diff --git a/README.md b/README.md index c661200..bfca190 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,8 @@ cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI_ID * `-O, --overwrite`: Overwrite existing messages when repeating them * `-s, --source-text FILE`: Add content of a file to the query * `-S, --source-code FILE`: Add source code file content to the chat history +* `-l, --location {cache,db,all}`: Use given location when building the chat history (default: 'db') +* `-g, --glob GLOB`: Filter message files using the given glob pattern #### Hist @@ -83,6 +85,8 @@ cmm hist [--print | --convert FORMAT] [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... * `-S, --source-code-only`: Only print embedded source code * `-A, --answer SUBSTRING`: Filter for answer substring * `-Q, --question SUBSTRING`: Filter for question substring +* `-l, --location {cache,db,all}`: Use given location when building the chat history (default: 'db') +* `-g, --glob GLOB`: Filter message files using the given glob pattern #### Tags diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 1cc21ba..41a69bd 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -6,7 +6,8 @@ from pathlib import Path from pprint import PrettyPrinter from pydoc import pager from dataclasses import dataclass -from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union +from enum import Enum +from typing import TypeVar, Type, Optional, Any, Callable, Union from .configuration import default_config_file from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats from .tags import Tag @@ -16,10 +17,17 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') db_next_file = '.next' ignored_files = [db_next_file, default_config_file] -msg_location = Literal['mem', 'disk', 'cache', 'db', 'all'] msg_suffix = Message.file_suffix_write +class msg_location(Enum): + MEM = 'mem' + DISK = 'disk' + CACHE = 'cache' + DB = 'db' + ALL = 'all' + + class ChatError(Exception): pass @@ -44,12 +52,12 @@ def read_dir(dir_path: Path, Parameters: * 'dir_path': source directory * 'glob': if specified, files will be filtered using 'path.glob()', - otherwise it uses 'path.iterdir()'. + otherwise it reads all files with the default message suffix * 'mfilter': use with 'Message.from_file()' to filter messages when reading them. """ messages: list[Message] = [] - file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() + file_iter = dir_path.glob(glob) if glob else dir_path.glob(f'*{msg_suffix}') for file_path in sorted(file_iter): if (file_path.is_file() and file_path.name not in ignored_files # noqa: W503 @@ -287,7 +295,7 @@ class ChatDB(Chat): # a MessageFilter that all messages must match (if given) mfilter: Optional[MessageFilter] = None # the glob pattern for all messages - glob: Optional[str] = None + glob: str = f'*{msg_suffix}' # message format (for writing) mformat: MessageFormat = Message.default_format @@ -303,20 +311,28 @@ class ChatDB(Chat): def from_dir(cls: Type[ChatDBInst], cache_path: Path, db_path: Path, - glob: Optional[str] = None, - mfilter: Optional[MessageFilter] = None) -> ChatDBInst: + glob: str = f'*{msg_suffix}', + mfilter: Optional[MessageFilter] = None, + loc: msg_location = msg_location.DB) -> ChatDBInst: """ Create a 'ChatDB' instance from the given directory structure. Reads all messages from 'db_path' into the local message list. Parameters: * 'cache_path': path to the directory for temporary messages * 'db_path': path to the directory for persistent messages - * 'glob': if specified, files will be filtered using 'path.glob()', - otherwise it uses 'path.iterdir()'. + * 'glob': if specified, files will be filtered using 'path.glob()' * 'mfilter': use with 'Message.from_file()' to filter messages when reading them. + * 'loc': read messages from given location instead of 'db_path' """ - messages = read_dir(db_path, glob, mfilter) + if loc == msg_location.MEM: + raise ChatError(f"Can't build ChatDB from message location '{loc}'") + messages: list[Message] = [] + if loc in [msg_location.DB, msg_location.DISK, msg_location.ALL]: + messages.extend(read_dir(db_path, glob, mfilter)) + if loc in [msg_location.CACHE, msg_location.DISK, msg_location.ALL]: + messages.extend(read_dir(cache_path, glob, mfilter)) + messages.sort(key=lambda x: x.msg_id()) return cls(messages, cache_path, db_path, mfilter, glob) @classmethod @@ -386,7 +402,7 @@ class ChatDB(Chat): def msg_gather(self, loc: msg_location, require_file_path: bool = False, - glob: Optional[str] = None, + glob: str = f'*{msg_suffix}', mfilter: Optional[MessageFilter] = None) -> list[Message]: """ Gather and return messages from the given locations: @@ -399,14 +415,14 @@ class ChatDB(Chat): If 'require_file_path' is True, return only files with a valid file_path. """ loc_messages: list[Message] = [] - if loc in ['mem', 'all']: + if loc in [msg_location.MEM, msg_location.ALL]: if require_file_path: loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))] else: loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))] - if loc in ['cache', 'disk', 'all']: + if loc in [msg_location.CACHE, msg_location.DISK, msg_location.ALL]: loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter) - if loc in ['db', 'disk', 'all']: + if loc in [msg_location.DB, msg_location.DISK, msg_location.ALL]: loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter) # remove_duplicates and sort the list unique_messages: list[Message] = [] @@ -422,7 +438,7 @@ class ChatDB(Chat): def msg_find(self, msg_names: list[str], - loc: msg_location = 'mem', + loc: msg_location = msg_location.MEM, ) -> list[Message]: """ Search and return the messages with the given names. Names can either be filenames @@ -440,7 +456,7 @@ class ChatDB(Chat): return [m for m in loc_messages if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)] - def msg_remove(self, msg_names: list[str], loc: msg_location = 'mem') -> None: + def msg_remove(self, msg_names: list[str], loc: msg_location = msg_location.MEM) -> None: """ Remove the messages with the given names. Names can either be filenames (with or without suffix), full paths or Message.msg_id(). Also deletes the @@ -452,7 +468,7 @@ class ChatDB(Chat): * 'db' : messages in the DB directory * 'all' : all messages ('mem' + 'disk') """ - if loc != 'mem': + if loc != msg_location.MEM: # delete the message files first rm_messages = self.msg_find(msg_names, loc=loc) for m in rm_messages: @@ -463,7 +479,7 @@ class ChatDB(Chat): def msg_latest(self, mfilter: Optional[MessageFilter] = None, - loc: msg_location = 'mem') -> Optional[Message]: + loc: msg_location = msg_location.MEM) -> Optional[Message]: """ Return the last added message (according to the file ID) that matches the given filter. Only consider messages with a valid file_path (except if loc is 'mem'). @@ -492,7 +508,7 @@ class ChatDB(Chat): and message.file_path.parent.samefile(self.cache_path) # noqa: W503 and message.file_path.exists()) # noqa: W503 else: - return len(self.msg_find([message], loc='cache')) > 0 + return len(self.msg_find([message], loc=msg_location.CACHE)) > 0 def msg_in_db(self, message: Union[Message, str]) -> bool: """ @@ -504,9 +520,9 @@ class ChatDB(Chat): and message.file_path.parent.samefile(self.db_path) # noqa: W503 and message.file_path.exists()) # noqa: W503 else: - return len(self.msg_find([message], loc='db')) > 0 + return len(self.msg_find([message], loc=msg_location.DB)) > 0 - def cache_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None: + def cache_read(self, glob: str = f'*{msg_suffix}', mfilter: Optional[MessageFilter] = None) -> None: """ Read messages from the cache directory. New ones are added to the internal list, existing ones are replaced. A message is determined as 'existing' if a message @@ -549,7 +565,7 @@ class ChatDB(Chat): self.messages += messages self.msg_sort() - def cache_clear(self, glob: Optional[str] = None) -> None: + def cache_clear(self, glob: str = f'*{msg_suffix}') -> None: """ Delete all message files from the cache dir and remove them from the internal list. """ @@ -569,11 +585,11 @@ class ChatDB(Chat): self.cache_write([message]) # remove the old one (if any) if old_path: - self.msg_remove([str(old_path)], loc='db') + self.msg_remove([str(old_path)], loc=msg_location.DB) # (re)add it to the internal list self.msg_add([message]) - def db_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None: + def db_read(self, glob: str = f'*{msg_suffix}', mfilter: Optional[MessageFilter] = None) -> None: """ Read messages from the DB directory. New ones are added to the internal list, existing ones are replaced. A message is determined as 'existing' if a message @@ -628,6 +644,6 @@ class ChatDB(Chat): self.db_write([message]) # remove the old one (if any) if old_path: - self.msg_remove([str(old_path)], loc='cache') + self.msg_remove([str(old_path)], loc=msg_location.CACHE) # (re)add it to the internal list self.msg_add([message]) diff --git a/chatmastermind/commands/hist.py b/chatmastermind/commands/hist.py index 8e68320..b065afb 100644 --- a/chatmastermind/commands/hist.py +++ b/chatmastermind/commands/hist.py @@ -2,7 +2,7 @@ import sys import argparse from pathlib import Path from ..configuration import Config -from ..chat import ChatDB +from ..chat import ChatDB, msg_location from ..message import MessageFilter, Message @@ -15,9 +15,10 @@ def convert_messages(args: argparse.Namespace, config: Config) -> None: ('.txt', '.yaml') to the latest default message file suffix ('.msg'). """ chat = ChatDB.from_dir(Path(config.cache), - Path(config.db)) + Path(config.db), + glob='*') # read all known message files - msgs = chat.msg_gather(loc='disk', glob='*.*') + msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*') # make a set of all message IDs msg_ids = set([m.msg_id() for m in msgs]) # set requested format and write all messages @@ -29,14 +30,14 @@ def convert_messages(args: argparse.Namespace, config: Config) -> None: m.file_path = m.file_path.with_suffix('') chat.msg_write(msgs) # read all messages with the current default suffix - msgs = chat.msg_gather(loc='disk', glob=f'*{msg_suffix}') + msgs = chat.msg_gather(loc=msg_location.DISK, glob=f'*{msg_suffix}') # make sure we converted all of the original messages for mid in msg_ids: if not any(mid == m.msg_id() for m in msgs): print(f"Message '{mid}' has not been found after conversion. Aborting.") sys.exit(1) # delete messages with old suffixes - msgs = chat.msg_gather(loc='disk', glob='*.*') + msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*') for m in msgs: if m.file_path and m.file_path.suffix != msg_suffix: m.rm_file() @@ -55,7 +56,9 @@ def print_chat(args: argparse.Namespace, config: Config) -> None: answer_contains=args.answer) chat = ChatDB.from_dir(Path(config.cache), Path(config.db), - mfilter=mfilter) + mfilter=mfilter, + loc=msg_location(args.location), + glob=args.glob) chat.print(args.source_code_only, args.with_metadata, paged=not args.no_paging, diff --git a/chatmastermind/commands/print.py b/chatmastermind/commands/print.py index d2a66a1..e1d66ca 100644 --- a/chatmastermind/commands/print.py +++ b/chatmastermind/commands/print.py @@ -3,7 +3,7 @@ import argparse from pathlib import Path from ..configuration import Config from ..message import Message, MessageError -from ..chat import ChatDB +from ..chat import ChatDB, msg_location def print_message(message: Message, args: argparse.Namespace) -> None: @@ -38,7 +38,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: # print latest message elif args.latest: chat = ChatDB.from_dir(Path(config.cache), Path(config.db)) - latest = chat.msg_latest(loc='disk') + latest = chat.msg_latest(loc=msg_location.DISK) if not latest: print("No message found!") sys.exit(1) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 4e7d8e0..cd31d54 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -4,7 +4,7 @@ from pathlib import Path from itertools import zip_longest from copy import deepcopy from ..configuration import Config -from ..chat import ChatDB +from ..chat import ChatDB, msg_location from ..message import Message, MessageFilter, MessageError, Question, source_code from ..ai_factory import create_ai from ..ai import AI, AIResponse @@ -186,7 +186,9 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: tags_not=args.exclude_tags) chat = ChatDB.from_dir(cache_path=Path(config.cache), db_path=Path(config.db), - mfilter=mfilter) + mfilter=mfilter, + glob=args.glob, + loc=msg_location(args.location)) # if it's a new question, create and store it immediately if args.ask or args.create: message = create_message(chat, args) @@ -202,14 +204,14 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: repeat_msgs: list[Message] = [] # repeat latest message if len(args.repeat) == 0: - lmessage = chat.msg_latest(loc='cache') + lmessage = chat.msg_latest(loc=msg_location.CACHE) if lmessage is None: print("No message found to repeat!") sys.exit(1) repeat_msgs.append(lmessage) # repeat given message(s) else: - repeat_msgs = chat.msg_find(args.repeat, loc='disk') + repeat_msgs = chat.msg_find(args.repeat, loc=msg_location.DISK) repeat_messages(repeat_msgs, chat, args, config) # === PROCESS === elif args.process is not None: diff --git a/chatmastermind/main.py b/chatmastermind/main.py index ff74d6c..3fba60f 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -14,6 +14,7 @@ from .commands.tags import tags_cmd from .commands.config import config_cmd from .commands.hist import hist_cmd from .commands.print import print_cmd +from .chat import msg_location def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: @@ -65,6 +66,11 @@ def create_parser() -> argparse.ArgumentParser: question_group.add_argument('-c', '--create', nargs='+', help='Create a question', metavar='QUESTION') question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question', metavar='MESSAGE') question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions', metavar='MESSAGE') + question_cmd_parser.add_argument('-l', '--location', + choices=[x.value for x in msg_location if x not in [msg_location.MEM, msg_location.DISK]], + default='db', + help='Use given location when building the chat history (default: \'db\')') + question_cmd_parser.add_argument('-g', '--glob', help='Filter message files using the given glob pattern') question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', action='store_true') question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query', metavar='FILE') @@ -87,6 +93,11 @@ def create_parser() -> argparse.ArgumentParser: hist_cmd_parser.add_argument('-Q', '--question', help='Print only questions with given substring', metavar='SUBSTRING') hist_cmd_parser.add_argument('-d', '--tight', help='Print without message separators', action='store_true') hist_cmd_parser.add_argument('-P', '--no-paging', help='Print without paging', action='store_true') + hist_cmd_parser.add_argument('-l', '--location', + choices=[x.value for x in msg_location if x not in [msg_location.MEM, msg_location.DISK]], + default='db', + help='Use given location when building the chat history (default: \'db\')') + hist_cmd_parser.add_argument('-g', '--glob', help='Filter message files using the given glob pattern') # 'tags' command parser tags_cmd_parser = cmdparser.add_parser('tags', diff --git a/tests/test_chat.py b/tests/test_chat.py index 802616d..7f68113 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -7,7 +7,7 @@ from io import StringIO from unittest.mock import patch from chatmastermind.tags import TagLine from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter -from chatmastermind.chat import Chat, ChatDB, ChatError +from chatmastermind.chat import Chat, ChatDB, ChatError, msg_location msg_suffix: str = Message.file_suffix_write @@ -240,7 +240,8 @@ class TestChatDB(TestChatBase): msg_to_file_force_suffix(duplicate_message) with self.assertRaises(ChatError) as cm: ChatDB.from_dir(pathlib.Path(self.cache_path.name), - pathlib.Path(self.db_path.name)) + pathlib.Path(self.db_path.name), + glob='*') self.assertEqual(str(cm.exception), "Validation failed") def test_file_path_ID_exists(self) -> None: @@ -595,92 +596,92 @@ class TestChatDB(TestChatBase): chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) # search for a DB file in memory - self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1]) - self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1]) - self.assertEqual(chat_db.msg_find(['0001.msg'], loc='mem'), [self.message1]) - self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1]) + self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc=msg_location.MEM), [self.message1]) + self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc=msg_location.MEM), [self.message1]) + self.assertEqual(chat_db.msg_find(['0001.msg'], loc=msg_location.MEM), [self.message1]) + self.assertEqual(chat_db.msg_find(['0001'], loc=msg_location.MEM), [self.message1]) # and on disk - self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2]) - self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2]) - self.assertEqual(chat_db.msg_find(['0002.msg'], loc='db'), [self.message2]) - self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2]) + self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc=msg_location.DB), [self.message2]) + self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc=msg_location.DB), [self.message2]) + self.assertEqual(chat_db.msg_find(['0002.msg'], loc=msg_location.DB), [self.message2]) + self.assertEqual(chat_db.msg_find(['0002'], loc=msg_location.DB), [self.message2]) # now search the cache -> expect empty result - self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), []) - self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), []) - self.assertEqual(chat_db.msg_find(['0003.msg'], loc='cache'), []) - self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), []) + self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc=msg_location.CACHE), []) + self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc=msg_location.CACHE), []) + self.assertEqual(chat_db.msg_find(['0003.msg'], loc=msg_location.CACHE), []) + self.assertEqual(chat_db.msg_find(['0003'], loc=msg_location.CACHE), []) # search for multiple messages # -> search one twice, expect result to be unique search_names = ['0001', '0002.msg', self.message3.msg_id(), str(self.message3.file_path)] expected_result = [self.message1, self.message2, self.message3] - result = chat_db.msg_find(search_names, loc='all') + result = chat_db.msg_find(search_names, loc=msg_location.ALL) self.assert_messages_equal(result, expected_result) def test_msg_latest(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) - self.assertEqual(chat_db.msg_latest(loc='mem'), self.message4) - self.assertEqual(chat_db.msg_latest(loc='db'), self.message4) - self.assertEqual(chat_db.msg_latest(loc='disk'), self.message4) - self.assertEqual(chat_db.msg_latest(loc='all'), self.message4) + self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), self.message4) + self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4) + self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), self.message4) + self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), self.message4) # the cache is currently empty: - self.assertIsNone(chat_db.msg_latest(loc='cache')) + self.assertIsNone(chat_db.msg_latest(loc=msg_location.CACHE)) # add new messages to the cache dir new_message = Message(question=Question("New Question"), answer=Answer("New Answer")) chat_db.cache_add([new_message]) - self.assertEqual(chat_db.msg_latest(loc='cache'), new_message) - self.assertEqual(chat_db.msg_latest(loc='mem'), new_message) - self.assertEqual(chat_db.msg_latest(loc='disk'), new_message) - self.assertEqual(chat_db.msg_latest(loc='all'), new_message) + self.assertEqual(chat_db.msg_latest(loc=msg_location.CACHE), new_message) + self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), new_message) + self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), new_message) + self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), new_message) # the DB does not contain the new message - self.assertEqual(chat_db.msg_latest(loc='db'), self.message4) + self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4) def test_msg_gather(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) all_messages = [self.message1, self.message2, self.message3, self.message4] - self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages) - self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) - self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages) - self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages) - self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), []) # add a new message, but only to the internal list new_message = Message(Question("What?")) all_messages_mem = all_messages + [new_message] chat_db.msg_add([new_message]) - self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages_mem) - self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages_mem) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages_mem) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages_mem) # the nr. of messages on disk did not change -> expect old result - self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) - self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages) - self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), []) # test with MessageFilter - self.assert_messages_equal(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})), + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL, mfilter=MessageFilter(tags_or={Tag('tag1')})), [self.message1]) - self.assert_messages_equal(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})), + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK, mfilter=MessageFilter(tags_or={Tag('tag2')})), [self.message2]) - self.assert_messages_equal(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})), + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE, mfilter=MessageFilter(tags_or={Tag('tag3')})), []) - self.assert_messages_equal(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")), + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM, mfilter=MessageFilter(question_contains="What")), [new_message]) def test_msg_move_and_gather(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) all_messages = [self.message1, self.message2, self.message3, self.message4] - self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) - self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), []) # move first message to the cache chat_db.cache_move(self.message1) - self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [self.message1]) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [self.message1]) self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] - self.assert_messages_equal(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4]) - self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages) - self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages) - self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), [self.message2, self.message3, self.message4]) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages) # now move first message back to the DB chat_db.db_move(self.message1) - self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), []) self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] - self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages) diff --git a/tests/test_hist_cmd.py b/tests/test_hist_cmd.py index cfd1c1f..2a8cbc1 100644 --- a/tests/test_hist_cmd.py +++ b/tests/test_hist_cmd.py @@ -4,7 +4,7 @@ import tempfile import yaml from pathlib import Path from chatmastermind.message import Message, Question -from chatmastermind.chat import ChatDB, ChatError +from chatmastermind.chat import ChatDB, ChatError, msg_location from chatmastermind.configuration import Config from chatmastermind.commands.hist import convert_messages @@ -41,7 +41,7 @@ class TestConvertMessages(unittest.TestCase): def test_convert_messages(self) -> None: self.args.convert = 'yaml' convert_messages(self.args, self.config) - msgs = self.chat.msg_gather(loc='disk', glob='*.*') + msgs = self.chat.msg_gather(loc=msg_location.DISK, glob='*.*') # Check if the number of messages is the same as before self.assertEqual(len(msgs), len(self.messages)) # Check if all messages have the requested suffix diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index be31b58..b809e4b 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -9,7 +9,7 @@ from chatmastermind.configuration import Config from chatmastermind.commands.question import create_message, question_cmd from chatmastermind.tags import Tag from chatmastermind.message import Message, Question, Answer -from chatmastermind.chat import Chat, ChatDB +from chatmastermind.chat import Chat, ChatDB, msg_location from chatmastermind.ai import AIError from .test_common import TestWithFakeAI @@ -234,6 +234,8 @@ class TestQuestionCmd(TestWithFakeAI): # create a mock argparse.Namespace self.args = argparse.Namespace( ask=['What is the meaning of life?'], + glob=None, + location='db', num_answers=1, output_tags=['science'], AI='FakeAI', @@ -279,7 +281,7 @@ class TestQuestionCmdAsk(TestQuestionCmd): # check for the expected message files chat = ChatDB.from_dir(Path(self.cache_dir.name), Path(self.db_dir.name)) - cached_msg = chat.msg_gather(loc='cache') + cached_msg = chat.msg_gather(loc=msg_location.CACHE) self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assert_msgs_equal_except_file_path(cached_msg, expected_responses) @@ -337,7 +339,7 @@ class TestQuestionCmdAsk(TestQuestionCmd): # check for the expected message files chat = ChatDB.from_dir(Path(self.cache_dir.name), Path(self.db_dir.name)) - cached_msg = chat.msg_gather(loc='cache') + cached_msg = chat.msg_gather(loc=msg_location.CACHE) self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assert_msgs_equal_except_file_path(cached_msg, [expected_question]) @@ -375,7 +377,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd): # we expect the original message + the one with the new response expected_responses = [message] + [expected_response] question_cmd(self.args, self.config) - cached_msg = chat.msg_gather(loc='cache') + cached_msg = chat.msg_gather(loc=msg_location.CACHE) print(self.message_list(self.cache_dir)) self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assert_msgs_equal_except_file_path(cached_msg, expected_responses) @@ -396,7 +398,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd): model=self.args.model, file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') chat.msg_write([message]) - cached_msg = chat.msg_gather(loc='cache') + cached_msg = chat.msg_gather(loc=msg_location.CACHE) assert cached_msg[0].file_path cached_msg_file_id = cached_msg[0].file_path.stem @@ -412,7 +414,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd): tags=message.tags, file_path=Path('')) question_cmd(self.args, self.config) - cached_msg = chat.msg_gather(loc='cache') + cached_msg = chat.msg_gather(loc=msg_location.CACHE) self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assert_msgs_equal_except_file_path(cached_msg, [expected_response]) # also check that the file ID has not been changed @@ -435,7 +437,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd): model=self.args.model, file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') chat.msg_write([message]) - cached_msg = chat.msg_gather(loc='cache') + cached_msg = chat.msg_gather(loc=msg_location.CACHE) assert cached_msg[0].file_path cached_msg_file_id = cached_msg[0].file_path.stem @@ -452,7 +454,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd): tags=message.tags, file_path=Path('')) question_cmd(self.args, self.config) - cached_msg = chat.msg_gather(loc='cache') + cached_msg = chat.msg_gather(loc=msg_location.CACHE) self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assert_msgs_equal_except_file_path(cached_msg, [expected_response]) # also check that the file ID has not been changed @@ -475,7 +477,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd): model=self.args.model, file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') chat.msg_write([message]) - cached_msg = chat.msg_gather(loc='cache') + cached_msg = chat.msg_gather(loc=msg_location.CACHE) assert cached_msg[0].file_path # repeat the last question with new arguments (without overwriting) @@ -493,7 +495,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd): tags={Tag('newtag')}, file_path=Path('')) question_cmd(self.args, self.config) - cached_msg = chat.msg_gather(loc='cache') + cached_msg = chat.msg_gather(loc=msg_location.CACHE) self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assert_msgs_equal_except_file_path(cached_msg, [message] + [new_expected_response]) @@ -513,7 +515,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd): model=self.args.model, file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') chat.msg_write([message]) - cached_msg = chat.msg_gather(loc='cache') + cached_msg = chat.msg_gather(loc=msg_location.CACHE) assert cached_msg[0].file_path # repeat the last question with new arguments @@ -530,7 +532,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd): tags={Tag('newtag')}, file_path=Path('')) question_cmd(self.args, self.config) - cached_msg = chat.msg_gather(loc='cache') + cached_msg = chat.msg_gather(loc=msg_location.CACHE) self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response]) @@ -586,8 +588,8 @@ class TestQuestionCmdRepeat(TestQuestionCmd): self.assertEqual(len(self.message_list(self.cache_dir)), 4) self.assertEqual(len(self.message_list(self.db_dir)), 1) expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]] - cached_msg = chat.msg_gather(loc='cache') + cached_msg = chat.msg_gather(loc=msg_location.CACHE) self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages) # check that the DB message has not been modified at all - db_msg = chat.msg_gather(loc='db') + db_msg = chat.msg_gather(loc=msg_location.DB) self.assert_msgs_all_equal(db_msg, [message3])