10 Commits

4 changed files with 223 additions and 110 deletions
+100 -31
View File
@@ -6,7 +6,7 @@ from pathlib import Path
from pprint import PrettyPrinter from pprint import PrettyPrinter
from pydoc import pager from pydoc import pager
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal, Union
from .configuration import default_config_file from .configuration import default_config_file
from .message import Message, MessageFilter, MessageError, message_in from .message import Message, MessageFilter, MessageError, message_in
from .tags import Tag from .tags import Tag
@@ -16,7 +16,7 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next' db_next_file = '.next'
ignored_files = [db_next_file, default_config_file] ignored_files = [db_next_file, default_config_file]
msg_place = Literal['mem', 'disk', 'cache', 'db', 'all'] msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
class ChatError(Exception): class ChatError(Exception):
@@ -107,7 +107,9 @@ def clear_dir(dir_path: Path,
""" """
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in file_iter: for file_path in file_iter:
if file_path.is_file() and file_path.suffix in Message.file_suffixes: if (file_path.is_file()
and file_path.name not in ignored_files # noqa: W503
and file_path.suffix in Message.file_suffixes): # noqa: W503
file_path.unlink(missing_ok=True) file_path.unlink(missing_ok=True)
@@ -373,11 +375,11 @@ class ChatDB(Chat):
self.msg_write(messages) self.msg_write(messages)
def msg_gather(self, def msg_gather(self,
source: msg_place, loc: msg_location,
require_file_path: bool = False, require_file_path: bool = False,
mfilter: Optional[MessageFilter] = None) -> list[Message]: mfilter: Optional[MessageFilter] = None) -> list[Message]:
""" """
Gather and return messages from the given source: Gather and return messages from the given locations:
* 'mem' : messages currently in memory * 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory * 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory * 'cache': messages in the cache directory
@@ -386,52 +388,63 @@ class ChatDB(Chat):
If 'require_file_path' is True, return only files with a valid file_path. If 'require_file_path' is True, return only files with a valid file_path.
""" """
source_messages: list[Message] = [] loc_messages: list[Message] = []
if source in ['mem', 'all']: if loc in ['mem', 'all']:
if require_file_path: if require_file_path:
source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))] loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
else: else:
source_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))] loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
if source in ['cache', 'disk', 'all']: if loc in ['cache', 'disk', 'all']:
source_messages += read_dir(self.cache_path, mfilter=mfilter) loc_messages += read_dir(self.cache_path, mfilter=mfilter)
if source in ['db', 'disk', 'all']: if loc in ['db', 'disk', 'all']:
source_messages += read_dir(self.db_path, mfilter=mfilter) loc_messages += read_dir(self.db_path, mfilter=mfilter)
# remove_duplicates and sort the list # remove_duplicates and sort the list
unique_messages: list[Message] = [] unique_messages: list[Message] = []
for m in source_messages: for m in loc_messages:
if not message_in(m, unique_messages): if not message_in(m, unique_messages):
unique_messages.append(m) unique_messages.append(m)
try:
unique_messages.sort(key=lambda m: m.msg_id()) unique_messages.sort(key=lambda m: m.msg_id())
# messages in 'mem' can have an empty file_path
except MessageError:
pass
return unique_messages return unique_messages
def msg_find(self, def msg_find(self,
msg_names: list[str], msg_names: list[str],
source: msg_place = 'mem', loc: msg_location = 'mem',
) -> list[Message]: ) -> list[Message]:
""" """
Search and return the messages with the given names. Names can either be filenames Search and return the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Messages that can't be (with or without suffix), full paths or Message.msg_id(). Messages that can't be
found are ignored (i. e. the caller should check the result if they require all found are ignored (i. e. the caller should check the result if they require all
messages). messages).
Searches one of the following places: Searches one of the following locations:
* 'mem' : messages currently in memory * 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory * 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory * 'cache': messages in the cache directory
* 'db' : messages in the DB directory * 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk') * 'all' : all messages ('mem' + 'disk')
""" """
source_messages = self.msg_gather(source, require_file_path=True) loc_messages = self.msg_gather(loc, require_file_path=True)
return [m for m in source_messages return [m for m in loc_messages
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)] if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def msg_remove(self, msg_names: list[str]) -> None: def msg_remove(self, msg_names: list[str], loc: msg_location = 'mem') -> None:
""" """
Remove the messages with the given names. Names can either be filenames Remove the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Also deletes the (with or without suffix), full paths or Message.msg_id(). Also deletes the
files of all given messages with a valid file_path. files of all given messages with a valid file_path.
Delete files from one of the following locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
""" """
if loc != 'mem':
# delete the message files first # delete the message files first
rm_messages = self.msg_find(msg_names, source='all') rm_messages = self.msg_find(msg_names, loc=loc)
for m in rm_messages: for m in rm_messages:
if (m.file_path): if (m.file_path):
m.file_path.unlink() m.file_path.unlink()
@@ -440,11 +453,11 @@ class ChatDB(Chat):
def msg_latest(self, def msg_latest(self,
mfilter: Optional[MessageFilter] = None, mfilter: Optional[MessageFilter] = None,
source: msg_place = 'mem') -> Optional[Message]: loc: msg_location = 'mem') -> Optional[Message]:
""" """
Return the last added message (according to the file ID) that matches the given filter. Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if source is 'mem'). Only consider messages with a valid file_path (except if loc is 'mem').
Searches one of the following places: Searches one of the following locations:
* 'mem' : messages currently in memory * 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory * 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory * 'cache': messages in the cache directory
@@ -452,20 +465,44 @@ class ChatDB(Chat):
* 'all' : all messages ('mem' + 'disk') * 'all' : all messages ('mem' + 'disk')
""" """
# only consider messages with a valid file_path so they can be sorted # only consider messages with a valid file_path so they can be sorted
source_messages = self.msg_gather(source, require_file_path=True) loc_messages = self.msg_gather(loc, require_file_path=True)
source_messages.sort(key=lambda m: m.msg_id(), reverse=True) loc_messages.sort(key=lambda m: m.msg_id(), reverse=True)
for m in source_messages: for m in loc_messages:
if mfilter is None or m.match(mfilter): if mfilter is None or m.match(mfilter):
return m return m
return None return None
def cache_read(self) -> None: def msg_in_cache(self, message: Union[Message, str]) -> bool:
"""
Return true if the given Message (or filename or Message.msg_id())
is located in the cache directory. False otherwise.
"""
if isinstance(message, Message):
return (message.file_path is not None
and message.file_path.parent.samefile(self.cache_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc='cache')) > 0
def msg_in_db(self, message: Union[Message, str]) -> bool:
"""
Return true if the given Message (or filename or Message.msg_id())
is located in the DB directory. False otherwise.
"""
if isinstance(message, Message):
return (message.file_path is not None
and message.file_path.parent.samefile(self.db_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc='db')) > 0
def cache_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None:
""" """
Read messages from the cache directory. New ones are added to the internal list, Read messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message existing ones are replaced. A message is determined as 'existing' if a message
with the same base filename (i. e. 'file_path.name') is already in the list. with the same base filename (i. e. 'file_path.name') is already in the list.
""" """
new_messages = read_dir(self.cache_path, self.glob, self.mfilter) new_messages = read_dir(self.cache_path, glob, mfilter)
# remove all messages from self.messages that are in the new list # remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)] self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them # copy the messages from the temporary list to self.messages and sort them
@@ -502,15 +539,31 @@ class ChatDB(Chat):
self.messages += messages self.messages += messages
self.msg_sort() self.msg_sort()
def cache_clear(self) -> None: def cache_clear(self, glob: Optional[str] = None) -> None:
""" """
Delete all message files from the cache dir and remove them from the internal list. Delete all message files from the cache dir and remove them from the internal list.
""" """
clear_dir(self.cache_path, self.glob) clear_dir(self.cache_path, glob)
# only keep messages from DB dir (or those that have not yet been written) # only keep messages from DB dir (or those that have not yet been written)
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)] self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
def db_read(self) -> None: def cache_move(self, message: Message) -> None:
"""
Moves the given messages to the cache directory.
"""
# remember the old path (if any)
old_path: Optional[Path] = None
if message.file_path:
old_path = message.file_path
# write message to the new destination
self.cache_write([message])
# remove the old one (if any)
if old_path:
self.msg_remove([str(old_path)], loc='db')
# (re)add it to the internal list
self.msg_add([message])
def db_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None:
""" """
Read messages from the DB directory. New ones are added to the internal list, Read messages from the DB directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message existing ones are replaced. A message is determined as 'existing' if a message
@@ -552,3 +605,19 @@ class ChatDB(Chat):
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid) m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages self.messages += messages
self.msg_sort() self.msg_sort()
def db_move(self, message: Message) -> None:
"""
Moves the given messages to the db directory.
"""
# remember the old path (if any)
old_path: Optional[Path] = None
if message.file_path:
old_path = message.file_path
# write message to the new destination
self.db_write([message])
# remove the old one (if any)
if old_path:
self.msg_remove([str(old_path)], loc='cache')
# (re)add it to the internal list
self.msg_add([message])
+21 -47
View File
@@ -1,4 +1,3 @@
import sys
import argparse import argparse
from pathlib import Path from pathlib import Path
from itertools import zip_longest from itertools import zip_longest
@@ -74,34 +73,10 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
tags=args.output_tags, # FIXME tags=args.output_tags, # FIXME
ai=args.AI, ai=args.AI,
model=args.model) model=args.model)
# only write the message (as a backup), don't add it chat.cache_add([message])
# to the current chat history
chat.cache_write([message])
return message return message
def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespace) -> None:
"""
Make an AI request with the give AI, chat history, message and CLI arguments.
Print all answers.
"""
ai.print()
chat.print(paged=False)
print(message.to_str() + '\n')
response: AIResponse = ai.request(message,
chat,
args.num_answers,
args.output_tags)
# write all answers to the cache, don't add them to the chat history
chat.cache_write(response.messages)
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
def question_cmd(args: argparse.Namespace, config: Config) -> None: def question_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'question' command. Handler for the 'question' command.
@@ -120,29 +95,28 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
# create the correct AI instance # create the correct AI instance
ai: AI = create_ai(args, config) ai: AI = create_ai(args, config)
# === ASK ===
if args.ask: if args.ask:
make_request(ai, chat, message, args) ai.print()
# === REPEAT === chat.print(paged=False)
response: AIResponse = ai.request(message,
chat,
args.num_answers, # FIXME
args.output_tags) # FIXME
chat.msg_update([response.messages[0]])
chat.cache_add(response.messages[1:])
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
elif args.repeat is not None: elif args.repeat is not None:
lmessage = chat.msg_latest(source='cache') lmessage = chat.msg_latest()
if lmessage is None: assert lmessage
print("No message found to repeat!") # TODO: repeat either the last question or the
sys.exit(1) # one(s) given in 'args.repeat' (overwrite
else: # existing ones if 'args.overwrite' is True)
print(f"Repeating message '{lmessage.msg_id()}':") pass
# overwrite the latest message if requested or empty
if lmessage.answer is None or args.overwrite is True:
lmessage.clear_answer()
make_request(ai, chat, lmessage, args)
# otherwise create a new one
else:
args.ask = [lmessage.question]
message = create_message(chat, args)
make_request(ai, chat, message, args)
# === PROCESS ===
elif args.process is not None: elif args.process is not None:
# TODO: process either all questions without an # TODO: process either all questions without an
# answer or the one(s) given in 'args.process' # answer or the one(s) given in 'args.process'
+1 -4
View File
@@ -393,9 +393,9 @@ class Message():
try: try:
data = yaml.load(fd, Loader=yaml.FullLoader) data = yaml.load(fd, Loader=yaml.FullLoader)
data[cls.file_yaml_key] = file_path data[cls.file_yaml_key] = file_path
return cls.from_dict(data)
except Exception: except Exception:
raise MessageError(f"'{file_path}' does not contain a valid message") raise MessageError(f"'{file_path}' does not contain a valid message")
return cls.from_dict(data)
def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str: def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str:
""" """
@@ -540,9 +540,6 @@ class Message():
if self.tags: if self.tags:
self.tags = rename_tags(self.tags, tags_rename) self.tags = rename_tags(self.tags, tags_rename)
def clear_answer(self) -> None:
self.answer = None
def msg_id(self) -> str: def msg_id(self) -> str:
""" """
Returns an ID that is unique throughout all messages in the same (DB) directory. Returns an ID that is unique throughout all messages in the same (DB) directory.
+96 -23
View File
@@ -2,6 +2,7 @@ import unittest
import pathlib import pathlib
import tempfile import tempfile
import time import time
import yaml
from io import StringIO from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from chatmastermind.tags import TagLine from chatmastermind.tags import TagLine
@@ -184,6 +185,10 @@ class TestChatDB(unittest.TestCase):
for file in self.trash_files: for file in self.trash_files:
with open(pathlib.Path(self.db_path.name) / file, 'w') as f: with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
f.write('test trash') f.write('test trash')
# also create a file with actual yaml content
with open(pathlib.Path(self.db_path.name) / 'content.yaml', 'w') as f:
yaml.dump({'key': 'value'}, f)
self.trash_files.append('content.yaml')
self.maxDiff = None self.maxDiff = None
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]: def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
@@ -284,6 +289,25 @@ class TestChatDB(unittest.TestCase):
with open(chat_db.next_path, 'r') as f: with open(chat_db.next_path, 'r') as f:
self.assertEqual(f.read(), '7') self.assertEqual(f.read(), '7')
def test_msg_in_db_or_cache(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertTrue(chat_db.msg_in_db(self.message1))
self.assertTrue(chat_db.msg_in_db(str(self.message1.file_path)))
self.assertTrue(chat_db.msg_in_db(self.message1.msg_id()))
self.assertFalse(chat_db.msg_in_cache(self.message1))
self.assertFalse(chat_db.msg_in_cache(str(self.message1.file_path)))
self.assertFalse(chat_db.msg_in_cache(self.message1.msg_id()))
# add new message to the cache dir
cache_message = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
chat_db.cache_add([cache_message])
self.assertTrue(chat_db.msg_in_cache(cache_message))
self.assertTrue(chat_db.msg_in_cache(cache_message.msg_id()))
self.assertFalse(chat_db.msg_in_db(cache_message))
self.assertFalse(chat_db.msg_in_db(str(cache_message.file_path)))
def test_db_write(self) -> None: def test_db_write(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
@@ -526,43 +550,92 @@ class TestChatDB(unittest.TestCase):
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
# search for a DB file in memory # search for a DB file in memory
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], source='mem'), [self.message1]) self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], source='mem'), [self.message1]) self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.txt'], source='mem'), [self.message1]) self.assertEqual(chat_db.msg_find(['0001.txt'], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], source='mem'), [self.message1]) self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1])
# and on disk # and on disk
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], source='db'), [self.message2]) self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], source='db'), [self.message2]) self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.yaml'], source='db'), [self.message2]) self.assertEqual(chat_db.msg_find(['0002.yaml'], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], source='db'), [self.message2]) self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2])
# now search the cache -> expect empty result # now search the cache -> expect empty result
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], source='cache'), []) self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], source='cache'), []) self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003.txt'], source='cache'), []) self.assertEqual(chat_db.msg_find(['0003.txt'], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003'], source='cache'), []) self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), [])
# search for multiple messages # search for multiple messages
# -> search one twice, expect result to be unique # -> search one twice, expect result to be unique
search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)] search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3] expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, source='all') result = chat_db.msg_find(search_names, loc='all')
self.assertSequenceEqual(result, expected_result) self.assertSequenceEqual(result, expected_result)
def test_msg_latest(self) -> None: def test_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.msg_latest(source='mem'), self.message4) self.assertEqual(chat_db.msg_latest(loc='mem'), self.message4)
self.assertEqual(chat_db.msg_latest(source='db'), self.message4) self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)
self.assertEqual(chat_db.msg_latest(source='disk'), self.message4) self.assertEqual(chat_db.msg_latest(loc='disk'), self.message4)
self.assertEqual(chat_db.msg_latest(source='all'), self.message4) self.assertEqual(chat_db.msg_latest(loc='all'), self.message4)
# the cache is currently empty: # the cache is currently empty:
self.assertIsNone(chat_db.msg_latest(source='cache')) self.assertIsNone(chat_db.msg_latest(loc='cache'))
# add new messages to the cache dir # add new messages to the cache dir
new_message = Message(question=Question("New Question"), new_message = Message(question=Question("New Question"),
answer=Answer("New Answer")) answer=Answer("New Answer"))
chat_db.cache_add([new_message]) chat_db.cache_add([new_message])
self.assertEqual(chat_db.msg_latest(source='cache'), new_message) self.assertEqual(chat_db.msg_latest(loc='cache'), new_message)
self.assertEqual(chat_db.msg_latest(source='mem'), new_message) self.assertEqual(chat_db.msg_latest(loc='mem'), new_message)
self.assertEqual(chat_db.msg_latest(source='disk'), new_message) self.assertEqual(chat_db.msg_latest(loc='disk'), new_message)
self.assertEqual(chat_db.msg_latest(source='all'), new_message) self.assertEqual(chat_db.msg_latest(loc='all'), new_message)
# the DB does not contain the new message # the DB does not contain the new message
self.assertEqual(chat_db.msg_latest(source='db'), self.message4) self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)
def test_msg_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
# add a new message, but only to the internal list
new_message = Message(Question("What?"))
all_messages_mem = all_messages + [new_message]
chat_db.msg_add([new_message])
self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages_mem)
self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages_mem)
# the nr. of messages on disk did not change -> expect old result
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
# test with MessageFilter
self.assertSequenceEqual(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})),
[self.message1])
self.assertSequenceEqual(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})),
[self.message2])
self.assertSequenceEqual(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})),
[])
self.assertSequenceEqual(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")),
[new_message])
def test_msg_move_and_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
# move first message to the cache
chat_db.cache_move(self.message1)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [self.message1])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4])
self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages)
# now move first message back to the DB
chat_db.db_move(self.message1)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)