15 Commits

14 changed files with 324 additions and 82 deletions
+6
View File
@@ -66,3 +66,9 @@ class AI(Protocol):
and is not implemented for all AIs. and is not implemented for all AIs.
""" """
raise NotImplementedError raise NotImplementedError
def print(self) -> None:
"""
Print some info about the current AI, like system message.
"""
pass
+16 -10
View File
@@ -4,25 +4,31 @@ Creates different AI instances, based on the given configuration.
import argparse import argparse
from typing import cast from typing import cast
from .configuration import Config, OpenAIConfig, default_ai_ID from .configuration import Config, AIConfig, OpenAIConfig
from .ai import AI, AIError from .ai import AI, AIError
from .ais.openai import OpenAI from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config) -> AI: def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
""" """
Creates an AI subclass instance from the given arguments Creates an AI subclass instance from the given arguments
and configuration file. and configuration file. If AI has not been set in the
arguments, it searches for the ID 'default'. If that
is not found, it uses the first AI in the list.
""" """
if args.ai: ai_conf: AIConfig
if args.AI:
try: try:
ai_conf = config.ais[args.ai] ai_conf = config.ais[args.AI]
except KeyError: except KeyError:
raise AIError(f"AI ID '{args.ai}' does not exist in this configuration") raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
elif default_ai_ID in config.ais: elif 'default' in config.ais:
ai_conf = config.ais[default_ai_ID] ai_conf = config.ais['default']
else: else:
raise AIError("No AI name given and no default exists") try:
ai_conf = next(iter(config.ais.values()))
except StopIteration:
raise AIError("No AI found in this configuration")
if ai_conf.name == 'openai': if ai_conf.name == 'openai':
ai = OpenAI(cast(OpenAIConfig, ai_conf)) ai = OpenAI(cast(OpenAIConfig, ai_conf))
@@ -34,4 +40,4 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI:
ai.config.temperature = args.temperature ai.config.temperature = args.temperature
return ai return ai
else: else:
raise AIError(f"AI '{args.ai}' is not supported") raise AIError(f"AI '{args.AI}' is not supported")
+15 -6
View File
@@ -43,16 +43,20 @@ class OpenAI(AI):
n=num_answers, n=num_answers,
frequency_penalty=self.config.frequency_penalty, frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty) presence_penalty=self.config.presence_penalty)
answers: list[Message] = [] question.answer = Answer(response['choices'][0]['message']['content'])
for choice in response['choices']: # type: ignore question.tags = otags
question.ai = self.ID
question.model = self.config.model
answers: list[Message] = [question]
for choice in response['choices'][1:]: # type: ignore
answers.append(Message(question=question.question, answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']), answer=Answer(choice['message']['content']),
tags=otags, tags=otags,
ai=self.name, ai=self.ID,
model=self.config.model)) model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt'], return AIResponse(answers, Tokens(response['usage']['prompt_tokens'],
response['usage']['completion'], response['usage']['completion_tokens'],
response['usage']['total'])) response['usage']['total_tokens']))
def models(self) -> list[str]: def models(self) -> list[str]:
""" """
@@ -95,3 +99,8 @@ class OpenAI(AI):
def tokens(self, data: Union[Message, Chat]) -> int: def tokens(self, data: Union[Message, Chat]) -> int:
raise NotImplementedError raise NotImplementedError
def print(self) -> None:
print(f"MODEL: {self.config.model}")
print("=== SYSTEM ===")
print(self.config.system)
+20 -1
View File
@@ -62,7 +62,10 @@ def make_file_path(dir_path: Path,
Create a file_path for the given directory using the Create a file_path for the given directory using the
given file_suffix and ID generator function. given file_suffix and ID generator function.
""" """
return dir_path / f"{next_fid():04d}{file_suffix}" file_path = dir_path / f"{next_fid():04d}{file_suffix}"
while file_path.exists():
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
return file_path
def write_dir(dir_path: Path, def write_dir(dir_path: Path,
@@ -386,3 +389,19 @@ class ChatDB(Chat):
msgs = iter(messages if messages else self.messages) msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)): while (m := next(msgs, None)):
m.to_file() m.to_file()
def update_messages(self, messages: list[Message], write: bool = True) -> None:
"""
Update existing messages. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list. Only accepts
existing messages.
"""
if any(not message_in(m, self.messages) for m in messages):
raise ChatError("Can't update messages that are not in the internal list")
# remove old versions and add new ones
self.messages = [m for m in self.messages if not message_in(m, messages)]
self.messages += messages
self.sort()
# write the UPDATED messages if requested
if write:
self.write_messages(messages)
+9 -1
View File
@@ -13,7 +13,15 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None:
try: try:
message = Message.from_file(fname) message = Message.from_file(fname)
if message: if message:
print(message.to_str(source_code_only=args.source_code_only)) if args.question:
print(message.question)
elif args.answer:
print(message.answer)
elif message.answer and args.only_source_code:
for code in message.answer.source_code():
print(code)
else:
print(message.to_str())
except MessageError: except MessageError:
print(f"File is not a valid message: {args.file}") print(f"File is not a valid message: {args.file}")
sys.exit(1) sys.exit(1)
+41 -23
View File
@@ -3,7 +3,7 @@ from pathlib import Path
from itertools import zip_longest from itertools import zip_longest
from ..configuration import Config from ..configuration import Config
from ..chat import ChatDB from ..chat import ChatDB
from ..message import Message, Question from ..message import Message, MessageFilter, Question, source_code
from ..ai_factory import create_ai from ..ai_factory import create_ai
from ..ai import AI, AIResponse from ..ai import AI, AIResponse
@@ -13,27 +13,36 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
Creates (and writes) a new message from the given arguments. Creates (and writes) a new message from the given arguments.
""" """
question_parts = [] question_parts = []
question_list = args.question if args.question is not None else [] question_list = args.ask if args.ask is not None else []
source_list = args.source if args.source is not None else [] text_files = args.source_text if args.source_text is not None else []
code_files = args.source_code if args.source_code is not None else []
# FIXME: don't surround all sourced files with ``` for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None):
# -> do it only if '--source-code-only' is True and no source code if question is not None and len(question.strip()) > 0:
# could be extracted from that file
for question, source in zip_longest(question_list, source_list, fillvalue=None):
if question is not None and source is not None:
with open(source) as r:
question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```")
elif question is not None:
question_parts.append(question) question_parts.append(question)
elif source is not None: if source is not None and len(source) > 0:
with open(source) as r: with open(source) as r:
question_parts.append(f"```\n{r.read().strip()}\n```") content = r.read().strip()
if len(content) > 0:
question_parts.append(content)
if code is not None and len(code) > 0:
with open(code) as r:
content = r.read().strip()
if len(content) == 0:
continue
# try to extract and add source code
code_parts = source_code(content, include_delims=True)
if len(code_parts) > 0:
question_parts += code_parts
# if there's none, add the whole file
else:
question_parts.append(f"```\n{content}\n```")
full_question = '\n\n'.join(question_parts) full_question = '\n\n'.join(question_parts)
message = Message(question=Question(full_question), message = Message(question=Question(full_question),
tags=args.output_tags, # FIXME tags=args.output_tags, # FIXME
ai=args.ai, ai=args.AI,
model=args.model) model=args.model)
chat.add_to_cache([message]) chat.add_to_cache([message])
return message return message
@@ -43,8 +52,12 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'question' command. Handler for the 'question' command.
""" """
mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(),
tags_and=args.and_tags if args.and_tags is not None else set(),
tags_not=args.exclude_tags if args.exclude_tags is not None else set())
chat = ChatDB.from_dir(cache_path=Path('.'), chat = ChatDB.from_dir(cache_path=Path('.'),
db_path=Path(config.db)) db_path=Path(config.db),
mfilter=mfilter)
# if it's a new question, create and store it immediately # if it's a new question, create and store it immediately
if args.ask or args.create: if args.ask or args.create:
message = create_message(chat, args) message = create_message(chat, args)
@@ -54,23 +67,28 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
# create the correct AI instance # create the correct AI instance
ai: AI = create_ai(args, config) ai: AI = create_ai(args, config)
if args.ask: if args.ask:
ai.print()
chat.print(paged=False)
response: AIResponse = ai.request(message, response: AIResponse = ai.request(message,
chat, chat,
args.num_answers, # FIXME args.num_answers, # FIXME
args.otags) # FIXME args.output_tags) # FIXME
assert response chat.update_messages([response.messages[0]])
# TODO: chat.add_to_cache(response.messages[1:])
# * add answer to the message above (and create for idx, msg in enumerate(response.messages):
# more messages for any additional answers) print(f"=== ANSWER {idx+1} ===")
pass print(msg.answer)
elif args.repeat: if response.tokens:
print("===============")
print(response.tokens)
elif args.repeat is not None:
lmessage = chat.latest_message() lmessage = chat.latest_message()
assert lmessage assert lmessage
# TODO: repeat either the last question or the # TODO: repeat either the last question or the
# one(s) given in 'args.repeat' (overwrite # one(s) given in 'args.repeat' (overwrite
# existing ones if 'args.overwrite' is True) # existing ones if 'args.overwrite' is True)
pass pass
elif args.process: elif args.process is not None:
# TODO: process either all questions without an # TODO: process either all questions without an
# answer or the one(s) given in 'args.process' # answer or the one(s) given in 'args.process'
pass pass
+19 -6
View File
@@ -9,7 +9,6 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
supported_ais: list[str] = ['openai'] supported_ais: list[str] = ['openai']
default_ai_ID: str = 'default'
default_config_path = '.config.yaml' default_config_path = '.config.yaml'
@@ -17,6 +16,18 @@ class ConfigError(Exception):
pass pass
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
"""
Changes the YAML dump style to multiline syntax for multiline strings.
"""
if len(data.splitlines()) > 1:
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
return dumper.represent_scalar('tag:yaml.org,2002:str', data)
yaml.add_representer(str, str_presenter)
@dataclass @dataclass
class AIConfig: class AIConfig:
""" """
@@ -46,15 +57,15 @@ class OpenAIConfig(AIConfig):
# all members have default values, so we can easily create # all members have default values, so we can easily create
# a default configuration # a default configuration
ID: str = 'default' ID: str = 'myopenai'
api_key: str = '0123456789' api_key: str = '0123456789'
system: str = 'You are an assistant'
model: str = 'gpt-3.5-turbo-16k' model: str = 'gpt-3.5-turbo-16k'
temperature: float = 1.0 temperature: float = 1.0
max_tokens: int = 4000 max_tokens: int = 4000
top_p: float = 1.0 top_p: float = 1.0
frequency_penalty: float = 0.0 frequency_penalty: float = 0.0
presence_penalty: float = 0.0 presence_penalty: float = 0.0
system: str = 'You are an assistant'
@classmethod @classmethod
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
@@ -62,14 +73,14 @@ class OpenAIConfig(AIConfig):
Create OpenAIConfig from a dict. Create OpenAIConfig from a dict.
""" """
res = cls( res = cls(
system=str(source['system']),
api_key=str(source['api_key']), api_key=str(source['api_key']),
model=str(source['model']), model=str(source['model']),
max_tokens=int(source['max_tokens']), max_tokens=int(source['max_tokens']),
temperature=float(source['temperature']), temperature=float(source['temperature']),
top_p=float(source['top_p']), top_p=float(source['top_p']),
frequency_penalty=float(source['frequency_penalty']), frequency_penalty=float(source['frequency_penalty']),
presence_penalty=float(source['presence_penalty']) presence_penalty=float(source['presence_penalty']),
system=str(source['system'])
) )
# overwrite default ID if provided # overwrite default ID if provided
if 'ID' in source: if 'ID' in source:
@@ -148,6 +159,8 @@ class Config:
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
res = asdict(self) res = asdict(self)
# add the AI name manually (as first element)
# (not done by 'asdict' because it's a class variable)
for ID, conf in res['ais'].items(): for ID, conf in res['ais'].items():
conf.update({'name': self.ais[ID].name}) res['ais'][ID] = {**{'name': self.ais[ID].name}, **conf}
return res return res
+6 -5
View File
@@ -67,9 +67,8 @@ def create_parser() -> argparse.ArgumentParser:
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true') action='store_true')
question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query')
question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history')
action='store_true')
# 'hist' command parser # 'hist' command parser
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
@@ -114,8 +113,10 @@ def create_parser() -> argparse.ArgumentParser:
aliases=['p']) aliases=['p'])
print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.set_defaults(func=print_cmd)
print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True)
print_cmd_parser.add_argument('-S', '--source-code-only', help='Print source code only (from the answer, if available)', print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group()
action='store_true') print_cmd_modes.add_argument('-q', '--question', help='Print only question', action='store_true')
print_cmd_modes.add_argument('-a', '--answer', help='Print only answer', action='store_true')
print_cmd_modes.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true')
argcomplete.autocomplete(parser) argcomplete.autocomplete(parser)
return parser return parser
+17 -11
View File
@@ -3,6 +3,8 @@ Module implementing message related functions and classes.
""" """
import pathlib import pathlib
import yaml import yaml
import tempfile
import shutil
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags from .tags import Tag, TagLine, TagError, match_tags, rename_tags
@@ -312,7 +314,7 @@ class Message():
mfilter.tags_not if mfilter else None) mfilter.tags_not if mfilter else None)
else: else:
message = cls.__from_file_yaml(file_path) message = cls.__from_file_yaml(file_path)
if message and (not mfilter or (mfilter and message.match(mfilter))): if message and (mfilter is None or message.match(mfilter)):
return message return message
else: else:
return None return None
@@ -414,7 +416,7 @@ class Message():
return '\n'.join(output) return '\n'.join(output)
def __str__(self) -> str: def __str__(self) -> str:
return self.to_str(False, False, False) return self.to_str(True, True, False)
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
""" """
@@ -445,16 +447,18 @@ class Message():
* Answer.txt_header * Answer.txt_header
* Answer * Answer
""" """
with open(file_path, "w") as fd: with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
temp_file_path = pathlib.Path(temp_fd.name)
if self.tags: if self.tags:
fd.write(f'{TagLine.from_set(self.tags)}\n') temp_fd.write(f'{TagLine.from_set(self.tags)}\n')
if self.ai: if self.ai:
fd.write(f'{AILine.from_ai(self.ai)}\n') temp_fd.write(f'{AILine.from_ai(self.ai)}\n')
if self.model: if self.model:
fd.write(f'{ModelLine.from_model(self.model)}\n') temp_fd.write(f'{ModelLine.from_model(self.model)}\n')
fd.write(f'{Question.txt_header}\n{self.question}\n') temp_fd.write(f'{Question.txt_header}\n{self.question}\n')
if self.answer: if self.answer:
fd.write(f'{Answer.txt_header}\n{self.answer}\n') temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n')
shutil.move(temp_file_path, file_path)
def __to_file_yaml(self, file_path: pathlib.Path) -> None: def __to_file_yaml(self, file_path: pathlib.Path) -> None:
""" """
@@ -466,7 +470,8 @@ class Message():
* Message.ai_yaml_key: str [Optional] * Message.ai_yaml_key: str [Optional]
* Message.model_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional]
""" """
with open(file_path, "w") as fd: with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
temp_file_path = pathlib.Path(temp_fd.name)
data: YamlDict = {Question.yaml_key: str(self.question)} data: YamlDict = {Question.yaml_key: str(self.question)}
if self.answer: if self.answer:
data[Answer.yaml_key] = str(self.answer) data[Answer.yaml_key] = str(self.answer)
@@ -476,7 +481,8 @@ class Message():
data[self.model_yaml_key] = self.model data[self.model_yaml_key] = self.model
if self.tags: if self.tags:
data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
yaml.dump(data, fd, sort_keys=False) yaml.dump(data, temp_fd, sort_keys=False)
shutil.move(temp_file_path, file_path)
def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
""" """
@@ -508,7 +514,7 @@ class Message():
Return True if all attributes match, else False. Return True if all attributes match, else False.
""" """
mytags = self.tags or set() mytags = self.tags or set()
if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not) if (((mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None)
and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503 and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503
or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503
or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503
+5 -5
View File
@@ -10,7 +10,7 @@ from chatmastermind.ais.openai import OpenAI
class TestCreateAI(unittest.TestCase): class TestCreateAI(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.args = MagicMock(spec=argparse.Namespace) self.args = MagicMock(spec=argparse.Namespace)
self.args.ai = 'default' self.args.AI = 'myopenai'
self.args.model = None self.args.model = None
self.args.max_tokens = None self.args.max_tokens = None
self.args.temperature = None self.args.temperature = None
@@ -18,19 +18,19 @@ class TestCreateAI(unittest.TestCase):
def test_create_ai_from_args(self) -> None: def test_create_ai_from_args(self) -> None:
# Create an AI with the default configuration # Create an AI with the default configuration
config = Config() config = Config()
self.args.ai = 'default' self.args.AI = 'myopenai'
ai = create_ai(self.args, config) ai = create_ai(self.args, config)
self.assertIsInstance(ai, OpenAI) self.assertIsInstance(ai, OpenAI)
def test_create_ai_from_default(self) -> None: def test_create_ai_from_default(self) -> None:
self.args.ai = None self.args.AI = None
# Create an AI with the default configuration # Create an AI with the default configuration
config = Config() config = Config()
ai = create_ai(self.args, config) ai = create_ai(self.args, config)
self.assertIsInstance(ai, OpenAI) self.assertIsInstance(ai, OpenAI)
def test_create_empty_ai_error(self) -> None: def test_create_empty_ai_error(self) -> None:
self.args.ai = None self.args.AI = None
# Create Config with empty AIs # Create Config with empty AIs
config = Config() config = Config()
config.ais = {} config.ais = {}
@@ -40,7 +40,7 @@ class TestCreateAI(unittest.TestCase):
def test_create_unsupported_ai_error(self) -> None: def test_create_unsupported_ai_error(self) -> None:
# Mock argparse.Namespace with ai='invalid_ai' # Mock argparse.Namespace with ai='invalid_ai'
self.args.ai = 'invalid_ai' self.args.AI = 'invalid_ai'
# Create default Config # Create default Config
config = Config() config = Config()
# Call create_ai function and assert that it raises AIError # Call create_ai function and assert that it raises AIError
+48 -2
View File
@@ -202,7 +202,25 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[1].file_path, self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0003.txt')) pathlib.Path(self.db_path.name, '0003.txt'))
def test_chat_db_filter(self) -> None: def test_chat_db_from_dir_filter_tags(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or={Tag('tag1')}))
self.assertEqual(len(chat_db.messages), 1)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
def test_chat_db_from_dir_filter_tags_empty(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or=set(),
tags_and=set(),
tags_not=set()))
self.assertEqual(len(chat_db.messages), 0)
def test_chat_db_from_dir_filter_answer(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
mfilter=MessageFilter(answer_contains='Answer 2')) mfilter=MessageFilter(answer_contains='Answer 2'))
@@ -213,7 +231,7 @@ class TestChatDB(unittest.TestCase):
pathlib.Path(self.db_path.name, '0002.yaml')) pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[0].answer, 'Answer 2') self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
def test_chat_db_from_messges(self) -> None: def test_chat_db_from_messages(self) -> None:
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
messages=[self.message1, self.message2, messages=[self.message1, self.message2,
@@ -440,3 +458,31 @@ class TestChatDB(unittest.TestCase):
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1) self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
def test_chat_db_update_messages(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
message = chat_db.messages[0]
message.answer = Answer("New answer")
# update message without writing
chat_db.update_messages([message], write=False)
self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# re-read the message and check for old content
chat_db.read_db()
self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1"))
# now check with writing (message should be overwritten)
chat_db.update_messages([message], write=True)
chat_db.read_db()
self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# test without file_path -> expect error
message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
with self.assertRaises(ChatError):
chat_db.update_messages([message1])
+7 -7
View File
@@ -59,7 +59,7 @@ class TestConfig(unittest.TestCase):
source_dict = { source_dict = {
'db': './test_db/', 'db': './test_db/',
'ais': { 'ais': {
'default': { 'myopenai': {
'name': 'openai', 'name': 'openai',
'system': 'Custom system', 'system': 'Custom system',
'api_key': '9876543210', 'api_key': '9876543210',
@@ -75,10 +75,10 @@ class TestConfig(unittest.TestCase):
config = Config.from_dict(source_dict) config = Config.from_dict(source_dict)
self.assertEqual(config.db, './test_db/') self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1) self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['default'].name, 'openai') self.assertEqual(config.ais['myopenai'].name, 'openai')
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system')
# check that 'ID' has been added # check that 'ID' has been added
self.assertEqual(config.ais['default'].ID, 'default') self.assertEqual(config.ais['myopenai'].ID, 'myopenai')
def test_create_default_should_create_default_config(self) -> None: def test_create_default_should_create_default_config(self) -> None:
Config.create_default(Path(self.test_file.name)) Config.create_default(Path(self.test_file.name))
@@ -117,8 +117,8 @@ class TestConfig(unittest.TestCase):
config = Config( config = Config(
db='./test_db/', db='./test_db/',
ais={ ais={
'default': OpenAIConfig( 'myopenai': OpenAIConfig(
ID='default', ID='myopenai',
system='Custom system', system='Custom system',
api_key='9876543210', api_key='9876543210',
model='custom_model', model='custom_model',
@@ -135,7 +135,7 @@ class TestConfig(unittest.TestCase):
saved_config = yaml.load(f, Loader=yaml.FullLoader) saved_config = yaml.load(f, Loader=yaml.FullLoader)
self.assertEqual(saved_config['db'], './test_db/') self.assertEqual(saved_config['db'], './test_db/')
self.assertEqual(len(saved_config['ais']), 1) self.assertEqual(len(saved_config['ais']), 1)
self.assertEqual(saved_config['ais']['default']['system'], 'Custom system') self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system')
def test_from_file_error_unknown_ai(self) -> None: def test_from_file_error_unknown_ai(self) -> None:
source_dict = { source_dict = {
+6
View File
@@ -300,6 +300,12 @@ This is a question.
MessageFilter(tags_or={Tag('tag1')})) MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNone(message) self.assertIsNone(message)
def test_from_file_txt_empty_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(tags_or=set(),
tags_and=set()))
self.assertIsNone(message)
def test_from_file_txt_no_tags_match_tags_not(self) -> None: def test_from_file_txt_no_tags_match_tags_not(self) -> None:
message = Message.from_file(self.file_path_min, message = Message.from_file(self.file_path_min,
MessageFilter(tags_not={Tag('tag1')})) MessageFilter(tags_not={Tag('tag1')}))
+109 -5
View File
@@ -1,3 +1,4 @@
import os
import unittest import unittest
import argparse import argparse
import tempfile import tempfile
@@ -14,23 +15,66 @@ class TestMessageCreate(unittest.TestCase):
the correct format. the correct format.
""" """
def setUp(self) -> None: def setUp(self) -> None:
# create ChatDB structure
self.db_path = tempfile.TemporaryDirectory() self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory() self.cache_path = tempfile.TemporaryDirectory()
self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name), self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name),
db_path=Path(self.db_path.name)) db_path=Path(self.db_path.name))
# create arguments mock # create arguments mock
self.args = MagicMock(spec=argparse.Namespace) self.args = MagicMock(spec=argparse.Namespace)
self.args.source = None self.args.source_text = None
self.args.ai = None self.args.source_code = None
self.args.AI = None
self.args.model = None self.args.model = None
self.args.output_tags = None self.args.output_tags = None
# File 1 : no source code block, only text
self.source_file1 = tempfile.NamedTemporaryFile(delete=False)
self.source_file1_content = """This is just text.
No source code.
Nope. Go look elsewhere!"""
with open(self.source_file1.name, 'w') as f:
f.write(self.source_file1_content)
# File 2 : one embedded source code block
self.source_file2 = tempfile.NamedTemporaryFile(delete=False)
self.source_file2_content = """This is just text.
```
This is embedded source code.
```
And some text again."""
with open(self.source_file2.name, 'w') as f:
f.write(self.source_file2_content)
# File 3 : all source code
self.source_file3 = tempfile.NamedTemporaryFile(delete=False)
self.source_file3_content = """This is all source code.
Yes, really.
Language is called 'brainfart'."""
with open(self.source_file3.name, 'w') as f:
f.write(self.source_file3_content)
# File 4 : two source code blocks
self.source_file4 = tempfile.NamedTemporaryFile(delete=False)
self.source_file4_content = """This is just text.
```
This is embedded source code.
```
And some text again.
```
This is embedded source code.
```
Aaaand again some text."""
with open(self.source_file4.name, 'w') as f:
f.write(self.source_file4_content)
def tearDown(self) -> None:
os.remove(self.source_file1.name)
os.remove(self.source_file2.name)
os.remove(self.source_file3.name)
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next' # exclude '.next'
return list(Path(tmp_dir.name).glob('*.[ty]*')) return list(Path(tmp_dir.name).glob('*.[ty]*'))
def test_message_file_created(self) -> None: def test_message_file_created(self) -> None:
self.args.question = ["What is this?"] self.args.ask = ["What is this?"]
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0) self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args) create_message(self.chat, self.args)
@@ -41,14 +85,14 @@ class TestMessageCreate(unittest.TestCase):
self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr] self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr]
def test_single_question(self) -> None: def test_single_question(self) -> None:
self.args.question = ["What is this?"] self.args.ask = ["What is this?"]
message = create_message(self.chat, self.args) message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("What is this?")) self.assertEqual(message.question, Question("What is this?"))
self.assertEqual(len(message.question.source_code()), 0) self.assertEqual(len(message.question.source_code()), 0)
def test_multipart_question(self) -> None: def test_multipart_question(self) -> None:
self.args.question = ["What is this", "'bard' thing?", "Is it good?"] self.args.ask = ["What is this", "'bard' thing?", "Is it good?"]
message = create_message(self.chat, self.args) message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("""What is this self.assertEqual(message.question, Question("""What is this
@@ -56,3 +100,63 @@ class TestMessageCreate(unittest.TestCase):
'bard' thing? 'bard' thing?
Is it good?""")) Is it good?"""))
def test_single_question_with_text_only_file(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_text = [f"{self.source_file1.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains no source code (only text)
# -> don't expect any in the question
self.assertEqual(len(message.question.source_code()), 0)
self.assertEqual(message.question, Question(f"""What is this?
{self.source_file1_content}"""))
def test_single_question_with_text_file_and_embedded_code(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.source_file2.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains 1 source code block
# -> expect it in the question
self.assertEqual(len(message.question.source_code()), 1)
self.assertEqual(message.question, Question("""What is this?
```
This is embedded source code.
```
"""))
def test_single_question_with_code_only_file(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.source_file3.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file is complete source code
self.assertEqual(len(message.question.source_code()), 1)
self.assertEqual(message.question, Question(f"""What is this?
```
{self.source_file3_content}
```"""))
def test_single_question_with_text_file_and_multi_embedded_code(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.source_file4.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains 2 source code blocks
# -> expect them in the question
self.assertEqual(len(message.question.source_code()), 2)
self.assertEqual(message.question, Question("""What is this?
```
This is embedded source code.
```
```
This is embedded source code.
```
"""))