'cmm question --repeat' supports multiple questions, added tests and fixes #15

Merged
juk0de merged 11 commits from repeat_multi into main 2023-09-26 18:04:29 +02:00
Showing only changes of commit a478408449 - Show all commits
+90 -66
View File
@@ -2,6 +2,7 @@ import os
import unittest import unittest
import argparse import argparse
import tempfile import tempfile
from copy import copy
from pathlib import Path from pathlib import Path
from unittest import mock from unittest import mock
from unittest.mock import MagicMock, call from unittest.mock import MagicMock, call
@@ -302,53 +303,6 @@ class TestQuestionCmd(TestQuestionCmdBase):
overwrite=None overwrite=None
) )
def create_single_message(self, args: argparse.Namespace, with_answer: bool = True) -> Message:
message = Message(Question(args.ask[0]),
tags=set(args.output_tags) if args.output_tags is not None else None,
ai=args.AI,
model=args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
if with_answer:
message.answer = Answer('Answer 0')
message.to_file()
return message
def create_multiple_messages(self) -> list[Message]:
# cached message without an answer
message1 = Message(Question('Question 1'),
ai='foo',
model='bla',
file_path=Path(self.cache_dir.name) / '0001.txt')
# cached message with an answer
message2 = Message(Question('Question 2'),
Answer('Answer 0'),
ai='openai',
model='gpt-3.5-turbo',
file_path=Path(self.cache_dir.name) / '0002.txt')
# DB message without an answer
message3 = Message(Question('Question 3'),
ai='openai',
model='gpt-3.5-turbo',
file_path=Path(self.db_dir.name) / '0003.txt')
message1.to_file()
message2.to_file()
message3.to_file()
return [message1, message2, message3]
def input_message(self, args: argparse.Namespace) -> Message:
"""
Create the expected input message for a question using the
given arguments.
"""
# NOTE: we only use the first question from the "ask" list
# -> message creation using "question.create_message()" is
# tested above
# the answer is always empty for the input message
return Message(Question(args.ask[0]),
tags=args.output_tags,
ai=args.AI,
model=args.model)
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next' # exclude '.next'
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')]) return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
@@ -362,7 +316,11 @@ class TestQuestionCmdAsk(TestQuestionCmd):
Test single answer with no errors. Test single answer with no errors.
""" """
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
expected_question = self.input_message(self.args) expected_question = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path('<NOT COMPARED>'))
fake_ai = self.mock_create_ai(self.args, self.config) fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question, expected_responses = fake_ai.request(expected_question,
Chat([]), Chat([]),
@@ -389,7 +347,11 @@ class TestQuestionCmdAsk(TestQuestionCmd):
mock_from_dir.return_value = chat mock_from_dir.return_value = chat
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
expected_question = self.input_message(self.args) expected_question = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path('<NOT COMPARED>'))
fake_ai = self.mock_create_ai(self.args, self.config) fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question, expected_responses = fake_ai.request(expected_question,
Chat([]), Chat([]),
@@ -416,7 +378,11 @@ class TestQuestionCmdAsk(TestQuestionCmd):
has been correctly stored in the cache. has been correctly stored in the cache.
""" """
mock_create_ai.side_effect = self.mock_create_ai_with_error mock_create_ai.side_effect = self.mock_create_ai_with_error
expected_question = self.input_message(self.args) expected_question = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path('<NOT COMPARED>'))
# execute the command # execute the command
with self.assertRaises(AIError): with self.assertRaises(AIError):
@@ -439,20 +405,28 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
""" """
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
# create a message # create a message
message = self.create_single_message(self.args) message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
message.to_file()
# repeat the last question (without overwriting) # repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path) # -> expect two identical messages (except for the file_path)
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.output_tags = []
self.args.overwrite = False self.args.overwrite = False
fake_ai = self.mock_create_ai(self.args, self.config) fake_ai = self.mock_create_ai(self.args, self.config)
expected_response = fake_ai.request(message, # since the message's answer is modified, we use a copy here
# -> the original is used for comparison below
expected_response = fake_ai.request(copy(message),
Chat([]), Chat([]),
self.args.num_answers, self.args.num_answers,
set(self.args.output_tags)).messages set(self.args.output_tags)).messages
expected_responses = expected_response + expected_response # we expect the original message + the one with the new response
expected_responses = [message] + expected_response
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
@@ -468,7 +442,13 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
""" """
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
# create a message # create a message
message = self.create_single_message(self.args) message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
message.to_file()
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
@@ -501,7 +481,12 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
# create a question WITHOUT an answer # create a question WITHOUT an answer
# -> just like after an error, which is tested above # -> just like after an error, which is tested above
question = self.create_single_message(self.args, with_answer=False) message = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
message.to_file()
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
@@ -515,7 +500,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = False self.args.overwrite = False
fake_ai = self.mock_create_ai(self.args, self.config) fake_ai = self.mock_create_ai(self.args, self.config)
expected_response = fake_ai.request(question, expected_response = fake_ai.request(message,
Chat([]), Chat([]),
self.args.num_answers, self.args.num_answers,
self.args.output_tags).messages self.args.output_tags).messages
@@ -534,7 +519,13 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
""" """
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
# create a message # create a message
message = self.create_single_message(self.args) message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
message.to_file()
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
@@ -561,19 +552,48 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, [message] + new_expected_response) self.assert_messages_equal(cached_msg, [message] + new_expected_response)
print(cached_msg)
print(message)
print(new_expected_question)
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None: def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None:
""" """
Repeat multiple questions. Repeat multiple questions.
""" """
# chat = ChatDB.from_dir(Path(self.cache_dir.name), mock_create_ai.side_effect = self.mock_create_ai
# Path(self.db_dir.name)) # 1. === create three questions ===
# cached message without an answer
message1 = Message(Question(self.args.ask[0]),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
# cached message with an answer
message2 = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / '0002.txt')
# DB message without an answer
message3 = Message(Question(self.args.ask[0]),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.db_dir.name) / '0003.txt')
message1.to_file()
message2.to_file()
message3.to_file()
questions = [message1, message2, message3]
expected_responses: list[Message] = []
fake_ai = self.mock_create_ai(self.args, self.config)
for question in questions:
# since the message's answer is modified, we use a copy
# -> the original is used for comparison below
expected_responses += fake_ai.request(copy(question),
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
# 2. repeat all three questions (without overwriting) # 2. === repeat all three questions (without overwriting) ===
self.args.ask = None self.args.ask = None
self.args.repeat = ['0001', '0002', '0003'] self.args.repeat = ['0001', '0002', '0003']
self.args.overwrite = False self.args.overwrite = False
@@ -581,7 +601,11 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
# two new files should be in the cache directory # two new files should be in the cache directory
# * the repeated cached message with answer # * the repeated cached message with answer
# * the repeated DB message # * the repeated DB message
# -> the cached message wihtout answer should be overwritten # -> the cached message without answer should be overwritten
self.assertEqual(len(self.message_list(self.cache_dir)), 4) self.assertEqual(len(self.message_list(self.cache_dir)), 4)
self.assertEqual(len(self.message_list(self.db_dir)), 1) self.assertEqual(len(self.message_list(self.db_dir)), 1)
# FIXME: also compare actual content! expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assert_messages_equal(cached_msg, expected_cache_messages)