11 Commits

10 changed files with 156 additions and 42 deletions
+12 -6
View File
@@ -4,25 +4,31 @@ Creates different AI instances, based on the given configuration.
import argparse
from typing import cast
from .configuration import Config, OpenAIConfig, default_ai_ID
from .configuration import Config, AIConfig, OpenAIConfig
from .ai import AI, AIError
from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config) -> AI:
def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
"""
Creates an AI subclass instance from the given arguments
and configuration file.
and configuration file. If AI has not been set in the
arguments, it searches for the ID 'default'. If that
is not found, it uses the first AI in the list.
"""
ai_conf: AIConfig
if args.AI:
try:
ai_conf = config.ais[args.AI]
except KeyError:
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
elif default_ai_ID in config.ais:
ai_conf = config.ais[default_ai_ID]
elif 'default' in config.ais:
ai_conf = config.ais['default']
else:
raise AIError("No AI name given and no default exists")
try:
ai_conf = next(iter(config.ais.values()))
except StopIteration:
raise AIError("No AI found in this configuration")
if ai_conf.name == 'openai':
ai = OpenAI(cast(OpenAIConfig, ai_conf))
+2 -2
View File
@@ -45,14 +45,14 @@ class OpenAI(AI):
presence_penalty=self.config.presence_penalty)
question.answer = Answer(response['choices'][0]['message']['content'])
question.tags = otags
question.ai = self.name
question.ai = self.ID
question.model = self.config.model
answers: list[Message] = [question]
for choice in response['choices'][1:]: # type: ignore
answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']),
tags=otags,
ai=self.name,
ai=self.ID,
model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt_tokens'],
response['usage']['completion_tokens'],
+23 -2
View File
@@ -62,7 +62,10 @@ def make_file_path(dir_path: Path,
Create a file_path for the given directory using the
given file_suffix and ID generator function.
"""
return dir_path / f"{next_fid():04d}{file_suffix}"
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
while file_path.exists():
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
return file_path
def write_dir(dir_path: Path,
@@ -201,7 +204,7 @@ class Chat:
output.append(message.to_str(source_code_only=True))
continue
output.append(message.to_str(with_tags, with_files))
# output.append('\n' + ('-' * terminal_width()) + '\n')
output.append('\n' + ('-' * terminal_width()) + '\n')
if paged:
print_paged('\n'.join(output))
else:
@@ -361,6 +364,8 @@ class ChatDB(Chat):
Add the given new messages and set the file_path to the cache directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.cache_path,
messages,
@@ -384,3 +389,19 @@ class ChatDB(Chat):
msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)):
m.to_file()
def update_messages(self, messages: list[Message], write: bool = True) -> None:
"""
Update existing messages. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list. Only accepts
existing messages.
"""
if any(not message_in(m, self.messages) for m in messages):
raise ChatError("Can't update messages that are not in the internal list")
# remove old versions and add new ones
self.messages = [m for m in self.messages if not message_in(m, messages)]
self.messages += messages
self.sort()
# write the UPDATED messages if requested
if write:
self.write_messages(messages)
+5 -4
View File
@@ -52,9 +52,9 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'question' command.
"""
mfilter = MessageFilter(tags_or=args.or_tags,
tags_and=args.and_tags,
tags_not=args.exclude_tags)
mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(),
tags_and=args.and_tags if args.and_tags is not None else set(),
tags_not=args.exclude_tags if args.exclude_tags is not None else set())
chat = ChatDB.from_dir(cache_path=Path('.'),
db_path=Path(config.db),
mfilter=mfilter)
@@ -73,7 +73,8 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
chat,
args.num_answers, # FIXME
args.output_tags) # FIXME
chat.add_to_cache(response.messages)
chat.update_messages([response.messages[0]])
chat.add_to_cache(response.messages[1:])
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
+19 -6
View File
@@ -9,7 +9,6 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
supported_ais: list[str] = ['openai']
default_ai_ID: str = 'default'
default_config_path = '.config.yaml'
@@ -17,6 +16,18 @@ class ConfigError(Exception):
pass
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
"""
Changes the YAML dump style to multiline syntax for multiline strings.
"""
if len(data.splitlines()) > 1:
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
return dumper.represent_scalar('tag:yaml.org,2002:str', data)
yaml.add_representer(str, str_presenter)
@dataclass
class AIConfig:
"""
@@ -46,15 +57,15 @@ class OpenAIConfig(AIConfig):
# all members have default values, so we can easily create
# a default configuration
ID: str = 'default'
ID: str = 'myopenai'
api_key: str = '0123456789'
system: str = 'You are an assistant'
model: str = 'gpt-3.5-turbo-16k'
temperature: float = 1.0
max_tokens: int = 4000
top_p: float = 1.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
system: str = 'You are an assistant'
@classmethod
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
@@ -62,14 +73,14 @@ class OpenAIConfig(AIConfig):
Create OpenAIConfig from a dict.
"""
res = cls(
system=str(source['system']),
api_key=str(source['api_key']),
model=str(source['model']),
max_tokens=int(source['max_tokens']),
temperature=float(source['temperature']),
top_p=float(source['top_p']),
frequency_penalty=float(source['frequency_penalty']),
presence_penalty=float(source['presence_penalty'])
presence_penalty=float(source['presence_penalty']),
system=str(source['system'])
)
# overwrite default ID if provided
if 'ID' in source:
@@ -148,6 +159,8 @@ class Config:
def as_dict(self) -> dict[str, Any]:
res = asdict(self)
# add the AI name manually (as first element)
# (not done by 'asdict' because it's a class variable)
for ID, conf in res['ais'].items():
conf.update({'name': self.ais[ID].name})
res['ais'][ID] = {**{'name': self.ais[ID].name}, **conf}
return res
+16 -10
View File
@@ -3,6 +3,8 @@ Module implementing message related functions and classes.
"""
import pathlib
import yaml
import tempfile
import shutil
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags
@@ -312,7 +314,7 @@ class Message():
mfilter.tags_not if mfilter else None)
else:
message = cls.__from_file_yaml(file_path)
if message and (not mfilter or (mfilter and message.match(mfilter))):
if message and (mfilter is None or message.match(mfilter)):
return message
else:
return None
@@ -445,16 +447,18 @@ class Message():
* Answer.txt_header
* Answer
"""
with open(file_path, "w") as fd:
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
temp_file_path = pathlib.Path(temp_fd.name)
if self.tags:
fd.write(f'{TagLine.from_set(self.tags)}\n')
temp_fd.write(f'{TagLine.from_set(self.tags)}\n')
if self.ai:
fd.write(f'{AILine.from_ai(self.ai)}\n')
temp_fd.write(f'{AILine.from_ai(self.ai)}\n')
if self.model:
fd.write(f'{ModelLine.from_model(self.model)}\n')
fd.write(f'{Question.txt_header}\n{self.question}\n')
temp_fd.write(f'{ModelLine.from_model(self.model)}\n')
temp_fd.write(f'{Question.txt_header}\n{self.question}\n')
if self.answer:
fd.write(f'{Answer.txt_header}\n{self.answer}\n')
temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n')
shutil.move(temp_file_path, file_path)
def __to_file_yaml(self, file_path: pathlib.Path) -> None:
"""
@@ -466,7 +470,8 @@ class Message():
* Message.ai_yaml_key: str [Optional]
* Message.model_yaml_key: str [Optional]
"""
with open(file_path, "w") as fd:
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
temp_file_path = pathlib.Path(temp_fd.name)
data: YamlDict = {Question.yaml_key: str(self.question)}
if self.answer:
data[Answer.yaml_key] = str(self.answer)
@@ -476,7 +481,8 @@ class Message():
data[self.model_yaml_key] = self.model
if self.tags:
data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
yaml.dump(data, fd, sort_keys=False)
yaml.dump(data, temp_fd, sort_keys=False)
shutil.move(temp_file_path, file_path)
def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
"""
@@ -508,7 +514,7 @@ class Message():
Return True if all attributes match, else False.
"""
mytags = self.tags or set()
if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not)
if (((mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None)
and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503
or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503
or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503
+2 -2
View File
@@ -10,7 +10,7 @@ from chatmastermind.ais.openai import OpenAI
class TestCreateAI(unittest.TestCase):
def setUp(self) -> None:
self.args = MagicMock(spec=argparse.Namespace)
self.args.AI = 'default'
self.args.AI = 'myopenai'
self.args.model = None
self.args.max_tokens = None
self.args.temperature = None
@@ -18,7 +18,7 @@ class TestCreateAI(unittest.TestCase):
def test_create_ai_from_args(self) -> None:
# Create an AI with the default configuration
config = Config()
self.args.AI = 'default'
self.args.AI = 'myopenai'
ai = create_ai(self.args, config)
self.assertIsInstance(ai, OpenAI)
+64 -3
View File
@@ -6,7 +6,7 @@ from io import StringIO
from unittest.mock import patch
from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, ChatError
from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError
class TestChat(unittest.TestCase):
@@ -92,10 +92,16 @@ class TestChat(unittest.TestCase):
Question 1
{Answer.txt_header}
Answer 1
{'-'*terminal_width()}
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
{'-'*terminal_width()}
"""
self.assertEqual(mock_stdout.getvalue(), expected_output)
@@ -109,12 +115,18 @@ FILE: 0001.txt
Question 1
{Answer.txt_header}
Answer 1
{'-'*terminal_width()}
{TagLine.prefix} btag2
FILE: 0002.txt
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
{'-'*terminal_width()}
"""
self.assertEqual(mock_stdout.getvalue(), expected_output)
@@ -190,7 +202,25 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0003.txt'))
def test_chat_db_filter(self) -> None:
def test_chat_db_from_dir_filter_tags(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or={Tag('tag1')}))
self.assertEqual(len(chat_db.messages), 1)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
def test_chat_db_from_dir_filter_tags_empty(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or=set(),
tags_and=set(),
tags_not=set()))
self.assertEqual(len(chat_db.messages), 0)
def test_chat_db_from_dir_filter_answer(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(answer_contains='Answer 2'))
@@ -201,7 +231,7 @@ class TestChatDB(unittest.TestCase):
pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
def test_chat_db_from_messges(self) -> None:
def test_chat_db_from_messages(self) -> None:
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
messages=[self.message1, self.message2,
@@ -403,6 +433,9 @@ class TestChatDB(unittest.TestCase):
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 5)
with self.assertRaises(ChatError):
chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))])
def test_chat_db_write_messages(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
@@ -425,3 +458,31 @@ class TestChatDB(unittest.TestCase):
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
def test_chat_db_update_messages(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
message = chat_db.messages[0]
message.answer = Answer("New answer")
# update message without writing
chat_db.update_messages([message], write=False)
self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# re-read the message and check for old content
chat_db.read_db()
self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1"))
# now check with writing (message should be overwritten)
chat_db.update_messages([message], write=True)
chat_db.read_db()
self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# test without file_path -> expect error
message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
with self.assertRaises(ChatError):
chat_db.update_messages([message1])
+7 -7
View File
@@ -59,7 +59,7 @@ class TestConfig(unittest.TestCase):
source_dict = {
'db': './test_db/',
'ais': {
'default': {
'myopenai': {
'name': 'openai',
'system': 'Custom system',
'api_key': '9876543210',
@@ -75,10 +75,10 @@ class TestConfig(unittest.TestCase):
config = Config.from_dict(source_dict)
self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['default'].name, 'openai')
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
self.assertEqual(config.ais['myopenai'].name, 'openai')
self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system')
# check that 'ID' has been added
self.assertEqual(config.ais['default'].ID, 'default')
self.assertEqual(config.ais['myopenai'].ID, 'myopenai')
def test_create_default_should_create_default_config(self) -> None:
Config.create_default(Path(self.test_file.name))
@@ -117,8 +117,8 @@ class TestConfig(unittest.TestCase):
config = Config(
db='./test_db/',
ais={
'default': OpenAIConfig(
ID='default',
'myopenai': OpenAIConfig(
ID='myopenai',
system='Custom system',
api_key='9876543210',
model='custom_model',
@@ -135,7 +135,7 @@ class TestConfig(unittest.TestCase):
saved_config = yaml.load(f, Loader=yaml.FullLoader)
self.assertEqual(saved_config['db'], './test_db/')
self.assertEqual(len(saved_config['ais']), 1)
self.assertEqual(saved_config['ais']['default']['system'], 'Custom system')
self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system')
def test_from_file_error_unknown_ai(self) -> None:
source_dict = {
+6
View File
@@ -300,6 +300,12 @@ This is a question.
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNone(message)
def test_from_file_txt_empty_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(tags_or=set(),
tags_and=set()))
self.assertIsNone(message)
def test_from_file_txt_no_tags_match_tags_not(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(tags_not={Tag('tag1')}))