42 Commits

Author SHA1 Message Date
Oleksandr Kozachuk dca1bc621f Add Jinja2 as preparation for pipes. 2023-10-17 11:57:52 +02:00
Oleksandr Kozachuk 5f29f60168 Add .old/ to git ignore, I use that dir ofter for old files, I do not want to delete. 2023-10-17 11:53:49 +02:00
juk0de 3ea1f49027 cmm: added options '--tight' and '--no-paging' to the 'hist --print' cmd 2023-10-02 08:35:19 +02:00
juk0de 8f56399844 cmm: replaced options '--with-tags' and '--with-file' with '--with-metadata' 2023-10-01 10:11:16 +02:00
juk0de e4cb6eb22b README: updated 'hist' command description 2023-10-01 09:27:40 +02:00
juk0de e19c6bb1ea hist_cmd: added module 'test_hist_cmd.py' 2023-09-30 08:25:33 +02:00
juk0de 811b2e6830 hist_cmd: implemented '--convert' option 2023-09-29 18:53:12 +02:00
juk0de 2a8f01aee4 chat: 'msg_gather()' now supports globbing 2023-09-29 07:16:20 +02:00
juk0de efdb3cae2f question: moved around some code 2023-09-29 07:16:20 +02:00
juk0de aecfd1088d chat: added message file format as ChatDB class member 2023-09-29 07:16:20 +02:00
juk0de 140dbed809 message: added function 'rm_file()' and test 2023-09-29 07:16:20 +02:00
juk0de 01860ace2c test_question_cmd: modified tests to use '.msg' file suffix 2023-09-29 07:16:20 +02:00
juk0de df42bcee09 test_chat: added test for file_path collision detection 2023-09-29 07:16:20 +02:00
juk0de e34eab6519 test_chat: changed all tests to use the new '.msg' suffix 2023-09-29 07:16:20 +02:00
juk0de d07fd13e8e test_message: changed all tests to use the new '.msg' suffix 2023-09-29 07:16:20 +02:00
juk0de b8681e8274 message: fixed tag matching for YAML file format 2023-09-29 07:16:20 +02:00
juk0de d2be53aeab chat: switched to new message suffix and formats
- no longer using file suffix to choose the format
- added 'mformat' argument to 'write_xxx()' functions
- file suffix is now set by 'Message.to_file()' per default
2023-09-29 07:16:20 +02:00
juk0de 9ca9a23569 message: introduced file suffix '.msg'
- '.msg' suffix is always used for writing
- 'Message.to_file()' will set the file suffix if the given file_path has none
- added 'mformat' argument to 'Message.to_file()' for choosing the file format
- '.txt' and '.yaml' suffixes are only supported for reading
2023-09-29 07:16:20 +02:00
juk0de 6f3758e12e question_cmd: fixed '--create' option 2023-09-29 07:15:46 +02:00
juk0de dd836cd72d Merge pull request 'cmm question --repeat supports multiple questions, added tests and fixes' (#15) from repeat_multi into main
This PR primarily modifies the `cmm question --repeat` command to allow repeating multiple questions, instead of only the last one.

Additionally, this PR includes the following changes:

- In `ai_factory.py`, added optional parameters 'def_ai' and 'def_model' to the `create_ai` function which allows specifying a default AI and model.
- In `openai.py`, a potential bug was fixed where the 'tags' attribute was updated to ensure it is always a set, even when 'otags' is None.
- In `question.py`, a significant amount of new code was added to facilitate the 'repeat' functionality. This includes functions to create modified args based on an existing message (`create_msg_args`), to repeat a given list of messages (`repeat_messages`), and to invert the semantics of the INPUT tags for this command (`invert_input_tag_args`).
- In `main.py`, the 'nargs' parameter was changed from `+` to `*` in the 'or-tags', 'and-tags', and 'exclude-tags' arguments to accommodate the updated handling of tags in `question.py`.
- A new `test_common.py` file was added which includes a `FakeAI` class for testing purposes, and a `TestWithFakeAI` class which includes a number of methods for asserting various conditions about messages.

This PR also includes additional tests to verify the correct operation of the new 'repeat' functionality.
2023-09-26 18:04:27 +02:00
juk0de 601ebe731a test_question_cmd: added a new testcase and made the old cases more explicit (easier to read) 2023-09-24 08:53:37 +02:00
juk0de 87b25993be tests: moved 'FakeAI' and common functions to 'test_common.py' 2023-09-24 08:38:52 +02:00
juk0de a478408449 test_question_cmd: test fixes and cleanup 2023-09-23 08:53:26 +02:00
juk0de b83b396c7b question_cmd: fixed msg specific argument creation 2023-09-23 08:11:11 +02:00
juk0de 3c932aa88e openai: fixed assignment of output tags 2023-09-23 08:11:11 +02:00
juk0de b50caa345c test_question_cmd: introduced 'FakeAI' class 2023-09-23 08:11:11 +02:00
juk0de 80c5dcc801 question_cmd: input tag options without a tag (e. g. '-t') now select ALL tags 2023-09-23 08:11:11 +02:00
juk0de 33df84beaa ai_factory: added optional 'def_ai' and 'def_model' arguments to 'create_ai' 2023-09-22 13:43:31 +02:00
juk0de 0657a1bab8 question_cmd: fixed AI and model arguments when repeating messages 2023-09-22 13:43:31 +02:00
juk0de e9175aface test_question_cmd: added testcase for --repeat with multiple messages 2023-09-22 13:43:31 +02:00
juk0de 21f81f3569 question_cmd: implemented repetition of multiple messages 2023-09-22 13:43:31 +02:00
juk0de 4538624247 Merge pull request 'Implemented the 'question --repeat' command and other improvements' (#14) from repeat into main
Reviewed-on: #14
2023-09-21 07:25:47 +02:00
juk0de ac3c19739d README: updates and fixes 2023-09-20 10:18:06 +02:00
juk0de ed379ed535 print_cmd: added option to print latest message 2023-09-20 10:18:06 +02:00
juk0de c43bafe47a main: improved metavar names and descriptions 2023-09-20 10:18:06 +02:00
juk0de 7dd83428fb test_question_cmd: added more testcases for '--repeat' 2023-09-20 10:18:06 +02:00
juk0de 3ad4b96b8f test_question_cmd: added testclass for the 'question_cmd()' function 2023-09-20 10:17:59 +02:00
juk0de 561003aabe question_cmd: implemented repeating of the latest message 2023-09-20 10:17:59 +02:00
juk0de 59eb45a3ca chat: improved message equality checks 2023-09-20 10:17:59 +02:00
juk0de 29a20bd2d8 message: added 'equals()' function and improved robustness and debugging 2023-09-20 10:17:59 +02:00
juk0de 80a1457dd1 configuration: the cache folder can now be specified in the configuration file 2023-09-20 10:17:59 +02:00
juk0de f964c5471e Merge pull request 'Refactoring, fixes and new features for the 'chat.py' module' (#12) from chat_refactoring into main
Reviewed-on: #12
2023-09-18 14:23:51 +02:00
20 changed files with 1371 additions and 461 deletions
+1
View File
@@ -106,6 +106,7 @@ celerybeat.pid
.venv
env/
venv/
.old/
ENV/
env.bak/
venv.bak/
+71 -62
View File
@@ -46,79 +46,81 @@ cmm [global options] command [command options]
The `question` command is used to ask, create, and process questions.
```bash
cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI] [-M MODEL] [-n NUM] [-m MAX] [-T TEMP] (-a ASK | -c CREATE | -r REPEAT | -p PROCESS) [-O] [-s SOURCE]... [-S SOURCE]...
cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI_ID] [-M MODEL] [-n NUM] [-m MAX] [-T TEMP] (-a QUESTION | -c QUESTION | -r [MESSAGE ...] | -p [MESSAGE ...]) [-O] [-s FILE]... [-S FILE]...
```
* `-t, --or-tags OTAGS` : List of tags (one must match)
* `-k, --and-tags ATAGS` : List of tags (all must match)
* `-x, --exclude-tags XTAGS` : List of tags to exclude
* `-o, --output-tags OUTTAGS` : List of output tags (default: use input tags)
* `-A, --AI AI` : AI ID to use
* `-M, --model MODEL` : Model to use
* `-n, --num-answers NUM` : Number of answers to request
* `-m, --max-tokens MAX` : Max. number of tokens
* `-T, --temperature TEMP` : Temperature value
* `-a, --ask ASK` : Ask a question
* `-c, --create CREATE` : Create a question
* `-r, --repeat REPEAT` : Repeat a question
* `-p, --process PROCESS` : Process existing questions
* `-O, --overwrite` : Overwrite existing messages when repeating them
* `-s, --source-text SOURCE` : Add content of a file to the query
* `-S, --source-code SOURCE` : Add source code file content to the chat history
* `-t, --or-tags OTAGS`: List of tags (one must match)
* `-k, --and-tags ATAGS`: List of tags (all must match)
* `-x, --exclude-tags XTAGS`: List of tags to exclude
* `-o, --output-tags OUTTAGS`: List of output tags (default: use input tags)
* `-A, --AI AI_ID`: AI ID to use
* `-M, --model MODEL`: Model to use
* `-n, --num-answers NUM`: Number of answers to request
* `-m, --max-tokens MAX`: Max. number of tokens
* `-T, --temperature TEMP`: Temperature value
* `-a, --ask QUESTION`: Ask a question
* `-c, --create QUESTION`: Create a question
* `-r, --repeat [MESSAGE ...]`: Repeat a question
* `-p, --process [MESSAGE ...]`: Process existing questions
* `-O, --overwrite`: Overwrite existing messages when repeating them
* `-s, --source-text FILE`: Add content of a file to the query
* `-S, --source-code FILE`: Add source code file content to the chat history
#### Hist
The `hist` command is used to print the chat history.
The `hist` command is used to print and manage the chat history.
```bash
cmm hist [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A ANSWER] [-Q QUESTION]
cmm hist [--print | --convert FORMAT] [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A SUBSTRING] [-Q SUBSTRING]
```
* `-t, --or-tags OTAGS` : List of tags (one must match)
* `-k, --and-tags ATAGS` : List of tags (all must match)
* `-x, --exclude-tags XTAGS` : List of tags to exclude
* `-w, --with-tags` : Print chat history with tags
* `-W, --with-files` : Print chat history with filenames
* `-S, --source-code-only` : Print only source code
* `-A, --answer ANSWER` : Search for answer substring
* `-Q, --question QUESTION` : Search for question substring
* `-p, --print`: Print the DB chat history
* `-c, --convert FORMAT`: Convert all messages to the given format
* `-t, --or-tags OTAGS`: List of tags (one must match)
* `-k, --and-tags ATAGS`: List of tags (all must match)
* `-x, --exclude-tags XTAGS`: List of tags to exclude
* `-w, --with-metadata`: Print chat history with metadata (tags, filenames, AI, etc.)
* `-S, --source-code-only`: Only print embedded source code
* `-A, --answer SUBSTRING`: Filter for answer substring
* `-Q, --question SUBSTRING`: Filter for question substring
#### Tags
The `tags` command is used to manage tags.
```bash
cmm tags (-l | -p PREFIX | -c CONTENT)
cmm tags (-l | -p PREFIX | -c SUBSTRING)
```
* `-l, --list` : List all tags and their frequency
* `-p, --prefix PREFIX` : Filter tags by prefix
* `-c, --contain CONTENT` : Filter tags by contained substring
* `-l, --list`: List all tags and their frequency
* `-p, --prefix PREFIX`: Filter tags by prefix
* `-c, --contain SUBSTRING`: Filter tags by contained substring
#### Config
The `config` command is used to manage the configuration.
```bash
cmm config (-l | -m | -c CREATE)
cmm config (-l | -m | -c FILE)
```
* `-l, --list-models` : List all available models
* `-m, --print-model` : Print the currently configured model
* `-c, --create CREATE` : Create config with default settings in the given file
* `-l, --list-models`: List all available models
* `-m, --print-model`: Print the currently configured model
* `-c, --create FILE`: Create config with default settings in the given file
#### Print
The `print` command is used to print message files.
```bash
cmm print -f FILE [-q | -a | -S]
cmm print (-f FILE | -l) [-q | -a | -S]
```
* `-f, --file FILE` : File to print
* `-q, --question` : Print only question
* `-a, --answer` : Print only answer
* `-S, --only-source-code` : Print only source code
* `-f, --file FILE`: Print given file
* `-l, --latest`: Print latest message
* `-q, --question`: Only print the question
* `-a, --answer`: Only print the answer
* `-S, --only-source-code`: Only print embedded source code
### Examples
@@ -160,18 +162,27 @@ cmm print -f example.yaml
## Configuration
The configuration file (`.config.yaml`) should contain the following fields:
The default configuration filename is `.config.yaml` (it is searched in the current working directory).
Use the command `cmm config --create <FILENAME>` to create a default configuration:
- `openai`:
- `api_key`: Your OpenAI API key.
- `model`: The name of the OpenAI model to use (e.g. "text-davinci-002").
- `temperature`: The temperature value for the model.
- `max_tokens`: The maximum number of tokens for the model.
- `top_p`: The top P value for the model.
- `frequency_penalty`: The frequency penalty value.
- `presence_penalty`: The presence penalty value.
- `system`: The system message used to set the behavior of the AI.
- `db`: The directory where the question-answer pairs are stored in YAML files.
```
cache: .
db: ./db/
ais:
myopenai:
name: openai
model: gpt-3.5-turbo-16k
api_key: 0123456789
temperature: 1.0
max_tokens: 4000
top_p: 1.0
frequency_penalty: 0.0
presence_penalty: 0.0
system: You are an assistant
```
Each AI has its own section and the name of that section is called the 'AI ID' (in the example above it is `myopenai`).
The AI ID can be any string, as long as it's unique within the `ais` section. The AI ID is used for all commands that support the `AI` parameter and it's also stored within each message file.
## Autocompletion
@@ -186,33 +197,33 @@ After adding this line, restart your shell or run `source <your-shell-config-fil
## Contributing
### Enable commit hooks
```
```bash
pip install pre-commit
pre-commit install
```
### Execute tests before opening a PR
```
```bash
pytest
```
### Consider using `pyenv` / `pyenv-virtualenv`
Short installation instructions:
* install `pyenv`:
```
```bash
cd ~
git clone https://github.com/pyenv/pyenv .pyenv
cd ~/.pyenv && src/configure && make -C src
```
* make sure that `~/.pyenv/shims` and `~/.pyenv/bin` are the first entries in your `PATH`, e. g. by setting it in `~/.bashrc`
* make sure that `~/.pyenv/shims` and `~/.pyenv/bin` are the first entries in your `PATH`, e.g., by setting it in `~/.bashrc`
* add the following to your `~/.bashrc` (after setting `PATH`): `eval "$(pyenv init -)"`
* create a new terminal or source the changes (e. g. `source ~/.bashrc`)
* create a new terminal or source the changes (e.g., `source ~/.bashrc`)
* install `virtualenv`
```
```bash
git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv
```
* add the following to your `~/.bashrc` (after the commands above): `eval "$(pyenv virtualenv-init -)`
* create a new terminal or source the changes (e. g. `source ~/.bashrc`)
* go back to the `ChatMasterMind` repo and create a virtual environment with the latest `Python`, e. g. `3.11.4`:
```
* create a new terminal or source the changes (e.g., `source ~/.bashrc`)
* go back to the `ChatMasterMind` repo and create a virtual environment with the latest `Python`, e.g., `3.11.4`:
```bash
cd <CMM_REPO_PATH>
pyenv install 3.11.4
pyenv virtualenv 3.11.4 py311
@@ -223,5 +234,3 @@ pyenv activate py311
## License
This project is licensed under the terms of the WTFPL License.
+12 -6
View File
@@ -3,18 +3,20 @@ Creates different AI instances, based on the given configuration.
"""
import argparse
from typing import cast
from typing import cast, Optional
from .configuration import Config, AIConfig, OpenAIConfig
from .ai import AI, AIError
from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
def create_ai(args: argparse.Namespace, config: Config, # noqa: 11
def_ai: Optional[str] = None,
def_model: Optional[str] = None) -> AI:
"""
Creates an AI subclass instance from the given arguments
and configuration file. If AI has not been set in the
arguments, it searches for the ID 'default'. If that
is not found, it uses the first AI in the list.
Creates an AI subclass instance from the given arguments and configuration file.
If AI has not been set in the arguments, it searches for the ID 'default'. If
that is not found, it uses the first AI in the list. It's also possible to
specify a default AI and model using 'def_ai' and 'def_model'.
"""
ai_conf: AIConfig
if hasattr(args, 'AI') and args.AI:
@@ -22,6 +24,8 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
ai_conf = config.ais[args.AI]
except KeyError:
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
elif def_ai:
ai_conf = config.ais[def_ai]
elif 'default' in config.ais:
ai_conf = config.ais['default']
else:
@@ -34,6 +38,8 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
ai = OpenAI(cast(OpenAIConfig, ai_conf))
if hasattr(args, 'model') and args.model:
ai.config.model = args.model
elif def_model:
ai.config.model = def_model
if hasattr(args, 'max_tokens') and args.max_tokens:
ai.config.max_tokens = args.max_tokens
if hasattr(args, 'temperature') and args.temperature:
+1 -1
View File
@@ -44,7 +44,7 @@ class OpenAI(AI):
frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty)
question.answer = Answer(response['choices'][0]['message']['content'])
question.tags = otags
question.tags = set(otags) if otags is not None else None
question.ai = self.ID
question.model = self.config.model
answers: list[Message] = [question]
+47 -37
View File
@@ -6,9 +6,9 @@ from pathlib import Path
from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal, Union
from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union
from .configuration import default_config_file
from .message import Message, MessageFilter, MessageError, message_in
from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats
from .tags import Tag
ChatInst = TypeVar('ChatInst', bound='Chat')
@@ -17,6 +17,7 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next'
ignored_files = [db_next_file, default_config_file]
msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
msg_suffix = Message.file_suffix_write
class ChatError(Exception):
@@ -52,7 +53,7 @@ def read_dir(dir_path: Path,
for file_path in sorted(file_iter):
if (file_path.is_file()
and file_path.name not in ignored_files # noqa: W503
and file_path.suffix in Message.file_suffixes): # noqa: W503
and file_path.suffix in Message.file_suffixes_read): # noqa: W503
try:
message = Message.from_file(file_path, mfilter)
if message:
@@ -63,22 +64,20 @@ def read_dir(dir_path: Path,
def make_file_path(dir_path: Path,
file_suffix: str,
next_fid: Callable[[], int]) -> Path:
"""
Create a file_path for the given directory using the
given file_suffix and ID generator function.
Create a file_path for the given directory using the given ID generator function.
"""
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
file_path = dir_path / f"{next_fid():04d}{msg_suffix}"
while file_path.exists():
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
file_path = dir_path / f"{next_fid():04d}{msg_suffix}"
return file_path
def write_dir(dir_path: Path,
messages: list[Message],
file_suffix: str,
next_fid: Callable[[], int]) -> None:
next_fid: Callable[[], int],
mformat: MessageFormat = Message.default_format) -> None:
"""
Write all messages to the given directory. If a message has no file_path,
a new one will be created. If message.file_path exists, it will be modified
@@ -86,18 +85,17 @@ def write_dir(dir_path: Path,
Parameters:
* 'dir_path': destination directory
* 'messages': list of messages to write
* 'file_suffix': suffix for the message files ['.txt'|'.yaml']
* 'next_fid': callable that returns the next file ID
"""
for message in messages:
file_path = message.file_path
# message has no file_path: create one
if not file_path:
file_path = make_file_path(dir_path, file_suffix, next_fid)
file_path = make_file_path(dir_path, next_fid)
# file_path does not point to given directory: modify it
elif not file_path.parent.samefile(dir_path):
file_path = dir_path / file_path.name
message.to_file(file_path)
message.to_file(file_path, mformat=mformat)
def clear_dir(dir_path: Path,
@@ -109,7 +107,7 @@ def clear_dir(dir_path: Path,
for file_path in file_iter:
if (file_path.is_file()
and file_path.name not in ignored_files # noqa: W503
and file_path.suffix in Message.file_suffixes): # noqa: W503
and file_path.suffix in Message.file_suffixes_read): # noqa: W503
file_path.unlink(missing_ok=True)
@@ -146,7 +144,7 @@ class Chat:
Matching is True if:
* 'name' matches the full 'file_path'
* 'name' matches 'file_path.name' (i. e. including the suffix)
* 'name' matches 'file_path.stem' (i. e. without a suffix)
* 'name' matches 'file_path.stem' (i. e. without the suffix)
"""
return Path(name) == file_path or name == file_path.name or name == file_path.stem
@@ -257,14 +255,17 @@ class Chat:
return sum(m.tokens() for m in self.messages)
def print(self, source_code_only: bool = False,
with_tags: bool = False, with_files: bool = False,
paged: bool = True) -> None:
with_metadata: bool = False,
paged: bool = True,
tight: bool = False) -> None:
output: list[str] = []
for message in self.messages:
if source_code_only:
output.append(message.to_str(source_code_only=True))
continue
output.append(message.to_str(with_tags, with_files))
output.append(message.to_str(with_metadata))
if not tight:
output.append('\n' + ('-' * terminal_width()) + '\n')
if paged:
print_paged('\n'.join(output))
else:
@@ -281,15 +282,14 @@ class ChatDB(Chat):
persistently.
"""
default_file_suffix: ClassVar[str] = '.txt'
cache_path: Path
db_path: Path
# a MessageFilter that all messages must match (if given)
mfilter: Optional[MessageFilter] = None
file_suffix: str = default_file_suffix
# the glob pattern for all messages
glob: Optional[str] = None
# message format (for writing)
mformat: MessageFormat = Message.default_format
def __post_init__(self) -> None:
# contains the latest message ID
@@ -317,8 +317,7 @@ class ChatDB(Chat):
when reading them.
"""
messages = read_dir(db_path, glob, mfilter)
return cls(messages, cache_path, db_path, mfilter,
cls.default_file_suffix, glob)
return cls(messages, cache_path, db_path, mfilter, glob)
@classmethod
def from_messages(cls: Type[ChatDBInst],
@@ -345,7 +344,17 @@ class ChatDB(Chat):
with open(self.next_path, 'w') as f:
f.write(f'{fid}')
def msg_write(self, messages: Optional[list[Message]] = None) -> None:
def set_msg_format(self, mformat: MessageFormat) -> None:
"""
Set message format for writing messages.
"""
if mformat not in message_valid_formats:
raise ChatError(f"Message format '{mformat}' is not supported")
self.mformat = mformat
def msg_write(self,
messages: Optional[list[Message]] = None,
mformat: Optional[MessageFormat] = None) -> None:
"""
Write either the given messages or the internal ones to their CURRENT file_path.
If messages are given, they all must have a valid file_path. When writing the
@@ -356,7 +365,7 @@ class ChatDB(Chat):
raise ChatError("Can't write files without a valid file_path")
msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)):
m.to_file()
m.to_file(mformat=mformat if mformat else self.mformat)
def msg_update(self, messages: list[Message], write: bool = True) -> None:
"""
@@ -377,6 +386,7 @@ class ChatDB(Chat):
def msg_gather(self,
loc: msg_location,
require_file_path: bool = False,
glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> list[Message]:
"""
Gather and return messages from the given locations:
@@ -395,9 +405,9 @@ class ChatDB(Chat):
else:
loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
if loc in ['cache', 'disk', 'all']:
loc_messages += read_dir(self.cache_path, mfilter=mfilter)
loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter)
if loc in ['db', 'disk', 'all']:
loc_messages += read_dir(self.db_path, mfilter=mfilter)
loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter)
# remove_duplicates and sort the list
unique_messages: list[Message] = []
for m in loc_messages:
@@ -518,8 +528,8 @@ class ChatDB(Chat):
"""
write_dir(self.cache_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
self.get_next_fid,
self.mformat)
def cache_add(self, messages: list[Message], write: bool = True) -> None:
"""
@@ -531,11 +541,11 @@ class ChatDB(Chat):
if write:
write_dir(self.cache_path,
messages,
self.file_suffix,
self.get_next_fid)
self.get_next_fid,
self.mformat)
else:
for m in messages:
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
m.file_path = make_file_path(self.cache_path, self.get_next_fid)
self.messages += messages
self.msg_sort()
@@ -585,8 +595,8 @@ class ChatDB(Chat):
"""
write_dir(self.db_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
self.get_next_fid,
self.mformat)
def db_add(self, messages: list[Message], write: bool = True) -> None:
"""
@@ -598,11 +608,11 @@ class ChatDB(Chat):
if write:
write_dir(self.db_path,
messages,
self.file_suffix,
self.get_next_fid)
self.get_next_fid,
self.mformat)
else:
for m in messages:
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
m.file_path = make_file_path(self.db_path, self.get_next_fid)
self.messages += messages
self.msg_sort()
+55 -6
View File
@@ -1,13 +1,51 @@
import sys
import argparse
from pathlib import Path
from ..configuration import Config
from ..chat import ChatDB
from ..message import MessageFilter
from ..message import MessageFilter, Message
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
msg_suffix = Message.file_suffix_write # currently '.msg'
def convert_messages(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'hist' command.
Convert messages to a new format. Also used to change old suffixes
('.txt', '.yaml') to the latest default message file suffix ('.msg').
"""
chat = ChatDB.from_dir(Path(config.cache),
Path(config.db))
# read all known message files
msgs = chat.msg_gather(loc='disk', glob='*.*')
# make a set of all message IDs
msg_ids = set([m.msg_id() for m in msgs])
# set requested format and write all messages
chat.set_msg_format(args.convert)
# delete the current suffix
# -> a new one will automatically be created
for m in msgs:
if m.file_path:
m.file_path = m.file_path.with_suffix('')
chat.msg_write(msgs)
# read all messages with the current default suffix
msgs = chat.msg_gather(loc='disk', glob=f'*{msg_suffix}')
# make sure we converted all of the original messages
for mid in msg_ids:
if not any(mid == m.msg_id() for m in msgs):
print(f"Message '{mid}' has not been found after conversion. Aborting.")
sys.exit(1)
# delete messages with old suffixes
msgs = chat.msg_gather(loc='disk', glob='*.*')
for m in msgs:
if m.file_path and m.file_path.suffix != msg_suffix:
m.rm_file()
print(f"Successfully converted {len(msg_ids)} messages.")
def print_chat(args: argparse.Namespace, config: Config) -> None:
"""
Print the DB chat history.
"""
mfilter = MessageFilter(tags_or=args.or_tags,
@@ -15,9 +53,20 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None:
tags_not=args.exclude_tags,
question_contains=args.question,
answer_contains=args.answer)
chat = ChatDB.from_dir(Path('.'),
chat = ChatDB.from_dir(Path(config.cache),
Path(config.db),
mfilter=mfilter)
chat.print(args.source_code_only,
args.with_tags,
args.with_files)
args.with_metadata,
paged=not args.no_paging,
tight=args.tight)
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'hist' command.
"""
if args.print:
print_chat(args, config)
elif args.convert:
convert_messages(args, config)
+34 -16
View File
@@ -3,25 +3,43 @@ import argparse
from pathlib import Path
from ..configuration import Config
from ..message import Message, MessageError
from ..chat import ChatDB
def print_message(message: Message, args: argparse.Namespace) -> None:
"""
Print given message according to give arguments.
"""
if args.question:
print(message.question)
elif args.answer:
print(message.answer)
elif message.answer and args.only_source_code:
for code in message.answer.source_code():
print(code)
else:
print(message.to_str())
def print_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'print' command.
"""
fname = Path(args.file)
try:
message = Message.from_file(fname)
if message:
if args.question:
print(message.question)
elif args.answer:
print(message.answer)
elif message.answer and args.only_source_code:
for code in message.answer.source_code():
print(code)
else:
print(message.to_str())
except MessageError:
print(f"File is not a valid message: {args.file}")
sys.exit(1)
# print given file
if args.file is not None:
fname = Path(args.file)
try:
message = Message.from_file(fname)
if message:
print_message(message, args)
except MessageError:
print(f"File is not a valid message: {args.file}")
sys.exit(1)
# print latest message
elif args.latest:
chat = ChatDB.from_dir(Path(config.cache), Path(config.db))
latest = chat.msg_latest(loc='disk')
if not latest:
print("No message found!")
sys.exit(1)
print_message(latest, args)
+125 -30
View File
@@ -1,6 +1,8 @@
import sys
import argparse
from pathlib import Path
from itertools import zip_longest
from copy import deepcopy
from ..configuration import Config
from ..chat import ChatDB
from ..message import Message, MessageFilter, MessageError, Question, source_code
@@ -8,6 +10,10 @@ from ..ai_factory import create_ai
from ..ai import AI, AIResponse
class QuestionCmdError(Exception):
pass
def add_file_as_text(question_parts: list[str], file: str) -> None:
"""
Add the given file as plain text to the question part list.
@@ -49,13 +55,41 @@ def add_file_as_code(question_parts: list[str], file: str) -> None:
question_parts.append(f"```\n{content}\n```")
def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
"""
Takes an existing message and CLI arguments, and returns modified args based
on the members of the given message. Used e.g. when repeating messages, where
it's necessary to determine the correct AI, module and output tags to use
(either from the existing message or the given args).
"""
msg_args = args
# if AI, model or output tags have not been specified,
# use those from the original message
if (args.AI is None
or args.model is None # noqa: W503
or args.output_tags is None): # noqa: W503
msg_args = deepcopy(args)
if args.AI is None and msg.ai is not None:
msg_args.AI = msg.ai
if args.model is None and msg.model is not None:
msg_args.model = msg.model
if args.output_tags is None and msg.tags is not None:
msg_args.output_tags = msg.tags
return msg_args
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
"""
Creates a new message from the given arguments and writes it
Create a new message from the given arguments and write it
to the cache directory.
"""
question_parts = []
question_list = args.ask if args.ask is not None else []
if args.create is not None:
question_list = args.create
elif args.ask is not None:
question_list = args.ask
else:
raise QuestionCmdError("No question found")
text_files = args.source_text if args.source_text is not None else []
code_files = args.source_code if args.source_code is not None else []
@@ -70,21 +104,87 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
full_question = '\n\n'.join(question_parts)
message = Message(question=Question(full_question),
tags=args.output_tags, # FIXME
tags=args.output_tags,
ai=args.AI,
model=args.model)
chat.cache_add([message])
# only write the new message to the cache,
# don't add it to the internal list
chat.cache_write([message])
return message
def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespace) -> None:
"""
Make an AI request with the given AI, chat history, message and arguments.
Write the response(s) to the cache directory, without appending it to the
given chat history. Then print the response(s).
"""
# print history and message question before making the request
ai.print()
chat.print(paged=False)
print(message.to_str())
response: AIResponse = ai.request(message,
chat,
args.num_answers,
args.output_tags)
# only write the response messages to the cache,
# don't add them to the internal list
chat.cache_write(response.messages)
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
"""
Repeat the given messages using the given arguments.
"""
ai: AI
for msg in messages:
msg_args = create_msg_args(msg, args)
ai = create_ai(msg_args, config)
print(f"--------- Repeating message '{msg.msg_id()}': ---------")
# overwrite the latest message if requested or empty
# -> but not if it's in the DB!
if ((msg.answer is None or msg_args.overwrite is True)
and (not chat.msg_in_db(msg))): # noqa: W503
msg.clear_answer()
make_request(ai, chat, msg, msg_args)
# otherwise create a new one
else:
msg_args.ask = [msg.question]
message = create_message(chat, msg_args)
make_request(ai, chat, message, msg_args)
def invert_input_tag_args(args: argparse.Namespace) -> None:
"""
Changes the semantics of the INPUT tags for this command:
* not tags specified on the CLI -> no tags are selected
* empty tags specified on the CLI -> all tags are selected
"""
if args.or_tags is None:
args.or_tags = set()
elif len(args.or_tags) == 0:
args.or_tags = None
if args.and_tags is None:
args.and_tags = set()
elif len(args.and_tags) == 0:
args.and_tags = None
def question_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'question' command.
"""
mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(),
tags_and=args.and_tags if args.and_tags is not None else set(),
tags_not=args.exclude_tags if args.exclude_tags is not None else set())
chat = ChatDB.from_dir(cache_path=Path('.'),
invert_input_tag_args(args)
mfilter = MessageFilter(tags_or=args.or_tags,
tags_and=args.and_tags,
tags_not=args.exclude_tags)
chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db),
mfilter=mfilter)
# if it's a new question, create and store it immediately
@@ -93,30 +193,25 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
if args.create:
return
# create the correct AI instance
ai: AI = create_ai(args, config)
# === ASK ===
if args.ask:
ai.print()
chat.print(paged=False)
response: AIResponse = ai.request(message,
chat,
args.num_answers, # FIXME
args.output_tags) # FIXME
chat.msg_update([response.messages[0]])
chat.cache_add(response.messages[1:])
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
ai: AI = create_ai(args, config)
make_request(ai, chat, message, args)
# === REPEAT ===
elif args.repeat is not None:
lmessage = chat.msg_latest()
assert lmessage
# TODO: repeat either the last question or the
# one(s) given in 'args.repeat' (overwrite
# existing ones if 'args.overwrite' is True)
pass
repeat_msgs: list[Message] = []
# repeat latest message
if len(args.repeat) == 0:
lmessage = chat.msg_latest(loc='cache')
if lmessage is None:
print("No message found to repeat!")
sys.exit(1)
repeat_msgs.append(lmessage)
# repeat given message(s)
else:
repeat_msgs = chat.msg_find(args.repeat, loc='disk')
repeat_messages(repeat_msgs, chat, args, config)
# === PROCESS ===
elif args.process is not None:
# TODO: process either all questions without an
# answer or the one(s) given in 'args.process'
+1 -1
View File
@@ -8,7 +8,7 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'tags' command.
"""
chat = ChatDB.from_dir(cache_path=Path('.'),
chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db))
if args.list:
tags_freq = chat.msg_tags_frequency(args.prefix, args.contain)
+2
View File
@@ -116,6 +116,7 @@ class Config:
"""
# all members have default values, so we can easily create
# a default configuration
cache: str = '.'
db: str = './db/'
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
@@ -132,6 +133,7 @@ class Config:
ai_conf = ai_config_instance(conf['name'], conf)
ais[ID] = ai_conf
return cls(
cache=str(source['cache']) if 'cache' in source else '.',
db=str(source['db']),
ais=ais
)
+32 -26
View File
@@ -34,23 +34,23 @@ def create_parser() -> argparse.ArgumentParser:
# a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False)
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+',
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='*',
help='List of tags (one must match)', metavar='OTAGS')
tag_arg.completer = tags_completer # type: ignore
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+',
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='*',
help='List of tags (all must match)', metavar='ATAGS')
atag_arg.completer = tags_completer # type: ignore
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+',
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='*',
help='List of tags to exclude', metavar='XTAGS')
etag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
help='List of output tags (default: use input tags)', metavar='OUTTAGS')
help='List of output tags (default: use input tags)', metavar='OUTAGS')
otag_arg.completer = tags_completer # type: ignore
# a parent parser for all commands that support AI configuration
ai_parser = argparse.ArgumentParser(add_help=False)
ai_parser.add_argument('-A', '--AI', help='AI ID to use')
ai_parser.add_argument('-M', '--model', help='Model to use')
ai_parser.add_argument('-A', '--AI', help='AI ID to use', metavar='AI_ID')
ai_parser.add_argument('-M', '--model', help='Model to use', metavar='MODEL')
ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1)
ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int)
ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float)
@@ -61,28 +61,32 @@ def create_parser() -> argparse.ArgumentParser:
aliases=['q'])
question_cmd_parser.set_defaults(func=question_cmd)
question_group = question_cmd_parser.add_mutually_exclusive_group(required=True)
question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question')
question_group.add_argument('-c', '--create', nargs='+', help='Create a question')
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question')
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions')
question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question', metavar='QUESTION')
question_group.add_argument('-c', '--create', nargs='+', help='Create a question', metavar='QUESTION')
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question', metavar='MESSAGE')
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions', metavar='MESSAGE')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true')
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query')
question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history')
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query', metavar='FILE')
question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history',
metavar='FILE')
# 'hist' command parser
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
help="Print chat history.",
help="Print and manage chat history.",
aliases=['h'])
hist_cmd_parser.set_defaults(func=hist_cmd)
hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.",
hist_group = hist_cmd_parser.add_mutually_exclusive_group(required=True)
hist_group.add_argument('-p', '--print', help='Print the DB chat history', action='store_true')
hist_group.add_argument('-c', '--convert', help='Convert all message files to the given format [txt|yaml]', metavar='FORMAT')
hist_cmd_parser.add_argument('-w', '--with-metadata', help="Print chat history with metadata (tags, filename, AI, etc.).",
action='store_true')
hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.",
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Only print embedded source code',
action='store_true')
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code',
action='store_true')
hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring')
hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring')
hist_cmd_parser.add_argument('-A', '--answer', help='Print only answers with given substring', metavar='SUBSTRING')
hist_cmd_parser.add_argument('-Q', '--question', help='Print only questions with given substring', metavar='SUBSTRING')
hist_cmd_parser.add_argument('-d', '--tight', help='Print without message separators', action='store_true')
hist_cmd_parser.add_argument('-P', '--no-paging', help='Print without paging', action='store_true')
# 'tags' command parser
tags_cmd_parser = cmdparser.add_parser('tags',
@@ -92,8 +96,8 @@ def create_parser() -> argparse.ArgumentParser:
tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True)
tags_group.add_argument('-l', '--list', help="List all tags and their frequency",
action='store_true')
tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix")
tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring")
tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix", metavar='PREFIX')
tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring", metavar='SUBSTRING')
# 'config' command parser
config_cmd_parser = cmdparser.add_parser('config',
@@ -106,18 +110,20 @@ def create_parser() -> argparse.ArgumentParser:
action='store_true')
config_group.add_argument('-m', '--print-model', help="Print the currently configured model",
action='store_true')
config_group.add_argument('-c', '--create', help="Create config with default settings in the given file")
config_group.add_argument('-c', '--create', help="Create config with default settings in the given file", metavar='FILE')
# 'print' command parser
print_cmd_parser = cmdparser.add_parser('print',
help="Print message files.",
aliases=['p'])
print_cmd_parser.set_defaults(func=print_cmd)
print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True)
print_group = print_cmd_parser.add_mutually_exclusive_group(required=True)
print_group.add_argument('-f', '--file', help='Print given message file', metavar='FILE')
print_group.add_argument('-l', '--latest', help='Print latest message', action='store_true')
print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group()
print_cmd_modes.add_argument('-q', '--question', help='Print only question', action='store_true')
print_cmd_modes.add_argument('-a', '--answer', help='Print only answer', action='store_true')
print_cmd_modes.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true')
print_cmd_modes.add_argument('-q', '--question', help='Only print the question', action='store_true')
print_cmd_modes.add_argument('-a', '--answer', help='Only print the answer', action='store_true')
print_cmd_modes.add_argument('-S', '--only-source-code', help='Only print embedded source code', action='store_true')
argcomplete.autocomplete(parser)
return parser
+96 -55
View File
@@ -5,7 +5,8 @@ import pathlib
import yaml
import tempfile
import shutil
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple
from typing import get_args as typing_get_args
from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags
@@ -15,6 +16,9 @@ MessageInst = TypeVar('MessageInst', bound='Message')
AILineInst = TypeVar('AILineInst', bound='AILine')
ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine')
YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]]
MessageFormat = Literal['txt', 'yaml']
message_valid_formats: Final[Tuple[MessageFormat, ...]] = typing_get_args(MessageFormat)
message_default_format: Final[MessageFormat] = 'txt'
class MessageError(Exception):
@@ -92,7 +96,7 @@ class MessageFilter:
class AILine(str):
"""
A line that represents the AI name in a '.txt' file..
A line that represents the AI name in the 'txt' format.
"""
prefix: Final[str] = 'AI:'
@@ -112,7 +116,7 @@ class AILine(str):
class ModelLine(str):
"""
A line that represents the model name in a '.txt' file..
A line that represents the model name in the 'txt' format.
"""
prefix: Final[str] = 'MODEL:'
@@ -216,18 +220,44 @@ class Message():
model: Optional[str] = field(default=None, compare=False)
file_path: Optional[pathlib.Path] = field(default=None, compare=False)
# class variables
file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml']
file_suffixes_read: ClassVar[list[str]] = ['.msg', '.txt', '.yaml']
file_suffix_write: ClassVar[str] = '.msg'
default_format: ClassVar[MessageFormat] = message_default_format
tags_yaml_key: ClassVar[str] = 'tags'
file_yaml_key: ClassVar[str] = 'file_path'
ai_yaml_key: ClassVar[str] = 'ai'
model_yaml_key: ClassVar[str] = 'model'
def __post_init__(self) -> None:
# convert some types that are often set wrong
if self.tags is not None and not isinstance(self.tags, set):
self.tags = set(self.tags)
if self.file_path is not None and not isinstance(self.file_path, pathlib.Path):
self.file_path = pathlib.Path(self.file_path)
def __hash__(self) -> int:
"""
The hash value is computed based on immutable members.
"""
return hash((self.question, self.answer))
def equals(self, other: MessageInst, tags: bool = True, ai: bool = True,
model: bool = True, file_path: bool = True, verbose: bool = False) -> bool:
"""
Compare this message with another one, including the metadata.
Return True if everything is identical, False otherwise.
"""
equal: bool = ((not tags or (self.tags == other.tags))
and (not ai or (self.ai == other.ai)) # noqa: W503
and (not model or (self.model == other.model)) # noqa: W503
and (not file_path or (self.file_path == other.file_path)) # noqa: W503
and (self == other)) # noqa: W503
if not equal and verbose:
print("Messages not equal:")
print(self)
print(other)
return equal
@classmethod
def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
"""
@@ -252,24 +282,16 @@ class Message():
tags: set[Tag] = set()
if not file_path.exists():
raise MessageError(f"Message file '{file_path}' does not exist")
if file_path.suffix not in cls.file_suffixes:
if file_path.suffix not in cls.file_suffixes_read:
raise MessageError(f"File type '{file_path.suffix}' is not supported")
# for TXT, it's enough to read the TagLine
if file_path.suffix == '.txt':
with open(file_path, "r") as fd:
try:
tags = TagLine(fd.readline()).tags(prefix, contain)
except TagError:
pass # message without tags
else: # '.yaml'
try:
message = cls.from_file(file_path)
if message:
msg_tags = message.filter_tags(prefix=prefix, contain=contain)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
if msg_tags:
tags = msg_tags
try:
message = cls.from_file(file_path)
if message:
msg_tags = message.filter_tags(prefix=prefix, contain=contain)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
if msg_tags:
tags = msg_tags
return tags
@classmethod
@@ -304,15 +326,16 @@ class Message():
"""
if not file_path.exists():
raise MessageError(f"Message file '{file_path}' does not exist")
if file_path.suffix not in cls.file_suffixes:
if file_path.suffix not in cls.file_suffixes_read:
raise MessageError(f"File type '{file_path.suffix}' is not supported")
if file_path.suffix == '.txt':
# try TXT first
try:
message = cls.__from_file_txt(file_path,
mfilter.tags_or if mfilter else None,
mfilter.tags_and if mfilter else None,
mfilter.tags_not if mfilter else None)
else:
# then YAML
except MessageError:
message = cls.__from_file_yaml(file_path)
if message and (mfilter is None or message.match(mfilter)):
return message
@@ -349,10 +372,6 @@ class Message():
tags = TagLine(fd.readline()).tags()
except TagError:
fd.seek(pos)
if tags_or or tags_and or tags_not:
# match with an empty set if the file has no tags
if not match_tags(tags, tags_or, tags_and, tags_not):
return None
# AILine (Optional)
try:
pos = fd.tell()
@@ -367,17 +386,23 @@ class Message():
fd.seek(pos)
# Question and Answer
text = fd.read().strip().split('\n')
try:
question_idx = text.index(Question.txt_header) + 1
except ValueError:
raise MessageError(f"'{file_path}' does not contain a valid message")
try:
answer_idx = text.index(Answer.txt_header)
question = Question.from_list(text[question_idx:answer_idx])
answer = Answer.from_list(text[answer_idx + 1:])
except ValueError:
question = Question.from_list(text[question_idx:])
return cls(question, answer, tags, ai, model, file_path)
try:
question_idx = text.index(Question.txt_header) + 1
except ValueError:
raise MessageError(f"'{file_path}' does not contain a valid message")
try:
answer_idx = text.index(Answer.txt_header)
question = Question.from_list(text[question_idx:answer_idx])
answer = Answer.from_list(text[answer_idx + 1:])
except ValueError:
question = Question.from_list(text[question_idx:])
# match tags AFTER reading the whole file
# -> make sure it's a valid 'txt' file format
if tags_or or tags_and or tags_not:
# match with an empty set if the file has no tags
if not match_tags(tags, tags_or, tags_and, tags_not):
return None
return cls(question, answer, tags, ai, model, file_path)
@classmethod
def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst:
@@ -397,7 +422,7 @@ class Message():
except Exception:
raise MessageError(f"'{file_path}' does not contain a valid message")
def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str:
def to_str(self, with_metadata: bool = False, source_code_only: bool = False) -> str:
"""
Return the current Message as a string.
"""
@@ -407,10 +432,11 @@ class Message():
if self.answer:
output.extend(self.answer.source_code(include_delims=True))
return '\n'.join(output) if len(output) > 0 else ''
if with_tags:
if with_metadata:
output.append(self.tags_str())
if with_file:
output.append('FILE: ' + str(self.file_path))
output.append('AI: ' + str(self.ai))
output.append('MODEL: ' + str(self.model))
output.append(Question.txt_header)
output.append(self.question)
if self.answer:
@@ -418,24 +444,29 @@ class Message():
output.append(self.answer)
return '\n'.join(output)
def __str__(self) -> str:
return self.to_str(True, True, False)
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11
"""
Write a Message to the given file. Type is determined based on the suffix.
Currently supported suffixes: ['.txt', '.yaml']
Write a Message to the given file. Supported message file formats are 'txt' and 'yaml'.
Suffix is always '.msg'.
"""
if file_path:
self.file_path = file_path
if not self.file_path:
raise MessageError("Got no valid path to write message")
if self.file_path.suffix not in self.file_suffixes:
raise MessageError(f"File type '{self.file_path.suffix}' is not supported")
if mformat not in message_valid_formats:
raise MessageError(f"File format '{mformat}' is not supported")
# check for valid suffix
# -> add one if it's empty
# -> refuse old or otherwise unsupported suffixes
if not self.file_path.suffix:
self.file_path = self.file_path.with_suffix(self.file_suffix_write)
elif self.file_path.suffix != self.file_suffix_write:
raise MessageError(f"File suffix '{self.file_path.suffix}' is not supported")
# TXT
if self.file_path.suffix == '.txt':
if mformat == 'txt':
return self.__to_file_txt(self.file_path)
elif self.file_path.suffix == '.yaml':
# YAML
elif mformat == 'yaml':
return self.__to_file_yaml(self.file_path)
def __to_file_txt(self, file_path: pathlib.Path) -> None:
@@ -447,8 +478,8 @@ class Message():
* Model [Optional]
* Question.txt_header
* Question
* Answer.txt_header
* Answer
* Answer.txt_header [Optional]
* Answer [Optional]
"""
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
temp_file_path = pathlib.Path(temp_fd.name)
@@ -487,6 +518,13 @@ class Message():
yaml.dump(data, temp_fd, sort_keys=False)
shutil.move(temp_file_path, file_path)
def rm_file(self) -> None:
"""
Delete the message file. Ignore empty file_path and not existing files.
"""
if self.file_path is not None:
self.file_path.unlink(missing_ok=True)
def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
"""
Filter tags based on their prefix (i. e. the tag starts with a given string)
@@ -540,6 +578,9 @@ class Message():
if self.tags:
self.tags = rename_tags(self.tags, tags_rename)
def clear_answer(self) -> None:
self.answer = None
def msg_id(self) -> str:
"""
Returns an ID that is unique throughout all messages in the same (DB) directory.
+1
View File
@@ -2,3 +2,4 @@ openai
PyYAML
argcomplete
pytest
Jinja2
+3 -6
View File
@@ -2,6 +2,8 @@ from setuptools import setup, find_packages
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
with open("requirements.txt", "r") as fh:
install_requirements = [line.strip() for line in fh]
setup(
name="ChatMastermind",
@@ -28,12 +30,7 @@ setup(
"Topic :: Utilities",
"Topic :: Text Processing",
],
install_requires=[
"openai",
"PyYAML",
"argcomplete",
"pytest",
],
install_requires=install_requirements,
python_requires=">=3.9",
test_suite="tests",
entry_points={
+176 -131
View File
@@ -10,40 +10,69 @@ from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, ChatError
class TestChat(unittest.TestCase):
msg_suffix: str = Message.file_suffix_write
def msg_to_file_force_suffix(msg: Message) -> None:
"""
Force writing a message file with illegal suffixes.
"""
def_suffix = Message.file_suffix_write
assert msg.file_path
Message.file_suffix_write = msg.file_path.suffix
msg.to_file()
Message.file_suffix_write = def_suffix
class TestChatBase(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using more than just Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
class TestChat(TestChatBase):
def setUp(self) -> None:
self.chat = Chat([])
self.message1 = Message(Question('Question 1'),
Answer('Answer 1'),
{Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('0001.txt'))
ai='FakeAI',
model='FakeModel',
file_path=pathlib.Path(f'0001{msg_suffix}'))
self.message2 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('btag2')},
file_path=pathlib.Path('0002.txt'))
ai='FakeAI',
model='FakeModel',
file_path=pathlib.Path(f'0002{msg_suffix}'))
self.maxDiff = None
def test_unique_id(self) -> None:
# test with two identical messages
self.chat.msg_add([self.message1, self.message1])
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1])
self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_id()
self.assertSequenceEqual(self.chat.messages, [self.message1])
self.assert_messages_equal(self.chat.messages, [self.message1])
# test with two different messages
self.chat.msg_add([self.message2])
self.chat.msg_unique_id()
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2])
self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])
def test_unique_content(self) -> None:
# test with two identical messages
self.chat.msg_add([self.message1, self.message1])
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1])
self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_content()
self.assertSequenceEqual(self.chat.messages, [self.message1])
self.assert_messages_equal(self.chat.messages, [self.message1])
# test with two different messages
self.chat.msg_add([self.message2])
self.chat.msg_unique_content()
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2])
self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])
def test_filter(self) -> None:
self.chat.msg_add([self.message1, self.message2])
@@ -88,24 +117,24 @@ class TestChat(unittest.TestCase):
def test_find_remove_messages(self) -> None:
self.chat.msg_add([self.message1, self.message2])
msgs = self.chat.msg_find(['0001.txt'])
msgs = self.chat.msg_find(['0001'])
self.assertListEqual(msgs, [self.message1])
msgs = self.chat.msg_find(['0001.txt', '0002.txt'])
msgs = self.chat.msg_find(['0001', '0002'])
self.assertListEqual(msgs, [self.message1, self.message2])
# add new Message with full path
message3 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('btag2')},
file_path=pathlib.Path('/foo/bla/0003.txt'))
file_path=pathlib.Path(f'/foo/bla/0003{msg_suffix}'))
self.chat.msg_add([message3])
# find new Message by full path
msgs = self.chat.msg_find(['/foo/bla/0003.txt'])
msgs = self.chat.msg_find([f'/foo/bla/0003{msg_suffix}'])
self.assertListEqual(msgs, [message3])
# find Message with full path only by filename
msgs = self.chat.msg_find(['0003.txt'])
msgs = self.chat.msg_find([f'0003{msg_suffix}'])
self.assertListEqual(msgs, [message3])
# remove last message
self.chat.msg_remove(['0003.txt'])
self.chat.msg_remove(['0003'])
self.assertListEqual(self.chat.messages, [self.message1, self.message2])
def test_latest_message(self) -> None:
@@ -118,7 +147,7 @@ class TestChat(unittest.TestCase):
@patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None:
self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False)
self.chat.print(paged=False, tight=True)
expected_output = f"""{Question.txt_header}
Question 1
{Answer.txt_header}
@@ -131,17 +160,21 @@ Answer 2
self.assertEqual(mock_stdout.getvalue(), expected_output)
@patch('sys.stdout', new_callable=StringIO)
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
def test_print_with_metadata(self, mock_stdout: StringIO) -> None:
self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False, with_tags=True, with_files=True)
self.chat.print(paged=False, with_metadata=True, tight=True)
expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: 0001.txt
FILE: 0001{msg_suffix}
AI: FakeAI
MODEL: FakeModel
{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{TagLine.prefix} btag2
FILE: 0002.txt
FILE: 0002{msg_suffix}
AI: FakeAI
MODEL: FakeModel
{Question.txt_header}
Question 2
{Answer.txt_header}
@@ -150,38 +183,34 @@ Answer 2
self.assertEqual(mock_stdout.getvalue(), expected_output)
class TestChatDB(unittest.TestCase):
class TestChatDB(TestChatBase):
def setUp(self) -> None:
self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory()
self.message1 = Message(Question('Question 1'),
Answer('Answer 1'),
{Tag('tag1')},
file_path=pathlib.Path('0001.txt'))
{Tag('tag1')})
self.message2 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('tag2')},
file_path=pathlib.Path('0002.yaml'))
{Tag('tag2')})
self.message3 = Message(Question('Question 3'),
Answer('Answer 3'),
{Tag('tag3')},
file_path=pathlib.Path('0003.txt'))
{Tag('tag3')})
self.message4 = Message(Question('Question 4'),
Answer('Answer 4'),
{Tag('tag4')},
file_path=pathlib.Path('0004.yaml'))
{Tag('tag4')})
self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt'))
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
self.message1.to_file(pathlib.Path(self.db_path.name, '0001'), mformat='txt')
self.message2.to_file(pathlib.Path(self.db_path.name, '0002'), mformat='yaml')
self.message3.to_file(pathlib.Path(self.db_path.name, '0003'), mformat='txt')
self.message4.to_file(pathlib.Path(self.db_path.name, '0004'), mformat='yaml')
# make the next FID match the current state
next_fname = pathlib.Path(self.db_path.name) / '.next'
with open(next_fname, 'w') as f:
f.write('4')
# add some "trash" in order to test if it's correctly handled / ignored
self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt']
self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt', 'fubar.msg']
for file in self.trash_files:
with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
f.write('test trash')
@@ -196,7 +225,7 @@ class TestChatDB(unittest.TestCase):
List all Message files in the given TemporaryDirectory.
"""
# exclude '.next'
return [f for f in pathlib.Path(tmp_dir.name).glob('*.[ty]*') if f.name not in self.trash_files]
return [f for f in pathlib.Path(tmp_dir.name).glob('*.[tym]*') if f.name not in self.trash_files]
def tearDown(self) -> None:
self.db_path.cleanup()
@@ -207,13 +236,31 @@ class TestChatDB(unittest.TestCase):
duplicate_message = Message(Question('Question 4'),
Answer('Answer 4'),
{Tag('tag4')},
file_path=pathlib.Path('0004.txt'))
duplicate_message.to_file(pathlib.Path(self.db_path.name, '0004.txt'))
file_path=pathlib.Path(self.db_path.name, '0004.txt'))
msg_to_file_force_suffix(duplicate_message)
with self.assertRaises(ChatError) as cm:
ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(str(cm.exception), "Validation failed")
def test_file_path_ID_exists(self) -> None:
"""
Tests if the CacheDB chooses another ID if a file path with
the given one exists.
"""
# create a new and empty CacheDB
db_path = tempfile.TemporaryDirectory()
cache_path = tempfile.TemporaryDirectory()
chat_db = ChatDB.from_dir(pathlib.Path(cache_path.name),
pathlib.Path(db_path.name))
# add a message file
message = Message(Question('What?'),
file_path=pathlib.Path(cache_path.name) / f'0001{msg_suffix}')
message.to_file()
message1 = Message(Question('Where?'))
chat_db.cache_write([message1])
self.assertEqual(message1.msg_id(), '0002')
def test_from_dir(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
@@ -222,25 +269,23 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
# check that the files are sorted
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0002.yaml'))
pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path,
pathlib.Path(self.db_path.name, '0003.txt'))
pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path,
pathlib.Path(self.db_path.name, '0004.yaml'))
pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
def test_from_dir_glob(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
glob='*.txt')
self.assertEqual(len(chat_db.messages), 2)
glob='*1.*')
self.assertEqual(len(chat_db.messages), 1)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0003.txt'))
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
def test_from_dir_filter_tags(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
@@ -250,7 +295,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
def test_from_dir_filter_tags_empty(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
@@ -268,7 +313,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0002.yaml'))
pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
def test_from_messages(self) -> None:
@@ -313,25 +358,25 @@ class TestChatDB(unittest.TestCase):
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# check that Message.file_path is correct
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
# write the messages to the cache directory
chat_db.cache_write()
# check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4)
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'), cache_dir_files)
# check that Message.file_path has been correctly updated
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'))
# check the timestamp of the files in the DB directory
db_dir_files = self.message_list(self.db_path)
@@ -343,18 +388,18 @@ class TestChatDB(unittest.TestCase):
# check if the written files are in the DB directory
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files)
# check if all files in the DB dir have actually been overwritten
for file in db_dir_files:
self.assertGreater(file.stat().st_mtime, old_timestamps[file])
# check that Message.file_path has been correctly updated (again)
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
def test_db_read(self) -> None:
# create a new ChatDB instance
@@ -369,65 +414,65 @@ class TestChatDB(unittest.TestCase):
new_message2 = Message(Question('Question 6'),
Answer('Answer 6'),
{Tag('tag6')})
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
# read and check them
chat_db.db_read()
self.assertEqual(len(chat_db.messages), 6)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
# create 2 new files in the cache directory
new_message3 = Message(Question('Question 7'),
Answer('Answer 5'),
Answer('Answer 7'),
{Tag('tag7')})
new_message4 = Message(Question('Question 8'),
Answer('Answer 6'),
Answer('Answer 8'),
{Tag('tag8')})
new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml'))
new_message3.to_file(pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'), mformat='txt')
new_message4.to_file(pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'), mformat='yaml')
# read and check them
chat_db.cache_read()
self.assertEqual(len(chat_db.messages), 8)
# check that the new message have the cache dir path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml'))
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'))
# an the old ones keep their path (since they have not been replaced)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
# now overwrite two messages in the DB directory
new_message1.question = Question('New Question 1')
new_message2.question = Question('New Question 2')
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
# read from the DB dir and check if the modified messages have been updated
chat_db.db_read()
self.assertEqual(len(chat_db.messages), 8)
self.assertEqual(chat_db.messages[4].question, 'New Question 1')
self.assertEqual(chat_db.messages[5].question, 'New Question 2')
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
# now write the messages from the cache to the DB directory
new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml'))
new_message3.to_file(pathlib.Path(self.db_path.name, f'0007{msg_suffix}'))
new_message4.to_file(pathlib.Path(self.db_path.name, f'0008{msg_suffix}'))
# read and check them
chat_db.db_read()
self.assertEqual(len(chat_db.messages), 8)
# check that they now have the DB path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml'))
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, f'0007{msg_suffix}'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, f'0008{msg_suffix}'))
def test_cache_clear(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# check that Message.file_path is correct
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
# write the messages to the cache directory
chat_db.cache_write()
@@ -439,10 +484,10 @@ class TestChatDB(unittest.TestCase):
chat_db.db_write()
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files)
# add a new message with empty file_path
message_empty = Message(question=Question("What the hell am I doing here?"),
@@ -450,7 +495,7 @@ class TestChatDB(unittest.TestCase):
# and one for the cache dir
message_cache = Message(question=Question("What the hell am I doing here?"),
answer=Answer("You're a creep!"),
file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
file_path=pathlib.Path(self.cache_path.name, '0005'))
chat_db.msg_add([message_empty, message_cache])
# clear the cache and check the cache dir
@@ -512,11 +557,11 @@ class TestChatDB(unittest.TestCase):
chat_db.msg_write([message])
# write a message with a valid file_path
message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt'
message.file_path = pathlib.Path(self.cache_path.name) / '123456'
chat_db.msg_write([message])
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, f'123456{msg_suffix}'), cache_dir_files)
def test_msg_update(self) -> None:
# create a new ChatDB instance
@@ -552,24 +597,24 @@ class TestChatDB(unittest.TestCase):
# search for a DB file in memory
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.txt'], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.msg'], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1])
# and on disk
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.yaml'], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.msg'], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2])
# now search the cache -> expect empty result
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003.txt'], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003.msg'], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), [])
# search for multiple messages
# -> search one twice, expect result to be unique
search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)]
search_names = ['0001', '0002.msg', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, loc='all')
self.assertSequenceEqual(result, expected_result)
self.assert_messages_equal(result, expected_result)
def test_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
@@ -595,47 +640,47 @@ class TestChatDB(unittest.TestCase):
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# add a new message, but only to the internal list
new_message = Message(Question("What?"))
all_messages_mem = all_messages + [new_message]
chat_db.msg_add([new_message])
self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages_mem)
self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages_mem)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages_mem)
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages_mem)
# the nr. of messages on disk did not change -> expect old result
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# test with MessageFilter
self.assertSequenceEqual(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})),
[self.message1])
self.assertSequenceEqual(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})),
[self.message2])
self.assertSequenceEqual(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})),
[])
self.assertSequenceEqual(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")),
[new_message])
self.assert_messages_equal(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})),
[self.message1])
self.assert_messages_equal(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})),
[self.message2])
self.assert_messages_equal(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})),
[])
self.assert_messages_equal(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")),
[new_message])
def test_msg_move_and_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# move first message to the cache
chat_db.cache_move(self.message1)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [self.message1])
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [self.message1])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4])
self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4])
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages)
# now move first message back to the DB
chat_db.db_move(self.message1)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
+100
View File
@@ -0,0 +1,100 @@
import unittest
import argparse
from typing import Union, Optional
from chatmastermind.configuration import Config, AIConfig
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Answer
from chatmastermind.chat import Chat
from chatmastermind.ai import AI, AIResponse, Tokens, AIError
class FakeAI(AI):
"""
A mocked version of the 'AI' class.
"""
ID: str
name: str
config: AIConfig
def models(self) -> list[str]:
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int:
return 123
def print(self) -> None:
pass
def print_models(self) -> None:
pass
def __init__(self, ID: str, model: str, error: bool = False):
self.ID = ID
self.model = model
self.error = error
def request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Mock the 'ai.request()' function by either returning fake
answers or raising an exception.
"""
if self.error:
raise AIError
question.answer = Answer("Answer 0")
question.tags = set(otags) if otags is not None else None
question.ai = self.ID
question.model = self.model
answers: list[Message] = [question]
for n in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(f"Answer {n}"),
tags=otags,
ai=self.ID,
model=self.model))
return AIResponse(answers, Tokens(10, 10, 20))
class TestWithFakeAI(unittest.TestCase):
"""
Base class for all tests that need to use the FakeAI.
"""
def assert_msgs_equal_except_file_path(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using Question, Answer and all metadata excecot for the file_path.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
def assert_msgs_all_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using Question, Answer and ALL metadata.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
self.assertTrue(m1.equals(m2, verbose=True))
def assert_msgs_content_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using only Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
self.assertEqual(m1, m2)
def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance.
"""
return FakeAI(args.AI, args.model)
def mock_create_ai_with_error(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance.
"""
return FakeAI(args.AI, args.model, error=True)
+7
View File
@@ -57,6 +57,7 @@ class TestConfig(unittest.TestCase):
def test_from_dict_should_create_config_from_dict(self) -> None:
source_dict = {
'cache': '.',
'db': './test_db/',
'ais': {
'myopenai': {
@@ -73,6 +74,7 @@ class TestConfig(unittest.TestCase):
}
}
config = Config.from_dict(source_dict)
self.assertEqual(config.cache, '.')
self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['myopenai'].name, 'openai')
@@ -89,6 +91,7 @@ class TestConfig(unittest.TestCase):
def test_from_file_should_load_config_from_file(self) -> None:
source_dict = {
'cache': './test_cache/',
'db': './test_db/',
'ais': {
'default': {
@@ -108,6 +111,7 @@ class TestConfig(unittest.TestCase):
yaml.dump(source_dict, f)
config = Config.from_file(self.test_file.name)
self.assertIsInstance(config, Config)
self.assertEqual(config.cache, './test_cache/')
self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1)
self.assertIsInstance(config.ais['default'], AIConfig)
@@ -115,6 +119,7 @@ class TestConfig(unittest.TestCase):
def test_to_file_should_save_config_to_file(self) -> None:
config = Config(
cache='./test_cache/',
db='./test_db/',
ais={
'myopenai': OpenAIConfig(
@@ -133,12 +138,14 @@ class TestConfig(unittest.TestCase):
config.to_file(Path(self.test_file.name))
with open(self.test_file.name, 'r') as f:
saved_config = yaml.load(f, Loader=yaml.FullLoader)
self.assertEqual(saved_config['cache'], './test_cache/')
self.assertEqual(saved_config['db'], './test_db/')
self.assertEqual(len(saved_config['ais']), 1)
self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system')
def test_from_file_error_unknown_ai(self) -> None:
source_dict = {
'cache': './test_cache/',
'db': './test_db/',
'ais': {
'default': {
+62
View File
@@ -0,0 +1,62 @@
import unittest
import argparse
import tempfile
import yaml
from pathlib import Path
from chatmastermind.message import Message, Question
from chatmastermind.chat import ChatDB, ChatError
from chatmastermind.configuration import Config
from chatmastermind.commands.hist import convert_messages
msg_suffix = Message.file_suffix_write
class TestConvertMessages(unittest.TestCase):
def setUp(self) -> None:
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
self.db_path = Path(self.db_dir.name)
self.cache_path = Path(self.cache_dir.name)
self.args = argparse.Namespace()
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
# Prepare some messages
self.chat = ChatDB.from_dir(Path(self.cache_path),
Path(self.db_path))
self.messages = [Message(Question(f'Question {i}')) for i in range(0, 6)]
self.chat.db_write(self.messages[0:2])
self.chat.cache_write(self.messages[2:])
# Change some of the suffixes
assert self.messages[0].file_path
assert self.messages[1].file_path
self.messages[0].file_path.rename(self.messages[0].file_path.with_suffix('.txt'))
self.messages[1].file_path.rename(self.messages[1].file_path.with_suffix('.yaml'))
def tearDown(self) -> None:
self.db_dir.cleanup()
self.cache_dir.cleanup()
def test_convert_messages(self) -> None:
self.args.convert = 'yaml'
convert_messages(self.args, self.config)
msgs = self.chat.msg_gather(loc='disk', glob='*.*')
# Check if the number of messages is the same as before
self.assertEqual(len(msgs), len(self.messages))
# Check if all messages have the requested suffix
for msg in msgs:
assert msg.file_path
self.assertEqual(msg.file_path.suffix, msg_suffix)
# Check if the message IDs are correctly maintained
for m_new, m_old in zip(msgs, self.messages):
self.assertEqual(m_new.msg_id(), m_old.msg_id())
# check if all messages have the new format
for m in msgs:
with open(str(m.file_path), "r") as fd:
yaml.load(fd, Loader=yaml.FullLoader)
def test_convert_messages_wrong_format(self) -> None:
self.args.convert = 'foo'
with self.assertRaises(ChatError):
convert_messages(self.args, self.config)
+135 -72
View File
@@ -1,11 +1,16 @@
import unittest
import pathlib
import tempfile
import itertools
from typing import cast
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine,\
MessageFilter, message_in, message_valid_formats
from chatmastermind.tags import Tag, TagLine
msg_suffix: str = Message.file_suffix_write
class SourceCodeTestCase(unittest.TestCase):
def test_source_code_with_include_delims(self) -> None:
text = """
@@ -101,7 +106,7 @@ class AnswerTestCase(unittest.TestCase):
class MessageToFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path = pathlib.Path(self.file.name)
self.message_complete = Message(Question('This is a question.'),
Answer('This is an answer.'),
@@ -117,7 +122,7 @@ class MessageToFileTxtTestCase(unittest.TestCase):
self.file_path.unlink()
def test_to_file_txt_complete(self) -> None:
self.message_complete.to_file(self.file_path)
self.message_complete.to_file(self.file_path, mformat='txt')
with open(self.file_path, "r") as fd:
content = fd.read()
@@ -132,7 +137,7 @@ This is an answer.
self.assertEqual(content, expected_content)
def test_to_file_txt_min(self) -> None:
self.message_min.to_file(self.file_path)
self.message_min.to_file(self.file_path, mformat='txt')
with open(self.file_path, "r") as fd:
content = fd.read()
@@ -141,11 +146,17 @@ This is a question.
"""
self.assertEqual(content, expected_content)
def test_to_file_unsupported_file_type(self) -> None:
def test_to_file_unsupported_file_suffix(self) -> None:
unsupported_file_path = pathlib.Path("example.doc")
with self.assertRaises(MessageError) as cm:
self.message_complete.to_file(unsupported_file_path)
self.assertEqual(str(cm.exception), "File type '.doc' is not supported")
self.assertEqual(str(cm.exception), "File suffix '.doc' is not supported")
def test_to_file_unsupported_file_format(self) -> None:
unsupported_file_format = pathlib.Path(f"example{msg_suffix}")
with self.assertRaises(MessageError) as cm:
self.message_complete.to_file(unsupported_file_format, mformat='doc') # type: ignore [arg-type]
self.assertEqual(str(cm.exception), "File format 'doc' is not supported")
def test_to_file_no_file_path(self) -> None:
"""
@@ -159,10 +170,24 @@ This is a question.
# reset the internal file_path
self.message_complete.file_path = self.file_path
def test_to_file_txt_auto_suffix(self) -> None:
"""
Test if suffix is auto-generated if omitted.
"""
file_path_no_suffix = self.file_path.with_suffix('')
# test with file_path member
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(mformat='txt')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
# test with explicit file_path
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(file_path=file_path_no_suffix, mformat='txt')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
class MessageToFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path = pathlib.Path(self.file.name)
self.message_complete = Message(Question('This is a question.'),
Answer('This is an answer.'),
@@ -184,7 +209,7 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.file_path.unlink()
def test_to_file_yaml_complete(self) -> None:
self.message_complete.to_file(self.file_path)
self.message_complete.to_file(self.file_path, mformat='yaml')
with open(self.file_path, "r") as fd:
content = fd.read()
@@ -199,7 +224,7 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.assertEqual(content, expected_content)
def test_to_file_yaml_multiline(self) -> None:
self.message_multiline.to_file(self.file_path)
self.message_multiline.to_file(self.file_path, mformat='yaml')
with open(self.file_path, "r") as fd:
content = fd.read()
@@ -218,17 +243,31 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.assertEqual(content, expected_content)
def test_to_file_yaml_min(self) -> None:
self.message_min.to_file(self.file_path)
self.message_min.to_file(self.file_path, mformat='yaml')
with open(self.file_path, "r") as fd:
content = fd.read()
expected_content = f"{Question.yaml_key}: This is a question.\n"
self.assertEqual(content, expected_content)
def test_to_file_yaml_auto_suffix(self) -> None:
"""
Test if suffix is auto-generated if omitted.
"""
file_path_no_suffix = self.file_path.with_suffix('')
# test with file_path member
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(mformat='yaml')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
# test with explicit file_path
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(file_path=file_path_no_suffix, mformat='yaml')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
class MessageFromFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2
@@ -239,7 +278,7 @@ This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path_min = pathlib.Path(self.file_min.name)
with open(self.file_path_min, "w") as fd:
fd.write(f"""{Question.txt_header}
@@ -259,13 +298,13 @@ This is a question.
message = Message.from_file(self.file_path)
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.ai, 'ChatGPT')
self.assertEqual(message.model, 'gpt-3.5-turbo')
self.assertEqual(message.file_path, self.file_path)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.ai, 'ChatGPT')
self.assertEqual(message.model, 'gpt-3.5-turbo')
self.assertEqual(message.file_path, self.file_path)
def test_from_file_txt_min(self) -> None:
"""
@@ -274,21 +313,21 @@ This is a question.
message = Message.from_file(self.file_path_min)
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer)
def test_from_file_txt_tags_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.file_path, self.file_path)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.file_path, self.file_path)
def test_from_file_txt_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path,
@@ -311,13 +350,13 @@ This is a question.
MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
def test_from_file_not_exists(self) -> None:
file_not_exists = pathlib.Path("example.txt")
file_not_exists = pathlib.Path(f"example{msg_suffix}")
with self.assertRaises(MessageError) as cm:
Message.from_file(file_not_exists)
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
@@ -396,7 +435,7 @@ This is a question.
class MessageFromFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd:
fd.write(f"""
@@ -410,7 +449,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
- tag1
- tag2
""")
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path_min = pathlib.Path(self.file_min.name)
with open(self.file_path_min, "w") as fd:
fd.write(f"""
@@ -431,13 +470,13 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
message = Message.from_file(self.file_path)
self.assertIsInstance(message, Message)
self.assertIsNotNone(message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.ai, 'ChatGPT')
self.assertEqual(message.model, 'gpt-3.5-turbo')
self.assertEqual(message.file_path, self.file_path)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.ai, 'ChatGPT')
self.assertEqual(message.model, 'gpt-3.5-turbo')
self.assertEqual(message.file_path, self.file_path)
def test_from_file_yaml_min(self) -> None:
"""
@@ -446,14 +485,14 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
message = Message.from_file(self.file_path_min)
self.assertIsInstance(message, Message)
self.assertIsNotNone(message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer)
def test_from_file_not_exists(self) -> None:
file_not_exists = pathlib.Path("example.yaml")
file_not_exists = pathlib.Path(f"example{msg_suffix}")
with self.assertRaises(MessageError) as cm:
Message.from_file(file_not_exists)
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
@@ -463,11 +502,11 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.file_path, self.file_path)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.file_path, self.file_path)
def test_from_file_yaml_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path,
@@ -484,10 +523,10 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
def test_from_file_yaml_question_match(self) -> None:
message = Message.from_file(self.file_path,
@@ -563,7 +602,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
class TagsFromFileTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path_txt = pathlib.Path(self.file_txt.name)
with open(self.file_path_txt, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3
@@ -572,7 +611,7 @@ This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name)
with open(self.file_path_txt_no_tags, "w") as fd:
fd.write(f"""{Question.txt_header}
@@ -580,7 +619,7 @@ This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name)
with open(self.file_path_txt_tags_empty, "w") as fd:
fd.write(f"""TAGS:
@@ -589,7 +628,7 @@ This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path_yaml = pathlib.Path(self.file_yaml.name)
with open(self.file_path_yaml, "w") as fd:
fd.write(f"""
@@ -602,7 +641,7 @@ This is an answer.
- tag2
- ptag3
""")
self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name)
with open(self.file_path_yaml_no_tags, "w") as fd:
fd.write(f"""
@@ -679,24 +718,25 @@ class TagsFromDirTestCase(unittest.TestCase):
{Tag('ctag5'), Tag('ctag6')}
]
self.files = [
pathlib.Path(self.temp_dir.name, 'file1.txt'),
pathlib.Path(self.temp_dir.name, 'file2.yaml'),
pathlib.Path(self.temp_dir.name, 'file3.txt')
pathlib.Path(self.temp_dir.name, f'file1{msg_suffix}'),
pathlib.Path(self.temp_dir.name, f'file2{msg_suffix}'),
pathlib.Path(self.temp_dir.name, f'file3{msg_suffix}')
]
self.files_no_tags = [
pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'),
pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'),
pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt')
pathlib.Path(self.temp_dir_no_tags.name, f'file4{msg_suffix}'),
pathlib.Path(self.temp_dir_no_tags.name, f'file5{msg_suffix}'),
pathlib.Path(self.temp_dir_no_tags.name, f'file6{msg_suffix}')
]
mformats = itertools.cycle(message_valid_formats)
for file, tags in zip(self.files, self.tag_sets):
message = Message(Question('This is a question.'),
Answer('This is an answer.'),
tags)
message.to_file(file)
message.to_file(file, next(mformats))
for file in self.files_no_tags:
message = Message(Question('This is a question.'),
Answer('This is an answer.'))
message.to_file(file)
message.to_file(file, next(mformats))
def tearDown(self) -> None:
self.temp_dir.cleanup()
@@ -719,7 +759,7 @@ class TagsFromDirTestCase(unittest.TestCase):
class MessageIDTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path = pathlib.Path(self.file.name)
self.message = Message(Question('This is a question.'),
file_path=self.file_path)
@@ -816,6 +856,8 @@ class MessageToStrTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
Answer('This is an answer.'),
ai=('FakeAI'),
model=('FakeModel'),
tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/foo/bla'))
@@ -829,8 +871,29 @@ This is an answer."""
def test_to_str_with_tags_and_file(self) -> None:
expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: /tmp/foo/bla
AI: FakeAI
MODEL: FakeModel
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer."""
self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output)
self.assertEqual(self.message.to_str(with_metadata=True), expected_output)
class MessageRmFileTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path = pathlib.Path(self.file.name)
self.message = Message(Question('This is a question.'),
file_path=self.file_path)
self.message.to_file()
def tearDown(self) -> None:
self.file.close()
self.file_path.unlink(missing_ok=True)
def test_rm_file(self) -> None:
assert self.message.file_path
self.assertTrue(self.message.file_path.exists())
self.message.rm_file()
self.assertFalse(self.message.file_path.exists())
+410 -12
View File
@@ -1,25 +1,33 @@
import os
import unittest
import argparse
import tempfile
from copy import copy
from pathlib import Path
from unittest.mock import MagicMock
from chatmastermind.commands.question import create_message
from unittest import mock
from unittest.mock import MagicMock, call
from chatmastermind.configuration import Config
from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import ChatDB
from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AIError
from .test_common import TestWithFakeAI
class TestMessageCreate(unittest.TestCase):
msg_suffix = Message.file_suffix_write
class TestMessageCreate(TestWithFakeAI):
"""
Test if messages created by the 'question' command have
the correct format.
"""
def setUp(self) -> None:
# create ChatDB structure
self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory()
self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name),
db_path=Path(self.db_path.name))
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
self.chat = ChatDB.from_dir(cache_path=Path(self.cache_dir.name),
db_path=Path(self.db_dir.name))
# create some messages
self.message_text = Message(Question("What is this?"),
Answer("It is pure text"))
@@ -33,6 +41,8 @@ class TestMessageCreate(unittest.TestCase):
self.args.AI = None
self.args.model = None
self.args.output_tags = None
self.args.ask = None
self.args.create = None
# File 1 : no source code block, only text
self.source_file1 = tempfile.NamedTemporaryFile(delete=False)
self.source_file1_content = """This is just text.
@@ -74,17 +84,18 @@ Aaaand again some text."""
os.remove(self.source_file1.name)
os.remove(self.source_file2.name)
os.remove(self.source_file3.name)
os.remove(self.source_file4.name)
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return list(Path(tmp_dir.name).glob('*.[ty]*'))
return list(Path(tmp_dir.name).glob(f'*{msg_suffix}'))
def test_message_file_created(self) -> None:
self.args.ask = ["What is this?"]
cache_dir_files = self.message_list(self.cache_path)
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_path)
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message)
@@ -193,3 +204,390 @@ This is embedded source code.
It is embedded code
```
"""))
class TestCreateOption(TestMessageCreate):
def test_message_file_created(self) -> None:
self.args.create = ["How does question --create work?"]
self.args.ask = None
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("How does question --create work?")) # type: ignore [union-attr]
class TestQuestionCmd(TestWithFakeAI):
def setUp(self) -> None:
# create DB and cache
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
# create configuration
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
# create a mock argparse.Namespace
self.args = argparse.Namespace(
ask=['What is the meaning of life?'],
num_answers=1,
output_tags=['science'],
AI='FakeAI',
model='FakeModel',
or_tags=None,
and_tags=None,
exclude_tags=None,
source_text=None,
source_code=None,
create=None,
repeat=None,
process=None,
overwrite=None
)
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return sorted([f for f in Path(tmp_dir.name).glob(f'*{msg_suffix}')])
class TestQuestionCmdAsk(TestQuestionCmd):
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
"""
Test single answer with no errors.
"""
mock_create_ai.side_effect = self.mock_create_ai
expected_question = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path('<NOT COMPARED>'))
fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command
question_cmd(self.args, self.config)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None:
"""
Test single answer with no errors (mocked ChatDB version).
"""
chat = MagicMock(spec=ChatDB)
mock_from_dir.return_value = chat
mock_create_ai.side_effect = self.mock_create_ai
expected_question = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path('<NOT COMPARED>'))
fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command
question_cmd(self.args, self.config)
# check for the correct ChatDB calls:
# - initial question has been written (prior to the actual request)
# - responses have been written (after the request)
chat.cache_write.assert_has_calls([call([expected_question]),
call(expected_responses)],
any_order=False)
# check that the messages have not been added to the internal message list
chat.cache_add.assert_not_called()
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_with_error(self, mock_create_ai: MagicMock) -> None:
"""
Provoke an error during the AI request and verify that the question
has been correctly stored in the cache.
"""
mock_create_ai.side_effect = self.mock_create_ai_with_error
expected_question = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path('<NOT COMPARED>'))
# execute the command
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_question])
class TestQuestionCmdRepeat(TestQuestionCmd):
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
# repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path)
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai=message.ai,
model=message.model,
tags=message.tags,
file_path=Path('<NOT COMPARED>'))
# we expect the original message + the one with the new response
expected_responses = [message] + [expected_response]
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
print(self.message_list(self.cache_dir))
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question and overwrite the old one.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
# repeat the last question (WITH overwriting)
# -> expect a single message afterwards (with a new answer)
self.args.ask = None
self.args.repeat = []
self.args.overwrite = True
expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai=message.ai,
model=message.model,
tags=message.tags,
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_after_error(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question after an error.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a question WITHOUT an answer
# -> just like after an error, which is tested above
message = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
# repeat the last question (without overwriting)
# -> expect a single message because if the original has
# no answer, it should be overwritten by default
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai=message.ai,
model=message.model,
tags=message.tags,
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_new_args(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question with new arguments.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
# repeat the last question with new arguments (without overwriting)
# -> expect two messages with identical question but different metadata and new answer
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
self.args.output_tags = ['newtag']
self.args.AI = 'newai'
self.args.model = 'newmodel'
new_expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai='newai',
model='newmodel',
tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_msgs_equal_except_file_path(cached_msg, [message] + [new_expected_response])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_new_args_overwrite(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question with new arguments, overwriting the old one.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
# repeat the last question with new arguments
self.args.ask = None
self.args.repeat = []
self.args.overwrite = True
self.args.output_tags = ['newtag']
self.args.AI = 'newai'
self.args.model = 'newmodel'
new_expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai='newai',
model='newmodel',
tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None:
"""
Repeat multiple questions.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# 1. === create three questions ===
# cached message without an answer
message1 = Message(Question(self.args.ask[0]),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
# cached message with an answer
message2 = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0002{msg_suffix}')
# DB message without an answer
message3 = Message(Question(self.args.ask[0]),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.db_dir.name) / f'0003{msg_suffix}')
chat.msg_write([message1, message2, message3])
questions = [message1, message2, message3]
expected_responses: list[Message] = []
fake_ai = self.mock_create_ai(self.args, self.config)
for question in questions:
# since the message's answer is modified, we use a copy
# -> the original is used for comparison below
expected_responses += fake_ai.request(copy(question),
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
# 2. === repeat all three questions (without overwriting) ===
self.args.ask = None
self.args.repeat = ['0001', '0002', '0003']
self.args.overwrite = False
question_cmd(self.args, self.config)
# two new files should be in the cache directory
# * the repeated cached message with answer
# * the repeated DB message
# -> the cached message without answer should be overwritten
self.assertEqual(len(self.message_list(self.cache_dir)), 4)
self.assertEqual(len(self.message_list(self.db_dir)), 1)
expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
cached_msg = chat.msg_gather(loc='cache')
self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
# check that the DB message has not been modified at all
db_msg = chat.msg_gather(loc='db')
self.assert_msgs_all_equal(db_msg, [message3])