diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index dd18293..f3637de 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -6,7 +6,7 @@ from pathlib import Path from pprint import PrettyPrinter from pydoc import pager from dataclasses import dataclass -from typing import TypeVar, Type, Optional, ClassVar, Any, Callable +from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal from .message import Message, MessageFilter, MessageError, message_in from .tags import Tag @@ -142,15 +142,18 @@ class Chat: self.messages += messages self.sort() - def latest_message(self) -> Optional[Message]: + def latest_message(self, mfilter: Optional[MessageFilter] = None) -> Optional[Message]: """ - Returns the last added message (according to the file ID). + Return the last added message (according to the file ID) that matches the given filter. + When containing messages without a valid file_path, it returns the latest message in + the internal list. """ if len(self.messages) > 0: self.sort() - return self.messages[-1] - else: - return None + for m in reversed(self.messages): + if mfilter is None or m.match(mfilter): + return m + return None def find_messages(self, msg_names: list[str]) -> list[Message]: """ @@ -404,3 +407,32 @@ class ChatDB(Chat): # write the UPDATED messages if requested if write: self.write_messages(messages) + + def latest_message(self, + mfilter: Optional[MessageFilter] = None, + source: Literal['mem', 'disk', 'cache', 'db', 'all'] = 'mem') -> Optional[Message]: + """ + Return the last added message (according to the file ID) that matches the given filter. + Only consider messages with a valid file_path (except if source is 'mem'). + Searches one of the following sources: + * 'mem' : only search messages currently in memory + * 'disk' : search messages on disk (cache + DB directory), but not in memory + * 'cache': only search messages in the cache directory + * 'db' : only search messages in the DB directory + * 'all' : search all messages ('mem' + 'disk') + """ + source_messages: list[Message] = [] + if source == 'mem': + return super().latest_message(mfilter) + if source in ['cache', 'disk', 'all']: + source_messages += read_dir(self.cache_path, mfilter=mfilter) + if source in ['db', 'disk', 'all']: + source_messages += read_dir(self.db_path, mfilter=mfilter) + if source in ['all']: + # only consider messages with a valid file_path so they can be sorted + source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))] + source_messages.sort(key=lambda m: m.msg_id(), reverse=True) + for m in source_messages: + if mfilter is None or m.match(mfilter): + return m + return None diff --git a/tests/test_chat.py b/tests/test_chat.py index f34cb24..ca74725 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -84,6 +84,13 @@ class TestChat(unittest.TestCase): self.chat.remove_messages(['0003.txt']) self.assertListEqual(self.chat.messages, [self.message1, self.message2]) + def test_latest_message(self) -> None: + self.assertIsNone(self.chat.latest_message()) + self.chat.add_messages([self.message1]) + self.assertEqual(self.chat.latest_message(), self.message1) + self.chat.add_messages([self.message2]) + self.assertEqual(self.chat.latest_message(), self.message2) + @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: self.chat.add_messages([self.message1, self.message2]) @@ -474,3 +481,23 @@ class TestChatDB(unittest.TestCase): answer=Answer("Answer 1")) with self.assertRaises(ChatError): chat_db.update_messages([message1]) + + def test_chat_db_latest_message(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.latest_message(source='mem'), self.message4) + self.assertEqual(chat_db.latest_message(source='db'), self.message4) + self.assertEqual(chat_db.latest_message(source='disk'), self.message4) + self.assertEqual(chat_db.latest_message(source='all'), self.message4) + # the cache is currently empty: + self.assertIsNone(chat_db.latest_message(source='cache')) + # add new messages to the cache dir + new_message = Message(question=Question("New Question"), + answer=Answer("New Answer")) + chat_db.add_to_cache([new_message]) + self.assertEqual(chat_db.latest_message(source='cache'), new_message) + self.assertEqual(chat_db.latest_message(source='mem'), new_message) + self.assertEqual(chat_db.latest_message(source='disk'), new_message) + self.assertEqual(chat_db.latest_message(source='all'), new_message) + # the DB does not contain the new message + self.assertEqual(chat_db.latest_message(source='db'), self.message4)