22 Commits

Author SHA1 Message Date
Oleksandr Kozachuk dca1bc621f Add Jinja2 as preparation for pipes. 2023-10-17 11:57:52 +02:00
Oleksandr Kozachuk 5f29f60168 Add .old/ to git ignore, I use that dir ofter for old files, I do not want to delete. 2023-10-17 11:53:49 +02:00
juk0de 3ea1f49027 cmm: added options '--tight' and '--no-paging' to the 'hist --print' cmd 2023-10-02 08:35:19 +02:00
juk0de 8f56399844 cmm: replaced options '--with-tags' and '--with-file' with '--with-metadata' 2023-10-01 10:11:16 +02:00
juk0de e4cb6eb22b README: updated 'hist' command description 2023-10-01 09:27:40 +02:00
juk0de e19c6bb1ea hist_cmd: added module 'test_hist_cmd.py' 2023-09-30 08:25:33 +02:00
juk0de 811b2e6830 hist_cmd: implemented '--convert' option 2023-09-29 18:53:12 +02:00
juk0de 2a8f01aee4 chat: 'msg_gather()' now supports globbing 2023-09-29 07:16:20 +02:00
juk0de efdb3cae2f question: moved around some code 2023-09-29 07:16:20 +02:00
juk0de aecfd1088d chat: added message file format as ChatDB class member 2023-09-29 07:16:20 +02:00
juk0de 140dbed809 message: added function 'rm_file()' and test 2023-09-29 07:16:20 +02:00
juk0de 01860ace2c test_question_cmd: modified tests to use '.msg' file suffix 2023-09-29 07:16:20 +02:00
juk0de df42bcee09 test_chat: added test for file_path collision detection 2023-09-29 07:16:20 +02:00
juk0de e34eab6519 test_chat: changed all tests to use the new '.msg' suffix 2023-09-29 07:16:20 +02:00
juk0de d07fd13e8e test_message: changed all tests to use the new '.msg' suffix 2023-09-29 07:16:20 +02:00
juk0de b8681e8274 message: fixed tag matching for YAML file format 2023-09-29 07:16:20 +02:00
juk0de d2be53aeab chat: switched to new message suffix and formats
- no longer using file suffix to choose the format
- added 'mformat' argument to 'write_xxx()' functions
- file suffix is now set by 'Message.to_file()' per default
2023-09-29 07:16:20 +02:00
juk0de 9ca9a23569 message: introduced file suffix '.msg'
- '.msg' suffix is always used for writing
- 'Message.to_file()' will set the file suffix if the given file_path has none
- added 'mformat' argument to 'Message.to_file()' for choosing the file format
- '.txt' and '.yaml' suffixes are only supported for reading
2023-09-29 07:16:20 +02:00
juk0de 6f3758e12e question_cmd: fixed '--create' option 2023-09-29 07:15:46 +02:00
juk0de dd836cd72d Merge pull request 'cmm question --repeat supports multiple questions, added tests and fixes' (#15) from repeat_multi into main
This PR primarily modifies the `cmm question --repeat` command to allow repeating multiple questions, instead of only the last one.

Additionally, this PR includes the following changes:

- In `ai_factory.py`, added optional parameters 'def_ai' and 'def_model' to the `create_ai` function which allows specifying a default AI and model.
- In `openai.py`, a potential bug was fixed where the 'tags' attribute was updated to ensure it is always a set, even when 'otags' is None.
- In `question.py`, a significant amount of new code was added to facilitate the 'repeat' functionality. This includes functions to create modified args based on an existing message (`create_msg_args`), to repeat a given list of messages (`repeat_messages`), and to invert the semantics of the INPUT tags for this command (`invert_input_tag_args`).
- In `main.py`, the 'nargs' parameter was changed from `+` to `*` in the 'or-tags', 'and-tags', and 'exclude-tags' arguments to accommodate the updated handling of tags in `question.py`.
- A new `test_common.py` file was added which includes a `FakeAI` class for testing purposes, and a `TestWithFakeAI` class which includes a number of methods for asserting various conditions about messages.

This PR also includes additional tests to verify the correct operation of the new 'repeat' functionality.
2023-09-26 18:04:27 +02:00
juk0de 4538624247 Merge pull request 'Implemented the 'question --repeat' command and other improvements' (#14) from repeat into main
Reviewed-on: #14
2023-09-21 07:25:47 +02:00
juk0de f964c5471e Merge pull request 'Refactoring, fixes and new features for the 'chat.py' module' (#12) from chat_refactoring into main
Reviewed-on: #12
2023-09-18 14:23:51 +02:00
13 changed files with 263 additions and 66 deletions
+1
View File
@@ -106,6 +106,7 @@ celerybeat.pid
.venv .venv
env/ env/
venv/ venv/
.old/
ENV/ ENV/
env.bak/ env.bak/
venv.bak/ venv.bak/
+7 -6
View File
@@ -68,20 +68,21 @@ cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI_ID
#### Hist #### Hist
The `hist` command is used to print the chat history. The `hist` command is used to print and manage the chat history.
```bash ```bash
cmm hist [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A SUBSTRING] [-Q SUBSTRING] cmm hist [--print | --convert FORMAT] [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A SUBSTRING] [-Q SUBSTRING]
``` ```
* `-p, --print`: Print the DB chat history
* `-c, --convert FORMAT`: Convert all messages to the given format
* `-t, --or-tags OTAGS`: List of tags (one must match) * `-t, --or-tags OTAGS`: List of tags (one must match)
* `-k, --and-tags ATAGS`: List of tags (all must match) * `-k, --and-tags ATAGS`: List of tags (all must match)
* `-x, --exclude-tags XTAGS`: List of tags to exclude * `-x, --exclude-tags XTAGS`: List of tags to exclude
* `-w, --with-tags`: Print chat history with tags * `-w, --with-metadata`: Print chat history with metadata (tags, filenames, AI, etc.)
* `-W, --with-files`: Print chat history with filenames
* `-S, --source-code-only`: Only print embedded source code * `-S, --source-code-only`: Only print embedded source code
* `-A, --answer SUBSTRING`: Search for answer substring * `-A, --answer SUBSTRING`: Filter for answer substring
* `-Q, --question SUBSTRING`: Search for question substring * `-Q, --question SUBSTRING`: Filter for question substring
#### Tags #### Tags
+30 -12
View File
@@ -8,7 +8,7 @@ from pydoc import pager
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union
from .configuration import default_config_file from .configuration import default_config_file
from .message import Message, MessageFilter, MessageError, MessageFormat, message_in from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats
from .tags import Tag from .tags import Tag
ChatInst = TypeVar('ChatInst', bound='Chat') ChatInst = TypeVar('ChatInst', bound='Chat')
@@ -255,14 +255,17 @@ class Chat:
return sum(m.tokens() for m in self.messages) return sum(m.tokens() for m in self.messages)
def print(self, source_code_only: bool = False, def print(self, source_code_only: bool = False,
with_tags: bool = False, with_files: bool = False, with_metadata: bool = False,
paged: bool = True) -> None: paged: bool = True,
tight: bool = False) -> None:
output: list[str] = [] output: list[str] = []
for message in self.messages: for message in self.messages:
if source_code_only: if source_code_only:
output.append(message.to_str(source_code_only=True)) output.append(message.to_str(source_code_only=True))
continue continue
output.append(message.to_str(with_tags, with_files)) output.append(message.to_str(with_metadata))
if not tight:
output.append('\n' + ('-' * terminal_width()) + '\n')
if paged: if paged:
print_paged('\n'.join(output)) print_paged('\n'.join(output))
else: else:
@@ -285,6 +288,8 @@ class ChatDB(Chat):
mfilter: Optional[MessageFilter] = None mfilter: Optional[MessageFilter] = None
# the glob pattern for all messages # the glob pattern for all messages
glob: Optional[str] = None glob: Optional[str] = None
# message format (for writing)
mformat: MessageFormat = Message.default_format
def __post_init__(self) -> None: def __post_init__(self) -> None:
# contains the latest message ID # contains the latest message ID
@@ -339,9 +344,17 @@ class ChatDB(Chat):
with open(self.next_path, 'w') as f: with open(self.next_path, 'w') as f:
f.write(f'{fid}') f.write(f'{fid}')
def set_msg_format(self, mformat: MessageFormat) -> None:
"""
Set message format for writing messages.
"""
if mformat not in message_valid_formats:
raise ChatError(f"Message format '{mformat}' is not supported")
self.mformat = mformat
def msg_write(self, def msg_write(self,
messages: Optional[list[Message]] = None, messages: Optional[list[Message]] = None,
mformat: MessageFormat = Message.default_format) -> None: mformat: Optional[MessageFormat] = None) -> None:
""" """
Write either the given messages or the internal ones to their CURRENT file_path. Write either the given messages or the internal ones to their CURRENT file_path.
If messages are given, they all must have a valid file_path. When writing the If messages are given, they all must have a valid file_path. When writing the
@@ -352,7 +365,7 @@ class ChatDB(Chat):
raise ChatError("Can't write files without a valid file_path") raise ChatError("Can't write files without a valid file_path")
msgs = iter(messages if messages else self.messages) msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)): while (m := next(msgs, None)):
m.to_file(mformat=mformat) m.to_file(mformat=mformat if mformat else self.mformat)
def msg_update(self, messages: list[Message], write: bool = True) -> None: def msg_update(self, messages: list[Message], write: bool = True) -> None:
""" """
@@ -373,6 +386,7 @@ class ChatDB(Chat):
def msg_gather(self, def msg_gather(self,
loc: msg_location, loc: msg_location,
require_file_path: bool = False, require_file_path: bool = False,
glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> list[Message]: mfilter: Optional[MessageFilter] = None) -> list[Message]:
""" """
Gather and return messages from the given locations: Gather and return messages from the given locations:
@@ -391,9 +405,9 @@ class ChatDB(Chat):
else: else:
loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))] loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
if loc in ['cache', 'disk', 'all']: if loc in ['cache', 'disk', 'all']:
loc_messages += read_dir(self.cache_path, mfilter=mfilter) loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter)
if loc in ['db', 'disk', 'all']: if loc in ['db', 'disk', 'all']:
loc_messages += read_dir(self.db_path, mfilter=mfilter) loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter)
# remove_duplicates and sort the list # remove_duplicates and sort the list
unique_messages: list[Message] = [] unique_messages: list[Message] = []
for m in loc_messages: for m in loc_messages:
@@ -514,7 +528,8 @@ class ChatDB(Chat):
""" """
write_dir(self.cache_path, write_dir(self.cache_path,
messages if messages else self.messages, messages if messages else self.messages,
self.get_next_fid) self.get_next_fid,
self.mformat)
def cache_add(self, messages: list[Message], write: bool = True) -> None: def cache_add(self, messages: list[Message], write: bool = True) -> None:
""" """
@@ -526,7 +541,8 @@ class ChatDB(Chat):
if write: if write:
write_dir(self.cache_path, write_dir(self.cache_path,
messages, messages,
self.get_next_fid) self.get_next_fid,
self.mformat)
else: else:
for m in messages: for m in messages:
m.file_path = make_file_path(self.cache_path, self.get_next_fid) m.file_path = make_file_path(self.cache_path, self.get_next_fid)
@@ -579,7 +595,8 @@ class ChatDB(Chat):
""" """
write_dir(self.db_path, write_dir(self.db_path,
messages if messages else self.messages, messages if messages else self.messages,
self.get_next_fid) self.get_next_fid,
self.mformat)
def db_add(self, messages: list[Message], write: bool = True) -> None: def db_add(self, messages: list[Message], write: bool = True) -> None:
""" """
@@ -591,7 +608,8 @@ class ChatDB(Chat):
if write: if write:
write_dir(self.db_path, write_dir(self.db_path,
messages, messages,
self.get_next_fid) self.get_next_fid,
self.mformat)
else: else:
for m in messages: for m in messages:
m.file_path = make_file_path(self.db_path, self.get_next_fid) m.file_path = make_file_path(self.db_path, self.get_next_fid)
+54 -5
View File
@@ -1,13 +1,51 @@
import sys
import argparse import argparse
from pathlib import Path from pathlib import Path
from ..configuration import Config from ..configuration import Config
from ..chat import ChatDB from ..chat import ChatDB
from ..message import MessageFilter from ..message import MessageFilter, Message
def hist_cmd(args: argparse.Namespace, config: Config) -> None: msg_suffix = Message.file_suffix_write # currently '.msg'
def convert_messages(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'hist' command. Convert messages to a new format. Also used to change old suffixes
('.txt', '.yaml') to the latest default message file suffix ('.msg').
"""
chat = ChatDB.from_dir(Path(config.cache),
Path(config.db))
# read all known message files
msgs = chat.msg_gather(loc='disk', glob='*.*')
# make a set of all message IDs
msg_ids = set([m.msg_id() for m in msgs])
# set requested format and write all messages
chat.set_msg_format(args.convert)
# delete the current suffix
# -> a new one will automatically be created
for m in msgs:
if m.file_path:
m.file_path = m.file_path.with_suffix('')
chat.msg_write(msgs)
# read all messages with the current default suffix
msgs = chat.msg_gather(loc='disk', glob=f'*{msg_suffix}')
# make sure we converted all of the original messages
for mid in msg_ids:
if not any(mid == m.msg_id() for m in msgs):
print(f"Message '{mid}' has not been found after conversion. Aborting.")
sys.exit(1)
# delete messages with old suffixes
msgs = chat.msg_gather(loc='disk', glob='*.*')
for m in msgs:
if m.file_path and m.file_path.suffix != msg_suffix:
m.rm_file()
print(f"Successfully converted {len(msg_ids)} messages.")
def print_chat(args: argparse.Namespace, config: Config) -> None:
"""
Print the DB chat history.
""" """
mfilter = MessageFilter(tags_or=args.or_tags, mfilter = MessageFilter(tags_or=args.or_tags,
@@ -19,5 +57,16 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None:
Path(config.db), Path(config.db),
mfilter=mfilter) mfilter=mfilter)
chat.print(args.source_code_only, chat.print(args.source_code_only,
args.with_tags, args.with_metadata,
args.with_files) paged=not args.no_paging,
tight=args.tight)
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'hist' command.
"""
if args.print:
print_chat(args, config)
elif args.convert:
convert_messages(args, config)
+33 -24
View File
@@ -10,6 +10,10 @@ from ..ai_factory import create_ai
from ..ai import AI, AIResponse from ..ai import AI, AIResponse
class QuestionCmdError(Exception):
pass
def add_file_as_text(question_parts: list[str], file: str) -> None: def add_file_as_text(question_parts: list[str], file: str) -> None:
""" """
Add the given file as plain text to the question part list. Add the given file as plain text to the question part list.
@@ -51,13 +55,41 @@ def add_file_as_code(question_parts: list[str], file: str) -> None:
question_parts.append(f"```\n{content}\n```") question_parts.append(f"```\n{content}\n```")
def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
"""
Takes an existing message and CLI arguments, and returns modified args based
on the members of the given message. Used e.g. when repeating messages, where
it's necessary to determine the correct AI, module and output tags to use
(either from the existing message or the given args).
"""
msg_args = args
# if AI, model or output tags have not been specified,
# use those from the original message
if (args.AI is None
or args.model is None # noqa: W503
or args.output_tags is None): # noqa: W503
msg_args = deepcopy(args)
if args.AI is None and msg.ai is not None:
msg_args.AI = msg.ai
if args.model is None and msg.model is not None:
msg_args.model = msg.model
if args.output_tags is None and msg.tags is not None:
msg_args.output_tags = msg.tags
return msg_args
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
""" """
Create a new message from the given arguments and write it Create a new message from the given arguments and write it
to the cache directory. to the cache directory.
""" """
question_parts = [] question_parts = []
question_list = args.ask if args.ask is not None else [] if args.create is not None:
question_list = args.create
elif args.ask is not None:
question_list = args.ask
else:
raise QuestionCmdError("No question found")
text_files = args.source_text if args.source_text is not None else [] text_files = args.source_text if args.source_text is not None else []
code_files = args.source_code if args.source_code is not None else [] code_files = args.source_code if args.source_code is not None else []
@@ -106,29 +138,6 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
print(response.tokens) print(response.tokens)
def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
"""
Takes an existing message and CLI arguments, and returns modified args based
on the members of the given message. Used e.g. when repeating messages, where
it's necessary to determine the correct AI, module and output tags to use
(either from the existing message or the given args).
"""
msg_args = args
# if AI, model or output tags have not been specified,
# use those from the original message
if (args.AI is None
or args.model is None # noqa: W503
or args.output_tags is None): # noqa: W503
msg_args = deepcopy(args)
if args.AI is None and msg.ai is not None:
msg_args.AI = msg.ai
if args.model is None and msg.model is not None:
msg_args.model = msg.model
if args.output_tags is None and msg.tags is not None:
msg_args.output_tags = msg.tags
return msg_args
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None: def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
""" """
Repeat the given messages using the given arguments. Repeat the given messages using the given arguments.
+9 -6
View File
@@ -73,17 +73,20 @@ def create_parser() -> argparse.ArgumentParser:
# 'hist' command parser # 'hist' command parser
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
help="Print chat history.", help="Print and manage chat history.",
aliases=['h']) aliases=['h'])
hist_cmd_parser.set_defaults(func=hist_cmd) hist_cmd_parser.set_defaults(func=hist_cmd)
hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", hist_group = hist_cmd_parser.add_mutually_exclusive_group(required=True)
action='store_true') hist_group.add_argument('-p', '--print', help='Print the DB chat history', action='store_true')
hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.", hist_group.add_argument('-c', '--convert', help='Convert all message files to the given format [txt|yaml]', metavar='FORMAT')
hist_cmd_parser.add_argument('-w', '--with-metadata', help="Print chat history with metadata (tags, filename, AI, etc.).",
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Only print embedded source code', hist_cmd_parser.add_argument('-S', '--source-code-only', help='Only print embedded source code',
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring', metavar='SUBSTRING') hist_cmd_parser.add_argument('-A', '--answer', help='Print only answers with given substring', metavar='SUBSTRING')
hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring', metavar='SUBSTRING') hist_cmd_parser.add_argument('-Q', '--question', help='Print only questions with given substring', metavar='SUBSTRING')
hist_cmd_parser.add_argument('-d', '--tight', help='Print without message separators', action='store_true')
hist_cmd_parser.add_argument('-P', '--no-paging', help='Print without paging', action='store_true')
# 'tags' command parser # 'tags' command parser
tags_cmd_parser = cmdparser.add_parser('tags', tags_cmd_parser = cmdparser.add_parser('tags',
+11 -3
View File
@@ -422,7 +422,7 @@ class Message():
except Exception: except Exception:
raise MessageError(f"'{file_path}' does not contain a valid message") raise MessageError(f"'{file_path}' does not contain a valid message")
def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str: def to_str(self, with_metadata: bool = False, source_code_only: bool = False) -> str:
""" """
Return the current Message as a string. Return the current Message as a string.
""" """
@@ -432,10 +432,11 @@ class Message():
if self.answer: if self.answer:
output.extend(self.answer.source_code(include_delims=True)) output.extend(self.answer.source_code(include_delims=True))
return '\n'.join(output) if len(output) > 0 else '' return '\n'.join(output) if len(output) > 0 else ''
if with_tags: if with_metadata:
output.append(self.tags_str()) output.append(self.tags_str())
if with_file:
output.append('FILE: ' + str(self.file_path)) output.append('FILE: ' + str(self.file_path))
output.append('AI: ' + str(self.ai))
output.append('MODEL: ' + str(self.model))
output.append(Question.txt_header) output.append(Question.txt_header)
output.append(self.question) output.append(self.question)
if self.answer: if self.answer:
@@ -517,6 +518,13 @@ class Message():
yaml.dump(data, temp_fd, sort_keys=False) yaml.dump(data, temp_fd, sort_keys=False)
shutil.move(temp_file_path, file_path) shutil.move(temp_file_path, file_path)
def rm_file(self) -> None:
"""
Delete the message file. Ignore empty file_path and not existing files.
"""
if self.file_path is not None:
self.file_path.unlink(missing_ok=True)
def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
""" """
Filter tags based on their prefix (i. e. the tag starts with a given string) Filter tags based on their prefix (i. e. the tag starts with a given string)
+1
View File
@@ -2,3 +2,4 @@ openai
PyYAML PyYAML
argcomplete argcomplete
pytest pytest
Jinja2
+3 -6
View File
@@ -2,6 +2,8 @@ from setuptools import setup, find_packages
with open("README.md", "r", encoding="utf-8") as fh: with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read() long_description = fh.read()
with open("requirements.txt", "r") as fh:
install_requirements = [line.strip() for line in fh]
setup( setup(
name="ChatMastermind", name="ChatMastermind",
@@ -28,12 +30,7 @@ setup(
"Topic :: Utilities", "Topic :: Utilities",
"Topic :: Text Processing", "Topic :: Text Processing",
], ],
install_requires=[ install_requires=install_requirements,
"openai",
"PyYAML",
"argcomplete",
"pytest",
],
python_requires=">=3.9", python_requires=">=3.9",
test_suite="tests", test_suite="tests",
entry_points={ entry_points={
+11 -3
View File
@@ -41,10 +41,14 @@ class TestChat(TestChatBase):
self.message1 = Message(Question('Question 1'), self.message1 = Message(Question('Question 1'),
Answer('Answer 1'), Answer('Answer 1'),
{Tag('atag1'), Tag('btag2')}, {Tag('atag1'), Tag('btag2')},
ai='FakeAI',
model='FakeModel',
file_path=pathlib.Path(f'0001{msg_suffix}')) file_path=pathlib.Path(f'0001{msg_suffix}'))
self.message2 = Message(Question('Question 2'), self.message2 = Message(Question('Question 2'),
Answer('Answer 2'), Answer('Answer 2'),
{Tag('btag2')}, {Tag('btag2')},
ai='FakeAI',
model='FakeModel',
file_path=pathlib.Path(f'0002{msg_suffix}')) file_path=pathlib.Path(f'0002{msg_suffix}'))
self.maxDiff = None self.maxDiff = None
@@ -143,7 +147,7 @@ class TestChat(TestChatBase):
@patch('sys.stdout', new_callable=StringIO) @patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None: def test_print(self, mock_stdout: StringIO) -> None:
self.chat.msg_add([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False) self.chat.print(paged=False, tight=True)
expected_output = f"""{Question.txt_header} expected_output = f"""{Question.txt_header}
Question 1 Question 1
{Answer.txt_header} {Answer.txt_header}
@@ -156,17 +160,21 @@ Answer 2
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
@patch('sys.stdout', new_callable=StringIO) @patch('sys.stdout', new_callable=StringIO)
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: def test_print_with_metadata(self, mock_stdout: StringIO) -> None:
self.chat.msg_add([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False, with_tags=True, with_files=True) self.chat.print(paged=False, with_metadata=True, tight=True)
expected_output = f"""{TagLine.prefix} atag1 btag2 expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: 0001{msg_suffix} FILE: 0001{msg_suffix}
AI: FakeAI
MODEL: FakeModel
{Question.txt_header} {Question.txt_header}
Question 1 Question 1
{Answer.txt_header} {Answer.txt_header}
Answer 1 Answer 1
{TagLine.prefix} btag2 {TagLine.prefix} btag2
FILE: 0002{msg_suffix} FILE: 0002{msg_suffix}
AI: FakeAI
MODEL: FakeModel
{Question.txt_header} {Question.txt_header}
Question 2 Question 2
{Answer.txt_header} {Answer.txt_header}
+62
View File
@@ -0,0 +1,62 @@
import unittest
import argparse
import tempfile
import yaml
from pathlib import Path
from chatmastermind.message import Message, Question
from chatmastermind.chat import ChatDB, ChatError
from chatmastermind.configuration import Config
from chatmastermind.commands.hist import convert_messages
msg_suffix = Message.file_suffix_write
class TestConvertMessages(unittest.TestCase):
def setUp(self) -> None:
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
self.db_path = Path(self.db_dir.name)
self.cache_path = Path(self.cache_dir.name)
self.args = argparse.Namespace()
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
# Prepare some messages
self.chat = ChatDB.from_dir(Path(self.cache_path),
Path(self.db_path))
self.messages = [Message(Question(f'Question {i}')) for i in range(0, 6)]
self.chat.db_write(self.messages[0:2])
self.chat.cache_write(self.messages[2:])
# Change some of the suffixes
assert self.messages[0].file_path
assert self.messages[1].file_path
self.messages[0].file_path.rename(self.messages[0].file_path.with_suffix('.txt'))
self.messages[1].file_path.rename(self.messages[1].file_path.with_suffix('.yaml'))
def tearDown(self) -> None:
self.db_dir.cleanup()
self.cache_dir.cleanup()
def test_convert_messages(self) -> None:
self.args.convert = 'yaml'
convert_messages(self.args, self.config)
msgs = self.chat.msg_gather(loc='disk', glob='*.*')
# Check if the number of messages is the same as before
self.assertEqual(len(msgs), len(self.messages))
# Check if all messages have the requested suffix
for msg in msgs:
assert msg.file_path
self.assertEqual(msg.file_path.suffix, msg_suffix)
# Check if the message IDs are correctly maintained
for m_new, m_old in zip(msgs, self.messages):
self.assertEqual(m_new.msg_id(), m_old.msg_id())
# check if all messages have the new format
for m in msgs:
with open(str(m.file_path), "r") as fd:
yaml.load(fd, Loader=yaml.FullLoader)
def test_convert_messages_wrong_format(self) -> None:
self.args.convert = 'foo'
with self.assertRaises(ChatError):
convert_messages(self.args, self.config)
+24 -1
View File
@@ -856,6 +856,8 @@ class MessageToStrTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message = Message(Question('This is a question.'), self.message = Message(Question('This is a question.'),
Answer('This is an answer.'), Answer('This is an answer.'),
ai=('FakeAI'),
model=('FakeModel'),
tags={Tag('atag1'), Tag('btag2')}, tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/foo/bla')) file_path=pathlib.Path('/tmp/foo/bla'))
@@ -869,8 +871,29 @@ This is an answer."""
def test_to_str_with_tags_and_file(self) -> None: def test_to_str_with_tags_and_file(self) -> None:
expected_output = f"""{TagLine.prefix} atag1 btag2 expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: /tmp/foo/bla FILE: /tmp/foo/bla
AI: FakeAI
MODEL: FakeModel
{Question.txt_header} {Question.txt_header}
This is a question. This is a question.
{Answer.txt_header} {Answer.txt_header}
This is an answer.""" This is an answer."""
self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output) self.assertEqual(self.message.to_str(with_metadata=True), expected_output)
class MessageRmFileTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path = pathlib.Path(self.file.name)
self.message = Message(Question('This is a question.'),
file_path=self.file_path)
self.message.to_file()
def tearDown(self) -> None:
self.file.close()
self.file_path.unlink(missing_ok=True)
def test_rm_file(self) -> None:
assert self.message.file_path
self.assertTrue(self.message.file_path.exists())
self.message.rm_file()
self.assertFalse(self.message.file_path.exists())
+17
View File
@@ -41,6 +41,8 @@ class TestMessageCreate(TestWithFakeAI):
self.args.AI = None self.args.AI = None
self.args.model = None self.args.model = None
self.args.output_tags = None self.args.output_tags = None
self.args.ask = None
self.args.create = None
# File 1 : no source code block, only text # File 1 : no source code block, only text
self.source_file1 = tempfile.NamedTemporaryFile(delete=False) self.source_file1 = tempfile.NamedTemporaryFile(delete=False)
self.source_file1_content = """This is just text. self.source_file1_content = """This is just text.
@@ -204,6 +206,21 @@ It is embedded code
""")) """))
class TestCreateOption(TestMessageCreate):
def test_message_file_created(self) -> None:
self.args.create = ["How does question --create work?"]
self.args.ask = None
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("How does question --create work?")) # type: ignore [union-attr]
class TestQuestionCmd(TestWithFakeAI): class TestQuestionCmd(TestWithFakeAI):
def setUp(self) -> None: def setUp(self) -> None: