Refactoring, fixes and new features for the 'chat.py' module #12

Merged
juk0de merged 12 commits from chat_refactoring into main 2023-09-18 14:23:52 +02:00
2 changed files with 44 additions and 1 deletions
Showing only changes of commit 2fb7410b43 - Show all commits
+25 -1
View File
@@ -6,7 +6,7 @@ from pathlib import Path
from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal, Union
from .configuration import default_config_file
from .message import Message, MessageFilter, MessageError, message_in
from .tags import Tag
@@ -466,6 +466,30 @@ class ChatDB(Chat):
return m
return None
def msg_in_cache(self, message: Union[Message, str]) -> bool:
"""
Return true if the given Message (or filename or Message.msg_id())
is located in the cache directory. False otherwise.
"""
if isinstance(message, Message):
return (message.file_path is not None
and message.file_path.parent.samefile(self.cache_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc='cache')) > 0
def msg_in_db(self, message: Union[Message, str]) -> bool:
"""
Return true if the given Message (or filename or Message.msg_id())
is located in the DB directory. False otherwise.
"""
if isinstance(message, Message):
return (message.file_path is not None
and message.file_path.parent.samefile(self.db_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc='db')) > 0
def cache_read(self) -> None:
"""
Read messages from the cache directory. New ones are added to the internal list,
+19
View File
@@ -289,6 +289,25 @@ class TestChatDB(unittest.TestCase):
with open(chat_db.next_path, 'r') as f:
self.assertEqual(f.read(), '7')
def test_msg_in_db_or_cache(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertTrue(chat_db.msg_in_db(self.message1))
self.assertTrue(chat_db.msg_in_db(str(self.message1.file_path)))
self.assertTrue(chat_db.msg_in_db(self.message1.msg_id()))
self.assertFalse(chat_db.msg_in_cache(self.message1))
self.assertFalse(chat_db.msg_in_cache(str(self.message1.file_path)))
self.assertFalse(chat_db.msg_in_cache(self.message1.msg_id()))
# add new message to the cache dir
cache_message = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
chat_db.cache_add([cache_message])
self.assertTrue(chat_db.msg_in_cache(cache_message))
self.assertTrue(chat_db.msg_in_cache(cache_message.msg_id()))
self.assertFalse(chat_db.msg_in_db(cache_message))
self.assertFalse(chat_db.msg_in_db(str(cache_message.file_path)))
def test_db_write(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),