2 Commits

Author SHA1 Message Date
juk0de fe4d28278d main: removed old code, cleaned up new code 2023-09-08 10:02:13 +02:00
juk0de 8dbe8b5f25 configuration et al: implemented new Config format 2023-09-08 10:01:48 +02:00
2 changed files with 8 additions and 98 deletions
+8 -10
View File
@@ -1,6 +1,6 @@
import yaml import yaml
from pathlib import Path from pathlib import Path
from typing import Type, TypeVar, Any, Optional, Final from typing import Type, TypeVar, Any, Optional
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
ConfigInst = TypeVar('ConfigInst', bound='Config') ConfigInst = TypeVar('ConfigInst', bound='Config')
@@ -34,6 +34,7 @@ class OpenAIConfig(AIConfig):
# all members have default values, so we can easily create # all members have default values, so we can easily create
# a default configuration # a default configuration
ID: str = 'default' ID: str = 'default'
name: str = 'openai'
api_key: str = '0123456789' api_key: str = '0123456789'
system: str = 'You are an assistant' system: str = 'You are an assistant'
model: str = 'gpt-3.5-turbo-16k' model: str = 'gpt-3.5-turbo-16k'
@@ -42,15 +43,15 @@ class OpenAIConfig(AIConfig):
top_p: float = 1.0 top_p: float = 1.0
frequency_penalty: float = 0.0 frequency_penalty: float = 0.0
presence_penalty: float = 0.0 presence_penalty: float = 0.0
# the name should not be changed
name: Final[str] = 'openai'
@classmethod @classmethod
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
""" """
Create OpenAIConfig from a dict. Create OpenAIConfig from a dict.
""" """
res = cls( return cls(
ID='openai',
name='openai',
system=str(source['system']), system=str(source['system']),
api_key=str(source['api_key']), api_key=str(source['api_key']),
model=str(source['model']), model=str(source['model']),
@@ -60,10 +61,6 @@ class OpenAIConfig(AIConfig):
frequency_penalty=float(source['frequency_penalty']), frequency_penalty=float(source['frequency_penalty']),
presence_penalty=float(source['presence_penalty']) presence_penalty=float(source['presence_penalty'])
) )
# overwrite default ID if provided
if 'ID' in source:
res.ID = source['ID']
return res
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
@@ -107,8 +104,6 @@ class Config:
# create the correct AI type instances # create the correct AI type instances
ais: dict[str, AIConfig] = {} ais: dict[str, AIConfig] = {}
for ID, conf in source['ais'].items(): for ID, conf in source['ais'].items():
# add the AI ID to the config (for easy internal access)
conf['ID'] = ID
ai_conf = ai_config_instance(conf['name'], conf) ai_conf = ai_config_instance(conf['name'], conf)
ais[ID] = ai_conf ais[ID] = ai_conf
return cls( return cls(
@@ -128,6 +123,9 @@ class Config:
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
with open(path, 'r') as f: with open(path, 'r') as f:
source = yaml.load(f, Loader=yaml.FullLoader) source = yaml.load(f, Loader=yaml.FullLoader)
# add the AI ID to the config (for easy internal access)
for ID, conf in source['ais'].items():
conf['ID'] = ID
return cls.from_dict(source) return cls.from_dict(source)
def to_file(self, file_path: Path) -> None: def to_file(self, file_path: Path) -> None:
-88
View File
@@ -1,88 +0,0 @@
import os
import unittest
import yaml
from tempfile import NamedTemporaryFile
from pathlib import Path
from typing import cast
from chatmastermind.configuration import OpenAIConfig, ConfigError, ai_config_instance, Config
class TestAIConfigInstance(unittest.TestCase):
def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None:
ai_config = cast(OpenAIConfig, ai_config_instance('openai'))
ai_reference = OpenAIConfig()
self.assertEqual(ai_config.ID, ai_reference.ID)
self.assertEqual(ai_config.name, ai_reference.name)
self.assertEqual(ai_config.api_key, ai_reference.api_key)
self.assertEqual(ai_config.system, ai_reference.system)
self.assertEqual(ai_config.model, ai_reference.model)
self.assertEqual(ai_config.temperature, ai_reference.temperature)
self.assertEqual(ai_config.max_tokens, ai_reference.max_tokens)
self.assertEqual(ai_config.top_p, ai_reference.top_p)
self.assertEqual(ai_config.frequency_penalty, ai_reference.frequency_penalty)
self.assertEqual(ai_config.presence_penalty, ai_reference.presence_penalty)
def test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None:
conf_dict = {
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
ai_config = cast(OpenAIConfig, ai_config_instance('openai', conf_dict))
self.assertEqual(ai_config.system, 'Custom system')
self.assertEqual(ai_config.api_key, '9876543210')
self.assertEqual(ai_config.model, 'custom_model')
self.assertEqual(ai_config.max_tokens, 5000)
self.assertAlmostEqual(ai_config.temperature, 0.5)
self.assertAlmostEqual(ai_config.top_p, 0.8)
self.assertAlmostEqual(ai_config.frequency_penalty, 0.7)
self.assertAlmostEqual(ai_config.presence_penalty, 0.2)
def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None:
with self.assertRaises(ConfigError):
ai_config_instance('invalid_name')
class TestConfig(unittest.TestCase):
def setUp(self) -> None:
self.test_file = NamedTemporaryFile(delete=False)
def tearDown(self) -> None:
os.remove(self.test_file.name)
def test_from_dict_should_create_config_from_dict(self) -> None:
source_dict = {
'db': './test_db/',
'ais': {
'default': {
'name': 'openai',
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
}
}
config = Config.from_dict(source_dict)
self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['default'].name, 'openai')
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
# check that 'ID' has been added
self.assertEqual(config.ais['default'].ID, 'default')
def test_create_default_should_create_default_config(self) -> None:
Config.create_default(Path(self.test_file.name))
with open(self.test_file.name, 'r') as f:
default_config = yaml.load(f, Loader=yaml.FullLoader)
config_reference = Config()
self.assertEqual(default_config['db'], config_reference.db)