Change type msg_location to an Enum instead of Literal to be able to get all values easy and improve type checks.

This commit is contained in:
Oleksandr Kozachuk
2023-10-19 16:00:44 +02:00
parent 5f29f60168
commit 9b0951cb3f
7 changed files with 95 additions and 86 deletions
+23 -14
View File
@@ -6,7 +6,8 @@ from pathlib import Path
from pprint import PrettyPrinter from pprint import PrettyPrinter
from pydoc import pager from pydoc import pager
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union from enum import Enum
from typing import TypeVar, Type, Optional, Any, Callable, Union
from .configuration import default_config_file from .configuration import default_config_file
from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats
from .tags import Tag from .tags import Tag
@@ -16,10 +17,17 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next' db_next_file = '.next'
ignored_files = [db_next_file, default_config_file] ignored_files = [db_next_file, default_config_file]
msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
msg_suffix = Message.file_suffix_write msg_suffix = Message.file_suffix_write
class msg_location(Enum):
MEM = 'mem'
DISK = 'disk'
CACHE = 'cache'
DB = 'db'
ALL = 'all'
class ChatError(Exception): class ChatError(Exception):
pass pass
@@ -304,7 +312,8 @@ class ChatDB(Chat):
cache_path: Path, cache_path: Path,
db_path: Path, db_path: Path,
glob: Optional[str] = None, glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> ChatDBInst: mfilter: Optional[MessageFilter] = None,
loc: msg_location = msg_location.DB) -> ChatDBInst:
""" """
Create a 'ChatDB' instance from the given directory structure. Create a 'ChatDB' instance from the given directory structure.
Reads all messages from 'db_path' into the local message list. Reads all messages from 'db_path' into the local message list.
@@ -399,14 +408,14 @@ class ChatDB(Chat):
If 'require_file_path' is True, return only files with a valid file_path. If 'require_file_path' is True, return only files with a valid file_path.
""" """
loc_messages: list[Message] = [] loc_messages: list[Message] = []
if loc in ['mem', 'all']: if loc in [msg_location.MEM, msg_location.ALL]:
if require_file_path: if require_file_path:
loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))] loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
else: else:
loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))] loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
if loc in ['cache', 'disk', 'all']: if loc in [msg_location.CACHE, msg_location.DISK, msg_location.ALL]:
loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter) loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter)
if loc in ['db', 'disk', 'all']: if loc in [msg_location.DB, msg_location.DISK, msg_location.ALL]:
loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter) loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter)
# remove_duplicates and sort the list # remove_duplicates and sort the list
unique_messages: list[Message] = [] unique_messages: list[Message] = []
@@ -422,7 +431,7 @@ class ChatDB(Chat):
def msg_find(self, def msg_find(self,
msg_names: list[str], msg_names: list[str],
loc: msg_location = 'mem', loc: msg_location = msg_location.MEM,
) -> list[Message]: ) -> list[Message]:
""" """
Search and return the messages with the given names. Names can either be filenames Search and return the messages with the given names. Names can either be filenames
@@ -440,7 +449,7 @@ class ChatDB(Chat):
return [m for m in loc_messages return [m for m in loc_messages
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)] if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def msg_remove(self, msg_names: list[str], loc: msg_location = 'mem') -> None: def msg_remove(self, msg_names: list[str], loc: msg_location = msg_location.MEM) -> None:
""" """
Remove the messages with the given names. Names can either be filenames Remove the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Also deletes the (with or without suffix), full paths or Message.msg_id(). Also deletes the
@@ -452,7 +461,7 @@ class ChatDB(Chat):
* 'db' : messages in the DB directory * 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk') * 'all' : all messages ('mem' + 'disk')
""" """
if loc != 'mem': if loc != msg_location.MEM:
# delete the message files first # delete the message files first
rm_messages = self.msg_find(msg_names, loc=loc) rm_messages = self.msg_find(msg_names, loc=loc)
for m in rm_messages: for m in rm_messages:
@@ -463,7 +472,7 @@ class ChatDB(Chat):
def msg_latest(self, def msg_latest(self,
mfilter: Optional[MessageFilter] = None, mfilter: Optional[MessageFilter] = None,
loc: msg_location = 'mem') -> Optional[Message]: loc: msg_location = msg_location.MEM) -> Optional[Message]:
""" """
Return the last added message (according to the file ID) that matches the given filter. Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if loc is 'mem'). Only consider messages with a valid file_path (except if loc is 'mem').
@@ -492,7 +501,7 @@ class ChatDB(Chat):
and message.file_path.parent.samefile(self.cache_path) # noqa: W503 and message.file_path.parent.samefile(self.cache_path) # noqa: W503
and message.file_path.exists()) # noqa: W503 and message.file_path.exists()) # noqa: W503
else: else:
return len(self.msg_find([message], loc='cache')) > 0 return len(self.msg_find([message], loc=msg_location.CACHE)) > 0
def msg_in_db(self, message: Union[Message, str]) -> bool: def msg_in_db(self, message: Union[Message, str]) -> bool:
""" """
@@ -504,7 +513,7 @@ class ChatDB(Chat):
and message.file_path.parent.samefile(self.db_path) # noqa: W503 and message.file_path.parent.samefile(self.db_path) # noqa: W503
and message.file_path.exists()) # noqa: W503 and message.file_path.exists()) # noqa: W503
else: else:
return len(self.msg_find([message], loc='db')) > 0 return len(self.msg_find([message], loc=msg_location.DB)) > 0
def cache_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None: def cache_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None:
""" """
@@ -569,7 +578,7 @@ class ChatDB(Chat):
self.cache_write([message]) self.cache_write([message])
# remove the old one (if any) # remove the old one (if any)
if old_path: if old_path:
self.msg_remove([str(old_path)], loc='db') self.msg_remove([str(old_path)], loc=msg_location.DB)
# (re)add it to the internal list # (re)add it to the internal list
self.msg_add([message]) self.msg_add([message])
@@ -628,6 +637,6 @@ class ChatDB(Chat):
self.db_write([message]) self.db_write([message])
# remove the old one (if any) # remove the old one (if any)
if old_path: if old_path:
self.msg_remove([str(old_path)], loc='cache') self.msg_remove([str(old_path)], loc=msg_location.CACHE)
# (re)add it to the internal list # (re)add it to the internal list
self.msg_add([message]) self.msg_add([message])
+4 -4
View File
@@ -2,7 +2,7 @@ import sys
import argparse import argparse
from pathlib import Path from pathlib import Path
from ..configuration import Config from ..configuration import Config
from ..chat import ChatDB from ..chat import ChatDB, msg_location
from ..message import MessageFilter, Message from ..message import MessageFilter, Message
@@ -17,7 +17,7 @@ def convert_messages(args: argparse.Namespace, config: Config) -> None:
chat = ChatDB.from_dir(Path(config.cache), chat = ChatDB.from_dir(Path(config.cache),
Path(config.db)) Path(config.db))
# read all known message files # read all known message files
msgs = chat.msg_gather(loc='disk', glob='*.*') msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*')
# make a set of all message IDs # make a set of all message IDs
msg_ids = set([m.msg_id() for m in msgs]) msg_ids = set([m.msg_id() for m in msgs])
# set requested format and write all messages # set requested format and write all messages
@@ -29,14 +29,14 @@ def convert_messages(args: argparse.Namespace, config: Config) -> None:
m.file_path = m.file_path.with_suffix('') m.file_path = m.file_path.with_suffix('')
chat.msg_write(msgs) chat.msg_write(msgs)
# read all messages with the current default suffix # read all messages with the current default suffix
msgs = chat.msg_gather(loc='disk', glob=f'*{msg_suffix}') msgs = chat.msg_gather(loc=msg_location.DISK, glob=f'*{msg_suffix}')
# make sure we converted all of the original messages # make sure we converted all of the original messages
for mid in msg_ids: for mid in msg_ids:
if not any(mid == m.msg_id() for m in msgs): if not any(mid == m.msg_id() for m in msgs):
print(f"Message '{mid}' has not been found after conversion. Aborting.") print(f"Message '{mid}' has not been found after conversion. Aborting.")
sys.exit(1) sys.exit(1)
# delete messages with old suffixes # delete messages with old suffixes
msgs = chat.msg_gather(loc='disk', glob='*.*') msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*')
for m in msgs: for m in msgs:
if m.file_path and m.file_path.suffix != msg_suffix: if m.file_path and m.file_path.suffix != msg_suffix:
m.rm_file() m.rm_file()
+2 -2
View File
@@ -3,7 +3,7 @@ import argparse
from pathlib import Path from pathlib import Path
from ..configuration import Config from ..configuration import Config
from ..message import Message, MessageError from ..message import Message, MessageError
from ..chat import ChatDB from ..chat import ChatDB, msg_location
def print_message(message: Message, args: argparse.Namespace) -> None: def print_message(message: Message, args: argparse.Namespace) -> None:
@@ -38,7 +38,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None:
# print latest message # print latest message
elif args.latest: elif args.latest:
chat = ChatDB.from_dir(Path(config.cache), Path(config.db)) chat = ChatDB.from_dir(Path(config.cache), Path(config.db))
latest = chat.msg_latest(loc='disk') latest = chat.msg_latest(loc=msg_location.DISK)
if not latest: if not latest:
print("No message found!") print("No message found!")
sys.exit(1) sys.exit(1)
+3 -3
View File
@@ -4,7 +4,7 @@ from pathlib import Path
from itertools import zip_longest from itertools import zip_longest
from copy import deepcopy from copy import deepcopy
from ..configuration import Config from ..configuration import Config
from ..chat import ChatDB from ..chat import ChatDB, msg_location
from ..message import Message, MessageFilter, MessageError, Question, source_code from ..message import Message, MessageFilter, MessageError, Question, source_code
from ..ai_factory import create_ai from ..ai_factory import create_ai
from ..ai import AI, AIResponse from ..ai import AI, AIResponse
@@ -202,14 +202,14 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
repeat_msgs: list[Message] = [] repeat_msgs: list[Message] = []
# repeat latest message # repeat latest message
if len(args.repeat) == 0: if len(args.repeat) == 0:
lmessage = chat.msg_latest(loc='cache') lmessage = chat.msg_latest(loc=msg_location.CACHE)
if lmessage is None: if lmessage is None:
print("No message found to repeat!") print("No message found to repeat!")
sys.exit(1) sys.exit(1)
repeat_msgs.append(lmessage) repeat_msgs.append(lmessage)
# repeat given message(s) # repeat given message(s)
else: else:
repeat_msgs = chat.msg_find(args.repeat, loc='disk') repeat_msgs = chat.msg_find(args.repeat, loc=msg_location.DISK)
repeat_messages(repeat_msgs, chat, args, config) repeat_messages(repeat_msgs, chat, args, config)
# === PROCESS === # === PROCESS ===
elif args.process is not None: elif args.process is not None:
+47 -47
View File
@@ -7,7 +7,7 @@ from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from chatmastermind.tags import TagLine from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, ChatError from chatmastermind.chat import Chat, ChatDB, ChatError, msg_location
msg_suffix: str = Message.file_suffix_write msg_suffix: str = Message.file_suffix_write
@@ -595,92 +595,92 @@ class TestChatDB(TestChatBase):
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
# search for a DB file in memory # search for a DB file in memory
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1]) self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1]) self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.msg'], loc='mem'), [self.message1]) self.assertEqual(chat_db.msg_find(['0001.msg'], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1]) self.assertEqual(chat_db.msg_find(['0001'], loc=msg_location.MEM), [self.message1])
# and on disk # and on disk
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2]) self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2]) self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.msg'], loc='db'), [self.message2]) self.assertEqual(chat_db.msg_find(['0002.msg'], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2]) self.assertEqual(chat_db.msg_find(['0002'], loc=msg_location.DB), [self.message2])
# now search the cache -> expect empty result # now search the cache -> expect empty result
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), []) self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), []) self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find(['0003.msg'], loc='cache'), []) self.assertEqual(chat_db.msg_find(['0003.msg'], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), []) self.assertEqual(chat_db.msg_find(['0003'], loc=msg_location.CACHE), [])
# search for multiple messages # search for multiple messages
# -> search one twice, expect result to be unique # -> search one twice, expect result to be unique
search_names = ['0001', '0002.msg', self.message3.msg_id(), str(self.message3.file_path)] search_names = ['0001', '0002.msg', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3] expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, loc='all') result = chat_db.msg_find(search_names, loc=msg_location.ALL)
self.assert_messages_equal(result, expected_result) self.assert_messages_equal(result, expected_result)
def test_msg_latest(self) -> None: def test_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.msg_latest(loc='mem'), self.message4) self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), self.message4)
self.assertEqual(chat_db.msg_latest(loc='db'), self.message4) self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4)
self.assertEqual(chat_db.msg_latest(loc='disk'), self.message4) self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), self.message4)
self.assertEqual(chat_db.msg_latest(loc='all'), self.message4) self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), self.message4)
# the cache is currently empty: # the cache is currently empty:
self.assertIsNone(chat_db.msg_latest(loc='cache')) self.assertIsNone(chat_db.msg_latest(loc=msg_location.CACHE))
# add new messages to the cache dir # add new messages to the cache dir
new_message = Message(question=Question("New Question"), new_message = Message(question=Question("New Question"),
answer=Answer("New Answer")) answer=Answer("New Answer"))
chat_db.cache_add([new_message]) chat_db.cache_add([new_message])
self.assertEqual(chat_db.msg_latest(loc='cache'), new_message) self.assertEqual(chat_db.msg_latest(loc=msg_location.CACHE), new_message)
self.assertEqual(chat_db.msg_latest(loc='mem'), new_message) self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), new_message)
self.assertEqual(chat_db.msg_latest(loc='disk'), new_message) self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), new_message)
self.assertEqual(chat_db.msg_latest(loc='all'), new_message) self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), new_message)
# the DB does not contain the new message # the DB does not contain the new message
self.assertEqual(chat_db.msg_latest(loc='db'), self.message4) self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4)
def test_msg_gather(self) -> None: def test_msg_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4] all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
# add a new message, but only to the internal list # add a new message, but only to the internal list
new_message = Message(Question("What?")) new_message = Message(Question("What?"))
all_messages_mem = all_messages + [new_message] all_messages_mem = all_messages + [new_message]
chat_db.msg_add([new_message]) chat_db.msg_add([new_message])
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages_mem) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages_mem)
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages_mem) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages_mem)
# the nr. of messages on disk did not change -> expect old result # the nr. of messages on disk did not change -> expect old result
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
# test with MessageFilter # test with MessageFilter
self.assert_messages_equal(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})), self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL, mfilter=MessageFilter(tags_or={Tag('tag1')})),
[self.message1]) [self.message1])
self.assert_messages_equal(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})), self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK, mfilter=MessageFilter(tags_or={Tag('tag2')})),
[self.message2]) [self.message2])
self.assert_messages_equal(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})), self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE, mfilter=MessageFilter(tags_or={Tag('tag3')})),
[]) [])
self.assert_messages_equal(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")), self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM, mfilter=MessageFilter(question_contains="What")),
[new_message]) [new_message])
def test_msg_move_and_gather(self) -> None: def test_msg_move_and_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4] all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
# move first message to the cache # move first message to the cache
chat_db.cache_move(self.message1) chat_db.cache_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [self.message1]) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [self.message1])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4]) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), [self.message2, self.message3, self.message4])
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages)
# now move first message back to the DB # now move first message back to the DB
chat_db.db_move(self.message1) chat_db.db_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
+2 -2
View File
@@ -4,7 +4,7 @@ import tempfile
import yaml import yaml
from pathlib import Path from pathlib import Path
from chatmastermind.message import Message, Question from chatmastermind.message import Message, Question
from chatmastermind.chat import ChatDB, ChatError from chatmastermind.chat import ChatDB, ChatError, msg_location
from chatmastermind.configuration import Config from chatmastermind.configuration import Config
from chatmastermind.commands.hist import convert_messages from chatmastermind.commands.hist import convert_messages
@@ -41,7 +41,7 @@ class TestConvertMessages(unittest.TestCase):
def test_convert_messages(self) -> None: def test_convert_messages(self) -> None:
self.args.convert = 'yaml' self.args.convert = 'yaml'
convert_messages(self.args, self.config) convert_messages(self.args, self.config)
msgs = self.chat.msg_gather(loc='disk', glob='*.*') msgs = self.chat.msg_gather(loc=msg_location.DISK, glob='*.*')
# Check if the number of messages is the same as before # Check if the number of messages is the same as before
self.assertEqual(len(msgs), len(self.messages)) self.assertEqual(len(msgs), len(self.messages))
# Check if all messages have the requested suffix # Check if all messages have the requested suffix
+14 -14
View File
@@ -9,7 +9,7 @@ from chatmastermind.configuration import Config
from chatmastermind.commands.question import create_message, question_cmd from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat, ChatDB from chatmastermind.chat import Chat, ChatDB, msg_location
from chatmastermind.ai import AIError from chatmastermind.ai import AIError
from .test_common import TestWithFakeAI from .test_common import TestWithFakeAI
@@ -279,7 +279,7 @@ class TestQuestionCmdAsk(TestQuestionCmd):
# check for the expected message files # check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses) self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@@ -337,7 +337,7 @@ class TestQuestionCmdAsk(TestQuestionCmd):
# check for the expected message files # check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_question]) self.assert_msgs_equal_except_file_path(cached_msg, [expected_question])
@@ -375,7 +375,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
# we expect the original message + the one with the new response # we expect the original message + the one with the new response
expected_responses = [message] + [expected_response] expected_responses = [message] + [expected_response]
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
print(self.message_list(self.cache_dir)) print(self.message_list(self.cache_dir))
self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses) self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@@ -396,7 +396,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
model=self.args.model, model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message]) chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem cached_msg_file_id = cached_msg[0].file_path.stem
@@ -412,7 +412,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
tags=message.tags, tags=message.tags,
file_path=Path('<NOT COMPARED>')) file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response]) self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed # also check that the file ID has not been changed
@@ -435,7 +435,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
model=self.args.model, model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message]) chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem cached_msg_file_id = cached_msg[0].file_path.stem
@@ -452,7 +452,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
tags=message.tags, tags=message.tags,
file_path=Path('<NOT COMPARED>')) file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response]) self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed # also check that the file ID has not been changed
@@ -475,7 +475,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
model=self.args.model, model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message]) chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path assert cached_msg[0].file_path
# repeat the last question with new arguments (without overwriting) # repeat the last question with new arguments (without overwriting)
@@ -493,7 +493,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
tags={Tag('newtag')}, tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>')) file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_msgs_equal_except_file_path(cached_msg, [message] + [new_expected_response]) self.assert_msgs_equal_except_file_path(cached_msg, [message] + [new_expected_response])
@@ -513,7 +513,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
model=self.args.model, model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message]) chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path assert cached_msg[0].file_path
# repeat the last question with new arguments # repeat the last question with new arguments
@@ -530,7 +530,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
tags={Tag('newtag')}, tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>')) file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response]) self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response])
@@ -586,8 +586,8 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
self.assertEqual(len(self.message_list(self.cache_dir)), 4) self.assertEqual(len(self.message_list(self.cache_dir)), 4)
self.assertEqual(len(self.message_list(self.db_dir)), 1) self.assertEqual(len(self.message_list(self.db_dir)), 1)
expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]] expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages) self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
# check that the DB message has not been modified at all # check that the DB message has not been modified at all
db_msg = chat.msg_gather(loc='db') db_msg = chat.msg_gather(loc=msg_location.DB)
self.assert_msgs_all_equal(db_msg, [message3]) self.assert_msgs_all_equal(db_msg, [message3])