Fix the max_tokens, temperature, and model setup.

This commit is contained in:
Oleksandr Kozachuk
2023-08-12 12:20:49 +02:00
parent bc5e6228a6
commit e4d055b900
2 changed files with 23 additions and 14 deletions
+18 -8
View File
@@ -21,7 +21,7 @@ def tags_completer(prefix, parsed_args, **kwargs):
return get_tags_unique(config, prefix)
def read_config(path: str):
def read_config(path: str) -> ConfigType:
with open(path, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
return config
@@ -81,6 +81,15 @@ def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None:
"""
Handler for the 'ask' command.
"""
if type(config['openai']) is not dict:
raise RuntimeError('Configuration openai is not a dict.')
config_openai = config['openai']
if args.max_tokens:
config_openai['max_tokens'] = args.max_tokens
if args.temperature:
config_openai['temperature'] = args.temperature
if args.model:
config_openai['model'] = args.model
chat, question, tags = create_question_with_hist(args, config)
print_chat_hist(chat, False, args.only_source_code)
otags = args.output_tags or []
@@ -211,13 +220,14 @@ def main() -> int:
config = read_config(args.config)
# modify config according to args
openai_api_key(config['openai']['api_key'])
if args.max_tokens:
config['openai']['max_tokens'] = args.max_tokens
if args.temperature:
config['openai']['temperature'] = args.temperature
if args.model:
config['openai']['model'] = args.model
if type(config['openai']) is dict:
config_openai = config['openai']
else:
RuntimeError("Configuration openai is not a dict.")
if type(config_openai['api_key']) is str:
openai_api_key(config_openai['api_key'])
else:
raise RuntimeError("Configuration openai.api_key is not a string.")
command.func(command, config)