Make `cmm q -a "test"` work completely, similarly to how it did before; tests are not fixed yet.

This commit is contained in:
Oleksandr Kozachuk
2023-09-09 19:24:45 +02:00
parent 6395941516
commit 458ec20cbd
4 changed files with 30 additions and 13 deletions
+6
View File
@@ -66,3 +66,9 @@ class AI(Protocol):
and is not implemented for all AIs. and is not implemented for all AIs.
""" """
raise NotImplementedError raise NotImplementedError
def print(self) -> None:
"""
Print some info about the current AI, like system message.
"""
pass
+14 -5
View File
@@ -43,16 +43,20 @@ class OpenAI(AI):
n=num_answers, n=num_answers,
frequency_penalty=self.config.frequency_penalty, frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty) presence_penalty=self.config.presence_penalty)
answers: list[Message] = [] question.answer = Answer(response['choices'][0]['message']['content'])
for choice in response['choices']: # type: ignore question.tags = otags
question.ai = self.name
question.model = self.config.model
answers: list[Message] = [question]
for choice in response['choices'][1:]: # type: ignore
answers.append(Message(question=question.question, answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']), answer=Answer(choice['message']['content']),
tags=otags, tags=otags,
ai=self.name, ai=self.name,
model=self.config.model)) model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt'], return AIResponse(answers, Tokens(response['usage']['prompt_tokens'],
response['usage']['completion'], response['usage']['completion_tokens'],
response['usage']['total'])) response['usage']['total_tokens']))
def models(self) -> list[str]: def models(self) -> list[str]:
""" """
@@ -95,3 +99,8 @@ class OpenAI(AI):
def tokens(self, data: Union[Message, Chat]) -> int: def tokens(self, data: Union[Message, Chat]) -> int:
raise NotImplementedError raise NotImplementedError
def print(self) -> None:
print(f"MODEL: {self.config.model}")
print("=== SYSTEM ===")
print(self.config.system)
+1 -3
View File
@@ -201,7 +201,7 @@ class Chat:
output.append(message.to_str(source_code_only=True)) output.append(message.to_str(source_code_only=True))
continue continue
output.append(message.to_str(with_tags, with_files)) output.append(message.to_str(with_tags, with_files))
output.append('\n' + ('-' * terminal_width()) + '\n') # output.append('\n' + ('-' * terminal_width()) + '\n')
if paged: if paged:
print_paged('\n'.join(output)) print_paged('\n'.join(output))
else: else:
@@ -361,8 +361,6 @@ class ChatDB(Chat):
Add the given new messages and set the file_path to the cache directory. Add the given new messages and set the file_path to the cache directory.
Only accepts messages without a file_path. Only accepts messages without a file_path.
""" """
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write: if write:
write_dir(self.cache_path, write_dir(self.cache_path,
messages, messages,
+9 -5
View File
@@ -63,15 +63,19 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
# create the correct AI instance # create the correct AI instance
ai: AI = create_ai(args, config) ai: AI = create_ai(args, config)
if args.ask: if args.ask:
ai.print()
chat.print(paged=False)
response: AIResponse = ai.request(message, response: AIResponse = ai.request(message,
chat, chat,
args.num_answers, # FIXME args.num_answers, # FIXME
args.output_tags) # FIXME args.output_tags) # FIXME
assert response chat.add_to_cache(response.messages)
# TODO: for idx, msg in enumerate(response.messages):
# * add answer to the message above (and create print(f"=== ANSWER {idx+1} ===")
# more messages for any additional answers) print(msg.answer)
pass if response.tokens:
print("===============")
print(response.tokens)
elif args.repeat: elif args.repeat:
lmessage = chat.latest_message() lmessage = chat.latest_message()
assert lmessage assert lmessage