chat: implemented special version of 'latest_message()' for the ChatDB class
This commit is contained in:
+38
-6
@@ -6,7 +6,7 @@ from pathlib import Path
|
||||
from pprint import PrettyPrinter
|
||||
from pydoc import pager
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable
|
||||
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal
|
||||
from .message import Message, MessageFilter, MessageError, message_in
|
||||
from .tags import Tag
|
||||
|
||||
@@ -142,15 +142,18 @@ class Chat:
|
||||
self.messages += messages
|
||||
self.sort()
|
||||
|
||||
def latest_message(self) -> Optional[Message]:
|
||||
def latest_message(self, mfilter: Optional[MessageFilter] = None) -> Optional[Message]:
|
||||
"""
|
||||
Returns the last added message (according to the file ID).
|
||||
Return the last added message (according to the file ID) that matches the given filter.
|
||||
When containing messages without a valid file_path, it returns the latest message in
|
||||
the internal list.
|
||||
"""
|
||||
if len(self.messages) > 0:
|
||||
self.sort()
|
||||
return self.messages[-1]
|
||||
else:
|
||||
return None
|
||||
for m in reversed(self.messages):
|
||||
if mfilter is None or m.match(mfilter):
|
||||
return m
|
||||
return None
|
||||
|
||||
def find_messages(self, msg_names: list[str]) -> list[Message]:
|
||||
"""
|
||||
@@ -404,3 +407,32 @@ class ChatDB(Chat):
|
||||
# write the UPDATED messages if requested
|
||||
if write:
|
||||
self.write_messages(messages)
|
||||
|
||||
def latest_message(self,
|
||||
mfilter: Optional[MessageFilter] = None,
|
||||
source: Literal['mem', 'disk', 'cache', 'db', 'all'] = 'mem') -> Optional[Message]:
|
||||
"""
|
||||
Return the last added message (according to the file ID) that matches the given filter.
|
||||
Only consider messages with a valid file_path (except if source is 'mem').
|
||||
Searches one of the following sources:
|
||||
* 'mem' : only search messages currently in memory
|
||||
* 'disk' : search messages on disk (cache + DB directory), but not in memory
|
||||
* 'cache': only search messages in the cache directory
|
||||
* 'db' : only search messages in the DB directory
|
||||
* 'all' : search all messages ('mem' + 'disk')
|
||||
"""
|
||||
source_messages: list[Message] = []
|
||||
if source == 'mem':
|
||||
return super().latest_message(mfilter)
|
||||
if source in ['cache', 'disk', 'all']:
|
||||
source_messages += read_dir(self.cache_path, mfilter=mfilter)
|
||||
if source in ['db', 'disk', 'all']:
|
||||
source_messages += read_dir(self.db_path, mfilter=mfilter)
|
||||
if source in ['all']:
|
||||
# only consider messages with a valid file_path so they can be sorted
|
||||
source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
|
||||
source_messages.sort(key=lambda m: m.msg_id(), reverse=True)
|
||||
for m in source_messages:
|
||||
if mfilter is None or m.match(mfilter):
|
||||
return m
|
||||
return None
|
||||
|
||||
@@ -84,6 +84,13 @@ class TestChat(unittest.TestCase):
|
||||
self.chat.remove_messages(['0003.txt'])
|
||||
self.assertListEqual(self.chat.messages, [self.message1, self.message2])
|
||||
|
||||
def test_latest_message(self) -> None:
|
||||
self.assertIsNone(self.chat.latest_message())
|
||||
self.chat.add_messages([self.message1])
|
||||
self.assertEqual(self.chat.latest_message(), self.message1)
|
||||
self.chat.add_messages([self.message2])
|
||||
self.assertEqual(self.chat.latest_message(), self.message2)
|
||||
|
||||
@patch('sys.stdout', new_callable=StringIO)
|
||||
def test_print(self, mock_stdout: StringIO) -> None:
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
@@ -474,3 +481,23 @@ class TestChatDB(unittest.TestCase):
|
||||
answer=Answer("Answer 1"))
|
||||
with self.assertRaises(ChatError):
|
||||
chat_db.update_messages([message1])
|
||||
|
||||
def test_chat_db_latest_message(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.latest_message(source='mem'), self.message4)
|
||||
self.assertEqual(chat_db.latest_message(source='db'), self.message4)
|
||||
self.assertEqual(chat_db.latest_message(source='disk'), self.message4)
|
||||
self.assertEqual(chat_db.latest_message(source='all'), self.message4)
|
||||
# the cache is currently empty:
|
||||
self.assertIsNone(chat_db.latest_message(source='cache'))
|
||||
# add new messages to the cache dir
|
||||
new_message = Message(question=Question("New Question"),
|
||||
answer=Answer("New Answer"))
|
||||
chat_db.add_to_cache([new_message])
|
||||
self.assertEqual(chat_db.latest_message(source='cache'), new_message)
|
||||
self.assertEqual(chat_db.latest_message(source='mem'), new_message)
|
||||
self.assertEqual(chat_db.latest_message(source='disk'), new_message)
|
||||
self.assertEqual(chat_db.latest_message(source='all'), new_message)
|
||||
# the DB does not contain the new message
|
||||
self.assertEqual(chat_db.latest_message(source='db'), self.message4)
|
||||
|
||||
Reference in New Issue
Block a user