From fc82f85b7ce469ffffc26fb4b5a43dd4abc056fb Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 15 Sep 2023 10:17:20 +0200 Subject: [PATCH] chat: added new functions: msg_unique_id(), msg_unique_content() and tests --- chatmastermind/chat.py | 29 ++++++++++++++++- tests/test_chat.py | 73 +++++++++++++++++++++++++++++++++--------- 2 files changed, 86 insertions(+), 16 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 0aee2fe..083b91e 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -146,6 +146,25 @@ class Chat: except MessageError: pass + def msg_unique_id(self) -> None: + """ + Remove duplicates from the internal messages, based on the msg_id (i. e. file_path). + Messages without a file_path are kept. + """ + old_msgs = self.messages.copy() + self.messages = [] + for m in old_msgs: + if not message_in(m, self.messages): + self.messages.append(m) + self.msg_sort() + + def msg_unique_content(self) -> None: + """ + Remove duplicates from the internal messages, based on the content (i. e. question + answer). + """ + self.messages = list(set(self.messages)) + self.msg_sort() + def msg_clear(self) -> None: """ Delete all messages. @@ -356,7 +375,13 @@ class ChatDB(Chat): source_messages += read_dir(self.cache_path, mfilter=mfilter) if source in ['db', 'disk', 'all']: source_messages += read_dir(self.db_path, mfilter=mfilter) - return source_messages + # remove_duplicates and sort the list + unique_messages: list[Message] = [] + for m in source_messages: + if not message_in(m, unique_messages): + unique_messages.append(m) + unique_messages.sort(key=lambda m: m.msg_id()) + return unique_messages def msg_find(self, msg_names: list[str], @@ -430,6 +455,7 @@ class ChatDB(Chat): Write messages to the cache directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the cache directory. + Does NOT add the messages to the internal list (use 'cache_add()' for that)! """ write_dir(self.cache_path, messages if messages else self.messages, @@ -480,6 +506,7 @@ class ChatDB(Chat): Write messages to the DB directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the DB directory. + Does NOT add the messages to the internal list (use 'db_add()' for that)! """ write_dir(self.db_path, messages if messages else self.messages, diff --git a/tests/test_chat.py b/tests/test_chat.py index 6aac1b5..7ea5b0c 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -21,6 +21,29 @@ class TestChat(unittest.TestCase): Answer('Answer 2'), {Tag('btag2')}, file_path=pathlib.Path('0002.txt')) + self.maxDiff = None + + def test_unique_id(self) -> None: + # test with two identical messages + self.chat.msg_add([self.message1, self.message1]) + self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1]) + self.chat.msg_unique_id() + self.assertSequenceEqual(self.chat.messages, [self.message1]) + # test with two different messages + self.chat.msg_add([self.message2]) + self.chat.msg_unique_id() + self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2]) + + def test_unique_content(self) -> None: + # test with two identical messages + self.chat.msg_add([self.message1, self.message1]) + self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1]) + self.chat.msg_unique_content() + self.assertSequenceEqual(self.chat.messages, [self.message1]) + # test with two different messages + self.chat.msg_add([self.message2]) + self.chat.msg_unique_content() + self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2]) def test_filter(self) -> None: self.chat.msg_add([self.message1, self.message2]) @@ -166,7 +189,6 @@ class TestChatDB(unittest.TestCase): with open(pathlib.Path(self.db_path.name) / 'content.yaml', 'w') as f: yaml.dump({'key': 'value'}, f) self.trash_files.append('content.yaml') - self.maxDiff = None def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]: @@ -181,7 +203,7 @@ class TestChatDB(unittest.TestCase): self.cache_path.cleanup() pass - def test_chat_db_from_dir(self) -> None: + def test_from_dir(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) self.assertEqual(len(chat_db.messages), 4) @@ -197,7 +219,7 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) - def test_chat_db_from_dir_glob(self) -> None: + def test_from_dir_glob(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), glob='*.txt') @@ -209,7 +231,7 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0003.txt')) - def test_chat_db_from_dir_filter_tags(self) -> None: + def test_from_dir_filter_tags(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), mfilter=MessageFilter(tags_or={Tag('tag1')})) @@ -219,7 +241,7 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) - def test_chat_db_from_dir_filter_tags_empty(self) -> None: + def test_from_dir_filter_tags_empty(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), mfilter=MessageFilter(tags_or=set(), @@ -227,7 +249,7 @@ class TestChatDB(unittest.TestCase): tags_not=set())) self.assertEqual(len(chat_db.messages), 0) - def test_chat_db_from_dir_filter_answer(self) -> None: + def test_from_dir_filter_answer(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), mfilter=MessageFilter(answer_contains='Answer 2')) @@ -238,7 +260,7 @@ class TestChatDB(unittest.TestCase): pathlib.Path(self.db_path.name, '0002.yaml')) self.assertEqual(chat_db.messages[0].answer, 'Answer 2') - def test_chat_db_from_messages(self) -> None: + def test_from_messages(self) -> None: chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), messages=[self.message1, self.message2, @@ -247,7 +269,7 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) - def test_chat_db_fids(self) -> None: + def test_fids(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.get_next_fid(), 5) @@ -256,7 +278,7 @@ class TestChatDB(unittest.TestCase): with open(chat_db.next_path, 'r') as f: self.assertEqual(f.read(), '7') - def test_chat_db_write(self) -> None: + def test_db_write(self) -> None: # create a new ChatDB instance chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) @@ -304,7 +326,7 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) - def test_chat_db_read(self) -> None: + def test_db_read(self) -> None: # create a new ChatDB instance chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) @@ -367,7 +389,7 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) - def test_chat_db_clear(self) -> None: + def test_cache_clear(self) -> None: # create a new ChatDB instance chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) @@ -412,7 +434,7 @@ class TestChatDB(unittest.TestCase): # but not the message with the cache dir path self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages)) - def test_chat_db_add(self) -> None: + def test_add(self) -> None: # create a new ChatDB instance chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) @@ -443,7 +465,7 @@ class TestChatDB(unittest.TestCase): with self.assertRaises(ChatError): chat_db.cache_add([Message(Question("?"), file_path=pathlib.Path("foo"))]) - def test_chat_db_write_messages(self) -> None: + def test_msg_write(self) -> None: # create a new ChatDB instance chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) @@ -466,7 +488,7 @@ class TestChatDB(unittest.TestCase): self.assertEqual(len(cache_dir_files), 1) self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) - def test_chat_db_update_messages(self) -> None: + def test_msg_update(self) -> None: # create a new ChatDB instance chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) @@ -494,7 +516,28 @@ class TestChatDB(unittest.TestCase): with self.assertRaises(ChatError): chat_db.msg_update([message1]) - def test_chat_db_latest_message(self) -> None: + def test_msg_find(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + # search for a DB file in memory + self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], source='mem'), [self.message1]) + self.assertEqual(chat_db.msg_find(['0001.txt'], source='mem'), [self.message1]) + self.assertEqual(chat_db.msg_find(['0001'], source='mem'), [self.message1]) + # and on disk + self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], source='db'), [self.message2]) + self.assertEqual(chat_db.msg_find(['0002.yaml'], source='db'), [self.message2]) + self.assertEqual(chat_db.msg_find(['0002'], source='db'), [self.message2]) + # now search the cache -> expect empty result + self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], source='cache'), []) + self.assertEqual(chat_db.msg_find(['0003.txt'], source='cache'), []) + self.assertEqual(chat_db.msg_find(['0003'], source='cache'), []) + # search for multiple messages + search_names = ['0001', '0002.yaml', str(self.message3.file_path)] + expected_result = [self.message1, self.message2, self.message3] + result = chat_db.msg_find(search_names, source='all') + self.assertSequenceEqual(result, expected_result) + + def test_msg_latest(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.msg_latest(source='mem'), self.message4)