From 9ca9a23569dc767cc9194648ab15ba5e2d627bf3 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 24 Sep 2023 18:20:38 +0200 Subject: [PATCH 01/16] message: introduced file suffix '.msg' - '.msg' suffix is always used for writing - 'Message.to_file()' will set the file suffix if the given file_path has none - added 'mformat' argument to 'Message.to_file()' for choosing the file format - '.txt' and '.yaml' suffixes are only supported for reading --- chatmastermind/message.py | 75 +++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 443455e..04def6d 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -5,7 +5,8 @@ import pathlib import yaml import tempfile import shutil -from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable +from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple +from typing import get_args as typing_get_args from dataclasses import dataclass, asdict, field from .tags import Tag, TagLine, TagError, match_tags, rename_tags @@ -15,6 +16,9 @@ MessageInst = TypeVar('MessageInst', bound='Message') AILineInst = TypeVar('AILineInst', bound='AILine') ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine') YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]] +MessageFormat = Literal['txt', 'yaml'] +message_valid_formats: Final[Tuple[MessageFormat, ...]] = typing_get_args(MessageFormat) +message_default_format: Final[MessageFormat] = 'txt' class MessageError(Exception): @@ -92,7 +96,7 @@ class MessageFilter: class AILine(str): """ - A line that represents the AI name in a '.txt' file.. + A line that represents the AI name in the 'txt' format. """ prefix: Final[str] = 'AI:' @@ -112,7 +116,7 @@ class AILine(str): class ModelLine(str): """ - A line that represents the model name in a '.txt' file.. + A line that represents the model name in the 'txt' format. """ prefix: Final[str] = 'MODEL:' @@ -216,7 +220,9 @@ class Message(): model: Optional[str] = field(default=None, compare=False) file_path: Optional[pathlib.Path] = field(default=None, compare=False) # class variables - file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml'] + file_suffixes_read: ClassVar[list[str]] = ['.msg', '.txt', '.yaml'] + file_suffix_write: ClassVar[str] = '.msg' + default_format: ClassVar[MessageFormat] = message_default_format tags_yaml_key: ClassVar[str] = 'tags' file_yaml_key: ClassVar[str] = 'file_path' ai_yaml_key: ClassVar[str] = 'ai' @@ -276,24 +282,16 @@ class Message(): tags: set[Tag] = set() if not file_path.exists(): raise MessageError(f"Message file '{file_path}' does not exist") - if file_path.suffix not in cls.file_suffixes: + if file_path.suffix not in cls.file_suffixes_read: raise MessageError(f"File type '{file_path.suffix}' is not supported") - # for TXT, it's enough to read the TagLine - if file_path.suffix == '.txt': - with open(file_path, "r") as fd: - try: - tags = TagLine(fd.readline()).tags(prefix, contain) - except TagError: - pass # message without tags - else: # '.yaml' - try: - message = cls.from_file(file_path) - if message: - msg_tags = message.filter_tags(prefix=prefix, contain=contain) - except MessageError as e: - print(f"Error processing message in '{file_path}': {str(e)}") - if msg_tags: - tags = msg_tags + try: + message = cls.from_file(file_path) + if message: + msg_tags = message.filter_tags(prefix=prefix, contain=contain) + except MessageError as e: + print(f"Error processing message in '{file_path}': {str(e)}") + if msg_tags: + tags = msg_tags return tags @classmethod @@ -328,15 +326,16 @@ class Message(): """ if not file_path.exists(): raise MessageError(f"Message file '{file_path}' does not exist") - if file_path.suffix not in cls.file_suffixes: + if file_path.suffix not in cls.file_suffixes_read: raise MessageError(f"File type '{file_path.suffix}' is not supported") - - if file_path.suffix == '.txt': + # try TXT first + try: message = cls.__from_file_txt(file_path, mfilter.tags_or if mfilter else None, mfilter.tags_and if mfilter else None, mfilter.tags_not if mfilter else None) - else: + # then YAML + except MessageError: message = cls.__from_file_yaml(file_path) if message and (mfilter is None or message.match(mfilter)): return message @@ -442,21 +441,29 @@ class Message(): output.append(self.answer) return '\n'.join(output) - def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 + def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11 """ - Write a Message to the given file. Type is determined based on the suffix. - Currently supported suffixes: ['.txt', '.yaml'] + Write a Message to the given file. Supported message file formats are 'txt' and 'yaml'. + Suffix is always '.msg'. """ if file_path: self.file_path = file_path if not self.file_path: raise MessageError("Got no valid path to write message") - if self.file_path.suffix not in self.file_suffixes: - raise MessageError(f"File type '{self.file_path.suffix}' is not supported") + if mformat not in message_valid_formats: + raise MessageError(f"File format '{mformat}' is not supported") + # check for valid suffix + # -> add one if it's empty + # -> refuse old or otherwise unsupported suffixes + if not self.file_path.suffix: + self.file_path = self.file_path.with_suffix(self.file_suffix_write) + elif self.file_path.suffix != self.file_suffix_write: + raise MessageError(f"File suffix '{self.file_path.suffix}' is not supported") # TXT - if self.file_path.suffix == '.txt': + if mformat == 'txt': return self.__to_file_txt(self.file_path) - elif self.file_path.suffix == '.yaml': + # YAML + elif mformat == 'yaml': return self.__to_file_yaml(self.file_path) def __to_file_txt(self, file_path: pathlib.Path) -> None: @@ -468,8 +475,8 @@ class Message(): * Model [Optional] * Question.txt_header * Question - * Answer.txt_header - * Answer + * Answer.txt_header [Optional] + * Answer [Optional] """ with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: temp_file_path = pathlib.Path(temp_fd.name) -- 2.36.6 From d2be53aeab003f2cb425c31361cc6532f70caca4 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 25 Sep 2023 09:18:19 +0200 Subject: [PATCH 02/16] chat: switched to new message suffix and formats - no longer using file suffix to choose the format - added 'mformat' argument to 'write_xxx()' functions - file suffix is now set by 'Message.to_file()' per default --- chatmastermind/chat.py | 48 ++++++++++++++++++------------------------ 1 file changed, 20 insertions(+), 28 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 17e5c38..63a5e7f 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -6,9 +6,9 @@ from pathlib import Path from pprint import PrettyPrinter from pydoc import pager from dataclasses import dataclass -from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal, Union +from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union from .configuration import default_config_file -from .message import Message, MessageFilter, MessageError, message_in +from .message import Message, MessageFilter, MessageError, MessageFormat, message_in from .tags import Tag ChatInst = TypeVar('ChatInst', bound='Chat') @@ -17,6 +17,7 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') db_next_file = '.next' ignored_files = [db_next_file, default_config_file] msg_location = Literal['mem', 'disk', 'cache', 'db', 'all'] +msg_suffix = Message.file_suffix_write class ChatError(Exception): @@ -52,7 +53,7 @@ def read_dir(dir_path: Path, for file_path in sorted(file_iter): if (file_path.is_file() and file_path.name not in ignored_files # noqa: W503 - and file_path.suffix in Message.file_suffixes): # noqa: W503 + and file_path.suffix in Message.file_suffixes_read): # noqa: W503 try: message = Message.from_file(file_path, mfilter) if message: @@ -63,22 +64,20 @@ def read_dir(dir_path: Path, def make_file_path(dir_path: Path, - file_suffix: str, next_fid: Callable[[], int]) -> Path: """ - Create a file_path for the given directory using the - given file_suffix and ID generator function. + Create a file_path for the given directory using the given ID generator function. """ - file_path = dir_path / f"{next_fid():04d}{file_suffix}" + file_path = dir_path / f"{next_fid():04d}{msg_suffix}" while file_path.exists(): - file_path = dir_path / f"{next_fid():04d}{file_suffix}" + file_path = dir_path / f"{next_fid():04d}{msg_suffix}" return file_path def write_dir(dir_path: Path, messages: list[Message], - file_suffix: str, - next_fid: Callable[[], int]) -> None: + next_fid: Callable[[], int], + mformat: MessageFormat = Message.default_format) -> None: """ Write all messages to the given directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified @@ -86,18 +85,17 @@ def write_dir(dir_path: Path, Parameters: * 'dir_path': destination directory * 'messages': list of messages to write - * 'file_suffix': suffix for the message files ['.txt'|'.yaml'] * 'next_fid': callable that returns the next file ID """ for message in messages: file_path = message.file_path # message has no file_path: create one if not file_path: - file_path = make_file_path(dir_path, file_suffix, next_fid) + file_path = make_file_path(dir_path, next_fid) # file_path does not point to given directory: modify it elif not file_path.parent.samefile(dir_path): file_path = dir_path / file_path.name - message.to_file(file_path) + message.to_file(file_path, mformat=mformat) def clear_dir(dir_path: Path, @@ -109,7 +107,7 @@ def clear_dir(dir_path: Path, for file_path in file_iter: if (file_path.is_file() and file_path.name not in ignored_files # noqa: W503 - and file_path.suffix in Message.file_suffixes): # noqa: W503 + and file_path.suffix in Message.file_suffixes_read): # noqa: W503 file_path.unlink(missing_ok=True) @@ -146,7 +144,7 @@ class Chat: Matching is True if: * 'name' matches the full 'file_path' * 'name' matches 'file_path.name' (i. e. including the suffix) - * 'name' matches 'file_path.stem' (i. e. without a suffix) + * 'name' matches 'file_path.stem' (i. e. without the suffix) """ return Path(name) == file_path or name == file_path.name or name == file_path.stem @@ -281,13 +279,10 @@ class ChatDB(Chat): persistently. """ - default_file_suffix: ClassVar[str] = '.txt' - cache_path: Path db_path: Path # a MessageFilter that all messages must match (if given) mfilter: Optional[MessageFilter] = None - file_suffix: str = default_file_suffix # the glob pattern for all messages glob: Optional[str] = None @@ -317,8 +312,7 @@ class ChatDB(Chat): when reading them. """ messages = read_dir(db_path, glob, mfilter) - return cls(messages, cache_path, db_path, mfilter, - cls.default_file_suffix, glob) + return cls(messages, cache_path, db_path, mfilter, glob) @classmethod def from_messages(cls: Type[ChatDBInst], @@ -345,7 +339,9 @@ class ChatDB(Chat): with open(self.next_path, 'w') as f: f.write(f'{fid}') - def msg_write(self, messages: Optional[list[Message]] = None) -> None: + def msg_write(self, + messages: Optional[list[Message]] = None, + mformat: MessageFormat = Message.default_format) -> None: """ Write either the given messages or the internal ones to their CURRENT file_path. If messages are given, they all must have a valid file_path. When writing the @@ -356,7 +352,7 @@ class ChatDB(Chat): raise ChatError("Can't write files without a valid file_path") msgs = iter(messages if messages else self.messages) while (m := next(msgs, None)): - m.to_file() + m.to_file(mformat=mformat) def msg_update(self, messages: list[Message], write: bool = True) -> None: """ @@ -518,7 +514,6 @@ class ChatDB(Chat): """ write_dir(self.cache_path, messages if messages else self.messages, - self.file_suffix, self.get_next_fid) def cache_add(self, messages: list[Message], write: bool = True) -> None: @@ -531,11 +526,10 @@ class ChatDB(Chat): if write: write_dir(self.cache_path, messages, - self.file_suffix, self.get_next_fid) else: for m in messages: - m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid) + m.file_path = make_file_path(self.cache_path, self.get_next_fid) self.messages += messages self.msg_sort() @@ -585,7 +579,6 @@ class ChatDB(Chat): """ write_dir(self.db_path, messages if messages else self.messages, - self.file_suffix, self.get_next_fid) def db_add(self, messages: list[Message], write: bool = True) -> None: @@ -598,11 +591,10 @@ class ChatDB(Chat): if write: write_dir(self.db_path, messages, - self.file_suffix, self.get_next_fid) else: for m in messages: - m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid) + m.file_path = make_file_path(self.db_path, self.get_next_fid) self.messages += messages self.msg_sort() -- 2.36.6 From b8681e82741b0e45e56a8c18c4b5fb94462ad2c9 Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 26 Sep 2023 10:11:27 +0200 Subject: [PATCH 03/16] message: fixed tag matching for YAML file format --- chatmastermind/message.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 04def6d..d88ac5c 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -372,10 +372,6 @@ class Message(): tags = TagLine(fd.readline()).tags() except TagError: fd.seek(pos) - if tags_or or tags_and or tags_not: - # match with an empty set if the file has no tags - if not match_tags(tags, tags_or, tags_and, tags_not): - return None # AILine (Optional) try: pos = fd.tell() @@ -390,17 +386,23 @@ class Message(): fd.seek(pos) # Question and Answer text = fd.read().strip().split('\n') - try: - question_idx = text.index(Question.txt_header) + 1 - except ValueError: - raise MessageError(f"'{file_path}' does not contain a valid message") - try: - answer_idx = text.index(Answer.txt_header) - question = Question.from_list(text[question_idx:answer_idx]) - answer = Answer.from_list(text[answer_idx + 1:]) - except ValueError: - question = Question.from_list(text[question_idx:]) - return cls(question, answer, tags, ai, model, file_path) + try: + question_idx = text.index(Question.txt_header) + 1 + except ValueError: + raise MessageError(f"'{file_path}' does not contain a valid message") + try: + answer_idx = text.index(Answer.txt_header) + question = Question.from_list(text[question_idx:answer_idx]) + answer = Answer.from_list(text[answer_idx + 1:]) + except ValueError: + question = Question.from_list(text[question_idx:]) + # match tags AFTER reading the whole file + # -> make sure it's a valid 'txt' file format + if tags_or or tags_and or tags_not: + # match with an empty set if the file has no tags + if not match_tags(tags, tags_or, tags_and, tags_not): + return None + return cls(question, answer, tags, ai, model, file_path) @classmethod def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst: -- 2.36.6 From d07fd13e8e04188de9746ad18853f1e913413b8c Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 26 Sep 2023 10:12:14 +0200 Subject: [PATCH 04/16] test_message: changed all tests to use the new '.msg' suffix --- tests/test_message.py | 182 ++++++++++++++++++++++++++---------------- 1 file changed, 111 insertions(+), 71 deletions(-) diff --git a/tests/test_message.py b/tests/test_message.py index 5c7997f..e486ce1 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,11 +1,16 @@ import unittest import pathlib import tempfile +import itertools from typing import cast -from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in +from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine,\ + MessageFilter, message_in, message_valid_formats from chatmastermind.tags import Tag, TagLine +msg_suffix: str = Message.file_suffix_write + + class SourceCodeTestCase(unittest.TestCase): def test_source_code_with_include_delims(self) -> None: text = """ @@ -101,7 +106,7 @@ class AnswerTestCase(unittest.TestCase): class MessageToFileTxtTestCase(unittest.TestCase): def setUp(self) -> None: - self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_path = pathlib.Path(self.file.name) self.message_complete = Message(Question('This is a question.'), Answer('This is an answer.'), @@ -117,7 +122,7 @@ class MessageToFileTxtTestCase(unittest.TestCase): self.file_path.unlink() def test_to_file_txt_complete(self) -> None: - self.message_complete.to_file(self.file_path) + self.message_complete.to_file(self.file_path, mformat='txt') with open(self.file_path, "r") as fd: content = fd.read() @@ -132,7 +137,7 @@ This is an answer. self.assertEqual(content, expected_content) def test_to_file_txt_min(self) -> None: - self.message_min.to_file(self.file_path) + self.message_min.to_file(self.file_path, mformat='txt') with open(self.file_path, "r") as fd: content = fd.read() @@ -141,11 +146,17 @@ This is a question. """ self.assertEqual(content, expected_content) - def test_to_file_unsupported_file_type(self) -> None: + def test_to_file_unsupported_file_suffix(self) -> None: unsupported_file_path = pathlib.Path("example.doc") with self.assertRaises(MessageError) as cm: self.message_complete.to_file(unsupported_file_path) - self.assertEqual(str(cm.exception), "File type '.doc' is not supported") + self.assertEqual(str(cm.exception), "File suffix '.doc' is not supported") + + def test_to_file_unsupported_file_format(self) -> None: + unsupported_file_format = pathlib.Path(f"example{msg_suffix}") + with self.assertRaises(MessageError) as cm: + self.message_complete.to_file(unsupported_file_format, mformat='doc') # type: ignore [arg-type] + self.assertEqual(str(cm.exception), "File format 'doc' is not supported") def test_to_file_no_file_path(self) -> None: """ @@ -159,10 +170,24 @@ This is a question. # reset the internal file_path self.message_complete.file_path = self.file_path + def test_to_file_txt_auto_suffix(self) -> None: + """ + Test if suffix is auto-generated if omitted. + """ + file_path_no_suffix = self.file_path.with_suffix('') + # test with file_path member + self.message_min.file_path = file_path_no_suffix + self.message_min.to_file(mformat='txt') + self.assertEqual(self.message_min.file_path.suffix, msg_suffix) + # test with explicit file_path + self.message_min.file_path = file_path_no_suffix + self.message_min.to_file(file_path=file_path_no_suffix, mformat='txt') + self.assertEqual(self.message_min.file_path.suffix, msg_suffix) + class MessageToFileYamlTestCase(unittest.TestCase): def setUp(self) -> None: - self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_path = pathlib.Path(self.file.name) self.message_complete = Message(Question('This is a question.'), Answer('This is an answer.'), @@ -184,7 +209,7 @@ class MessageToFileYamlTestCase(unittest.TestCase): self.file_path.unlink() def test_to_file_yaml_complete(self) -> None: - self.message_complete.to_file(self.file_path) + self.message_complete.to_file(self.file_path, mformat='yaml') with open(self.file_path, "r") as fd: content = fd.read() @@ -199,7 +224,7 @@ class MessageToFileYamlTestCase(unittest.TestCase): self.assertEqual(content, expected_content) def test_to_file_yaml_multiline(self) -> None: - self.message_multiline.to_file(self.file_path) + self.message_multiline.to_file(self.file_path, mformat='yaml') with open(self.file_path, "r") as fd: content = fd.read() @@ -218,17 +243,31 @@ class MessageToFileYamlTestCase(unittest.TestCase): self.assertEqual(content, expected_content) def test_to_file_yaml_min(self) -> None: - self.message_min.to_file(self.file_path) + self.message_min.to_file(self.file_path, mformat='yaml') with open(self.file_path, "r") as fd: content = fd.read() expected_content = f"{Question.yaml_key}: This is a question.\n" self.assertEqual(content, expected_content) + def test_to_file_yaml_auto_suffix(self) -> None: + """ + Test if suffix is auto-generated if omitted. + """ + file_path_no_suffix = self.file_path.with_suffix('') + # test with file_path member + self.message_min.file_path = file_path_no_suffix + self.message_min.to_file(mformat='yaml') + self.assertEqual(self.message_min.file_path.suffix, msg_suffix) + # test with explicit file_path + self.message_min.file_path = file_path_no_suffix + self.message_min.to_file(file_path=file_path_no_suffix, mformat='yaml') + self.assertEqual(self.message_min.file_path.suffix, msg_suffix) + class MessageFromFileTxtTestCase(unittest.TestCase): def setUp(self) -> None: - self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_path = pathlib.Path(self.file.name) with open(self.file_path, "w") as fd: fd.write(f"""{TagLine.prefix} tag1 tag2 @@ -239,7 +278,7 @@ This is a question. {Answer.txt_header} This is an answer. """) - self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_path_min = pathlib.Path(self.file_min.name) with open(self.file_path_min, "w") as fd: fd.write(f"""{Question.txt_header} @@ -259,13 +298,13 @@ This is a question. message = Message.from_file(self.file_path) self.assertIsNotNone(message) self.assertIsInstance(message, Message) - if message: # mypy bug - self.assertEqual(message.question, 'This is a question.') - self.assertEqual(message.answer, 'This is an answer.') - self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) - self.assertEqual(message.ai, 'ChatGPT') - self.assertEqual(message.model, 'gpt-3.5-turbo') - self.assertEqual(message.file_path, self.file_path) + assert message + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.ai, 'ChatGPT') + self.assertEqual(message.model, 'gpt-3.5-turbo') + self.assertEqual(message.file_path, self.file_path) def test_from_file_txt_min(self) -> None: """ @@ -274,21 +313,21 @@ This is a question. message = Message.from_file(self.file_path_min) self.assertIsNotNone(message) self.assertIsInstance(message, Message) - if message: # mypy bug - self.assertEqual(message.question, 'This is a question.') - self.assertEqual(message.file_path, self.file_path_min) - self.assertIsNone(message.answer) + assert message + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.file_path, self.file_path_min) + self.assertIsNone(message.answer) def test_from_file_txt_tags_match(self) -> None: message = Message.from_file(self.file_path, MessageFilter(tags_or={Tag('tag1')})) self.assertIsNotNone(message) self.assertIsInstance(message, Message) - if message: # mypy bug - self.assertEqual(message.question, 'This is a question.') - self.assertEqual(message.answer, 'This is an answer.') - self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) - self.assertEqual(message.file_path, self.file_path) + assert message + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.file_path, self.file_path) def test_from_file_txt_tags_dont_match(self) -> None: message = Message.from_file(self.file_path, @@ -311,13 +350,13 @@ This is a question. MessageFilter(tags_not={Tag('tag1')})) self.assertIsNotNone(message) self.assertIsInstance(message, Message) - if message: # mypy bug - self.assertEqual(message.question, 'This is a question.') - self.assertSetEqual(cast(set[Tag], message.tags), set()) - self.assertEqual(message.file_path, self.file_path_min) + assert message + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) def test_from_file_not_exists(self) -> None: - file_not_exists = pathlib.Path("example.txt") + file_not_exists = pathlib.Path(f"example{msg_suffix}") with self.assertRaises(MessageError) as cm: Message.from_file(file_not_exists) self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") @@ -396,7 +435,7 @@ This is a question. class MessageFromFileYamlTestCase(unittest.TestCase): def setUp(self) -> None: - self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_path = pathlib.Path(self.file.name) with open(self.file_path, "w") as fd: fd.write(f""" @@ -410,7 +449,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase): - tag1 - tag2 """) - self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_path_min = pathlib.Path(self.file_min.name) with open(self.file_path_min, "w") as fd: fd.write(f""" @@ -431,13 +470,13 @@ class MessageFromFileYamlTestCase(unittest.TestCase): message = Message.from_file(self.file_path) self.assertIsInstance(message, Message) self.assertIsNotNone(message) - if message: # mypy bug - self.assertEqual(message.question, 'This is a question.') - self.assertEqual(message.answer, 'This is an answer.') - self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) - self.assertEqual(message.ai, 'ChatGPT') - self.assertEqual(message.model, 'gpt-3.5-turbo') - self.assertEqual(message.file_path, self.file_path) + assert message + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.ai, 'ChatGPT') + self.assertEqual(message.model, 'gpt-3.5-turbo') + self.assertEqual(message.file_path, self.file_path) def test_from_file_yaml_min(self) -> None: """ @@ -446,14 +485,14 @@ class MessageFromFileYamlTestCase(unittest.TestCase): message = Message.from_file(self.file_path_min) self.assertIsInstance(message, Message) self.assertIsNotNone(message) - if message: # mypy bug - self.assertEqual(message.question, 'This is a question.') - self.assertSetEqual(cast(set[Tag], message.tags), set()) - self.assertEqual(message.file_path, self.file_path_min) - self.assertIsNone(message.answer) + assert message + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + self.assertIsNone(message.answer) def test_from_file_not_exists(self) -> None: - file_not_exists = pathlib.Path("example.yaml") + file_not_exists = pathlib.Path(f"example{msg_suffix}") with self.assertRaises(MessageError) as cm: Message.from_file(file_not_exists) self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") @@ -463,11 +502,11 @@ class MessageFromFileYamlTestCase(unittest.TestCase): MessageFilter(tags_or={Tag('tag1')})) self.assertIsNotNone(message) self.assertIsInstance(message, Message) - if message: # mypy bug - self.assertEqual(message.question, 'This is a question.') - self.assertEqual(message.answer, 'This is an answer.') - self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) - self.assertEqual(message.file_path, self.file_path) + assert message + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.file_path, self.file_path) def test_from_file_yaml_tags_dont_match(self) -> None: message = Message.from_file(self.file_path, @@ -484,10 +523,10 @@ class MessageFromFileYamlTestCase(unittest.TestCase): MessageFilter(tags_not={Tag('tag1')})) self.assertIsNotNone(message) self.assertIsInstance(message, Message) - if message: # mypy bug - self.assertEqual(message.question, 'This is a question.') - self.assertSetEqual(cast(set[Tag], message.tags), set()) - self.assertEqual(message.file_path, self.file_path_min) + assert message + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) def test_from_file_yaml_question_match(self) -> None: message = Message.from_file(self.file_path, @@ -563,7 +602,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase): class TagsFromFileTestCase(unittest.TestCase): def setUp(self) -> None: - self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_path_txt = pathlib.Path(self.file_txt.name) with open(self.file_path_txt, "w") as fd: fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3 @@ -572,7 +611,7 @@ This is a question. {Answer.txt_header} This is an answer. """) - self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name) with open(self.file_path_txt_no_tags, "w") as fd: fd.write(f"""{Question.txt_header} @@ -580,7 +619,7 @@ This is a question. {Answer.txt_header} This is an answer. """) - self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name) with open(self.file_path_txt_tags_empty, "w") as fd: fd.write(f"""TAGS: @@ -589,7 +628,7 @@ This is a question. {Answer.txt_header} This is an answer. """) - self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_path_yaml = pathlib.Path(self.file_yaml.name) with open(self.file_path_yaml, "w") as fd: fd.write(f""" @@ -602,7 +641,7 @@ This is an answer. - tag2 - ptag3 """) - self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name) with open(self.file_path_yaml_no_tags, "w") as fd: fd.write(f""" @@ -679,24 +718,25 @@ class TagsFromDirTestCase(unittest.TestCase): {Tag('ctag5'), Tag('ctag6')} ] self.files = [ - pathlib.Path(self.temp_dir.name, 'file1.txt'), - pathlib.Path(self.temp_dir.name, 'file2.yaml'), - pathlib.Path(self.temp_dir.name, 'file3.txt') + pathlib.Path(self.temp_dir.name, f'file1{msg_suffix}'), + pathlib.Path(self.temp_dir.name, f'file2{msg_suffix}'), + pathlib.Path(self.temp_dir.name, f'file3{msg_suffix}') ] self.files_no_tags = [ - pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'), - pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'), - pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt') + pathlib.Path(self.temp_dir_no_tags.name, f'file4{msg_suffix}'), + pathlib.Path(self.temp_dir_no_tags.name, f'file5{msg_suffix}'), + pathlib.Path(self.temp_dir_no_tags.name, f'file6{msg_suffix}') ] + mformats = itertools.cycle(message_valid_formats) for file, tags in zip(self.files, self.tag_sets): message = Message(Question('This is a question.'), Answer('This is an answer.'), tags) - message.to_file(file) + message.to_file(file, next(mformats)) for file in self.files_no_tags: message = Message(Question('This is a question.'), Answer('This is an answer.')) - message.to_file(file) + message.to_file(file, next(mformats)) def tearDown(self) -> None: self.temp_dir.cleanup() @@ -719,7 +759,7 @@ class TagsFromDirTestCase(unittest.TestCase): class MessageIDTestCase(unittest.TestCase): def setUp(self) -> None: - self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_path = pathlib.Path(self.file.name) self.message = Message(Question('This is a question.'), file_path=self.file_path) -- 2.36.6 From e34eab651982a5c452a17f0989a2c8a920aa03e1 Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 26 Sep 2023 10:12:24 +0200 Subject: [PATCH 05/16] test_chat: changed all tests to use the new '.msg' suffix --- tests/test_chat.py | 192 +++++++++++++++++++++++---------------------- 1 file changed, 100 insertions(+), 92 deletions(-) diff --git a/tests/test_chat.py b/tests/test_chat.py index a69f92c..0d4f672 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -10,6 +10,20 @@ from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.chat import Chat, ChatDB, ChatError +msg_suffix: str = Message.file_suffix_write + + +def msg_to_file_force_suffix(msg: Message) -> None: + """ + Force writing a message file with illegal suffixes. + """ + def_suffix = Message.file_suffix_write + assert msg.file_path + Message.file_suffix_write = msg.file_path.suffix + msg.to_file() + Message.file_suffix_write = def_suffix + + class TestChatBase(unittest.TestCase): def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None: """ @@ -27,11 +41,11 @@ class TestChat(TestChatBase): self.message1 = Message(Question('Question 1'), Answer('Answer 1'), {Tag('atag1'), Tag('btag2')}, - file_path=pathlib.Path('0001.txt')) + file_path=pathlib.Path(f'0001{msg_suffix}')) self.message2 = Message(Question('Question 2'), Answer('Answer 2'), {Tag('btag2')}, - file_path=pathlib.Path('0002.txt')) + file_path=pathlib.Path(f'0002{msg_suffix}')) self.maxDiff = None def test_unique_id(self) -> None: @@ -99,24 +113,24 @@ class TestChat(TestChatBase): def test_find_remove_messages(self) -> None: self.chat.msg_add([self.message1, self.message2]) - msgs = self.chat.msg_find(['0001.txt']) + msgs = self.chat.msg_find(['0001']) self.assertListEqual(msgs, [self.message1]) - msgs = self.chat.msg_find(['0001.txt', '0002.txt']) + msgs = self.chat.msg_find(['0001', '0002']) self.assertListEqual(msgs, [self.message1, self.message2]) # add new Message with full path message3 = Message(Question('Question 2'), Answer('Answer 2'), {Tag('btag2')}, - file_path=pathlib.Path('/foo/bla/0003.txt')) + file_path=pathlib.Path(f'/foo/bla/0003{msg_suffix}')) self.chat.msg_add([message3]) # find new Message by full path - msgs = self.chat.msg_find(['/foo/bla/0003.txt']) + msgs = self.chat.msg_find([f'/foo/bla/0003{msg_suffix}']) self.assertListEqual(msgs, [message3]) # find Message with full path only by filename - msgs = self.chat.msg_find(['0003.txt']) + msgs = self.chat.msg_find([f'0003{msg_suffix}']) self.assertListEqual(msgs, [message3]) # remove last message - self.chat.msg_remove(['0003.txt']) + self.chat.msg_remove(['0003']) self.assertListEqual(self.chat.messages, [self.message1, self.message2]) def test_latest_message(self) -> None: @@ -146,13 +160,13 @@ Answer 2 self.chat.msg_add([self.message1, self.message2]) self.chat.print(paged=False, with_tags=True, with_files=True) expected_output = f"""{TagLine.prefix} atag1 btag2 -FILE: 0001.txt +FILE: 0001{msg_suffix} {Question.txt_header} Question 1 {Answer.txt_header} Answer 1 {TagLine.prefix} btag2 -FILE: 0002.txt +FILE: 0002{msg_suffix} {Question.txt_header} Question 2 {Answer.txt_header} @@ -168,31 +182,27 @@ class TestChatDB(TestChatBase): self.message1 = Message(Question('Question 1'), Answer('Answer 1'), - {Tag('tag1')}, - file_path=pathlib.Path('0001.txt')) + {Tag('tag1')}) self.message2 = Message(Question('Question 2'), Answer('Answer 2'), - {Tag('tag2')}, - file_path=pathlib.Path('0002.yaml')) + {Tag('tag2')}) self.message3 = Message(Question('Question 3'), Answer('Answer 3'), - {Tag('tag3')}, - file_path=pathlib.Path('0003.txt')) + {Tag('tag3')}) self.message4 = Message(Question('Question 4'), Answer('Answer 4'), - {Tag('tag4')}, - file_path=pathlib.Path('0004.yaml')) + {Tag('tag4')}) - self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt')) - self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml')) - self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt')) - self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml')) + self.message1.to_file(pathlib.Path(self.db_path.name, '0001'), mformat='txt') + self.message2.to_file(pathlib.Path(self.db_path.name, '0002'), mformat='yaml') + self.message3.to_file(pathlib.Path(self.db_path.name, '0003'), mformat='txt') + self.message4.to_file(pathlib.Path(self.db_path.name, '0004'), mformat='yaml') # make the next FID match the current state next_fname = pathlib.Path(self.db_path.name) / '.next' with open(next_fname, 'w') as f: f.write('4') # add some "trash" in order to test if it's correctly handled / ignored - self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt'] + self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt', 'fubar.msg'] for file in self.trash_files: with open(pathlib.Path(self.db_path.name) / file, 'w') as f: f.write('test trash') @@ -207,7 +217,7 @@ class TestChatDB(TestChatBase): List all Message files in the given TemporaryDirectory. """ # exclude '.next' - return [f for f in pathlib.Path(tmp_dir.name).glob('*.[ty]*') if f.name not in self.trash_files] + return [f for f in pathlib.Path(tmp_dir.name).glob('*.[tym]*') if f.name not in self.trash_files] def tearDown(self) -> None: self.db_path.cleanup() @@ -218,8 +228,8 @@ class TestChatDB(TestChatBase): duplicate_message = Message(Question('Question 4'), Answer('Answer 4'), {Tag('tag4')}, - file_path=pathlib.Path('0004.txt')) - duplicate_message.to_file(pathlib.Path(self.db_path.name, '0004.txt')) + file_path=pathlib.Path(self.db_path.name, '0004.txt')) + msg_to_file_force_suffix(duplicate_message) with self.assertRaises(ChatError) as cm: ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) @@ -233,25 +243,23 @@ class TestChatDB(TestChatBase): self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) # check that the files are sorted self.assertEqual(chat_db.messages[0].file_path, - pathlib.Path(self.db_path.name, '0001.txt')) + pathlib.Path(self.db_path.name, f'0001{msg_suffix}')) self.assertEqual(chat_db.messages[1].file_path, - pathlib.Path(self.db_path.name, '0002.yaml')) + pathlib.Path(self.db_path.name, f'0002{msg_suffix}')) self.assertEqual(chat_db.messages[2].file_path, - pathlib.Path(self.db_path.name, '0003.txt')) + pathlib.Path(self.db_path.name, f'0003{msg_suffix}')) self.assertEqual(chat_db.messages[3].file_path, - pathlib.Path(self.db_path.name, '0004.yaml')) + pathlib.Path(self.db_path.name, f'0004{msg_suffix}')) def test_from_dir_glob(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), - glob='*.txt') - self.assertEqual(len(chat_db.messages), 2) + glob='*1.*') + self.assertEqual(len(chat_db.messages), 1) self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.messages[0].file_path, - pathlib.Path(self.db_path.name, '0001.txt')) - self.assertEqual(chat_db.messages[1].file_path, - pathlib.Path(self.db_path.name, '0003.txt')) + pathlib.Path(self.db_path.name, f'0001{msg_suffix}')) def test_from_dir_filter_tags(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), @@ -261,7 +269,7 @@ class TestChatDB(TestChatBase): self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.messages[0].file_path, - pathlib.Path(self.db_path.name, '0001.txt')) + pathlib.Path(self.db_path.name, f'0001{msg_suffix}')) def test_from_dir_filter_tags_empty(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), @@ -279,7 +287,7 @@ class TestChatDB(TestChatBase): self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.messages[0].file_path, - pathlib.Path(self.db_path.name, '0002.yaml')) + pathlib.Path(self.db_path.name, f'0002{msg_suffix}')) self.assertEqual(chat_db.messages[0].answer, 'Answer 2') def test_from_messages(self) -> None: @@ -324,25 +332,25 @@ class TestChatDB(TestChatBase): chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) # check that Message.file_path is correct - self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) - self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) - self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) - self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}')) # write the messages to the cache directory chat_db.cache_write() # check if the written files are in the cache directory cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 4) - self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files) - self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files) - self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files) - self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'), cache_dir_files) # check that Message.file_path has been correctly updated - self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt')) - self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml')) - self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt')) - self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml')) + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, f'0001{msg_suffix}')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, f'0002{msg_suffix}')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, f'0003{msg_suffix}')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, f'0004{msg_suffix}')) # check the timestamp of the files in the DB directory db_dir_files = self.message_list(self.db_path) @@ -354,18 +362,18 @@ class TestChatDB(TestChatBase): # check if the written files are in the DB directory db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) - self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) - self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) - self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) - self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files) # check if all files in the DB dir have actually been overwritten for file in db_dir_files: self.assertGreater(file.stat().st_mtime, old_timestamps[file]) # check that Message.file_path has been correctly updated (again) - self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) - self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) - self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) - self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}')) def test_db_read(self) -> None: # create a new ChatDB instance @@ -380,65 +388,65 @@ class TestChatDB(TestChatBase): new_message2 = Message(Question('Question 6'), Answer('Answer 6'), {Tag('tag6')}) - new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) - new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) + new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt') + new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml') # read and check them chat_db.db_read() self.assertEqual(len(chat_db.messages), 6) - self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) - self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}')) # create 2 new files in the cache directory new_message3 = Message(Question('Question 7'), - Answer('Answer 5'), + Answer('Answer 7'), {Tag('tag7')}) new_message4 = Message(Question('Question 8'), - Answer('Answer 6'), + Answer('Answer 8'), {Tag('tag8')}) - new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt')) - new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml')) + new_message3.to_file(pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'), mformat='txt') + new_message4.to_file(pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'), mformat='yaml') # read and check them chat_db.cache_read() self.assertEqual(len(chat_db.messages), 8) # check that the new message have the cache dir path - self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt')) - self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml')) + self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, f'0007{msg_suffix}')) + self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, f'0008{msg_suffix}')) # an the old ones keep their path (since they have not been replaced) - self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) - self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}')) # now overwrite two messages in the DB directory new_message1.question = Question('New Question 1') new_message2.question = Question('New Question 2') - new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) - new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) + new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt') + new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml') # read from the DB dir and check if the modified messages have been updated chat_db.db_read() self.assertEqual(len(chat_db.messages), 8) self.assertEqual(chat_db.messages[4].question, 'New Question 1') self.assertEqual(chat_db.messages[5].question, 'New Question 2') - self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) - self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}')) # now write the messages from the cache to the DB directory - new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt')) - new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml')) + new_message3.to_file(pathlib.Path(self.db_path.name, f'0007{msg_suffix}')) + new_message4.to_file(pathlib.Path(self.db_path.name, f'0008{msg_suffix}')) # read and check them chat_db.db_read() self.assertEqual(len(chat_db.messages), 8) # check that they now have the DB path - self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) - self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) + self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, f'0007{msg_suffix}')) + self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, f'0008{msg_suffix}')) def test_cache_clear(self) -> None: # create a new ChatDB instance chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) # check that Message.file_path is correct - self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) - self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) - self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) - self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}')) # write the messages to the cache directory chat_db.cache_write() @@ -450,10 +458,10 @@ class TestChatDB(TestChatBase): chat_db.db_write() db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) - self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) - self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) - self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) - self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files) # add a new message with empty file_path message_empty = Message(question=Question("What the hell am I doing here?"), @@ -461,7 +469,7 @@ class TestChatDB(TestChatBase): # and one for the cache dir message_cache = Message(question=Question("What the hell am I doing here?"), answer=Answer("You're a creep!"), - file_path=pathlib.Path(self.cache_path.name, '0005.txt')) + file_path=pathlib.Path(self.cache_path.name, '0005')) chat_db.msg_add([message_empty, message_cache]) # clear the cache and check the cache dir @@ -523,11 +531,11 @@ class TestChatDB(TestChatBase): chat_db.msg_write([message]) # write a message with a valid file_path - message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt' + message.file_path = pathlib.Path(self.cache_path.name) / '123456' chat_db.msg_write([message]) cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 1) - self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, f'123456{msg_suffix}'), cache_dir_files) def test_msg_update(self) -> None: # create a new ChatDB instance @@ -563,21 +571,21 @@ class TestChatDB(TestChatBase): # search for a DB file in memory self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1]) self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1]) - self.assertEqual(chat_db.msg_find(['0001.txt'], loc='mem'), [self.message1]) + self.assertEqual(chat_db.msg_find(['0001.msg'], loc='mem'), [self.message1]) self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1]) # and on disk self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2]) self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2]) - self.assertEqual(chat_db.msg_find(['0002.yaml'], loc='db'), [self.message2]) + self.assertEqual(chat_db.msg_find(['0002.msg'], loc='db'), [self.message2]) self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2]) # now search the cache -> expect empty result self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), []) self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), []) - self.assertEqual(chat_db.msg_find(['0003.txt'], loc='cache'), []) + self.assertEqual(chat_db.msg_find(['0003.msg'], loc='cache'), []) self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), []) # search for multiple messages # -> search one twice, expect result to be unique - search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)] + search_names = ['0001', '0002.msg', self.message3.msg_id(), str(self.message3.file_path)] expected_result = [self.message1, self.message2, self.message3] result = chat_db.msg_find(search_names, loc='all') self.assert_messages_equal(result, expected_result) -- 2.36.6 From df42bcee09e0930fa042cb05a6a536a19043172d Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 26 Sep 2023 18:18:11 +0200 Subject: [PATCH 06/16] test_chat: added test for file_path collision detection --- tests/test_chat.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_chat.py b/tests/test_chat.py index 0d4f672..fca61e5 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -235,6 +235,24 @@ class TestChatDB(TestChatBase): pathlib.Path(self.db_path.name)) self.assertEqual(str(cm.exception), "Validation failed") + def test_file_path_ID_exists(self) -> None: + """ + Tests if the CacheDB chooses another ID if a file path with + the given one exists. + """ + # create a new and empty CacheDB + db_path = tempfile.TemporaryDirectory() + cache_path = tempfile.TemporaryDirectory() + chat_db = ChatDB.from_dir(pathlib.Path(cache_path.name), + pathlib.Path(db_path.name)) + # add a message file + message = Message(Question('What?'), + file_path=pathlib.Path(cache_path.name) / f'0001{msg_suffix}') + message.to_file() + message1 = Message(Question('Where?')) + chat_db.cache_write([message1]) + self.assertEqual(message1.msg_id(), '0002') + def test_from_dir(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) -- 2.36.6 From 01860ace2ce14bbb43468859fd2114bb4959b362 Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 26 Sep 2023 18:24:36 +0200 Subject: [PATCH 07/16] test_question_cmd: modified tests to use '.msg' file suffix --- tests/test_question_cmd.py | 61 +++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index 5033e9f..be31b58 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -14,6 +14,9 @@ from chatmastermind.ai import AIError from .test_common import TestWithFakeAI +msg_suffix = Message.file_suffix_write + + class TestMessageCreate(TestWithFakeAI): """ Test if messages created by the 'question' command have @@ -85,7 +88,7 @@ Aaaand again some text.""" def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: # exclude '.next' - return list(Path(tmp_dir.name).glob('*.[ty]*')) + return list(Path(tmp_dir.name).glob(f'*{msg_suffix}')) def test_message_file_created(self) -> None: self.args.ask = ["What is this?"] @@ -248,7 +251,7 @@ class TestQuestionCmd(TestWithFakeAI): def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: # exclude '.next' - return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')]) + return sorted([f for f in Path(tmp_dir.name).glob(f'*{msg_suffix}')]) class TestQuestionCmdAsk(TestQuestionCmd): @@ -347,14 +350,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd): Repeat a single question. """ mock_create_ai.side_effect = self.mock_create_ai + chat = ChatDB.from_dir(Path(self.cache_dir.name), + Path(self.db_dir.name)) # create a message message = Message(Question(self.args.ask[0]), Answer('Old Answer'), tags=set(self.args.output_tags), ai=self.args.AI, model=self.args.model, - file_path=Path(self.cache_dir.name) / '0001.txt') - message.to_file() + file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') + chat.msg_write([message]) # repeat the last question (without overwriting) # -> expect two identical messages (except for the file_path) @@ -370,8 +375,6 @@ class TestQuestionCmdRepeat(TestQuestionCmd): # we expect the original message + the one with the new response expected_responses = [message] + [expected_response] question_cmd(self.args, self.config) - chat = ChatDB.from_dir(Path(self.cache_dir.name), - Path(self.db_dir.name)) cached_msg = chat.msg_gather(loc='cache') print(self.message_list(self.cache_dir)) self.assertEqual(len(self.message_list(self.cache_dir)), 2) @@ -383,16 +386,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd): Repeat a single question and overwrite the old one. """ mock_create_ai.side_effect = self.mock_create_ai + chat = ChatDB.from_dir(Path(self.cache_dir.name), + Path(self.db_dir.name)) # create a message message = Message(Question(self.args.ask[0]), Answer('Old Answer'), tags=set(self.args.output_tags), ai=self.args.AI, model=self.args.model, - file_path=Path(self.cache_dir.name) / '0001.txt') - message.to_file() - chat = ChatDB.from_dir(Path(self.cache_dir.name), - Path(self.db_dir.name)) + file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') + chat.msg_write([message]) cached_msg = chat.msg_gather(loc='cache') assert cached_msg[0].file_path cached_msg_file_id = cached_msg[0].file_path.stem @@ -422,16 +425,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd): Repeat a single question after an error. """ mock_create_ai.side_effect = self.mock_create_ai + chat = ChatDB.from_dir(Path(self.cache_dir.name), + Path(self.db_dir.name)) # create a question WITHOUT an answer # -> just like after an error, which is tested above message = Message(Question(self.args.ask[0]), tags=set(self.args.output_tags), ai=self.args.AI, model=self.args.model, - file_path=Path(self.cache_dir.name) / '0001.txt') - message.to_file() - chat = ChatDB.from_dir(Path(self.cache_dir.name), - Path(self.db_dir.name)) + file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') + chat.msg_write([message]) cached_msg = chat.msg_gather(loc='cache') assert cached_msg[0].file_path cached_msg_file_id = cached_msg[0].file_path.stem @@ -462,16 +465,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd): Repeat a single question with new arguments. """ mock_create_ai.side_effect = self.mock_create_ai + chat = ChatDB.from_dir(Path(self.cache_dir.name), + Path(self.db_dir.name)) # create a message message = Message(Question(self.args.ask[0]), Answer('Old Answer'), tags=set(self.args.output_tags), ai=self.args.AI, model=self.args.model, - file_path=Path(self.cache_dir.name) / '0001.txt') - message.to_file() - chat = ChatDB.from_dir(Path(self.cache_dir.name), - Path(self.db_dir.name)) + file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') + chat.msg_write([message]) cached_msg = chat.msg_gather(loc='cache') assert cached_msg[0].file_path @@ -500,16 +503,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd): Repeat a single question with new arguments, overwriting the old one. """ mock_create_ai.side_effect = self.mock_create_ai + chat = ChatDB.from_dir(Path(self.cache_dir.name), + Path(self.db_dir.name)) # create a message message = Message(Question(self.args.ask[0]), Answer('Old Answer'), tags=set(self.args.output_tags), ai=self.args.AI, model=self.args.model, - file_path=Path(self.cache_dir.name) / '0001.txt') - message.to_file() - chat = ChatDB.from_dir(Path(self.cache_dir.name), - Path(self.db_dir.name)) + file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') + chat.msg_write([message]) cached_msg = chat.msg_gather(loc='cache') assert cached_msg[0].file_path @@ -537,29 +540,29 @@ class TestQuestionCmdRepeat(TestQuestionCmd): Repeat multiple questions. """ mock_create_ai.side_effect = self.mock_create_ai + chat = ChatDB.from_dir(Path(self.cache_dir.name), + Path(self.db_dir.name)) # 1. === create three questions === # cached message without an answer message1 = Message(Question(self.args.ask[0]), tags=self.args.output_tags, ai=self.args.AI, model=self.args.model, - file_path=Path(self.cache_dir.name) / '0001.txt') + file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') # cached message with an answer message2 = Message(Question(self.args.ask[0]), Answer('Old Answer'), tags=self.args.output_tags, ai=self.args.AI, model=self.args.model, - file_path=Path(self.cache_dir.name) / '0002.txt') + file_path=Path(self.cache_dir.name) / f'0002{msg_suffix}') # DB message without an answer message3 = Message(Question(self.args.ask[0]), tags=self.args.output_tags, ai=self.args.AI, model=self.args.model, - file_path=Path(self.db_dir.name) / '0003.txt') - message1.to_file() - message2.to_file() - message3.to_file() + file_path=Path(self.db_dir.name) / f'0003{msg_suffix}') + chat.msg_write([message1, message2, message3]) questions = [message1, message2, message3] expected_responses: list[Message] = [] fake_ai = self.mock_create_ai(self.args, self.config) @@ -583,8 +586,6 @@ class TestQuestionCmdRepeat(TestQuestionCmd): self.assertEqual(len(self.message_list(self.cache_dir)), 4) self.assertEqual(len(self.message_list(self.db_dir)), 1) expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]] - chat = ChatDB.from_dir(Path(self.cache_dir.name), - Path(self.db_dir.name)) cached_msg = chat.msg_gather(loc='cache') self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages) # check that the DB message has not been modified at all -- 2.36.6 From 140dbed8094dbbe5e4c65e46f20879d16a6be311 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 27 Sep 2023 08:14:56 +0200 Subject: [PATCH 08/16] message: added function 'rm_file()' and test --- chatmastermind/message.py | 7 +++++++ tests/test_message.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index d88ac5c..ced11bb 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -517,6 +517,13 @@ class Message(): yaml.dump(data, temp_fd, sort_keys=False) shutil.move(temp_file_path, file_path) + def rm_file(self) -> None: + """ + Delete the message file. Ignore empty file_path and not existing files. + """ + if self.file_path is not None: + self.file_path.unlink(missing_ok=True) + def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ Filter tags based on their prefix (i. e. the tag starts with a given string) diff --git a/tests/test_message.py b/tests/test_message.py index e486ce1..6e39143 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -874,3 +874,22 @@ This is a question. {Answer.txt_header} This is an answer.""" self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output) + + +class MessageRmFileTestCase(unittest.TestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) + self.file_path = pathlib.Path(self.file.name) + self.message = Message(Question('This is a question.'), + file_path=self.file_path) + self.message.to_file() + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink(missing_ok=True) + + def test_rm_file(self) -> None: + assert self.message.file_path + self.assertTrue(self.message.file_path.exists()) + self.message.rm_file() + self.assertFalse(self.message.file_path.exists()) -- 2.36.6 From aecfd1088db825b953ff0b9d02d04a217f8a849a Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 27 Sep 2023 08:15:35 +0200 Subject: [PATCH 09/16] chat: added message file format as ChatDB class member --- chatmastermind/chat.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 63a5e7f..ad1cece 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -285,6 +285,8 @@ class ChatDB(Chat): mfilter: Optional[MessageFilter] = None # the glob pattern for all messages glob: Optional[str] = None + # message format (for writing) + mformat: MessageFormat = Message.default_format def __post_init__(self) -> None: # contains the latest message ID @@ -339,9 +341,15 @@ class ChatDB(Chat): with open(self.next_path, 'w') as f: f.write(f'{fid}') + def set_msg_format(self, mformat: MessageFormat) -> None: + """ + Set message format for writing messages. + """ + self.mformat = mformat + def msg_write(self, messages: Optional[list[Message]] = None, - mformat: MessageFormat = Message.default_format) -> None: + mformat: Optional[MessageFormat] = None) -> None: """ Write either the given messages or the internal ones to their CURRENT file_path. If messages are given, they all must have a valid file_path. When writing the @@ -352,7 +360,7 @@ class ChatDB(Chat): raise ChatError("Can't write files without a valid file_path") msgs = iter(messages if messages else self.messages) while (m := next(msgs, None)): - m.to_file(mformat=mformat) + m.to_file(mformat=mformat if mformat else self.mformat) def msg_update(self, messages: list[Message], write: bool = True) -> None: """ @@ -514,7 +522,8 @@ class ChatDB(Chat): """ write_dir(self.cache_path, messages if messages else self.messages, - self.get_next_fid) + self.get_next_fid, + self.mformat) def cache_add(self, messages: list[Message], write: bool = True) -> None: """ @@ -526,7 +535,8 @@ class ChatDB(Chat): if write: write_dir(self.cache_path, messages, - self.get_next_fid) + self.get_next_fid, + self.mformat) else: for m in messages: m.file_path = make_file_path(self.cache_path, self.get_next_fid) @@ -579,7 +589,8 @@ class ChatDB(Chat): """ write_dir(self.db_path, messages if messages else self.messages, - self.get_next_fid) + self.get_next_fid, + self.mformat) def db_add(self, messages: list[Message], write: bool = True) -> None: """ @@ -591,7 +602,8 @@ class ChatDB(Chat): if write: write_dir(self.db_path, messages, - self.get_next_fid) + self.get_next_fid, + self.mformat) else: for m in messages: m.file_path = make_file_path(self.db_path, self.get_next_fid) -- 2.36.6 From efdb3cae2f6605d114dfb3af5c602b0e38cf5219 Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 28 Sep 2023 07:19:00 +0200 Subject: [PATCH 10/16] question: moved around some code --- chatmastermind/commands/question.py | 46 ++++++++++++++--------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 785349b..4e7d8e0 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -55,6 +55,29 @@ def add_file_as_code(question_parts: list[str], file: str) -> None: question_parts.append(f"```\n{content}\n```") +def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace: + """ + Takes an existing message and CLI arguments, and returns modified args based + on the members of the given message. Used e.g. when repeating messages, where + it's necessary to determine the correct AI, module and output tags to use + (either from the existing message or the given args). + """ + msg_args = args + # if AI, model or output tags have not been specified, + # use those from the original message + if (args.AI is None + or args.model is None # noqa: W503 + or args.output_tags is None): # noqa: W503 + msg_args = deepcopy(args) + if args.AI is None and msg.ai is not None: + msg_args.AI = msg.ai + if args.model is None and msg.model is not None: + msg_args.model = msg.model + if args.output_tags is None and msg.tags is not None: + msg_args.output_tags = msg.tags + return msg_args + + def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: """ Create a new message from the given arguments and write it @@ -115,29 +138,6 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac print(response.tokens) -def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace: - """ - Takes an existing message and CLI arguments, and returns modified args based - on the members of the given message. Used e.g. when repeating messages, where - it's necessary to determine the correct AI, module and output tags to use - (either from the existing message or the given args). - """ - msg_args = args - # if AI, model or output tags have not been specified, - # use those from the original message - if (args.AI is None - or args.model is None # noqa: W503 - or args.output_tags is None): # noqa: W503 - msg_args = deepcopy(args) - if args.AI is None and msg.ai is not None: - msg_args.AI = msg.ai - if args.model is None and msg.model is not None: - msg_args.model = msg.model - if args.output_tags is None and msg.tags is not None: - msg_args.output_tags = msg.tags - return msg_args - - def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None: """ Repeat the given messages using the given arguments. -- 2.36.6 From 2a8f01aee429a9748aa70901ecec656a3953680a Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 28 Sep 2023 07:51:56 +0200 Subject: [PATCH 11/16] chat: 'msg_gather()' now supports globbing --- chatmastermind/chat.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index ad1cece..2640c8b 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -8,7 +8,7 @@ from pydoc import pager from dataclasses import dataclass from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union from .configuration import default_config_file -from .message import Message, MessageFilter, MessageError, MessageFormat, message_in +from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats from .tags import Tag ChatInst = TypeVar('ChatInst', bound='Chat') @@ -345,6 +345,8 @@ class ChatDB(Chat): """ Set message format for writing messages. """ + if mformat not in message_valid_formats: + raise ChatError(f"Message format '{mformat}' is not supported") self.mformat = mformat def msg_write(self, @@ -381,6 +383,7 @@ class ChatDB(Chat): def msg_gather(self, loc: msg_location, require_file_path: bool = False, + glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> list[Message]: """ Gather and return messages from the given locations: @@ -399,9 +402,9 @@ class ChatDB(Chat): else: loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))] if loc in ['cache', 'disk', 'all']: - loc_messages += read_dir(self.cache_path, mfilter=mfilter) + loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter) if loc in ['db', 'disk', 'all']: - loc_messages += read_dir(self.db_path, mfilter=mfilter) + loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter) # remove_duplicates and sort the list unique_messages: list[Message] = [] for m in loc_messages: -- 2.36.6 From 811b2e6830186bb080c66d1904d5cae223a924b4 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 29 Sep 2023 06:59:46 +0200 Subject: [PATCH 12/16] hist_cmd: implemented '--convert' option --- chatmastermind/commands/hist.py | 56 +++++++++++++++++++++++++++++++-- chatmastermind/main.py | 9 ++++-- 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/chatmastermind/commands/hist.py b/chatmastermind/commands/hist.py index 5b14bd2..104ae2f 100644 --- a/chatmastermind/commands/hist.py +++ b/chatmastermind/commands/hist.py @@ -1,13 +1,53 @@ +import sys import argparse from pathlib import Path from ..configuration import Config from ..chat import ChatDB -from ..message import MessageFilter +from ..message import MessageFilter, Message -def hist_cmd(args: argparse.Namespace, config: Config) -> None: +msg_suffix = Message.file_suffix_write # currently '.msg' + + +def convert_messages(args: argparse.Namespace, config: Config) -> None: """ - Handler for the 'hist' command. + Convert messages to a new format. Also used to change old suffixes + ('.txt', '.yaml') to the latest default message file suffix ('.msg'). + """ + chat = ChatDB.from_dir(Path(config.cache), + Path(config.db)) + # read all known message files + msgs = chat.msg_gather(loc='disk', glob='*.*') + # make a set of all message IDs + msg_ids = set([m.msg_id() for m in msgs]) + # set requested format and write all messages + chat.set_msg_format(args.convert) + # delete the current suffix + # -> a new one will automatically be created + for m in msgs: + if m.file_path: + m.file_path = m.file_path.with_suffix('') + chat.msg_write(msgs) + # read all messages with the current default suffix + msgs = chat.msg_gather(loc='disk', glob='*{msg_suffix}') + # make sure we converted all of the original messages + for mid in msg_ids: + if not any(mid == m.msg_id() for m in msgs): + print(f"Message '{mid}' has not been found after conversion. Aborting.") + sys.exit(1) + # delete messages with old suffixes + msgs = chat.msg_gather(loc='disk', glob='*.*') + for m in msgs: + if m.file_path and m.file_path.suffix != msg_suffix: + m.rm_file() + print(f"Successfully converted {len(msg_ids)} messages.") + if len(msgs): + print(f"Deleted {len(msgs)} messages with deprecated suffixes.") + + +def print_chat(args: argparse.Namespace, config: Config) -> None: + """ + Print the DB chat history. """ mfilter = MessageFilter(tags_or=args.or_tags, @@ -21,3 +61,13 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None: chat.print(args.source_code_only, args.with_tags, args.with_files) + + +def hist_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'hist' command. + """ + if args.print: + print_chat(args, config) + elif args.convert: + convert_messages(args, config) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index ac4f7cc..a803d0e 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -73,17 +73,20 @@ def create_parser() -> argparse.ArgumentParser: # 'hist' command parser hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], - help="Print chat history.", + help="Print and manage chat history.", aliases=['h']) hist_cmd_parser.set_defaults(func=hist_cmd) + hist_group = hist_cmd_parser.add_mutually_exclusive_group(required=True) + hist_group.add_argument('-p', '--print', help='Print the DB chat history', action='store_true') + hist_group.add_argument('-c', '--convert', help='Convert all message files to the given format [txt|yaml]', metavar='FORMAT') hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", action='store_true') hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.", action='store_true') hist_cmd_parser.add_argument('-S', '--source-code-only', help='Only print embedded source code', action='store_true') - hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring', metavar='SUBSTRING') - hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring', metavar='SUBSTRING') + hist_cmd_parser.add_argument('-A', '--answer', help='Print only answers with given substring', metavar='SUBSTRING') + hist_cmd_parser.add_argument('-Q', '--question', help='Print only questions with given substring', metavar='SUBSTRING') # 'tags' command parser tags_cmd_parser = cmdparser.add_parser('tags', -- 2.36.6 From e19c6bb1ea06be626fcd84ab61b614a316fa2a24 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 29 Sep 2023 18:53:02 +0200 Subject: [PATCH 13/16] hist_cmd: added module 'test_hist_cmd.py' --- chatmastermind/commands/hist.py | 4 +-- tests/test_hist_cmd.py | 62 +++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 3 deletions(-) create mode 100644 tests/test_hist_cmd.py diff --git a/chatmastermind/commands/hist.py b/chatmastermind/commands/hist.py index 104ae2f..e84e761 100644 --- a/chatmastermind/commands/hist.py +++ b/chatmastermind/commands/hist.py @@ -29,7 +29,7 @@ def convert_messages(args: argparse.Namespace, config: Config) -> None: m.file_path = m.file_path.with_suffix('') chat.msg_write(msgs) # read all messages with the current default suffix - msgs = chat.msg_gather(loc='disk', glob='*{msg_suffix}') + msgs = chat.msg_gather(loc='disk', glob=f'*{msg_suffix}') # make sure we converted all of the original messages for mid in msg_ids: if not any(mid == m.msg_id() for m in msgs): @@ -41,8 +41,6 @@ def convert_messages(args: argparse.Namespace, config: Config) -> None: if m.file_path and m.file_path.suffix != msg_suffix: m.rm_file() print(f"Successfully converted {len(msg_ids)} messages.") - if len(msgs): - print(f"Deleted {len(msgs)} messages with deprecated suffixes.") def print_chat(args: argparse.Namespace, config: Config) -> None: diff --git a/tests/test_hist_cmd.py b/tests/test_hist_cmd.py new file mode 100644 index 0000000..cfd1c1f --- /dev/null +++ b/tests/test_hist_cmd.py @@ -0,0 +1,62 @@ +import unittest +import argparse +import tempfile +import yaml +from pathlib import Path +from chatmastermind.message import Message, Question +from chatmastermind.chat import ChatDB, ChatError +from chatmastermind.configuration import Config +from chatmastermind.commands.hist import convert_messages + + +msg_suffix = Message.file_suffix_write + + +class TestConvertMessages(unittest.TestCase): + def setUp(self) -> None: + self.db_dir = tempfile.TemporaryDirectory() + self.cache_dir = tempfile.TemporaryDirectory() + self.db_path = Path(self.db_dir.name) + self.cache_path = Path(self.cache_dir.name) + self.args = argparse.Namespace() + self.config = Config() + self.config.cache = self.cache_dir.name + self.config.db = self.db_dir.name + # Prepare some messages + self.chat = ChatDB.from_dir(Path(self.cache_path), + Path(self.db_path)) + self.messages = [Message(Question(f'Question {i}')) for i in range(0, 6)] + self.chat.db_write(self.messages[0:2]) + self.chat.cache_write(self.messages[2:]) + # Change some of the suffixes + assert self.messages[0].file_path + assert self.messages[1].file_path + self.messages[0].file_path.rename(self.messages[0].file_path.with_suffix('.txt')) + self.messages[1].file_path.rename(self.messages[1].file_path.with_suffix('.yaml')) + + def tearDown(self) -> None: + self.db_dir.cleanup() + self.cache_dir.cleanup() + + def test_convert_messages(self) -> None: + self.args.convert = 'yaml' + convert_messages(self.args, self.config) + msgs = self.chat.msg_gather(loc='disk', glob='*.*') + # Check if the number of messages is the same as before + self.assertEqual(len(msgs), len(self.messages)) + # Check if all messages have the requested suffix + for msg in msgs: + assert msg.file_path + self.assertEqual(msg.file_path.suffix, msg_suffix) + # Check if the message IDs are correctly maintained + for m_new, m_old in zip(msgs, self.messages): + self.assertEqual(m_new.msg_id(), m_old.msg_id()) + # check if all messages have the new format + for m in msgs: + with open(str(m.file_path), "r") as fd: + yaml.load(fd, Loader=yaml.FullLoader) + + def test_convert_messages_wrong_format(self) -> None: + self.args.convert = 'foo' + with self.assertRaises(ChatError): + convert_messages(self.args, self.config) -- 2.36.6 From e4cb6eb22b659c1f745e5fecc693aee5896739d9 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 1 Oct 2023 09:27:40 +0200 Subject: [PATCH 14/16] README: updated 'hist' command description --- README.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 223ea85..4fb1d08 100644 --- a/README.md +++ b/README.md @@ -68,20 +68,22 @@ cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI_ID #### Hist -The `hist` command is used to print the chat history. +The `hist` command is used to print and manage the chat history. ```bash -cmm hist [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A SUBSTRING] [-Q SUBSTRING] +cmm hist [--print | --convert FORMAT] [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A SUBSTRING] [-Q SUBSTRING] ``` +* `-p, --print`: Print the DB chat history +* `-c, --convert FORMAT`: Convert all messages to the given format * `-t, --or-tags OTAGS`: List of tags (one must match) * `-k, --and-tags ATAGS`: List of tags (all must match) * `-x, --exclude-tags XTAGS`: List of tags to exclude * `-w, --with-tags`: Print chat history with tags * `-W, --with-files`: Print chat history with filenames * `-S, --source-code-only`: Only print embedded source code -* `-A, --answer SUBSTRING`: Search for answer substring -* `-Q, --question SUBSTRING`: Search for question substring +* `-A, --answer SUBSTRING`: Filter for answer substring +* `-Q, --question SUBSTRING`: Filter for question substring #### Tags -- 2.36.6 From 8f563998443268ed9aab521397e1cf91314faa8f Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 1 Oct 2023 10:11:16 +0200 Subject: [PATCH 15/16] cmm: replaced options '--with-tags' and '--with-file' with '--with-metadata' --- README.md | 3 +-- chatmastermind/chat.py | 4 ++-- chatmastermind/commands/hist.py | 3 +-- chatmastermind/main.py | 4 +--- chatmastermind/message.py | 7 ++++--- tests/test_chat.py | 12 ++++++++++-- tests/test_message.py | 6 +++++- 7 files changed, 24 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 4fb1d08..c661200 100644 --- a/README.md +++ b/README.md @@ -79,8 +79,7 @@ cmm hist [--print | --convert FORMAT] [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... * `-t, --or-tags OTAGS`: List of tags (one must match) * `-k, --and-tags ATAGS`: List of tags (all must match) * `-x, --exclude-tags XTAGS`: List of tags to exclude -* `-w, --with-tags`: Print chat history with tags -* `-W, --with-files`: Print chat history with filenames +* `-w, --with-metadata`: Print chat history with metadata (tags, filenames, AI, etc.) * `-S, --source-code-only`: Only print embedded source code * `-A, --answer SUBSTRING`: Filter for answer substring * `-Q, --question SUBSTRING`: Filter for question substring diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 2640c8b..eea78c6 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -255,14 +255,14 @@ class Chat: return sum(m.tokens() for m in self.messages) def print(self, source_code_only: bool = False, - with_tags: bool = False, with_files: bool = False, + with_metadata: bool = False, paged: bool = True) -> None: output: list[str] = [] for message in self.messages: if source_code_only: output.append(message.to_str(source_code_only=True)) continue - output.append(message.to_str(with_tags, with_files)) + output.append(message.to_str(with_metadata)) if paged: print_paged('\n'.join(output)) else: diff --git a/chatmastermind/commands/hist.py b/chatmastermind/commands/hist.py index e84e761..de7dd66 100644 --- a/chatmastermind/commands/hist.py +++ b/chatmastermind/commands/hist.py @@ -57,8 +57,7 @@ def print_chat(args: argparse.Namespace, config: Config) -> None: Path(config.db), mfilter=mfilter) chat.print(args.source_code_only, - args.with_tags, - args.with_files) + args.with_metadata) def hist_cmd(args: argparse.Namespace, config: Config) -> None: diff --git a/chatmastermind/main.py b/chatmastermind/main.py index a803d0e..482806c 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -79,9 +79,7 @@ def create_parser() -> argparse.ArgumentParser: hist_group = hist_cmd_parser.add_mutually_exclusive_group(required=True) hist_group.add_argument('-p', '--print', help='Print the DB chat history', action='store_true') hist_group.add_argument('-c', '--convert', help='Convert all message files to the given format [txt|yaml]', metavar='FORMAT') - hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", - action='store_true') - hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.", + hist_cmd_parser.add_argument('-w', '--with-metadata', help="Print chat history with metadata (tags, filename, AI, etc.).", action='store_true') hist_cmd_parser.add_argument('-S', '--source-code-only', help='Only print embedded source code', action='store_true') diff --git a/chatmastermind/message.py b/chatmastermind/message.py index ced11bb..8e7a55d 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -422,7 +422,7 @@ class Message(): except Exception: raise MessageError(f"'{file_path}' does not contain a valid message") - def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str: + def to_str(self, with_metadata: bool = False, source_code_only: bool = False) -> str: """ Return the current Message as a string. """ @@ -432,10 +432,11 @@ class Message(): if self.answer: output.extend(self.answer.source_code(include_delims=True)) return '\n'.join(output) if len(output) > 0 else '' - if with_tags: + if with_metadata: output.append(self.tags_str()) - if with_file: output.append('FILE: ' + str(self.file_path)) + output.append('AI: ' + str(self.ai)) + output.append('MODEL: ' + str(self.model)) output.append(Question.txt_header) output.append(self.question) if self.answer: diff --git a/tests/test_chat.py b/tests/test_chat.py index fca61e5..7616dde 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -41,10 +41,14 @@ class TestChat(TestChatBase): self.message1 = Message(Question('Question 1'), Answer('Answer 1'), {Tag('atag1'), Tag('btag2')}, + ai='FakeAI', + model='FakeModel', file_path=pathlib.Path(f'0001{msg_suffix}')) self.message2 = Message(Question('Question 2'), Answer('Answer 2'), {Tag('btag2')}, + ai='FakeAI', + model='FakeModel', file_path=pathlib.Path(f'0002{msg_suffix}')) self.maxDiff = None @@ -156,17 +160,21 @@ Answer 2 self.assertEqual(mock_stdout.getvalue(), expected_output) @patch('sys.stdout', new_callable=StringIO) - def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: + def test_print_with_metadata(self, mock_stdout: StringIO) -> None: self.chat.msg_add([self.message1, self.message2]) - self.chat.print(paged=False, with_tags=True, with_files=True) + self.chat.print(paged=False, with_metadata=True) expected_output = f"""{TagLine.prefix} atag1 btag2 FILE: 0001{msg_suffix} +AI: FakeAI +MODEL: FakeModel {Question.txt_header} Question 1 {Answer.txt_header} Answer 1 {TagLine.prefix} btag2 FILE: 0002{msg_suffix} +AI: FakeAI +MODEL: FakeModel {Question.txt_header} Question 2 {Answer.txt_header} diff --git a/tests/test_message.py b/tests/test_message.py index 6e39143..b79bcae 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -856,6 +856,8 @@ class MessageToStrTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), Answer('This is an answer.'), + ai=('FakeAI'), + model=('FakeModel'), tags={Tag('atag1'), Tag('btag2')}, file_path=pathlib.Path('/tmp/foo/bla')) @@ -869,11 +871,13 @@ This is an answer.""" def test_to_str_with_tags_and_file(self) -> None: expected_output = f"""{TagLine.prefix} atag1 btag2 FILE: /tmp/foo/bla +AI: FakeAI +MODEL: FakeModel {Question.txt_header} This is a question. {Answer.txt_header} This is an answer.""" - self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output) + self.assertEqual(self.message.to_str(with_metadata=True), expected_output) class MessageRmFileTestCase(unittest.TestCase): -- 2.36.6 From 3ea1f4902793479d58cdb040c43247c6ac3ce8f5 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 1 Oct 2023 10:22:26 +0200 Subject: [PATCH 16/16] cmm: added options '--tight' and '--no-paging' to the 'hist --print' cmd --- chatmastermind/chat.py | 5 ++++- chatmastermind/commands/hist.py | 4 +++- chatmastermind/main.py | 2 ++ tests/test_chat.py | 4 ++-- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index eea78c6..1cc21ba 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -256,13 +256,16 @@ class Chat: def print(self, source_code_only: bool = False, with_metadata: bool = False, - paged: bool = True) -> None: + paged: bool = True, + tight: bool = False) -> None: output: list[str] = [] for message in self.messages: if source_code_only: output.append(message.to_str(source_code_only=True)) continue output.append(message.to_str(with_metadata)) + if not tight: + output.append('\n' + ('-' * terminal_width()) + '\n') if paged: print_paged('\n'.join(output)) else: diff --git a/chatmastermind/commands/hist.py b/chatmastermind/commands/hist.py index de7dd66..8e68320 100644 --- a/chatmastermind/commands/hist.py +++ b/chatmastermind/commands/hist.py @@ -57,7 +57,9 @@ def print_chat(args: argparse.Namespace, config: Config) -> None: Path(config.db), mfilter=mfilter) chat.print(args.source_code_only, - args.with_metadata) + args.with_metadata, + paged=not args.no_paging, + tight=args.tight) def hist_cmd(args: argparse.Namespace, config: Config) -> None: diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 482806c..ff74d6c 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -85,6 +85,8 @@ def create_parser() -> argparse.ArgumentParser: action='store_true') hist_cmd_parser.add_argument('-A', '--answer', help='Print only answers with given substring', metavar='SUBSTRING') hist_cmd_parser.add_argument('-Q', '--question', help='Print only questions with given substring', metavar='SUBSTRING') + hist_cmd_parser.add_argument('-d', '--tight', help='Print without message separators', action='store_true') + hist_cmd_parser.add_argument('-P', '--no-paging', help='Print without paging', action='store_true') # 'tags' command parser tags_cmd_parser = cmdparser.add_parser('tags', diff --git a/tests/test_chat.py b/tests/test_chat.py index 7616dde..802616d 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -147,7 +147,7 @@ class TestChat(TestChatBase): @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: self.chat.msg_add([self.message1, self.message2]) - self.chat.print(paged=False) + self.chat.print(paged=False, tight=True) expected_output = f"""{Question.txt_header} Question 1 {Answer.txt_header} @@ -162,7 +162,7 @@ Answer 2 @patch('sys.stdout', new_callable=StringIO) def test_print_with_metadata(self, mock_stdout: StringIO) -> None: self.chat.msg_add([self.message1, self.message2]) - self.chat.print(paged=False, with_metadata=True) + self.chat.print(paged=False, with_metadata=True, tight=True) expected_output = f"""{TagLine.prefix} atag1 btag2 FILE: 0001{msg_suffix} AI: FakeAI -- 2.36.6