3 Commits

2 changed files with 234 additions and 147 deletions
+100
View File
@@ -0,0 +1,100 @@
import unittest
import argparse
from typing import Union, Optional
from chatmastermind.configuration import Config, AIConfig
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Answer
from chatmastermind.chat import Chat
from chatmastermind.ai import AI, AIResponse, Tokens, AIError
class FakeAI(AI):
    """
    A mocked version of the 'AI' class.

    Produces deterministic canned answers from 'request()' so tests can
    run without a real AI backend; when constructed with error=True,
    'request()' raises 'AIError' instead.
    """
    # NOTE(review): the original annotated 'name: str' but never assigned
    # it; the attributes actually set in __init__ are 'model' and 'error'.
    ID: str
    model: str
    error: bool
    config: AIConfig

    def __init__(self, ID: str, model: str, error: bool = False):
        """
        :param ID: identifier reported back on every answered message
        :param model: model name reported back on every answered message
        :param error: if True, 'request()' raises 'AIError'
        """
        self.ID = ID
        self.model = model
        self.error = error

    def models(self) -> list[str]:
        # Not needed by any test; fail loudly if it is ever called.
        raise NotImplementedError

    def tokens(self, data: Union[Message, Chat]) -> int:
        # Fixed token count — tests only need a deterministic value.
        return 123

    def print(self) -> None:
        pass

    def print_models(self) -> None:
        pass

    def request(self,
                question: Message,
                chat: Chat,
                num_answers: int = 1,
                otags: Optional[set[Tag]] = None) -> AIResponse:
        """
        Mock the 'ai.request()' function by either returning fake
        answers or raising an exception.

        The passed-in 'question' is mutated in place to carry the first
        answer ("Answer 0") plus this fake's metadata; additional
        messages ("Answer 1", ...) are appended up to 'num_answers'.

        :raises AIError: when the instance was created with error=True
        """
        if self.error:
            raise AIError
        question.answer = Answer("Answer 0")
        question.tags = set(otags) if otags is not None else None
        question.ai = self.ID
        question.model = self.model
        answers: list[Message] = [question]
        for n in range(1, num_answers):
            answers.append(Message(question=question.question,
                                   answer=Answer(f"Answer {n}"),
                                   tags=otags,
                                   ai=self.ID,
                                   model=self.model))
        return AIResponse(answers, Tokens(10, 10, 20))
class TestWithFakeAI(unittest.TestCase):
    """
    Base class for all tests that need to use the FakeAI.

    Provides message-list comparison helpers at three strictness levels
    and factory methods suitable as 'create_ai' mock side effects.
    """

    def assert_msgs_equal_except_file_path(self,
                                           msg1: list[Message],
                                           msg2: list[Message]) -> None:
        """
        Compare messages using Question, Answer and all metadata except
        for the file_path.
        """
        self.assertEqual(len(msg1), len(msg2))
        for m1, m2 in zip(msg1, msg2):
            # exclude the file_path, compare only Q, A and metadata
            self.assertTrue(m1.equals(m2, file_path=False, verbose=True))

    def assert_msgs_all_equal(self,
                              msg1: list[Message],
                              msg2: list[Message]) -> None:
        """
        Compare messages using Question, Answer and ALL metadata.
        """
        self.assertEqual(len(msg1), len(msg2))
        for m1, m2 in zip(msg1, msg2):
            self.assertTrue(m1.equals(m2, verbose=True))

    def assert_msgs_content_equal(self,
                                  msg1: list[Message],
                                  msg2: list[Message]) -> None:
        """
        Compare messages using only Question and Answer.
        """
        self.assertEqual(len(msg1), len(msg2))
        for m1, m2 in zip(msg1, msg2):
            self.assertEqual(m1, m2)

    def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI:
        """
        Mocked 'create_ai' that returns a 'FakeAI' instance.
        """
        return FakeAI(args.AI, args.model)

    def mock_create_ai_with_error(self,
                                  args: argparse.Namespace,
                                  config: Config) -> AI:
        """
        Mocked 'create_ai' that returns a 'FakeAI' instance whose
        'request()' raises 'AIError'.
        """
        return FakeAI(args.AI, args.model, error=True)
