9 Commits

7 changed files with 26 additions and 45 deletions
+6 -12
View File
@@ -4,31 +4,25 @@ Creates different AI instances, based on the given configuration.
import argparse import argparse
from typing import cast from typing import cast
from .configuration import Config, AIConfig, OpenAIConfig from .configuration import Config, OpenAIConfig, default_ai_ID
from .ai import AI, AIError from .ai import AI, AIError
from .ais.openai import OpenAI from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 def create_ai(args: argparse.Namespace, config: Config) -> AI:
""" """
Creates an AI subclass instance from the given arguments Creates an AI subclass instance from the given arguments
and configuration file. If AI has not been set in the and configuration file.
arguments, it searches for the ID 'default'. If that
is not found, it uses the first AI in the list.
""" """
ai_conf: AIConfig
if args.AI: if args.AI:
try: try:
ai_conf = config.ais[args.AI] ai_conf = config.ais[args.AI]
except KeyError: except KeyError:
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
elif 'default' in config.ais: elif default_ai_ID in config.ais:
ai_conf = config.ais['default'] ai_conf = config.ais[default_ai_ID]
else: else:
try: raise AIError("No AI name given and no default exists")
ai_conf = next(iter(config.ais.values()))
except StopIteration:
raise AIError("No AI found in this configuration")
if ai_conf.name == 'openai': if ai_conf.name == 'openai':
ai = OpenAI(cast(OpenAIConfig, ai_conf)) ai = OpenAI(cast(OpenAIConfig, ai_conf))
+1 -4
View File
@@ -62,10 +62,7 @@ def make_file_path(dir_path: Path,
Create a file_path for the given directory using the Create a file_path for the given directory using the
given file_suffix and ID generator function. given file_suffix and ID generator function.
""" """
file_path = dir_path / f"{next_fid():04d}{file_suffix}" return dir_path / f"{next_fid():04d}{file_suffix}"
while file_path.exists():
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
return file_path
def write_dir(dir_path: Path, def write_dir(dir_path: Path,
+2 -1
View File
@@ -9,6 +9,7 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
supported_ais: list[str] = ['openai'] supported_ais: list[str] = ['openai']
default_ai_ID: str = 'default'
default_config_path = '.config.yaml' default_config_path = '.config.yaml'
@@ -57,7 +58,7 @@ class OpenAIConfig(AIConfig):
# all members have default values, so we can easily create # all members have default values, so we can easily create
# a default configuration # a default configuration
ID: str = 'myopenai' ID: str = 'default'
api_key: str = '0123456789' api_key: str = '0123456789'
model: str = 'gpt-3.5-turbo-16k' model: str = 'gpt-3.5-turbo-16k'
temperature: float = 1.0 temperature: float = 1.0
+8 -14
View File
@@ -3,8 +3,6 @@ Module implementing message related functions and classes.
""" """
import pathlib import pathlib
import yaml import yaml
import tempfile
import shutil
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags from .tags import Tag, TagLine, TagError, match_tags, rename_tags
@@ -447,18 +445,16 @@ class Message():
* Answer.txt_header * Answer.txt_header
* Answer * Answer
""" """
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: with open(file_path, "w") as fd:
temp_file_path = pathlib.Path(temp_fd.name)
if self.tags: if self.tags:
temp_fd.write(f'{TagLine.from_set(self.tags)}\n') fd.write(f'{TagLine.from_set(self.tags)}\n')
if self.ai: if self.ai:
temp_fd.write(f'{AILine.from_ai(self.ai)}\n') fd.write(f'{AILine.from_ai(self.ai)}\n')
if self.model: if self.model:
temp_fd.write(f'{ModelLine.from_model(self.model)}\n') fd.write(f'{ModelLine.from_model(self.model)}\n')
temp_fd.write(f'{Question.txt_header}\n{self.question}\n') fd.write(f'{Question.txt_header}\n{self.question}\n')
if self.answer: if self.answer:
temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n') fd.write(f'{Answer.txt_header}\n{self.answer}\n')
shutil.move(temp_file_path, file_path)
def __to_file_yaml(self, file_path: pathlib.Path) -> None: def __to_file_yaml(self, file_path: pathlib.Path) -> None:
""" """
@@ -470,8 +466,7 @@ class Message():
* Message.ai_yaml_key: str [Optional] * Message.ai_yaml_key: str [Optional]
* Message.model_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional]
""" """
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: with open(file_path, "w") as fd:
temp_file_path = pathlib.Path(temp_fd.name)
data: YamlDict = {Question.yaml_key: str(self.question)} data: YamlDict = {Question.yaml_key: str(self.question)}
if self.answer: if self.answer:
data[Answer.yaml_key] = str(self.answer) data[Answer.yaml_key] = str(self.answer)
@@ -481,8 +476,7 @@ class Message():
data[self.model_yaml_key] = self.model data[self.model_yaml_key] = self.model
if self.tags: if self.tags:
data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
yaml.dump(data, temp_fd, sort_keys=False) yaml.dump(data, fd, sort_keys=False)
shutil.move(temp_file_path, file_path)
def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
""" """
+2 -2
View File
@@ -10,7 +10,7 @@ from chatmastermind.ais.openai import OpenAI
class TestCreateAI(unittest.TestCase): class TestCreateAI(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.args = MagicMock(spec=argparse.Namespace) self.args = MagicMock(spec=argparse.Namespace)
self.args.AI = 'myopenai' self.args.AI = 'default'
self.args.model = None self.args.model = None
self.args.max_tokens = None self.args.max_tokens = None
self.args.temperature = None self.args.temperature = None
@@ -18,7 +18,7 @@ class TestCreateAI(unittest.TestCase):
def test_create_ai_from_args(self) -> None: def test_create_ai_from_args(self) -> None:
# Create an AI with the default configuration # Create an AI with the default configuration
config = Config() config = Config()
self.args.AI = 'myopenai' self.args.AI = 'default'
ai = create_ai(self.args, config) ai = create_ai(self.args, config)
self.assertIsInstance(ai, OpenAI) self.assertIsInstance(ai, OpenAI)
-5
View File
@@ -481,8 +481,3 @@ class TestChatDB(unittest.TestCase):
chat_db.update_messages([message], write=True) chat_db.update_messages([message], write=True)
chat_db.read_db() chat_db.read_db()
self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# test without file_path -> expect error
message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
with self.assertRaises(ChatError):
chat_db.update_messages([message1])
+7 -7
View File
@@ -59,7 +59,7 @@ class TestConfig(unittest.TestCase):
source_dict = { source_dict = {
'db': './test_db/', 'db': './test_db/',
'ais': { 'ais': {
'myopenai': { 'default': {
'name': 'openai', 'name': 'openai',
'system': 'Custom system', 'system': 'Custom system',
'api_key': '9876543210', 'api_key': '9876543210',
@@ -75,10 +75,10 @@ class TestConfig(unittest.TestCase):
config = Config.from_dict(source_dict) config = Config.from_dict(source_dict)
self.assertEqual(config.db, './test_db/') self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1) self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['myopenai'].name, 'openai') self.assertEqual(config.ais['default'].name, 'openai')
self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system') self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
# check that 'ID' has been added # check that 'ID' has been added
self.assertEqual(config.ais['myopenai'].ID, 'myopenai') self.assertEqual(config.ais['default'].ID, 'default')
def test_create_default_should_create_default_config(self) -> None: def test_create_default_should_create_default_config(self) -> None:
Config.create_default(Path(self.test_file.name)) Config.create_default(Path(self.test_file.name))
@@ -117,8 +117,8 @@ class TestConfig(unittest.TestCase):
config = Config( config = Config(
db='./test_db/', db='./test_db/',
ais={ ais={
'myopenai': OpenAIConfig( 'default': OpenAIConfig(
ID='myopenai', ID='default',
system='Custom system', system='Custom system',
api_key='9876543210', api_key='9876543210',
model='custom_model', model='custom_model',
@@ -135,7 +135,7 @@ class TestConfig(unittest.TestCase):
saved_config = yaml.load(f, Loader=yaml.FullLoader) saved_config = yaml.load(f, Loader=yaml.FullLoader)
self.assertEqual(saved_config['db'], './test_db/') self.assertEqual(saved_config['db'], './test_db/')
self.assertEqual(len(saved_config['ais']), 1) self.assertEqual(len(saved_config['ais']), 1)
self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system') self.assertEqual(saved_config['ais']['default']['system'], 'Custom system')
def test_from_file_error_unknown_ai(self) -> None: def test_from_file_error_unknown_ai(self) -> None:
source_dict = { source_dict = {