configuration: minor improvements / fixes

Could not extend the subclass of 'TypedDict' the way I wanted, so I switched to 'dataclass'.
This commit is contained in:
2023-08-16 23:22:20 +02:00
parent 380b7c1b67
commit a5c91adc41
5 changed files with 80 additions and 95 deletions
+43 -40
View File
@@ -1,8 +1,13 @@
import pathlib
from typing import TypedDict, Any, Union
import yaml
from typing import Type, TypeVar, Any
from dataclasses import dataclass, asdict
ConfigInst = TypeVar('ConfigInst', bound='Config')
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
class OpenAIConfig(TypedDict):
@dataclass
class OpenAIConfig():
"""
The OpenAI section of the configuration file.
"""
@@ -14,27 +19,24 @@ class OpenAIConfig(TypedDict):
frequency_penalty: float
presence_penalty: float
def openai_config_valid(conf: dict[str, Union[str, float, int]]) -> bool:
"""
Checks if the given Open AI configuration dict is complete
and contains valid types and values.
"""
try:
str(conf['api_key'])
str(conf['model'])
int(conf['max_tokens'])
float(conf['temperature'])
float(conf['top_p'])
float(conf['frequency_penalty'])
float(conf['presence_penalty'])
return True
except Exception as e:
print(f"OpenAI configuration is invalid: {e}")
return False
@classmethod
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
"""
Create OpenAIConfig from a dict.
"""
return cls(
api_key=str(source['api_key']),
model=str(source['model']),
max_tokens=int(source['max_tokens']),
temperature=float(source['temperature']),
top_p=float(source['top_p']),
frequency_penalty=float(source['frequency_penalty']),
presence_penalty=float(source['presence_penalty'])
)
class Config(TypedDict):
@dataclass
class Config():
"""
The configuration file structure.
"""
@@ -42,22 +44,23 @@ class Config(TypedDict):
db: str
openai: OpenAIConfig
@classmethod
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
"""
Create OpenAIConfig from a dict.
"""
return cls(
system=str(source['system']),
db=str(source['db']),
openai=OpenAIConfig.from_dict(source['openai'])
)
def config_valid(conf: dict[str, Any]) -> bool:
"""
Checks if the given configuration dict is complete
and contains valid types and values.
"""
try:
str(conf['system'])
pathlib.Path(str(conf['db']))
return True
except Exception as e:
print(f"Configuration is invalid: {e}")
return False
if 'openai' in conf:
return openai_config_valid(conf['openai'])
else:
# required as long as we only support OpenAI
print("Section 'openai' is missing in the configuration!")
return False
@classmethod
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
with open(path, 'r') as f:
source = yaml.load(f, Loader=yaml.FullLoader)
return cls.from_dict(source)
def to_file(self, path: str) -> None:
with open(path, 'w') as f:
yaml.dump(asdict(self), f)