+134 -147
View File
@@ -1,93 +1,20 @@
import os import os
import unittest
import argparse import argparse
import tempfile import tempfile
from copy import copy
from pathlib import Path from pathlib import Path
from unittest import mock from unittest import mock
from unittest.mock import MagicMock, call from unittest.mock import MagicMock, call
from typing import Optional, Union from chatmastermind.configuration import Config
from chatmastermind.configuration import Config, AIConfig
from chatmastermind.commands.question import create_message, question_cmd from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat, ChatDB from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AI, AIResponse, Tokens, AIError from chatmastermind.ai import AIError
from .test_common import TestWithFakeAI
class FakeAI(AI): class TestMessageCreate(TestWithFakeAI):
"""
A mocked version of the 'AI' class.
"""
ID: str
name: str
config: AIConfig
def models(self) -> list[str]:
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int:
return 123
def print(self) -> None:
pass
def print_models(self) -> None:
pass
def __init__(self, ID: str, model: str, error: bool = False):
self.ID = ID
self.model = model
self.error = error
def request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Mock the 'ai.request()' function by either returning fake
answers or raising an exception.
"""
if self.error:
raise AIError
question.answer = Answer("Answer 0")
question.tags = set(otags) if otags is not None else None
question.ai = self.ID
question.model = self.model
answers: list[Message] = [question]
for n in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(f"Answer {n}"),
tags=otags,
ai=self.ID,
model=self.model))
return AIResponse(answers, Tokens(10, 10, 20))
class TestQuestionCmdBase(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using more than just Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance.
"""
return FakeAI(args.AI, args.model)
def mock_create_ai_with_error(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance.
"""
return FakeAI(args.AI, args.model, error=True)
class TestMessageCreate(TestQuestionCmdBase):
""" """
Test if messages created by the 'question' command have Test if messages created by the 'question' command have
the correct format. the correct format.
@@ -274,7 +201,7 @@ It is embedded code
""")) """))
class TestQuestionCmd(TestQuestionCmdBase): class TestQuestionCmd(TestWithFakeAI):
def setUp(self) -> None: def setUp(self) -> None:
# create DB and cache # create DB and cache
@@ -302,17 +229,6 @@ class TestQuestionCmd(TestQuestionCmdBase):
overwrite=None overwrite=None
) )
def create_single_message(self, args: argparse.Namespace, with_answer: bool = True) -> Message:
message = Message(Question(args.ask[0]),
tags=set(args.output_tags) if args.output_tags is not None else None,
ai=args.AI,
model=args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
if with_answer:
message.answer = Answer('Answer 0')
message.to_file()
return message
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next' # exclude '.next'
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')]) return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
@@ -327,9 +243,10 @@ class TestQuestionCmdAsk(TestQuestionCmd):
""" """
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
expected_question = Message(Question(self.args.ask[0]), expected_question = Message(Question(self.args.ask[0]),
tags=self.args.output_tags, tags=set(self.args.output_tags),
ai=self.args.AI, ai=self.args.AI,
model=self.args.model) model=self.args.model,
file_path=Path('<NOT COMPARED>'))
fake_ai = self.mock_create_ai(self.args, self.config) fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question, expected_responses = fake_ai.request(expected_question,
Chat([]), Chat([]),
@@ -344,7 +261,7 @@ class TestQuestionCmdAsk(TestQuestionCmd):
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses) self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.ChatDB.from_dir') @mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
@@ -357,9 +274,10 @@ class TestQuestionCmdAsk(TestQuestionCmd):
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
expected_question = Message(Question(self.args.ask[0]), expected_question = Message(Question(self.args.ask[0]),
tags=self.args.output_tags, tags=set(self.args.output_tags),
ai=self.args.AI, ai=self.args.AI,
model=self.args.model) model=self.args.model,
file_path=Path('<NOT COMPARED>'))
fake_ai = self.mock_create_ai(self.args, self.config) fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question, expected_responses = fake_ai.request(expected_question,
Chat([]), Chat([]),
@@ -387,9 +305,10 @@ class TestQuestionCmdAsk(TestQuestionCmd):
""" """
mock_create_ai.side_effect = self.mock_create_ai_with_error mock_create_ai.side_effect = self.mock_create_ai_with_error
expected_question = Message(Question(self.args.ask[0]), expected_question = Message(Question(self.args.ask[0]),
tags=self.args.output_tags, tags=set(self.args.output_tags),
ai=self.args.AI, ai=self.args.AI,
model=self.args.model) model=self.args.model,
file_path=Path('<NOT COMPARED>'))
# execute the command # execute the command
with self.assertRaises(AIError): with self.assertRaises(AIError):
@@ -400,7 +319,7 @@ class TestQuestionCmdAsk(TestQuestionCmd):
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, [expected_question]) self.assert_msgs_equal_except_file_path(cached_msg, [expected_question])
class TestQuestionCmdRepeat(TestQuestionCmd): class TestQuestionCmdRepeat(TestQuestionCmd):
@@ -412,26 +331,34 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
""" """
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
# create a message # create a message
message = self.create_single_message(self.args) message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
message.to_file()
# repeat the last question (without overwriting) # repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path) # -> expect two identical messages (except for the file_path)
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = False self.args.overwrite = False
fake_ai = self.mock_create_ai(self.args, self.config) expected_response = Message(Question(message.question),
expected_response = fake_ai.request(message, Answer('Answer 0'),
Chat([]), ai=message.ai,
self.args.num_answers, model=message.model,
set(self.args.output_tags)).messages tags=message.tags,
expected_responses = expected_response + expected_response file_path=Path('<NOT COMPARED>'))
# we expect the original message + the one with the new response
expected_responses = [message] + [expected_response]
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
print(self.message_list(self.cache_dir)) print(self.message_list(self.cache_dir))
self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, expected_responses) self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None: def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
@@ -440,7 +367,13 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
""" """
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
# create a message # create a message
message = self.create_single_message(self.args) message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
message.to_file()
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
@@ -448,19 +381,20 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
cached_msg_file_id = cached_msg[0].file_path.stem cached_msg_file_id = cached_msg[0].file_path.stem
# repeat the last question (WITH overwriting) # repeat the last question (WITH overwriting)
# -> expect a single message afterwards # -> expect a single message afterwards (with a new answer)
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = True self.args.overwrite = True
fake_ai = self.mock_create_ai(self.args, self.config) expected_response = Message(Question(message.question),
expected_response = fake_ai.request(message, Answer('Answer 0'),
Chat([]), ai=message.ai,
self.args.num_answers, model=message.model,
set(self.args.output_tags)).messages tags=message.tags,
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_response) self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed # also check that the file ID has not been changed
assert cached_msg[0].file_path assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem) self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@@ -473,7 +407,12 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
# create a question WITHOUT an answer # create a question WITHOUT an answer
# -> just like after an error, which is tested above # -> just like after an error, which is tested above
question = self.create_single_message(self.args, with_answer=False) message = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
message.to_file()
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
@@ -486,15 +425,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = False self.args.overwrite = False
fake_ai = self.mock_create_ai(self.args, self.config) expected_response = Message(Question(message.question),
expected_response = fake_ai.request(question, Answer('Answer 0'),
Chat([]), ai=message.ai,
self.args.num_answers, model=message.model,
self.args.output_tags).messages tags=message.tags,
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_response) self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed # also check that the file ID has not been changed
assert cached_msg[0].file_path assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem) self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@@ -506,55 +446,99 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
""" """
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
# create a message # create a message
message = self.create_single_message(self.args) message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
message.to_file()
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path assert cached_msg[0].file_path
# repeat the last question with new arguments (without overwriting) # repeat the last question with new arguments (without overwriting)
# -> expect two messages with identical question and answer, but different metadata # -> expect two messages with identical question but different metadata and new answer
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = False self.args.overwrite = False
self.args.output_tags = ['newtag'] self.args.output_tags = ['newtag']
self.args.AI = 'newai' self.args.AI = 'newai'
self.args.model = 'newmodel' self.args.model = 'newmodel'
new_expected_question = Message(question=Question(message.question), new_expected_response = Message(Question(message.question),
tags=set(self.args.output_tags), Answer('Answer 0'),
ai=self.args.AI, ai='newai',
model=self.args.model) model='newmodel',
fake_ai = self.mock_create_ai(self.args, self.config) tags={Tag('newtag')},
new_expected_response = fake_ai.request(new_expected_question, file_path=Path('<NOT COMPARED>'))
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, [message] + new_expected_response) self.assert_msgs_equal_except_file_path(cached_msg, [message] + [new_expected_response])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_new_args_overwrite(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question with new arguments, overwriting the old one.
"""
mock_create_ai.side_effect = self.mock_create_ai
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
message.to_file()
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
# repeat the last question with new arguments
self.args.ask = None
self.args.repeat = []
self.args.overwrite = True
self.args.output_tags = ['newtag']
self.args.AI = 'newai'
self.args.model = 'newmodel'
new_expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai='newai',
model='newmodel',
tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response])
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None: def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None:
""" """
Repeat multiple questions. Repeat multiple questions.
""" """
mock_create_ai.side_effect = self.mock_create_ai
# 1. === create three questions === # 1. === create three questions ===
# cached message without an answer # cached message without an answer
message1 = Message(Question('Question 1'), message1 = Message(Question(self.args.ask[0]),
ai='foo', tags=self.args.output_tags,
model='bla', ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt') file_path=Path(self.cache_dir.name) / '0001.txt')
# cached message with an answer # cached message with an answer
message2 = Message(Question('Question 2'), message2 = Message(Question(self.args.ask[0]),
Answer('Answer 0'), Answer('Old Answer'),
ai='openai', tags=self.args.output_tags,
model='gpt-3.5-turbo', ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / '0002.txt') file_path=Path(self.cache_dir.name) / '0002.txt')
# DB message without an answer # DB message without an answer
message3 = Message(Question('Question 3'), message3 = Message(Question(self.args.ask[0]),
ai='openai', tags=self.args.output_tags,
model='gpt-3.5-turbo', ai=self.args.AI,
model=self.args.model,
file_path=Path(self.db_dir.name) / '0003.txt') file_path=Path(self.db_dir.name) / '0003.txt')
message1.to_file() message1.to_file()
message2.to_file() message2.to_file()
@@ -563,7 +547,9 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
expected_responses: list[Message] = [] expected_responses: list[Message] = []
fake_ai = self.mock_create_ai(self.args, self.config) fake_ai = self.mock_create_ai(self.args, self.config)
for question in questions: for question in questions:
expected_responses += fake_ai.request(question, # since the message's answer is modified, we use a copy
# -> the original is used for comparison below
expected_responses += fake_ai.request(copy(question),
Chat([]), Chat([]),
self.args.num_answers, self.args.num_answers,
set(self.args.output_tags)).messages set(self.args.output_tags)).messages
@@ -583,6 +569,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
print(f"Cached: {cached_msg}") self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
print(f"Expected: {expected_cache_messages}") # check that the DB message has not been modified at all
self.assert_messages_equal(cached_msg, expected_cache_messages) db_msg = chat.msg_gather(loc='db')
self.assert_msgs_all_equal(db_msg, [message3])