Working question -a
#7
@@ -66,3 +66,9 @@ class AI(Protocol):
|
||||
and is not implemented for all AIs.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def print(self) -> None:
|
||||
"""
|
||||
Print some info about the current AI, like system message.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -43,16 +43,20 @@ class OpenAI(AI):
|
||||
n=num_answers,
|
||||
frequency_penalty=self.config.frequency_penalty,
|
||||
presence_penalty=self.config.presence_penalty)
|
||||
answers: list[Message] = []
|
||||
for choice in response['choices']: # type: ignore
|
||||
question.answer = Answer(response['choices'][0]['message']['content'])
|
||||
question.tags = otags
|
||||
question.ai = self.name
|
||||
question.model = self.config.model
|
||||
answers: list[Message] = [question]
|
||||
for choice in response['choices'][1:]: # type: ignore
|
||||
answers.append(Message(question=question.question,
|
||||
answer=Answer(choice['message']['content']),
|
||||
tags=otags,
|
||||
ai=self.name,
|
||||
model=self.config.model))
|
||||
return AIResponse(answers, Tokens(response['usage']['prompt'],
|
||||
response['usage']['completion'],
|
||||
response['usage']['total']))
|
||||
return AIResponse(answers, Tokens(response['usage']['prompt_tokens'],
|
||||
response['usage']['completion_tokens'],
|
||||
response['usage']['total_tokens']))
|
||||
|
||||
def models(self) -> list[str]:
|
||||
"""
|
||||
@@ -95,3 +99,8 @@ class OpenAI(AI):
|
||||
|
||||
def tokens(self, data: Union[Message, Chat]) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def print(self) -> None:
|
||||
print(f"MODEL: {self.config.model}")
|
||||
print("=== SYSTEM ===")
|
||||
print(self.config.system)
|
||||
|
||||
@@ -201,7 +201,7 @@ class Chat:
|
||||
output.append(message.to_str(source_code_only=True))
|
||||
continue
|
||||
output.append(message.to_str(with_tags, with_files))
|
||||
output.append('\n' + ('-' * terminal_width()) + '\n')
|
||||
# output.append('\n' + ('-' * terminal_width()) + '\n')
|
||||
if paged:
|
||||
print_paged('\n'.join(output))
|
||||
else:
|
||||
@@ -361,8 +361,6 @@ class ChatDB(Chat):
|
||||
Add the given new messages and set the file_path to the cache directory.
|
||||
Only accepts messages without a file_path.
|
||||
"""
|
||||
if any(m.file_path is not None for m in messages):
|
||||
raise ChatError("Can't add new messages with existing file_path")
|
||||
if write:
|
||||
write_dir(self.cache_path,
|
||||
messages,
|
||||
|
||||
@@ -63,15 +63,19 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
# create the correct AI instance
|
||||
ai: AI = create_ai(args, config)
|
||||
if args.ask:
|
||||
ai.print()
|
||||
chat.print(paged=False)
|
||||
response: AIResponse = ai.request(message,
|
||||
chat,
|
||||
args.num_answers, # FIXME
|
||||
args.output_tags) # FIXME
|
||||
assert response
|
||||
# TODO:
|
||||
# * add answer to the message above (and create
|
||||
# more messages for any additional answers)
|
||||
pass
|
||||
chat.add_to_cache(response.messages)
|
||||
for idx, msg in enumerate(response.messages):
|
||||
print(f"=== ANSWER {idx+1} ===")
|
||||
print(msg.answer)
|
||||
if response.tokens:
|
||||
print("===============")
|
||||
print(response.tokens)
|
||||
elif args.repeat:
|
||||
lmessage = chat.latest_message()
|
||||
assert lmessage
|
||||
|
||||
Reference in New Issue
Block a user