configuration: minor improvements / fixes
Could not extend the subclass of 'TypedDict' the way I wanted, so I switched to 'dataclass'.
This commit is contained in:
@@ -1,8 +1,13 @@
|
||||
import pathlib
|
||||
from typing import TypedDict, Any, Union
|
||||
import yaml
|
||||
from typing import Type, TypeVar, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
ConfigInst = TypeVar('ConfigInst', bound='Config')
|
||||
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
|
||||
|
||||
|
||||
class OpenAIConfig(TypedDict):
|
||||
@dataclass
|
||||
class OpenAIConfig():
|
||||
"""
|
||||
The OpenAI section of the configuration file.
|
||||
"""
|
||||
@@ -14,27 +19,24 @@ class OpenAIConfig(TypedDict):
|
||||
frequency_penalty: float
|
||||
presence_penalty: float
|
||||
|
||||
|
||||
def openai_config_valid(conf: dict[str, Union[str, float, int]]) -> bool:
|
||||
"""
|
||||
Checks if the given Open AI configuration dict is complete
|
||||
and contains valid types and values.
|
||||
"""
|
||||
try:
|
||||
str(conf['api_key'])
|
||||
str(conf['model'])
|
||||
int(conf['max_tokens'])
|
||||
float(conf['temperature'])
|
||||
float(conf['top_p'])
|
||||
float(conf['frequency_penalty'])
|
||||
float(conf['presence_penalty'])
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"OpenAI configuration is invalid: {e}")
|
||||
return False
|
||||
@classmethod
|
||||
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
|
||||
"""
|
||||
Create OpenAIConfig from a dict.
|
||||
"""
|
||||
return cls(
|
||||
api_key=str(source['api_key']),
|
||||
model=str(source['model']),
|
||||
max_tokens=int(source['max_tokens']),
|
||||
temperature=float(source['temperature']),
|
||||
top_p=float(source['top_p']),
|
||||
frequency_penalty=float(source['frequency_penalty']),
|
||||
presence_penalty=float(source['presence_penalty'])
|
||||
)
|
||||
|
||||
|
||||
class Config(TypedDict):
|
||||
@dataclass
|
||||
class Config():
|
||||
"""
|
||||
The configuration file structure.
|
||||
"""
|
||||
@@ -42,22 +44,23 @@ class Config(TypedDict):
|
||||
db: str
|
||||
openai: OpenAIConfig
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
|
||||
"""
|
||||
Create OpenAIConfig from a dict.
|
||||
"""
|
||||
return cls(
|
||||
system=str(source['system']),
|
||||
db=str(source['db']),
|
||||
openai=OpenAIConfig.from_dict(source['openai'])
|
||||
)
|
||||
|
||||
def config_valid(conf: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Checks if the given configuration dict is complete
|
||||
and contains valid types and values.
|
||||
"""
|
||||
try:
|
||||
str(conf['system'])
|
||||
pathlib.Path(str(conf['db']))
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Configuration is invalid: {e}")
|
||||
return False
|
||||
if 'openai' in conf:
|
||||
return openai_config_valid(conf['openai'])
|
||||
else:
|
||||
# required as long as we only support OpenAI
|
||||
print("Section 'openai' is missing in the configuration!")
|
||||
return False
|
||||
@classmethod
|
||||
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
|
||||
with open(path, 'r') as f:
|
||||
source = yaml.load(f, Loader=yaml.FullLoader)
|
||||
return cls.from_dict(source)
|
||||
|
||||
def to_file(self, path: str) -> None:
|
||||
with open(path, 'w') as f:
|
||||
yaml.dump(asdict(self), f)
|
||||
|
||||
Reference in New Issue
Block a user