test_question_cmd: introduced 'FakeAI' class

This commit is contained in:
2023-09-22 13:38:24 +02:00
parent 80c5dcc801
commit b50caa345c
+163 -136
View File
@@ -4,9 +4,9 @@ import argparse
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from unittest import mock from unittest import mock
from unittest.mock import MagicMock, call, ANY from unittest.mock import MagicMock, call
from typing import Optional from typing import Optional, Union
from chatmastermind.configuration import Config from chatmastermind.configuration import Config, AIConfig
from chatmastermind.commands.question import create_message, question_cmd from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer from chatmastermind.message import Message, Question, Answer
@@ -14,6 +14,56 @@ from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AI, AIResponse, Tokens, AIError from chatmastermind.ai import AI, AIResponse, Tokens, AIError
class FakeAI(AI):
"""
A mocked version of the 'AI' class.
"""
ID: str
name: str
config: AIConfig
def models(self) -> list[str]:
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int:
return 123
def print(self) -> None:
pass
def print_models(self) -> None:
pass
def __init__(self, ID: str, model: str, error: bool = False):
self.ID = ID
self.model = model
self.error = error
def request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Mock the 'ai.request()' function by either returning fake
answers or raising an exception.
"""
if self.error:
raise AIError
question.answer = Answer("Answer 0")
question.tags = set(otags) if otags is not None else None
question.ai = self.ID
question.model = self.model
answers: list[Message] = [question]
for n in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(f"Answer {n}"),
tags=otags,
ai=self.ID,
model=self.model))
return AIResponse(answers, Tokens(10, 10, 20))
class TestQuestionCmdBase(unittest.TestCase): class TestQuestionCmdBase(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None: def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
""" """
@@ -24,6 +74,18 @@ class TestQuestionCmdBase(unittest.TestCase):
# exclude the file_path, compare only Q, A and metadata # exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True)) self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance.
"""
return FakeAI(args.AI, args.model)
def mock_create_ai_with_error(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance.
"""
return FakeAI(args.AI, args.model, error=True)
class TestMessageCreate(TestQuestionCmdBase): class TestMessageCreate(TestQuestionCmdBase):
""" """
@@ -227,8 +289,8 @@ class TestQuestionCmd(TestQuestionCmdBase):
ask=['What is the meaning of life?'], ask=['What is the meaning of life?'],
num_answers=1, num_answers=1,
output_tags=['science'], output_tags=['science'],
AI='openai', AI='FakeAI',
model='gpt-3.5-turbo', model='FakeModel',
or_tags=None, or_tags=None,
and_tags=None, and_tags=None,
exclude_tags=None, exclude_tags=None,
@@ -239,9 +301,39 @@ class TestQuestionCmd(TestQuestionCmdBase):
process=None, process=None,
overwrite=None overwrite=None
) )
# create a mock AI instance
self.ai = MagicMock(spec=AI) def create_single_message(self, args: argparse.Namespace, with_answer: bool = True) -> Message:
self.ai.request.side_effect = self.mock_request message = Message(Question(args.ask[0]),
tags=set(args.output_tags) if args.output_tags is not None else None,
ai=args.AI,
model=args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
if with_answer:
message.answer = Answer('Answer 0')
message.to_file()
return message
def create_multiple_messages(self) -> list[Message]:
# cached message without an answer
message1 = Message(Question('Question 1'),
ai='foo',
model='bla',
file_path=Path(self.cache_dir.name) / '0001.txt')
# cached message with an answer
message2 = Message(Question('Question 2'),
Answer('Answer 0'),
ai='openai',
model='gpt-3.5-turbo',
file_path=Path(self.cache_dir.name) / '0002.txt')
# DB message without an answer
message3 = Message(Question('Question 3'),
ai='openai',
model='gpt-3.5-turbo',
file_path=Path(self.db_dir.name) / '0003.txt')
message1.to_file()
message2.to_file()
message3.to_file()
return [message1, message2, message3]
def input_message(self, args: argparse.Namespace) -> Message: def input_message(self, args: argparse.Namespace) -> Message:
""" """
@@ -257,27 +349,6 @@ class TestQuestionCmd(TestQuestionCmdBase):
ai=args.AI, ai=args.AI,
model=args.model) model=args.model)
def mock_request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Mock the 'ai.request()' function
"""
question.answer = Answer("Answer 0")
question.tags = set(otags) if otags else None
question.ai = 'FakeAI'
question.model = 'FakeModel'
answers: list[Message] = [question]
for n in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(f"Answer {n}"),
tags=otags,
ai='FakeAI',
model='FakeModel'))
return AIResponse(answers, Tokens(10, 10, 20))
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next' # exclude '.next'
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')]) return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
@@ -290,21 +361,17 @@ class TestQuestionCmdAsk(TestQuestionCmd):
""" """
Test single answer with no errors. Test single answer with no errors.
""" """
mock_create_ai.return_value = self.ai mock_create_ai.side_effect = self.mock_create_ai
expected_question = self.input_message(self.args) expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question, fake_ai = self.mock_create_ai(self.args, self.config)
Chat([]), expected_responses = fake_ai.request(expected_question,
self.args.num_answers, Chat([]),
self.args.output_tags).messages self.args.num_answers,
self.args.output_tags).messages
# execute the command # execute the command
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files # check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
@@ -321,22 +388,17 @@ class TestQuestionCmdAsk(TestQuestionCmd):
chat = MagicMock(spec=ChatDB) chat = MagicMock(spec=ChatDB)
mock_from_dir.return_value = chat mock_from_dir.return_value = chat
mock_create_ai.return_value = self.ai mock_create_ai.side_effect = self.mock_create_ai
expected_question = self.input_message(self.args) expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question, fake_ai = self.mock_create_ai(self.args, self.config)
Chat([]), expected_responses = fake_ai.request(expected_question,
self.args.num_answers, Chat([]),
self.args.output_tags).messages self.args.num_answers,
self.args.output_tags).messages
# execute the command # execute the command
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
chat,
self.args.num_answers,
self.args.output_tags)
# check for the correct ChatDB calls: # check for the correct ChatDB calls:
# - initial question has been written (prior to the actual request) # - initial question has been written (prior to the actual request)
# - responses have been written (after the request) # - responses have been written (after the request)
@@ -353,19 +415,13 @@ class TestQuestionCmdAsk(TestQuestionCmd):
Provoke an error during the AI request and verify that the question Provoke an error during the AI request and verify that the question
has been correctly stored in the cache. has been correctly stored in the cache.
""" """
mock_create_ai.return_value = self.ai mock_create_ai.side_effect = self.mock_create_ai_with_error
expected_question = self.input_message(self.args) expected_question = self.input_message(self.args)
self.ai.request.side_effect = AIError
# execute the command # execute the command
with self.assertRaises(AIError): with self.assertRaises(AIError):
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files # check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
@@ -381,28 +437,27 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
""" """
Repeat a single question. Repeat a single question.
""" """
# 1. ask a question mock_create_ai.side_effect = self.mock_create_ai
mock_create_ai.return_value = self.ai # create a message
expected_question = self.input_message(self.args) message = self.create_single_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]), # repeat the last question (without overwriting)
self.args.num_answers, # -> expect two identical messages (except for the file_path)
self.args.output_tags).messages self.args.ask = None
self.args.repeat = []
self.args.output_tags = []
self.args.overwrite = False
fake_ai = self.mock_create_ai(self.args, self.config)
expected_response = fake_ai.request(message,
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
expected_responses = expected_response + expected_response
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) print(self.message_list(self.cache_dir))
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path)
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
expected_responses += expected_responses
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, expected_responses) self.assert_messages_equal(cached_msg, expected_responses)
@@ -411,31 +466,29 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
""" """
Repeat a single question and overwrite the old one. Repeat a single question and overwrite the old one.
""" """
# 1. ask a question mock_create_ai.side_effect = self.mock_create_ai
mock_create_ai.return_value = self.ai # create a message
expected_question = self.input_message(self.args) message = self.create_single_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem cached_msg_file_id = cached_msg[0].file_path.stem
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question (WITH overwriting) # repeat the last question (WITH overwriting)
# -> expect a single message afterwards # -> expect a single message afterwards
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = True self.args.overwrite = True
fake_ai = self.mock_create_ai(self.args, self.config)
expected_response = fake_ai.request(message,
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses) self.assert_messages_equal(cached_msg, expected_response)
# also check that the file ID has not been changed # also check that the file ID has not been changed
assert cached_msg[0].file_path assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem) self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@@ -445,35 +498,31 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
""" """
Repeat a single question after an error. Repeat a single question after an error.
""" """
# 1. ask a question and provoke an error mock_create_ai.side_effect = self.mock_create_ai
mock_create_ai.return_value = self.ai # create a question WITHOUT an answer
expected_question = self.input_message(self.args) # -> just like after an error, which is tested above
self.ai.request.side_effect = AIError question = self.create_single_message(self.args, with_answer=False)
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem cached_msg_file_id = cached_msg[0].file_path.stem
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, [expected_question])
# 2. repeat the last question (without overwriting) # repeat the last question (without overwriting)
# -> expect a single message because if the original has # -> expect a single message because if the original has
# no answer, it should be overwritten by default # no answer, it should be overwritten by default
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = False self.args.overwrite = False
self.ai.request.side_effect = self.mock_request fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = self.mock_request(expected_question, expected_response = fake_ai.request(question,
Chat([]), Chat([]),
self.args.num_answers, self.args.num_answers,
self.args.output_tags).messages self.args.output_tags).messages
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses) self.assert_messages_equal(cached_msg, expected_response)
# also check that the file ID has not been changed # also check that the file ID has not been changed
assert cached_msg[0].file_path assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem) self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@@ -483,21 +532,15 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
""" """
Repeat a single question with new arguments. Repeat a single question with new arguments.
""" """
# 1. ask a question mock_create_ai.side_effect = self.mock_create_ai
mock_create_ai.return_value = self.ai # create a message
expected_question = self.input_message(self.args) message = self.create_single_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) assert cached_msg[0].file_path
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question with new arguments (without overwriting) # repeat the last question with new arguments (without overwriting)
# -> expect two messages with identical question and answer, but different metadata # -> expect two messages with identical question and answer, but different metadata
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
@@ -505,44 +548,28 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
self.args.output_tags = ['newtag'] self.args.output_tags = ['newtag']
self.args.AI = 'newai' self.args.AI = 'newai'
self.args.model = 'newmodel' self.args.model = 'newmodel'
new_expected_question = Message(question=Question(expected_question.question), new_expected_question = Message(question=Question(message.question),
tags=set(self.args.output_tags), tags=set(self.args.output_tags),
ai=self.args.AI, ai=self.args.AI,
model=self.args.model) model=self.args.model)
expected_responses += self.mock_request(new_expected_question, fake_ai = self.mock_create_ai(self.args, self.config)
new_expected_response = fake_ai.request(new_expected_question,
Chat([]), Chat([]),
self.args.num_answers, self.args.num_answers,
set(self.args.output_tags)).messages set(self.args.output_tags)).messages
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, expected_responses) self.assert_messages_equal(cached_msg, [message] + new_expected_response)
print(cached_msg)
print(message)
print(new_expected_question)
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None: def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None:
""" """
Repeat multiple questions. Repeat multiple questions.
""" """
# 1. create some questions / messages
# cached message without an answer
message1 = Message(Question('Question 1'),
ai='foo',
model='bla',
file_path=Path(self.cache_dir.name) / '0001.txt')
# cached message with an answer
message2 = Message(Question('Question 2'),
Answer('Answer 2'),
ai='openai',
model='gpt-3.5-turbo',
file_path=Path(self.cache_dir.name) / '0002.txt')
# DB message without an answer
message3 = Message(Question('Question 3'),
ai='openai',
model='gpt-3.5-turbo',
file_path=Path(self.db_dir.name) / '0003.txt')
message1.to_file()
message2.to_file()
message3.to_file()
# chat = ChatDB.from_dir(Path(self.cache_dir.name), # chat = ChatDB.from_dir(Path(self.cache_dir.name),
# Path(self.db_dir.name)) # Path(self.db_dir.name))