3 Commits

7 changed files with 162 additions and 400 deletions
+43 -51
View File
@@ -46,32 +46,32 @@ cmm [global options] command [command options]
The `question` command is used to ask, create, and process questions. The `question` command is used to ask, create, and process questions.
```bash ```bash
cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI_ID] [-M MODEL] [-n NUM] [-m MAX] [-T TEMP] (-a QUESTION | -c QUESTION | -r [MESSAGE ...] | -p [MESSAGE ...]) [-O] [-s FILE]... [-S FILE]... cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI] [-M MODEL] [-n NUM] [-m MAX] [-T TEMP] (-a ASK | -c CREATE | -r REPEAT | -p PROCESS) [-O] [-s SOURCE]... [-S SOURCE]...
``` ```
* `-t, --or-tags OTAGS` : List of tags (one must match) * `-t, --or-tags OTAGS` : List of tags (one must match)
* `-k, --and-tags ATAGS` : List of tags (all must match) * `-k, --and-tags ATAGS` : List of tags (all must match)
* `-x, --exclude-tags XTAGS` : List of tags to exclude * `-x, --exclude-tags XTAGS` : List of tags to exclude
* `-o, --output-tags OUTTAGS` : List of output tags (default: use input tags) * `-o, --output-tags OUTTAGS` : List of output tags (default: use input tags)
* `-A, --AI AI_ID`: AI ID to use * `-A, --AI AI` : AI ID to use
* `-M, --model MODEL` : Model to use * `-M, --model MODEL` : Model to use
* `-n, --num-answers NUM` : Number of answers to request * `-n, --num-answers NUM` : Number of answers to request
* `-m, --max-tokens MAX` : Max. number of tokens * `-m, --max-tokens MAX` : Max. number of tokens
* `-T, --temperature TEMP` : Temperature value * `-T, --temperature TEMP` : Temperature value
* `-a, --ask QUESTION`: Ask a question * `-a, --ask ASK` : Ask a question
* `-c, --create QUESTION`: Create a question * `-c, --create CREATE` : Create a question
* `-r, --repeat [MESSAGE ...]`: Repeat a question * `-r, --repeat REPEAT` : Repeat a question
* `-p, --process [MESSAGE ...]`: Process existing questions * `-p, --process PROCESS` : Process existing questions
* `-O, --overwrite` : Overwrite existing messages when repeating them * `-O, --overwrite` : Overwrite existing messages when repeating them
* `-s, --source-text FILE`: Add content of a file to the query * `-s, --source-text SOURCE` : Add content of a file to the query
* `-S, --source-code FILE`: Add source code file content to the chat history * `-S, --source-code SOURCE` : Add source code file content to the chat history
#### Hist #### Hist
The `hist` command is used to print the chat history. The `hist` command is used to print the chat history.
```bash ```bash
cmm hist [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A SUBSTRING] [-Q SUBSTRING] cmm hist [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A ANSWER] [-Q QUESTION]
``` ```
* `-t, --or-tags OTAGS` : List of tags (one must match) * `-t, --or-tags OTAGS` : List of tags (one must match)
@@ -79,47 +79,46 @@ cmm hist [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A SUBSTRING]
* `-x, --exclude-tags XTAGS` : List of tags to exclude * `-x, --exclude-tags XTAGS` : List of tags to exclude
* `-w, --with-tags` : Print chat history with tags * `-w, --with-tags` : Print chat history with tags
* `-W, --with-files` : Print chat history with filenames * `-W, --with-files` : Print chat history with filenames
* `-S, --source-code-only`: Only print embedded source code * `-S, --source-code-only` : Print only source code
* `-A, --answer SUBSTRING`: Search for answer substring * `-A, --answer ANSWER` : Search for answer substring
* `-Q, --question SUBSTRING`: Search for question substring * `-Q, --question QUESTION` : Search for question substring
#### Tags #### Tags
The `tags` command is used to manage tags. The `tags` command is used to manage tags.
```bash ```bash
cmm tags (-l | -p PREFIX | -c SUBSTRING) cmm tags (-l | -p PREFIX | -c CONTENT)
``` ```
* `-l, --list` : List all tags and their frequency * `-l, --list` : List all tags and their frequency
* `-p, --prefix PREFIX` : Filter tags by prefix * `-p, --prefix PREFIX` : Filter tags by prefix
* `-c, --contain SUBSTRING`: Filter tags by contained substring * `-c, --contain CONTENT` : Filter tags by contained substring
#### Config #### Config
The `config` command is used to manage the configuration. The `config` command is used to manage the configuration.
```bash ```bash
cmm config (-l | -m | -c FILE) cmm config (-l | -m | -c CREATE)
``` ```
* `-l, --list-models` : List all available models * `-l, --list-models` : List all available models
* `-m, --print-model` : Print the currently configured model * `-m, --print-model` : Print the currently configured model
* `-c, --create FILE`: Create config with default settings in the given file * `-c, --create CREATE` : Create config with default settings in the given file
#### Print #### Print
The `print` command is used to print message files. The `print` command is used to print message files.
```bash ```bash
cmm print (-f FILE | -l) [-q | -a | -S] cmm print -f FILE [-q | -a | -S]
``` ```
* `-f, --file FILE`: Print given file * `-f, --file FILE` : File to print
* `-l, --latest`: Print latest message * `-q, --question` : Print only question
* `-q, --question`: Only print the question * `-a, --answer` : Print only answer
* `-a, --answer`: Only print the answer * `-S, --only-source-code` : Print only source code
* `-S, --only-source-code`: Only print embedded source code
### Examples ### Examples
@@ -161,27 +160,18 @@ cmm print -f example.yaml
## Configuration ## Configuration
The default configuration filename is `.config.yaml` (it is searched in the current working directory). The configuration file (`.config.yaml`) should contain the following fields:
Use the command `cmm config --create <FILENAME>` to create a default configuration:
``` - `openai`:
cache: . - `api_key`: Your OpenAI API key.
db: ./db/ - `model`: The name of the OpenAI model to use (e.g. "text-davinci-002").
ais: - `temperature`: The temperature value for the model.
myopenai: - `max_tokens`: The maximum number of tokens for the model.
name: openai - `top_p`: The top P value for the model.
model: gpt-3.5-turbo-16k - `frequency_penalty`: The frequency penalty value.
api_key: 0123456789 - `presence_penalty`: The presence penalty value.
temperature: 1.0 - `system`: The system message used to set the behavior of the AI.
max_tokens: 4000 - `db`: The directory where the question-answer pairs are stored in YAML files.
top_p: 1.0
frequency_penalty: 0.0
presence_penalty: 0.0
system: You are an assistant
```
Each AI has its own section and the name of that section is called the 'AI ID' (in the example above it is `myopenai`).
The AI ID can be any string, as long as it's unique within the `ais` section. The AI ID is used for all commands that support the `AI` parameter and it's also stored within each message file.
## Autocompletion ## Autocompletion
@@ -196,33 +186,33 @@ After adding this line, restart your shell or run `source <your-shell-config-fil
## Contributing ## Contributing
### Enable commit hooks ### Enable commit hooks
```bash ```
pip install pre-commit pip install pre-commit
pre-commit install pre-commit install
``` ```
### Execute tests before opening a PR ### Execute tests before opening a PR
```bash ```
pytest pytest
``` ```
### Consider using `pyenv` / `pyenv-virtualenv` ### Consider using `pyenv` / `pyenv-virtualenv`
Short installation instructions: Short installation instructions:
* install `pyenv`: * install `pyenv`:
```bash ```
cd ~ cd ~
git clone https://github.com/pyenv/pyenv .pyenv git clone https://github.com/pyenv/pyenv .pyenv
cd ~/.pyenv && src/configure && make -C src cd ~/.pyenv && src/configure && make -C src
``` ```
* make sure that `~/.pyenv/shims` and `~/.pyenv/bin` are the first entries in your `PATH`, e.g., by setting it in `~/.bashrc` * make sure that `~/.pyenv/shims` and `~/.pyenv/bin` are the first entries in your `PATH`, e. g. by setting it in `~/.bashrc`
* add the following to your `~/.bashrc` (after setting `PATH`): `eval "$(pyenv init -)"` * add the following to your `~/.bashrc` (after setting `PATH`): `eval "$(pyenv init -)"`
* create a new terminal or source the changes (e.g., `source ~/.bashrc`) * create a new terminal or source the changes (e. g. `source ~/.bashrc`)
* install `virtualenv` * install `virtualenv`
```bash ```
git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv
``` ```
* add the following to your `~/.bashrc` (after the commands above): `eval "$(pyenv virtualenv-init -)` * add the following to your `~/.bashrc` (after the commands above): `eval "$(pyenv virtualenv-init -)`
* create a new terminal or source the changes (e.g., `source ~/.bashrc`) * create a new terminal or source the changes (e. g. `source ~/.bashrc`)
* go back to the `ChatMasterMind` repo and create a virtual environment with the latest `Python`, e.g., `3.11.4`: * go back to the `ChatMasterMind` repo and create a virtual environment with the latest `Python`, e. g. `3.11.4`:
```bash ```
cd <CMM_REPO_PATH> cd <CMM_REPO_PATH>
pyenv install 3.11.4 pyenv install 3.11.4
pyenv virtualenv 3.11.4 py311 pyenv virtualenv 3.11.4 py311
@@ -233,3 +223,5 @@ pyenv activate py311
## License ## License
This project is licensed under the terms of the WTFPL License. This project is licensed under the terms of the WTFPL License.
+6 -24
View File
@@ -3,13 +3,16 @@ import argparse
from pathlib import Path from pathlib import Path
from ..configuration import Config from ..configuration import Config
from ..message import Message, MessageError from ..message import Message, MessageError
from ..chat import ChatDB
def print_message(message: Message, args: argparse.Namespace) -> None: def print_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Print given message according to give arguments. Handler for the 'print' command.
""" """
fname = Path(args.file)
try:
message = Message.from_file(fname)
if message:
if args.question: if args.question:
print(message.question) print(message.question)
elif args.answer: elif args.answer:
@@ -19,27 +22,6 @@ def print_message(message: Message, args: argparse.Namespace) -> None:
print(code) print(code)
else: else:
print(message.to_str()) print(message.to_str())
def print_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'print' command.
"""
# print given file
if args.file is not None:
fname = Path(args.file)
try:
message = Message.from_file(fname)
if message:
print_message(message, args)
except MessageError: except MessageError:
print(f"File is not a valid message: {args.file}") print(f"File is not a valid message: {args.file}")
sys.exit(1) sys.exit(1)
# print latest message
elif args.latest:
chat = ChatDB.from_dir(Path(config.cache), Path(config.db))
latest = chat.msg_latest(loc='disk')
if not latest:
print("No message found!")
sys.exit(1)
print_message(latest, args)
+3 -3
View File
@@ -71,7 +71,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
full_question = '\n\n'.join(question_parts) full_question = '\n\n'.join(question_parts)
message = Message(question=Question(full_question), message = Message(question=Question(full_question),
tags=args.output_tags, tags=args.output_tags, # FIXME
ai=args.AI, ai=args.AI,
model=args.model) model=args.model)
# only write the new message to the cache, # only write the new message to the cache,
@@ -92,8 +92,8 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
print(message.to_str()) print(message.to_str())
response: AIResponse = ai.request(message, response: AIResponse = ai.request(message,
chat, chat,
args.num_answers, args.num_answers, # FIXME
args.output_tags) args.output_tags) # FIXME
# only write the response messages to the cache, # only write the response messages to the cache,
# don't add them to the internal list # don't add them to the internal list
chat.cache_write(response.messages) chat.cache_write(response.messages)
+19 -22
View File
@@ -44,13 +44,13 @@ def create_parser() -> argparse.ArgumentParser:
help='List of tags to exclude', metavar='XTAGS') help='List of tags to exclude', metavar='XTAGS')
etag_arg.completer = tags_completer # type: ignore etag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
help='List of output tags (default: use input tags)', metavar='OUTAGS') help='List of output tags (default: use input tags)', metavar='OUTTAGS')
otag_arg.completer = tags_completer # type: ignore otag_arg.completer = tags_completer # type: ignore
# a parent parser for all commands that support AI configuration # a parent parser for all commands that support AI configuration
ai_parser = argparse.ArgumentParser(add_help=False) ai_parser = argparse.ArgumentParser(add_help=False)
ai_parser.add_argument('-A', '--AI', help='AI ID to use', metavar='AI_ID') ai_parser.add_argument('-A', '--AI', help='AI ID to use')
ai_parser.add_argument('-M', '--model', help='Model to use', metavar='MODEL') ai_parser.add_argument('-M', '--model', help='Model to use')
ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1) ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1)
ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int) ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int)
ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float) ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float)
@@ -61,15 +61,14 @@ def create_parser() -> argparse.ArgumentParser:
aliases=['q']) aliases=['q'])
question_cmd_parser.set_defaults(func=question_cmd) question_cmd_parser.set_defaults(func=question_cmd)
question_group = question_cmd_parser.add_mutually_exclusive_group(required=True) question_group = question_cmd_parser.add_mutually_exclusive_group(required=True)
question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question', metavar='QUESTION') question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question')
question_group.add_argument('-c', '--create', nargs='+', help='Create a question', metavar='QUESTION') question_group.add_argument('-c', '--create', nargs='+', help='Create a question')
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question', metavar='MESSAGE') question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question')
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions', metavar='MESSAGE') question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true') action='store_true')
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query', metavar='FILE') question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query')
question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history', question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history')
metavar='FILE')
# 'hist' command parser # 'hist' command parser
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
@@ -80,10 +79,10 @@ def create_parser() -> argparse.ArgumentParser:
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.", hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.",
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Only print embedded source code', hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code',
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring', metavar='SUBSTRING') hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring')
hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring', metavar='SUBSTRING') hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring')
# 'tags' command parser # 'tags' command parser
tags_cmd_parser = cmdparser.add_parser('tags', tags_cmd_parser = cmdparser.add_parser('tags',
@@ -93,8 +92,8 @@ def create_parser() -> argparse.ArgumentParser:
tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True) tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True)
tags_group.add_argument('-l', '--list', help="List all tags and their frequency", tags_group.add_argument('-l', '--list', help="List all tags and their frequency",
action='store_true') action='store_true')
tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix", metavar='PREFIX') tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix")
tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring", metavar='SUBSTRING') tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring")
# 'config' command parser # 'config' command parser
config_cmd_parser = cmdparser.add_parser('config', config_cmd_parser = cmdparser.add_parser('config',
@@ -107,20 +106,18 @@ def create_parser() -> argparse.ArgumentParser:
action='store_true') action='store_true')
config_group.add_argument('-m', '--print-model', help="Print the currently configured model", config_group.add_argument('-m', '--print-model', help="Print the currently configured model",
action='store_true') action='store_true')
config_group.add_argument('-c', '--create', help="Create config with default settings in the given file", metavar='FILE') config_group.add_argument('-c', '--create', help="Create config with default settings in the given file")
# 'print' command parser # 'print' command parser
print_cmd_parser = cmdparser.add_parser('print', print_cmd_parser = cmdparser.add_parser('print',
help="Print message files.", help="Print message files.",
aliases=['p']) aliases=['p'])
print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.set_defaults(func=print_cmd)
print_group = print_cmd_parser.add_mutually_exclusive_group(required=True) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True)
print_group.add_argument('-f', '--file', help='Print given message file', metavar='FILE')
print_group.add_argument('-l', '--latest', help='Print latest message', action='store_true')
print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group() print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group()
print_cmd_modes.add_argument('-q', '--question', help='Only print the question', action='store_true') print_cmd_modes.add_argument('-q', '--question', help='Print only question', action='store_true')
print_cmd_modes.add_argument('-a', '--answer', help='Only print the answer', action='store_true') print_cmd_modes.add_argument('-a', '--answer', help='Print only answer', action='store_true')
print_cmd_modes.add_argument('-S', '--only-source-code', help='Only print embedded source code', action='store_true') print_cmd_modes.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true')
argcomplete.autocomplete(parser) argcomplete.autocomplete(parser)
return parser return parser
+3 -24
View File
@@ -222,36 +222,12 @@ class Message():
ai_yaml_key: ClassVar[str] = 'ai' ai_yaml_key: ClassVar[str] = 'ai'
model_yaml_key: ClassVar[str] = 'model' model_yaml_key: ClassVar[str] = 'model'
def __post_init__(self) -> None:
# convert some types that are often set wrong
if self.tags is not None and not isinstance(self.tags, set):
self.tags = set(self.tags)
if self.file_path is not None and not isinstance(self.file_path, pathlib.Path):
self.file_path = pathlib.Path(self.file_path)
def __hash__(self) -> int: def __hash__(self) -> int:
""" """
The hash value is computed based on immutable members. The hash value is computed based on immutable members.
""" """
return hash((self.question, self.answer)) return hash((self.question, self.answer))
def equals(self, other: MessageInst, tags: bool = True, ai: bool = True,
model: bool = True, file_path: bool = True, verbose: bool = False) -> bool:
"""
Compare this message with another one, including the metadata.
Return True if everything is identical, False otherwise.
"""
equal: bool = ((not tags or (self.tags == other.tags))
and (not ai or (self.ai == other.ai)) # noqa: W503
and (not model or (self.model == other.model)) # noqa: W503
and (not file_path or (self.file_path == other.file_path)) # noqa: W503
and (self == other)) # noqa: W503
if not equal and verbose:
print("Messages not equal:")
print(self)
print(other)
return equal
@classmethod @classmethod
def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst: def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
""" """
@@ -442,6 +418,9 @@ class Message():
output.append(self.answer) output.append(self.answer)
return '\n'.join(output) return '\n'.join(output)
def __str__(self) -> str:
return self.to_str(True, True, False)
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
""" """
Write a Message to the given file. Type is determined based on the suffix. Write a Message to the given file. Type is determined based on the suffix.
+32 -43
View File
@@ -10,18 +10,7 @@ from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, ChatError from chatmastermind.chat import Chat, ChatDB, ChatError
class TestChatBase(unittest.TestCase): class TestChat(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using more than just Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
class TestChat(TestChatBase):
def setUp(self) -> None: def setUp(self) -> None:
self.chat = Chat([]) self.chat = Chat([])
self.message1 = Message(Question('Question 1'), self.message1 = Message(Question('Question 1'),
@@ -37,24 +26,24 @@ class TestChat(TestChatBase):
def test_unique_id(self) -> None: def test_unique_id(self) -> None:
# test with two identical messages # test with two identical messages
self.chat.msg_add([self.message1, self.message1]) self.chat.msg_add([self.message1, self.message1])
self.assert_messages_equal(self.chat.messages, [self.message1, self.message1]) self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_id() self.chat.msg_unique_id()
self.assert_messages_equal(self.chat.messages, [self.message1]) self.assertSequenceEqual(self.chat.messages, [self.message1])
# test with two different messages # test with two different messages
self.chat.msg_add([self.message2]) self.chat.msg_add([self.message2])
self.chat.msg_unique_id() self.chat.msg_unique_id()
self.assert_messages_equal(self.chat.messages, [self.message1, self.message2]) self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2])
def test_unique_content(self) -> None: def test_unique_content(self) -> None:
# test with two identical messages # test with two identical messages
self.chat.msg_add([self.message1, self.message1]) self.chat.msg_add([self.message1, self.message1])
self.assert_messages_equal(self.chat.messages, [self.message1, self.message1]) self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_content() self.chat.msg_unique_content()
self.assert_messages_equal(self.chat.messages, [self.message1]) self.assertSequenceEqual(self.chat.messages, [self.message1])
# test with two different messages # test with two different messages
self.chat.msg_add([self.message2]) self.chat.msg_add([self.message2])
self.chat.msg_unique_content() self.chat.msg_unique_content()
self.assert_messages_equal(self.chat.messages, [self.message1, self.message2]) self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2])
def test_filter(self) -> None: def test_filter(self) -> None:
self.chat.msg_add([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
@@ -161,7 +150,7 @@ Answer 2
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
class TestChatDB(TestChatBase): class TestChatDB(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.db_path = tempfile.TemporaryDirectory() self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory() self.cache_path = tempfile.TemporaryDirectory()
@@ -580,7 +569,7 @@ class TestChatDB(TestChatBase):
search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)] search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3] expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, loc='all') result = chat_db.msg_find(search_names, loc='all')
self.assert_messages_equal(result, expected_result) self.assertSequenceEqual(result, expected_result)
def test_msg_latest(self) -> None: def test_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
@@ -606,47 +595,47 @@ class TestChatDB(TestChatBase):
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4] all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages) self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages) self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages) self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
# add a new message, but only to the internal list # add a new message, but only to the internal list
new_message = Message(Question("What?")) new_message = Message(Question("What?"))
all_messages_mem = all_messages + [new_message] all_messages_mem = all_messages + [new_message]
chat_db.msg_add([new_message]) chat_db.msg_add([new_message])
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages_mem) self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages_mem)
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages_mem) self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages_mem)
# the nr. of messages on disk did not change -> expect old result # the nr. of messages on disk did not change -> expect old result
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages) self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
# test with MessageFilter # test with MessageFilter
self.assert_messages_equal(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})), self.assertSequenceEqual(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})),
[self.message1]) [self.message1])
self.assert_messages_equal(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})), self.assertSequenceEqual(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})),
[self.message2]) [self.message2])
self.assert_messages_equal(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})), self.assertSequenceEqual(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})),
[]) [])
self.assert_messages_equal(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")), self.assertSequenceEqual(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")),
[new_message]) [new_message])
def test_msg_move_and_gather(self) -> None: def test_msg_move_and_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4] all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
# move first message to the cache # move first message to the cache
chat_db.cache_move(self.message1) chat_db.cache_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [self.message1]) self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [self.message1])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4]) self.assertSequenceEqual(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4])
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages) self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages) self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages) self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages)
# now move first message back to the DB # now move first message back to the DB
chat_db.db_move(self.message1) chat_db.db_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)
+18 -195
View File
@@ -11,21 +11,10 @@ from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat, ChatDB from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AI, AIResponse, Tokens, AIError from chatmastermind.ai import AI, AIResponse, Tokens
class TestQuestionCmdBase(unittest.TestCase): class TestMessageCreate(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using more than just Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
class TestMessageCreate(TestQuestionCmdBase):
""" """
Test if messages created by the 'question' command have Test if messages created by the 'question' command have
the correct format. the correct format.
@@ -212,7 +201,7 @@ It is embedded code
""")) """))
class TestQuestionCmd(TestQuestionCmdBase): class TestQuestionCmd(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
# create DB and cache # create DB and cache
@@ -236,12 +225,8 @@ class TestQuestionCmd(TestQuestionCmdBase):
source_code=None, source_code=None,
create=None, create=None,
repeat=None, repeat=None,
process=None, process=None
overwrite=None
) )
# create a mock AI instance
self.ai = MagicMock(spec=AI)
self.ai.request.side_effect = self.mock_request
def input_message(self, args: argparse.Namespace) -> Message: def input_message(self, args: argparse.Namespace) -> Message:
""" """
@@ -266,7 +251,7 @@ class TestQuestionCmd(TestQuestionCmdBase):
Mock the 'ai.request()' function Mock the 'ai.request()' function
""" """
question.answer = Answer("Answer 0") question.answer = Answer("Answer 0")
question.tags = set(otags) if otags else None question.tags = otags
question.ai = 'FakeAI' question.ai = 'FakeAI'
question.model = 'FakeModel' question.model = 'FakeModel'
answers: list[Message] = [question] answers: list[Message] = [question]
@@ -285,9 +270,12 @@ class TestQuestionCmd(TestQuestionCmdBase):
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None: def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
""" """
Test single answer with no errors. Test single answer with no errors
""" """
mock_create_ai.return_value = self.ai # create a mock AI instance
ai = MagicMock(spec=AI)
ai.request.side_effect = self.mock_request
mock_create_ai.return_value = ai
expected_question = self.input_message(self.args) expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question, expected_responses = self.mock_request(expected_question,
Chat([]), Chat([]),
@@ -298,7 +286,7 @@ class TestQuestionCmd(TestQuestionCmdBase):
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
# check for correct request call # check for correct request call
self.ai.request.assert_called_once_with(expected_question, ai.request.assert_called_once_with(expected_question,
ANY, ANY,
self.args.num_answers, self.args.num_answers,
self.args.output_tags) self.args.output_tags)
@@ -307,18 +295,21 @@ class TestQuestionCmd(TestQuestionCmdBase):
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses) self.assertSequenceEqual(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.ChatDB.from_dir') @mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None: def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None:
""" """
Test single answer with no errors (mocked ChatDB version). Test single answer with no errors (mocked ChatDB version)
""" """
chat = MagicMock(spec=ChatDB) chat = MagicMock(spec=ChatDB)
mock_from_dir.return_value = chat mock_from_dir.return_value = chat
mock_create_ai.return_value = self.ai # create a mock AI instance
ai = MagicMock(spec=AI)
ai.request.side_effect = self.mock_request
mock_create_ai.return_value = ai
expected_question = self.input_message(self.args) expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question, expected_responses = self.mock_request(expected_question,
Chat([]), Chat([]),
@@ -329,7 +320,7 @@ class TestQuestionCmd(TestQuestionCmdBase):
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
# check for correct request call # check for correct request call
self.ai.request.assert_called_once_with(expected_question, ai.request.assert_called_once_with(expected_question,
chat, chat,
self.args.num_answers, self.args.num_answers,
self.args.output_tags) self.args.output_tags)
@@ -343,171 +334,3 @@ class TestQuestionCmd(TestQuestionCmdBase):
# check that the messages have not been added to the internal message list # check that the messages have not been added to the internal message list
chat.cache_add.assert_not_called() chat.cache_add.assert_not_called()
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_with_error(self, mock_create_ai: MagicMock) -> None:
"""
Provoke an error during the AI request and verify that the question
has been correctly stored in the cache.
"""
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
self.ai.request.side_effect = AIError
# execute the command
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, [expected_question])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question.
"""
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path)
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
expected_responses += expected_responses
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question and overwrite the old one.
"""
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question (WITH overwriting)
# -> expect a single message afterwards
self.args.ask = None
self.args.repeat = []
self.args.overwrite = True
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_after_error(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question after an error.
"""
# 1. ask a question and provoke an error
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
self.ai.request.side_effect = AIError
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, [expected_question])
# 2. repeat the last question (without overwriting)
# -> expect a single message because if the original has
# no answer, it should be overwritten by default
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
self.ai.request.side_effect = self.mock_request
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_new_args(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question with new arguments.
"""
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question with new arguments (without overwriting)
# -> expect two messages with identical question and answer, but different metadata
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
self.args.output_tags = ['newtag']
self.args.AI = 'newai'
self.args.model = 'newmodel'
new_expected_question = Message(question=Question(expected_question.question),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model)
expected_responses += self.mock_request(new_expected_question,
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, expected_responses)