chat: implemented special version of 'latest_message()' for the ChatDB class

This commit is contained in:
2023-09-14 11:45:47 +02:00
parent 17a0264025
commit 5cb88dad1b
2 changed files with 65 additions and 6 deletions
+38 -6
View File
@@ -6,7 +6,7 @@ from pathlib import Path
from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal
from .message import Message, MessageFilter, MessageError, message_in
from .tags import Tag
@@ -142,15 +142,18 @@ class Chat:
self.messages += messages
self.sort()
def latest_message(self) -> Optional[Message]:
def latest_message(self, mfilter: Optional[MessageFilter] = None) -> Optional[Message]:
"""
Returns the last added message (according to the file ID).
Return the last added message (according to the file ID) that matches the given filter.
When containing messages without a valid file_path, it returns the latest message in
the internal list.
"""
if len(self.messages) > 0:
self.sort()
return self.messages[-1]
else:
return None
for m in reversed(self.messages):
if mfilter is None or m.match(mfilter):
return m
return None
def find_messages(self, msg_names: list[str]) -> list[Message]:
"""
@@ -404,3 +407,32 @@ class ChatDB(Chat):
# write the UPDATED messages if requested
if write:
self.write_messages(messages)
def latest_message(self,
mfilter: Optional[MessageFilter] = None,
source: Literal['mem', 'disk', 'cache', 'db', 'all'] = 'mem') -> Optional[Message]:
"""
Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if source is 'mem').
Searches one of the following sources:
* 'mem' : only search messages currently in memory
* 'disk' : search messages on disk (cache + DB directory), but not in memory
* 'cache': only search messages in the cache directory
* 'db' : only search messages in the DB directory
* 'all' : search all messages ('mem' + 'disk')
"""
source_messages: list[Message] = []
if source == 'mem':
return super().latest_message(mfilter)
if source in ['cache', 'disk', 'all']:
source_messages += read_dir(self.cache_path, mfilter=mfilter)
if source in ['db', 'disk', 'all']:
source_messages += read_dir(self.db_path, mfilter=mfilter)
if source in ['all']:
# only consider messages with a valid file_path so they can be sorted
source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
source_messages.sort(key=lambda m: m.msg_id(), reverse=True)
for m in source_messages:
if mfilter is None or m.match(mfilter):
return m
return None