Fix the max_tokens, temperature, and model setup.
This commit is contained in:
+18
-8
@@ -21,7 +21,7 @@ def tags_completer(prefix, parsed_args, **kwargs):
|
||||
return get_tags_unique(config, prefix)
|
||||
|
||||
|
||||
def read_config(path: str):
|
||||
def read_config(path: str) -> ConfigType:
|
||||
with open(path, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
return config
|
||||
@@ -81,6 +81,15 @@ def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None:
|
||||
"""
|
||||
Handler for the 'ask' command.
|
||||
"""
|
||||
if type(config['openai']) is not dict:
|
||||
raise RuntimeError('Configuration openai is not a dict.')
|
||||
config_openai = config['openai']
|
||||
if args.max_tokens:
|
||||
config_openai['max_tokens'] = args.max_tokens
|
||||
if args.temperature:
|
||||
config_openai['temperature'] = args.temperature
|
||||
if args.model:
|
||||
config_openai['model'] = args.model
|
||||
chat, question, tags = create_question_with_hist(args, config)
|
||||
print_chat_hist(chat, False, args.only_source_code)
|
||||
otags = args.output_tags or []
|
||||
@@ -211,13 +220,14 @@ def main() -> int:
|
||||
config = read_config(args.config)
|
||||
|
||||
# modify config according to args
|
||||
openai_api_key(config['openai']['api_key'])
|
||||
if args.max_tokens:
|
||||
config['openai']['max_tokens'] = args.max_tokens
|
||||
if args.temperature:
|
||||
config['openai']['temperature'] = args.temperature
|
||||
if args.model:
|
||||
config['openai']['model'] = args.model
|
||||
if type(config['openai']) is dict:
|
||||
config_openai = config['openai']
|
||||
else:
|
||||
RuntimeError("Configuration openai is not a dict.")
|
||||
if type(config_openai['api_key']) is str:
|
||||
openai_api_key(config_openai['api_key'])
|
||||
else:
|
||||
raise RuntimeError("Configuration openai.api_key is not a string.")
|
||||
|
||||
command.func(command, config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user