19 Commits

Author SHA1 Message Date
juk0de ac3c19739d README: updates and fixes 2023-09-20 10:18:06 +02:00
juk0de ed379ed535 print_cmd: added option to print latest message 2023-09-20 10:18:06 +02:00
juk0de c43bafe47a main: improved metavar names and descriptions 2023-09-20 10:18:06 +02:00
juk0de 7dd83428fb test_question_cmd: added more testcases for '--repeat' 2023-09-20 10:18:06 +02:00
juk0de 3ad4b96b8f test_question_cmd: added testclass for the 'question_cmd()' function 2023-09-20 10:17:59 +02:00
juk0de 561003aabe question_cmd: implemented repeating of the latest message 2023-09-20 10:17:59 +02:00
juk0de 59eb45a3ca chat: improved message equality checks 2023-09-20 10:17:59 +02:00
juk0de 29a20bd2d8 message: added 'equals()' function and improved robustness and debugging 2023-09-20 10:17:59 +02:00
juk0de 80a1457dd1 configuration: the cache folder can now be specified in the configuration file 2023-09-20 10:17:59 +02:00
juk0de 25fffb6fea chat: db_read() and cache_read() now also support globbing and filtering 2023-09-17 10:59:29 +02:00
juk0de cf572e1882 chat: added functions db_move() and chat_move() (and tests) 2023-09-17 10:59:29 +02:00
juk0de 2fb7410b43 chat: added functions msg_in_cache() and msg_in_db(), also tests 2023-09-17 10:59:29 +02:00
juk0de 33ae27f00e chat: msg_remove() now supports multiple locations 2023-09-17 10:59:29 +02:00
juk0de f6a6e6036b chat: added validation during initialization 2023-09-17 10:59:29 +02:00
juk0de 525cdb92a1 message / chat: 'msg_id()' now returns 'file_path.stem' (removed suffix) 2023-09-17 10:59:29 +02:00
juk0de fc82f85b7c chat: added new functions: msg_unique_id(), msg_unique_content() and tests 2023-09-17 10:59:24 +02:00
juk0de d90845b58b chat: added new functions to ChatDB: msg_gather(), msg_find(), msg_remove() 2023-09-17 10:58:26 +02:00
juk0de 98777295d6 refactor: renamed (almost) all Chat/ChatDB functions 2023-09-17 10:58:26 +02:00
juk0de f6109949c8 chat: ChatDB now correctly ignores files that contain no valid messages 2023-09-17 10:58:10 +02:00
13 changed files with 1090 additions and 348 deletions
+51 -43
@@ -46,32 +46,32 @@ cmm [global options] command [command options]
The `question` command is used to ask, create, and process questions. The `question` command is used to ask, create, and process questions.
```bash ```bash
cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI] [-M MODEL] [-n NUM] [-m MAX] [-T TEMP] (-a ASK | -c CREATE | -r REPEAT | -p PROCESS) [-O] [-s SOURCE]... [-S SOURCE]... cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI_ID] [-M MODEL] [-n NUM] [-m MAX] [-T TEMP] (-a QUESTION | -c QUESTION | -r [MESSAGE ...] | -p [MESSAGE ...]) [-O] [-s FILE]... [-S FILE]...
``` ```
* `-t, --or-tags OTAGS`: List of tags (one must match) * `-t, --or-tags OTAGS`: List of tags (one must match)
* `-k, --and-tags ATAGS`: List of tags (all must match) * `-k, --and-tags ATAGS`: List of tags (all must match)
* `-x, --exclude-tags XTAGS`: List of tags to exclude * `-x, --exclude-tags XTAGS`: List of tags to exclude
* `-o, --output-tags OUTTAGS`: List of output tags (default: use input tags) * `-o, --output-tags OUTTAGS`: List of output tags (default: use input tags)
* `-A, --AI AI` : AI ID to use * `-A, --AI AI_ID`: AI ID to use
* `-M, --model MODEL`: Model to use * `-M, --model MODEL`: Model to use
* `-n, --num-answers NUM`: Number of answers to request * `-n, --num-answers NUM`: Number of answers to request
* `-m, --max-tokens MAX`: Max. number of tokens * `-m, --max-tokens MAX`: Max. number of tokens
* `-T, --temperature TEMP`: Temperature value * `-T, --temperature TEMP`: Temperature value
* `-a, --ask ASK` : Ask a question * `-a, --ask QUESTION`: Ask a question
* `-c, --create CREATE` : Create a question * `-c, --create QUESTION`: Create a question
* `-r, --repeat REPEAT` : Repeat a question * `-r, --repeat [MESSAGE ...]`: Repeat a question
* `-p, --process PROCESS` : Process existing questions * `-p, --process [MESSAGE ...]`: Process existing questions
* `-O, --overwrite`: Overwrite existing messages when repeating them * `-O, --overwrite`: Overwrite existing messages when repeating them
* `-s, --source-text SOURCE` : Add content of a file to the query * `-s, --source-text FILE`: Add content of a file to the query
* `-S, --source-code SOURCE` : Add source code file content to the chat history * `-S, --source-code FILE`: Add source code file content to the chat history
#### Hist #### Hist
The `hist` command is used to print the chat history. The `hist` command is used to print the chat history.
```bash ```bash
cmm hist [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A ANSWER] [-Q QUESTION] cmm hist [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A SUBSTRING] [-Q SUBSTRING]
``` ```
* `-t, --or-tags OTAGS`: List of tags (one must match) * `-t, --or-tags OTAGS`: List of tags (one must match)
@@ -79,46 +79,47 @@ cmm hist [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A ANSWER] [-
* `-x, --exclude-tags XTAGS`: List of tags to exclude * `-x, --exclude-tags XTAGS`: List of tags to exclude
* `-w, --with-tags`: Print chat history with tags * `-w, --with-tags`: Print chat history with tags
* `-W, --with-files`: Print chat history with filenames * `-W, --with-files`: Print chat history with filenames
* `-S, --source-code-only` : Print only source code * `-S, --source-code-only`: Only print embedded source code
* `-A, --answer ANSWER` : Search for answer substring * `-A, --answer SUBSTRING`: Search for answer substring
* `-Q, --question QUESTION` : Search for question substring * `-Q, --question SUBSTRING`: Search for question substring
#### Tags #### Tags
The `tags` command is used to manage tags. The `tags` command is used to manage tags.
```bash ```bash
cmm tags (-l | -p PREFIX | -c CONTENT) cmm tags (-l | -p PREFIX | -c SUBSTRING)
``` ```
* `-l, --list`: List all tags and their frequency * `-l, --list`: List all tags and their frequency
* `-p, --prefix PREFIX`: Filter tags by prefix * `-p, --prefix PREFIX`: Filter tags by prefix
* `-c, --contain CONTENT` : Filter tags by contained substring * `-c, --contain SUBSTRING`: Filter tags by contained substring
#### Config #### Config
The `config` command is used to manage the configuration. The `config` command is used to manage the configuration.
```bash ```bash
cmm config (-l | -m | -c CREATE) cmm config (-l | -m | -c FILE)
``` ```
* `-l, --list-models`: List all available models * `-l, --list-models`: List all available models
* `-m, --print-model`: Print the currently configured model * `-m, --print-model`: Print the currently configured model
* `-c, --create CREATE` : Create config with default settings in the given file * `-c, --create FILE`: Create config with default settings in the given file
#### Print #### Print
The `print` command is used to print message files. The `print` command is used to print message files.
```bash ```bash
cmm print -f FILE [-q | -a | -S] cmm print (-f FILE | -l) [-q | -a | -S]
``` ```
* `-f, --file FILE` : File to print * `-f, --file FILE`: Print given file
* `-q, --question` : Print only question * `-l, --latest`: Print latest message
* `-a, --answer` : Print only answer * `-q, --question`: Only print the question
* `-S, --only-source-code` : Print only source code * `-a, --answer`: Only print the answer
* `-S, --only-source-code`: Only print embedded source code
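The synopsis `cmm print (-f FILE | -l) [-q | -a | -S]` describes two groups of mutually exclusive options: one required, one optional. The project's actual parser setup is not part of this diff, so the following is only a minimal sketch of how such a synopsis could be expressed with `argparse`. The function name `build_print_parser` is hypothetical.

```python
import argparse


def build_print_parser() -> argparse.ArgumentParser:
    """Build a stand-in parser for 'cmm print' (hypothetical, not the project's code)."""
    parser = argparse.ArgumentParser(prog="cmm print")
    # Exactly one of -f/--file or -l/--latest must be given.
    source = parser.add_mutually_exclusive_group(required=True)
    source.add_argument("-f", "--file", metavar="FILE", help="Print given file")
    source.add_argument("-l", "--latest", action="store_true", help="Print latest message")
    # At most one of the output restrictions may be given.
    part = parser.add_mutually_exclusive_group()
    part.add_argument("-q", "--question", action="store_true", help="Only print the question")
    part.add_argument("-a", "--answer", action="store_true", help="Only print the answer")
    part.add_argument("-S", "--only-source-code", action="store_true",
                      help="Only print embedded source code")
    return parser


args = build_print_parser().parse_args(["-l", "-q"])
print(args.latest, args.question, args.file)
```

Combining `-f FILE` with `-l`, or `-q` with `-a`, makes `argparse` exit with a usage error, which matches the parentheses and brackets in the synopsis.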
### Examples ### Examples
@@ -160,18 +161,27 @@ cmm print -f example.yaml
## Configuration ## Configuration
The configuration file (`.config.yaml`) should contain the following fields: The default configuration filename is `.config.yaml` (it is looked up in the current working directory).
Use the command `cmm config --create <FILENAME>` to create a default configuration:
- `openai`: ```
- `api_key`: Your OpenAI API key. cache: .
- `model`: The name of the OpenAI model to use (e.g. "text-davinci-002"). db: ./db/
- `temperature`: The temperature value for the model. ais:
- `max_tokens`: The maximum number of tokens for the model. myopenai:
- `top_p`: The top P value for the model. name: openai
- `frequency_penalty`: The frequency penalty value. model: gpt-3.5-turbo-16k
- `presence_penalty`: The presence penalty value. api_key: 0123456789
- `system`: The system message used to set the behavior of the AI. temperature: 1.0
- `db`: The directory where the question-answer pairs are stored in YAML files. max_tokens: 4000
top_p: 1.0
frequency_penalty: 0.0
presence_penalty: 0.0
system: You are an assistant
```
Each AI has its own section and the name of that section is called the 'AI ID' (in the example above it is `myopenai`).
The AI ID can be any string, as long as it's unique within the `ais` section. The AI ID is used for all commands that support the `AI` parameter and it's also stored within each message file.
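To illustrate how an AI ID selects one section of `ais`, here is a minimal sketch that uses a plain dictionary in place of the parsed configuration file. The helper `select_ai` and its error message are assumptions made for illustration. They are not the project's API.

```python
from typing import Any

# Plain-dict stand-in for the parsed example configuration (subset of the fields).
config: dict[str, Any] = {
    "cache": ".",
    "db": "./db/",
    "ais": {
        "myopenai": {
            "name": "openai",
            "model": "gpt-3.5-turbo-16k",
            "temperature": 1.0,
        },
    },
}


def select_ai(config: dict[str, Any], ai_id: str) -> dict[str, Any]:
    """Return the section of 'ais' whose key equals the given AI ID (hypothetical helper)."""
    ais = config.get("ais", {})
    if ai_id not in ais:
        raise KeyError(f"Unknown AI ID '{ai_id}', known IDs: {sorted(ais)}")
    return ais[ai_id]


ai = select_ai(config, "myopenai")
print(ai["name"], ai["model"])
```

Because the AI ID is the dictionary key, uniqueness within `ais` follows from the file format itself.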
## Autocompletion ## Autocompletion
@@ -186,33 +196,33 @@ After adding this line, restart your shell or run `source <your-shell-config-fil
## Contributing ## Contributing
### Enable commit hooks ### Enable commit hooks
``` ```bash
pip install pre-commit pip install pre-commit
pre-commit install pre-commit install
``` ```
### Execute tests before opening a PR ### Execute tests before opening a PR
``` ```bash
pytest pytest
``` ```
### Consider using `pyenv` / `pyenv-virtualenv` ### Consider using `pyenv` / `pyenv-virtualenv`
Short installation instructions: Short installation instructions:
* install `pyenv`: * install `pyenv`:
``` ```bash
cd ~ cd ~
git clone https://github.com/pyenv/pyenv .pyenv git clone https://github.com/pyenv/pyenv .pyenv
cd ~/.pyenv && src/configure && make -C src cd ~/.pyenv && src/configure && make -C src
``` ```
* make sure that `~/.pyenv/shims` and `~/.pyenv/bin` are the first entries in your `PATH`, e. g. by setting it in `~/.bashrc` * make sure that `~/.pyenv/shims` and `~/.pyenv/bin` are the first entries in your `PATH`, e.g., by setting it in `~/.bashrc`
* add the following to your `~/.bashrc` (after setting `PATH`): `eval "$(pyenv init -)"` * add the following to your `~/.bashrc` (after setting `PATH`): `eval "$(pyenv init -)"`
* create a new terminal or source the changes (e. g. `source ~/.bashrc`) * create a new terminal or source the changes (e.g., `source ~/.bashrc`)
* install `virtualenv` * install `virtualenv`
``` ```bash
git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv
``` ```
* add the following to your `~/.bashrc` (after the commands above): `eval "$(pyenv virtualenv-init -)"` * add the following to your `~/.bashrc` (after the commands above): `eval "$(pyenv virtualenv-init -)"`
* create a new terminal or source the changes (e. g. `source ~/.bashrc`) * create a new terminal or source the changes (e.g., `source ~/.bashrc`)
* go back to the `ChatMasterMind` repo and create a virtual environment with the latest `Python`, e. g. `3.11.4`: * go back to the `ChatMasterMind` repo and create a virtual environment with the latest `Python`, e.g., `3.11.4`:
``` ```bash
cd <CMM_REPO_PATH> cd <CMM_REPO_PATH>
pyenv install 3.11.4 pyenv install 3.11.4
pyenv virtualenv 3.11.4 py311 pyenv virtualenv 3.11.4 py311
@@ -223,5 +233,3 @@ pyenv activate py311
## License ## License
This project is licensed under the terms of the WTFPL License. This project is licensed under the terms of the WTFPL License.
+307 -128
@@ -6,7 +6,7 @@ from pathlib import Path
from pprint import PrettyPrinter from pprint import PrettyPrinter
from pydoc import pager from pydoc import pager
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal, Union
from .configuration import default_config_file from .configuration import default_config_file
from .message import Message, MessageFilter, MessageError, message_in from .message import Message, MessageFilter, MessageError, message_in
from .tags import Tag from .tags import Tag
@@ -16,6 +16,7 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next' db_next_file = '.next'
ignored_files = [db_next_file, default_config_file] ignored_files = [db_next_file, default_config_file]
msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
class ChatError(Exception): class ChatError(Exception):
@@ -57,7 +58,7 @@ def read_dir(dir_path: Path,
if message: if message:
messages.append(message) messages.append(message)
except MessageError as e: except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}") print(f"WARNING: Skipping message in '{file_path}': {str(e)}")
return messages return messages
@@ -106,7 +107,9 @@ def clear_dir(dir_path: Path,
""" """
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in file_iter: for file_path in file_iter:
if file_path.is_file() and file_path.suffix in Message.file_suffixes: if (file_path.is_file()
and file_path.name not in ignored_files # noqa: W503
and file_path.suffix in Message.file_suffixes): # noqa: W503
file_path.unlink(missing_ok=True) file_path.unlink(missing_ok=True)
@@ -118,14 +121,43 @@ class Chat:
messages: list[Message] messages: list[Message]
def filter(self, mfilter: MessageFilter) -> None: def __post_init__(self) -> None:
self.validate()
def validate(self) -> None:
"""
Validate this Chat instance.
"""
def msg_paths(stem: str) -> list[str]:
return [str(fp) for fp in file_paths if fp.stem == stem]
file_paths: set[Path] = {m.file_path for m in self.messages if m.file_path is not None}
file_stems = [m.file_path.stem for m in self.messages if m.file_path is not None]
error = False
for fp in file_paths:
if file_stems.count(fp.stem) > 1:
print(f"ERROR: Found multiple copies of message '{fp.stem}': {msg_paths(fp.stem)}")
error = True
if error:
raise ChatError("Validation failed")
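The validation above rejects a chat in which the same message stem appears more than once. As a standalone illustration, the sketch below groups paths by stem and reports stems that occur under more than one distinct path. It is a simplified variant and not the method itself: the original also counts the same path appearing twice in the message list, which this version does not. The function name and the file names are hypothetical.

```python
from collections import defaultdict
from pathlib import Path


def find_duplicate_stems(paths: list[Path]) -> dict[str, list[str]]:
    """Map each stem that occurs under more than one distinct path to those paths."""
    by_stem: dict[str, set[Path]] = defaultdict(set)
    for path in paths:
        by_stem[path.stem].add(path)
    return {stem: sorted(str(p) for p in group)
            for stem, group in by_stem.items() if len(group) > 1}


# Hypothetical file names: '0001' exists both in the cache and in the DB directory.
paths = [Path("cache/0001.yaml"), Path("db/0001.txt"), Path("db/0002.yaml")]
duplicates = find_duplicate_stems(paths)
print(duplicates)
```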
def msg_name_matches(self, file_path: Path, name: str) -> bool:
"""
Return True if the given name matches the given file_path.
Matching is True if:
* 'name' matches the full 'file_path'
* 'name' matches 'file_path.name' (i. e. including the suffix)
* 'name' matches 'file_path.stem' (i. e. without a suffix)
"""
return Path(name) == file_path or name == file_path.name or name == file_path.stem
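The three matching modes can be tried in isolation. The sketch below copies the return expression into a free function, named `name_matches` here for illustration, and runs it against a hypothetical path.

```python
from pathlib import Path


def name_matches(file_path: Path, name: str) -> bool:
    """Same expression as msg_name_matches(), without the Chat instance."""
    return Path(name) == file_path or name == file_path.name or name == file_path.stem


fp = Path("db/0001.yaml")  # hypothetical message file
for candidate in ["db/0001.yaml", "0001.yaml", "0001", "0002", "cache/0001.yaml"]:
    print(candidate, name_matches(fp, candidate))
```

The first three candidates match (full path, name with suffix, stem). The last two do not, because a path with a different parent directory is compared as a full path only.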
def msg_filter(self, mfilter: MessageFilter) -> None:
""" """
Use 'Message.match(mfilter)' to remove all messages that Use 'Message.match(mfilter)' to remove all messages that
don't fulfill the filter requirements. don't fulfill the filter requirements.
""" """
self.messages = [m for m in self.messages if m.match(mfilter)] self.messages = [m for m in self.messages if m.match(mfilter)]
def sort(self, reverse: bool = False) -> None: def msg_sort(self, reverse: bool = False) -> None:
""" """
Sort the messages according to 'Message.msg_id()'. Sort the messages according to 'Message.msg_id()'.
""" """
@@ -135,51 +167,71 @@ class Chat:
except MessageError: except MessageError:
pass pass
def clear(self) -> None: def msg_unique_id(self) -> None:
"""
Remove duplicates from the internal messages, based on the msg_id (i. e. file_path).
Messages without a file_path are kept.
"""
old_msgs = self.messages.copy()
self.messages = []
for m in old_msgs:
if not message_in(m, self.messages):
self.messages.append(m)
self.msg_sort()
def msg_unique_content(self) -> None:
"""
Remove duplicates from the internal messages, based on the content (i. e. question + answer).
"""
self.messages = list(set(self.messages))
self.msg_sort()
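`list(set(self.messages))` only removes content duplicates if `Message` compares and hashes by content. The `Message` class is not part of this hunk, so the following sketch uses a hypothetical stand-in dataclass, `Msg`, to show the behaviour the docstring describes. It assumes equality is based on question and answer only.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional


@dataclass(frozen=True)
class Msg:
    """Hypothetical stand-in for Message: equality and hash use content only."""
    question: str
    answer: str
    # Excluded from __eq__ and __hash__, so the location does not affect uniqueness.
    file_path: Optional[Path] = field(default=None, compare=False)


messages = [
    Msg("What is 2+2?", "4", Path("db/0001.yaml")),
    Msg("What is 2+2?", "4", Path("cache/0007.yaml")),
    Msg("What is 3+3?", "6"),
]
# A set does not preserve order, so the result is sorted again afterwards.
unique = sorted(set(messages), key=lambda m: m.question)
print(len(unique))
```

Which of two equal messages survives is not defined by `set()`, so a caller that cares about the remaining `file_path` would need an explicit rule.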
def msg_clear(self) -> None:
""" """
Delete all messages. Delete all messages.
""" """
self.messages = [] self.messages = []
def add_messages(self, messages: list[Message]) -> None: def msg_add(self, messages: list[Message]) -> None:
""" """
Add new messages and sort them if possible. Add new messages and sort them if possible.
""" """
self.messages += messages self.messages += messages
self.sort() self.msg_sort()
def latest_message(self, mfilter: Optional[MessageFilter] = None) -> Optional[Message]: def msg_latest(self, mfilter: Optional[MessageFilter] = None) -> Optional[Message]:
""" """
Return the last added message (according to the file ID) that matches the given filter. Return the last added message (according to the file ID) that matches the given filter.
When containing messages without a valid file_path, it returns the latest message in When containing messages without a valid file_path, it returns the latest message in
the internal list. the internal list.
""" """
if len(self.messages) > 0: if len(self.messages) > 0:
self.sort() self.msg_sort()
for m in reversed(self.messages): for m in reversed(self.messages):
if mfilter is None or m.match(mfilter): if mfilter is None or m.match(mfilter):
return m return m
return None return None
def find_messages(self, msg_names: list[str]) -> list[Message]: def msg_find(self, msg_names: list[str]) -> list[Message]:
""" """
Search and return the messages with the given names. Names can either be filenames Search and return the messages with the given names. Names can either be filenames
(incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the (with or without suffix), full paths or Message.msg_id(). Messages that can't be
caller should check the result if he requires all messages). found are ignored (i. e. the caller should check the result if they require all
messages).
""" """
return [m for m in self.messages return [m for m in self.messages
if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def remove_messages(self, msg_names: list[str]) -> None: def msg_remove(self, msg_names: list[str]) -> None:
""" """
Remove the messages with the given names. Names can either be filenames Remove the messages with the given names. Names can either be filenames
(incl. the suffix) or full paths. (with or without suffix), full paths or Message.msg_id().
""" """
self.messages = [m for m in self.messages self.messages = [m for m in self.messages
if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] if not any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
self.sort() self.msg_sort()
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: def msg_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
""" """
Get the tags of all messages, optionally filtered by prefix or substring. Get the tags of all messages, optionally filtered by prefix or substring.
""" """
@@ -188,7 +240,7 @@ class Chat:
tags |= m.filter_tags(prefix, contain) tags |= m.filter_tags(prefix, contain)
return set(sorted(tags)) return set(sorted(tags))
def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]: def msg_tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
""" """
Get the frequency of all tags of all messages, optionally filtered by prefix or substring. Get the frequency of all tags of all messages, optionally filtered by prefix or substring.
""" """
@@ -245,6 +297,7 @@ class ChatDB(Chat):
# make all paths absolute # make all paths absolute
self.cache_path = self.cache_path.absolute() self.cache_path = self.cache_path.absolute()
self.db_path = self.db_path.absolute() self.db_path = self.db_path.absolute()
self.validate()
@classmethod @classmethod
def from_dir(cls: Type[ChatDBInst], def from_dir(cls: Type[ChatDBInst],
@@ -292,84 +345,185 @@ class ChatDB(Chat):
with open(self.next_path, 'w') as f: with open(self.next_path, 'w') as f:
f.write(f'{fid}') f.write(f'{fid}')
def read_db(self) -> None: def msg_write(self, messages: Optional[list[Message]] = None) -> None:
""" """
Reads new messages from the DB directory. New ones are added to the internal list, Write either the given messages or the internal ones to their CURRENT file_path.
existing ones are replaced. A message is determined as 'existing' if a message with If messages are given, they all must have a valid file_path. When writing the
internal messages, the ones with a valid file_path are written, the others
are ignored.
"""
if messages and any(m.file_path is None for m in messages):
raise ChatError("Can't write files without a valid file_path")
msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)):
m.to_file()
def msg_update(self, messages: list[Message], write: bool = True) -> None:
"""
Update EXISTING messages. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list. the same base filename (i. e. 'file_path.name') is already in the list.
Only accepts existing messages.
""" """
new_messages = read_dir(self.db_path, self.glob, self.mfilter) if any(not message_in(m, self.messages) for m in messages):
raise ChatError("Can't update messages that are not in the internal list")
# remove old versions and add new ones
self.messages = [m for m in self.messages if not message_in(m, messages)]
self.messages += messages
self.msg_sort()
# write the UPDATED messages if requested
if write:
self.msg_write(messages)
def msg_gather(self,
loc: msg_location,
require_file_path: bool = False,
mfilter: Optional[MessageFilter] = None) -> list[Message]:
"""
Gather and return messages from the given locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
If 'require_file_path' is True, return only messages with a valid file_path.
"""
loc_messages: list[Message] = []
if loc in ['mem', 'all']:
if require_file_path:
loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
else:
loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
if loc in ['cache', 'disk', 'all']:
loc_messages += read_dir(self.cache_path, mfilter=mfilter)
if loc in ['db', 'disk', 'all']:
loc_messages += read_dir(self.db_path, mfilter=mfilter)
# remove duplicates and sort the list
unique_messages: list[Message] = []
for m in loc_messages:
if not message_in(m, unique_messages):
unique_messages.append(m)
try:
unique_messages.sort(key=lambda m: m.msg_id())
# messages in 'mem' can have an empty file_path
except MessageError:
pass
return unique_messages
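The three `if loc in [...]` checks in `msg_gather()` expand each location into an ordered list of sources. The sketch below reproduces only that expansion. The names `Location` and `sources_for` are made up for illustration.

```python
from typing import Literal

Location = Literal['mem', 'disk', 'cache', 'db', 'all']


def sources_for(loc: Location) -> list[str]:
    """Return the sources read for a location, in the order msg_gather() checks them."""
    sources: list[str] = []
    if loc in ('mem', 'all'):
        sources.append('mem')
    if loc in ('cache', 'disk', 'all'):
        sources.append('cache')
    if loc in ('db', 'disk', 'all'):
        sources.append('db')
    return sources


for loc in ('mem', 'disk', 'cache', 'db', 'all'):
    print(f"{loc:5} -> {sources_for(loc)}")
```

Since duplicates are dropped on first occurrence, this order suggests that for `'all'` an in-memory copy takes precedence over a cached one, and a cached one over the DB copy, assuming `message_in()` treats them as the same message.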
def msg_find(self,
msg_names: list[str],
loc: msg_location = 'mem',
) -> list[Message]:
"""
Search and return the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Messages that can't be
found are ignored (i. e. the caller should check the result if they require all
messages).
Searches one of the following locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
loc_messages = self.msg_gather(loc, require_file_path=True)
return [m for m in loc_messages
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def msg_remove(self, msg_names: list[str], loc: msg_location = 'mem') -> None:
"""
Remove the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Also deletes the
files of all given messages with a valid file_path.
Delete files from one of the following locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
if loc != 'mem':
# delete the message files first
rm_messages = self.msg_find(msg_names, loc=loc)
for m in rm_messages:
if (m.file_path):
m.file_path.unlink()
# then remove them from the internal list
super().msg_remove(msg_names)
def msg_latest(self,
mfilter: Optional[MessageFilter] = None,
loc: msg_location = 'mem') -> Optional[Message]:
"""
Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if loc is 'mem').
Searches one of the following locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
# only consider messages with a valid file_path so they can be sorted
loc_messages = self.msg_gather(loc, require_file_path=True)
loc_messages.sort(key=lambda m: m.msg_id(), reverse=True)
for m in loc_messages:
if mfilter is None or m.match(mfilter):
return m
return None
def msg_in_cache(self, message: Union[Message, str]) -> bool:
"""
Return true if the given Message (or filename or Message.msg_id())
is located in the cache directory. False otherwise.
"""
if isinstance(message, Message):
return (message.file_path is not None
and message.file_path.parent.samefile(self.cache_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc='cache')) > 0
def msg_in_db(self, message: Union[Message, str]) -> bool:
"""
Return true if the given Message (or filename or Message.msg_id())
is located in the DB directory. False otherwise.
"""
if isinstance(message, Message):
return (message.file_path is not None
and message.file_path.parent.samefile(self.db_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc='db')) > 0
def cache_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None:
"""
Read messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message
with the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.cache_path, glob, mfilter)
# remove all messages from self.messages that are in the new list # remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)] self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them # copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages self.messages += new_messages
self.sort() self.msg_sort()
def read_cache(self) -> None: def cache_write(self, messages: Optional[list[Message]] = None) -> None:
"""
Reads new messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.cache_path, self.glob, self.mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.sort()
def write_db(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the DB directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point
to the DB directory.
"""
write_dir(self.db_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def write_cache(self, messages: Optional[list[Message]] = None) -> None:
""" """
Write messages to the cache directory. If a message has no file_path, a new one Write messages to the cache directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point to will be created. If message.file_path exists, it will be modified to point to
the cache directory. the cache directory.
Does NOT add the messages to the internal list (use 'cache_add()' for that)!
""" """
write_dir(self.cache_path, write_dir(self.cache_path,
messages if messages else self.messages, messages if messages else self.messages,
self.file_suffix, self.file_suffix,
self.get_next_fid) self.get_next_fid)
def clear_cache(self) -> None: def cache_add(self, messages: list[Message], write: bool = True) -> None:
""" """
Deletes all Message files from the cache dir and removes those messages from Add NEW messages and set the file_path to the cache directory.
the internal list.
"""
clear_dir(self.cache_path, self.glob)
# only keep messages from DB dir (or those that have not yet been written)
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
def add_to_db(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the DB directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.db_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.sort()
def add_to_cache(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the cache directory.
Only accepts messages without a file_path. Only accepts messages without a file_path.
""" """
if any(m.file_path is not None for m in messages): if any(m.file_path is not None for m in messages):
@@ -383,62 +537,87 @@ class ChatDB(Chat):
for m in messages: for m in messages:
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid) m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages self.messages += messages
self.sort() self.msg_sort()
def write_messages(self, messages: Optional[list[Message]] = None) -> None: def cache_clear(self, glob: Optional[str] = None) -> None:
""" """
Write either the given messages or the internal ones to their current file_path. Delete all message files from the cache dir and remove them from the internal list.
If messages are given, they all must have a valid file_path. When writing the
internal messages, the ones with a valid file_path are written, the others
are ignored.
""" """
if messages and any(m.file_path is None for m in messages): clear_dir(self.cache_path, glob)
raise ChatError("Can't write files without a valid file_path") # only keep messages from DB dir (or those that have not yet been written)
msgs = iter(messages if messages else self.messages) self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
while (m := next(msgs, None)):
m.to_file()
def update_messages(self, messages: list[Message], write: bool = True) -> None: def cache_move(self, message: Message) -> None:
""" """
Update existing messages. A message is determined as 'existing' if a message with Move the given message to the cache directory.
the same base filename (i. e. 'file_path.name') is already in the list. Only accepts
existing messages.
""" """
if any(not message_in(m, self.messages) for m in messages): # remember the old path (if any)
raise ChatError("Can't update messages that are not in the internal list") old_path: Optional[Path] = None
# remove old versions and add new ones if message.file_path:
self.messages = [m for m in self.messages if not message_in(m, messages)] old_path = message.file_path
self.messages += messages # write message to the new destination
self.sort() self.cache_write([message])
# write the UPDATED messages if requested # remove the old one (if any)
if old_path:
self.msg_remove([str(old_path)], loc='db')
# (re)add it to the internal list
self.msg_add([message])
def db_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None:
"""
Read messages from the DB directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message
with the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.db_path, glob, mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.msg_sort()
def db_write(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the DB directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point
to the DB directory.
Does NOT add the messages to the internal list (use 'db_add()' for that)!
"""
write_dir(self.db_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def db_add(self, messages: list[Message], write: bool = True) -> None:
"""
Add NEW messages and set the file_path to the DB directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write: if write:
self.write_messages(messages) write_dir(self.db_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.msg_sort()
def latest_message(self, def db_move(self, message: Message) -> None:
mfilter: Optional[MessageFilter] = None,
source: Literal['mem', 'disk', 'cache', 'db', 'all'] = 'mem') -> Optional[Message]:
""" """
Return the last added message (according to the file ID) that matches the given filter. Move the given message to the DB directory.
Only consider messages with a valid file_path (except if source is 'mem').
Searches one of the following sources:
* 'mem' : only search messages currently in memory
* 'disk' : search messages on disk (cache + DB directory), but not in memory
* 'cache': only search messages in the cache directory
* 'db' : only search messages in the DB directory
* 'all' : search all messages ('mem' + 'disk')
""" """
source_messages: list[Message] = [] # remember the old path (if any)
if source == 'mem': old_path: Optional[Path] = None
return super().latest_message(mfilter) if message.file_path:
if source in ['cache', 'disk', 'all']: old_path = message.file_path
source_messages += read_dir(self.cache_path, mfilter=mfilter) # write message to the new destination
if source in ['db', 'disk', 'all']: self.db_write([message])
source_messages += read_dir(self.db_path, mfilter=mfilter) # remove the old one (if any)
if source in ['all']: if old_path:
# only consider messages with a valid file_path so they can be sorted self.msg_remove([str(old_path)], loc='cache')
source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))] # (re)add it to the internal list
source_messages.sort(key=lambda m: m.msg_id(), reverse=True) self.msg_add([message])
for m in source_messages:
if mfilter is None or m.match(mfilter):
return m
return None
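The move pattern used by cache_move() and db_move() above (remember the old path, write to the new location, then remove the old file) can be sketched outside of ChatDB roughly as follows. This is only an illustration under assumptions: `move_file` is a hypothetical helper that is not part of the project, and plain text files stand in for messages.

```python
import tempfile
from pathlib import Path
from typing import Optional


def move_file(src: Optional[Path], dst_dir: Path, name: str, content: str) -> Path:
    """Hypothetical helper: write 'content' to dst_dir/name, then delete 'src' (if any)."""
    # remember the old path (if any)
    old_path = src
    # write the content to the new destination
    new_path = dst_dir / name
    new_path.write_text(content)
    # remove the old file, unless it is the file that was just written
    if old_path and old_path.exists() and not old_path.samefile(new_path):
        old_path.unlink()
    return new_path


with tempfile.TemporaryDirectory() as cache, tempfile.TemporaryDirectory() as db:
    cached = Path(cache) / '0001.txt'
    cached.write_text('Question 1')
    moved = move_file(cached, Path(db), cached.name, cached.read_text())
    # the file now exists only in the destination directory
    moved_ok = moved.exists() and not cached.exists()
```

Writing first and deleting afterwards means that a failed write leaves the original file in place.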
+1 -1
@@ -15,7 +15,7 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None:
tags_not=args.exclude_tags, tags_not=args.exclude_tags,
question_contains=args.question, question_contains=args.question,
answer_contains=args.answer) answer_contains=args.answer)
chat = ChatDB.from_dir(Path('.'), chat = ChatDB.from_dir(Path(config.cache),
Path(config.db), Path(config.db),
mfilter=mfilter) mfilter=mfilter)
chat.print(args.source_code_only, chat.print(args.source_code_only,
+24 -6
@@ -3,16 +3,13 @@ import argparse
from pathlib import Path from pathlib import Path
from ..configuration import Config from ..configuration import Config
from ..message import Message, MessageError from ..message import Message, MessageError
from ..chat import ChatDB
def print_cmd(args: argparse.Namespace, config: Config) -> None: def print_message(message: Message, args: argparse.Namespace) -> None:
""" """
Handler for the 'print' command. Print the given message according to the given arguments.
""" """
fname = Path(args.file)
try:
message = Message.from_file(fname)
if message:
if args.question: if args.question:
print(message.question) print(message.question)
elif args.answer: elif args.answer:
@@ -22,6 +19,27 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None:
print(code) print(code)
else: else:
print(message.to_str()) print(message.to_str())
def print_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'print' command.
"""
# print given file
if args.file is not None:
fname = Path(args.file)
try:
message = Message.from_file(fname)
if message:
print_message(message, args)
except MessageError: except MessageError:
print(f"File is not a valid message: {args.file}") print(f"File is not a valid message: {args.file}")
sys.exit(1) sys.exit(1)
# print latest message
elif args.latest:
chat = ChatDB.from_dir(Path(config.cache), Path(config.db))
latest = chat.msg_latest(loc='disk')
if not latest:
print("No message found!")
sys.exit(1)
print_message(latest, args)
+54 -24
@@ -1,3 +1,4 @@
import sys
import argparse import argparse
from pathlib import Path from pathlib import Path
from itertools import zip_longest from itertools import zip_longest
@@ -51,7 +52,8 @@ def add_file_as_code(question_parts: list[str], file: str) -> None:
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
""" """
Creates (and writes) a new message from the given arguments. Create a new message from the given arguments and write it
to the cache directory.
""" """
question_parts = [] question_parts = []
question_list = args.ask if args.ask is not None else [] question_list = args.ask if args.ask is not None else []
@@ -69,13 +71,40 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
full_question = '\n\n'.join(question_parts) full_question = '\n\n'.join(question_parts)
message = Message(question=Question(full_question), message = Message(question=Question(full_question),
tags=args.output_tags, # FIXME tags=args.output_tags,
ai=args.AI, ai=args.AI,
model=args.model) model=args.model)
chat.add_to_cache([message]) # only write the new message to the cache,
# don't add it to the internal list
chat.cache_write([message])
return message return message
def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespace) -> None:
"""
Make an AI request with the given AI, chat history, message and arguments.
Write the response(s) to the cache directory, without appending it to the
given chat history. Then print the response(s).
"""
# print history and message question before making the request
ai.print()
chat.print(paged=False)
print(message.to_str())
response: AIResponse = ai.request(message,
chat,
args.num_answers,
args.output_tags)
# only write the response messages to the cache,
# don't add them to the internal list
chat.cache_write(response.messages)
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
def question_cmd(args: argparse.Namespace, config: Config) -> None: def question_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'question' command. Handler for the 'question' command.
@@ -83,7 +112,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(), mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(),
tags_and=args.and_tags if args.and_tags is not None else set(), tags_and=args.and_tags if args.and_tags is not None else set(),
tags_not=args.exclude_tags if args.exclude_tags is not None else set()) tags_not=args.exclude_tags if args.exclude_tags is not None else set())
chat = ChatDB.from_dir(cache_path=Path('.'), chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db), db_path=Path(config.db),
mfilter=mfilter) mfilter=mfilter)
# if it's a new question, create and store it immediately # if it's a new question, create and store it immediately
@@ -94,28 +123,29 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
# create the correct AI instance # create the correct AI instance
ai: AI = create_ai(args, config) ai: AI = create_ai(args, config)
# === ASK ===
if args.ask: if args.ask:
ai.print() make_request(ai, chat, message, args)
chat.print(paged=False) # === REPEAT ===
response: AIResponse = ai.request(message,
chat,
args.num_answers, # FIXME
args.output_tags) # FIXME
chat.update_messages([response.messages[0]])
chat.add_to_cache(response.messages[1:])
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
elif args.repeat is not None: elif args.repeat is not None:
lmessage = chat.latest_message() lmessage = chat.msg_latest(loc='cache')
assert lmessage if lmessage is None:
# TODO: repeat either the last question or the print("No message found to repeat!")
# one(s) given in 'args.repeat' (overwrite sys.exit(1)
# existing ones if 'args.overwrite' is True) else:
pass print(f"Repeating message '{lmessage.msg_id()}':")
# overwrite the latest message if requested or empty
if lmessage.answer is None or args.overwrite is True:
lmessage.clear_answer()
make_request(ai, chat, lmessage, args)
# otherwise create a new one
else:
args.ask = [lmessage.question]
message = create_message(chat, args)
make_request(ai, chat, message, args)
# === PROCESS ===
elif args.process is not None: elif args.process is not None:
# TODO: process either all questions without an # TODO: process either all questions without an
# answer or the one(s) given in 'args.process' # answer or the one(s) given in 'args.process'
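The decision made in the new '--repeat' branch above can be summarised as a small pure function. This is a sketch for illustration only: `Msg` and `repeat_action` are hypothetical names that do not exist in the project, and the returned strings merely label the three outcomes (exit with an error, reuse the latest message, or create a new one).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Msg:
    """Hypothetical stand-in for a message with an optional answer."""
    question: str
    answer: Optional[str] = None


def repeat_action(latest: Optional[Msg], overwrite: bool) -> str:
    # nothing in the cache: there is no message to repeat
    if latest is None:
        return 'exit'
    # unanswered message, or overwriting explicitly requested: reuse it
    if latest.answer is None or overwrite:
        return 'overwrite'
    # otherwise ask the same question again in a new message
    return 'create'
```

For example, an answered message without '--overwrite' leads to a new message, so the existing answer is kept.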
+2 -2
@@ -8,10 +8,10 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'tags' command. Handler for the 'tags' command.
""" """
chat = ChatDB.from_dir(cache_path=Path('.'), chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db)) db_path=Path(config.db))
if args.list: if args.list:
tags_freq = chat.tags_frequency(args.prefix, args.contain) tags_freq = chat.msg_tags_frequency(args.prefix, args.contain)
for tag, freq in tags_freq.items(): for tag, freq in tags_freq.items():
print(f"- {tag}: {freq}") print(f"- {tag}: {freq}")
# TODO: add renaming # TODO: add renaming
+2
@@ -116,6 +116,7 @@ class Config:
""" """
# all members have default values, so we can easily create # all members have default values, so we can easily create
# a default configuration # a default configuration
cache: str = '.'
db: str = './db/' db: str = './db/'
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs) ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
@@ -132,6 +133,7 @@ class Config:
ai_conf = ai_config_instance(conf['name'], conf) ai_conf = ai_config_instance(conf['name'], conf)
ais[ID] = ai_conf ais[ID] = ai_conf
return cls( return cls(
cache=str(source['cache']) if 'cache' in source else '.',
db=str(source['db']), db=str(source['db']),
ais=ais ais=ais
) )
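The effect of the optional 'cache' key can be shown with a reduced stand-in for the configuration class. This is a minimal sketch, assuming only the two fields visible in the diff: `MiniConfig` is a hypothetical name, and the AI configurations are left out.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class MiniConfig:
    """Hypothetical, reduced configuration with only the two path settings."""
    cache: str = '.'
    db: str = './db/'

    @classmethod
    def from_dict(cls, source: dict[str, Any]) -> 'MiniConfig':
        return cls(
            # 'cache' is optional and falls back to the current directory
            cache=str(source['cache']) if 'cache' in source else '.',
            # 'db' is still mandatory, a missing key raises KeyError
            db=str(source['db'])
        )


old_style = MiniConfig.from_dict({'db': './db/'})
new_style = MiniConfig.from_dict({'db': './db/', 'cache': './cache/'})
```

Configuration files written before this change contain no 'cache' key and therefore keep the previous behaviour of using the current directory.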
+22 -19
@@ -44,13 +44,13 @@ def create_parser() -> argparse.ArgumentParser:
help='List of tags to exclude', metavar='XTAGS') help='List of tags to exclude', metavar='XTAGS')
etag_arg.completer = tags_completer # type: ignore etag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
help='List of output tags (default: use input tags)', metavar='OUTTAGS') help='List of output tags (default: use input tags)', metavar='OUTAGS')
otag_arg.completer = tags_completer # type: ignore otag_arg.completer = tags_completer # type: ignore
# a parent parser for all commands that support AI configuration # a parent parser for all commands that support AI configuration
ai_parser = argparse.ArgumentParser(add_help=False) ai_parser = argparse.ArgumentParser(add_help=False)
ai_parser.add_argument('-A', '--AI', help='AI ID to use') ai_parser.add_argument('-A', '--AI', help='AI ID to use', metavar='AI_ID')
ai_parser.add_argument('-M', '--model', help='Model to use') ai_parser.add_argument('-M', '--model', help='Model to use', metavar='MODEL')
ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1) ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1)
ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int) ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int)
ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float) ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float)
@@ -61,14 +61,15 @@ def create_parser() -> argparse.ArgumentParser:
aliases=['q']) aliases=['q'])
question_cmd_parser.set_defaults(func=question_cmd) question_cmd_parser.set_defaults(func=question_cmd)
question_group = question_cmd_parser.add_mutually_exclusive_group(required=True) question_group = question_cmd_parser.add_mutually_exclusive_group(required=True)
question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question') question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question', metavar='QUESTION')
question_group.add_argument('-c', '--create', nargs='+', help='Create a question') question_group.add_argument('-c', '--create', nargs='+', help='Create a question', metavar='QUESTION')
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question') question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question', metavar='MESSAGE')
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions', metavar='MESSAGE')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true') action='store_true')
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query') question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query', metavar='FILE')
question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history') question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history',
metavar='FILE')
# 'hist' command parser # 'hist' command parser
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
@@ -79,10 +80,10 @@ def create_parser() -> argparse.ArgumentParser:
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.", hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.",
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', hist_cmd_parser.add_argument('-S', '--source-code-only', help='Only print embedded source code',
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring') hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring', metavar='SUBSTRING')
hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring') hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring', metavar='SUBSTRING')
# 'tags' command parser # 'tags' command parser
tags_cmd_parser = cmdparser.add_parser('tags', tags_cmd_parser = cmdparser.add_parser('tags',
@@ -92,8 +93,8 @@ def create_parser() -> argparse.ArgumentParser:
tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True) tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True)
tags_group.add_argument('-l', '--list', help="List all tags and their frequency", tags_group.add_argument('-l', '--list', help="List all tags and their frequency",
action='store_true') action='store_true')
tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix") tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix", metavar='PREFIX')
tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring") tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring", metavar='SUBSTRING')
# 'config' command parser # 'config' command parser
config_cmd_parser = cmdparser.add_parser('config', config_cmd_parser = cmdparser.add_parser('config',
@@ -106,18 +107,20 @@ def create_parser() -> argparse.ArgumentParser:
action='store_true') action='store_true')
config_group.add_argument('-m', '--print-model', help="Print the currently configured model", config_group.add_argument('-m', '--print-model', help="Print the currently configured model",
action='store_true') action='store_true')
config_group.add_argument('-c', '--create', help="Create config with default settings in the given file") config_group.add_argument('-c', '--create', help="Create config with default settings in the given file", metavar='FILE')
# 'print' command parser # 'print' command parser
print_cmd_parser = cmdparser.add_parser('print', print_cmd_parser = cmdparser.add_parser('print',
help="Print message files.", help="Print message files.",
aliases=['p']) aliases=['p'])
print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.set_defaults(func=print_cmd)
print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) print_group = print_cmd_parser.add_mutually_exclusive_group(required=True)
print_group.add_argument('-f', '--file', help='Print given message file', metavar='FILE')
print_group.add_argument('-l', '--latest', help='Print latest message', action='store_true')
print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group() print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group()
print_cmd_modes.add_argument('-q', '--question', help='Print only question', action='store_true') print_cmd_modes.add_argument('-q', '--question', help='Only print the question', action='store_true')
print_cmd_modes.add_argument('-a', '--answer', help='Print only answer', action='store_true') print_cmd_modes.add_argument('-a', '--answer', help='Only print the answer', action='store_true')
print_cmd_modes.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') print_cmd_modes.add_argument('-S', '--only-source-code', help='Only print embedded source code', action='store_true')
argcomplete.autocomplete(parser) argcomplete.autocomplete(parser)
return parser return parser
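The new required group for the 'print' command means that exactly one of '--file' and '--latest' must be given. The standalone sketch below reproduces only that group with the standard argparse module; the program name 'print-demo' and the helper `rejects` are assumptions made for the illustration.

```python
import argparse

parser = argparse.ArgumentParser(prog='print-demo')
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('-f', '--file', help='Print given message file', metavar='FILE')
group.add_argument('-l', '--latest', help='Print latest message', action='store_true')

args_file = parser.parse_args(['-f', '0001.txt'])
args_latest = parser.parse_args(['--latest'])


def rejects(argv: list[str]) -> bool:
    """Return True if argparse refuses the given argument list."""
    try:
        parser.parse_args(argv)
    except SystemExit:
        # argparse reports usage errors by exiting
        return True
    return False
```

Because '--file' has no default, the handler can tell the two modes apart by checking `args.file is not None`, as print_cmd() does.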
+34 -6
@@ -222,12 +222,36 @@ class Message():
ai_yaml_key: ClassVar[str] = 'ai' ai_yaml_key: ClassVar[str] = 'ai'
model_yaml_key: ClassVar[str] = 'model' model_yaml_key: ClassVar[str] = 'model'
def __post_init__(self) -> None:
# convert some types that are often set wrong
if self.tags is not None and not isinstance(self.tags, set):
self.tags = set(self.tags)
if self.file_path is not None and not isinstance(self.file_path, pathlib.Path):
self.file_path = pathlib.Path(self.file_path)
def __hash__(self) -> int: def __hash__(self) -> int:
""" """
The hash value is computed based on immutable members. The hash value is computed based on immutable members.
""" """
return hash((self.question, self.answer)) return hash((self.question, self.answer))
def equals(self, other: MessageInst, tags: bool = True, ai: bool = True,
model: bool = True, file_path: bool = True, verbose: bool = False) -> bool:
"""
Compare this message with another one, including the metadata.
Return True if everything is identical, False otherwise.
"""
equal: bool = ((not tags or (self.tags == other.tags))
and (not ai or (self.ai == other.ai)) # noqa: W503
and (not model or (self.model == other.model)) # noqa: W503
and (not file_path or (self.file_path == other.file_path)) # noqa: W503
and (self == other)) # noqa: W503
if not equal and verbose:
print("Messages not equal:")
print(self)
print(other)
return equal
@classmethod @classmethod
def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst: def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
""" """
@@ -370,7 +394,7 @@ class Message():
try: try:
question_idx = text.index(Question.txt_header) + 1 question_idx = text.index(Question.txt_header) + 1
except ValueError: except ValueError:
raise MessageError(f"Question header '{Question.txt_header}' not found in '{file_path}'") raise MessageError(f"'{file_path}' does not contain a valid message")
try: try:
answer_idx = text.index(Answer.txt_header) answer_idx = text.index(Answer.txt_header)
question = Question.from_list(text[question_idx:answer_idx]) question = Question.from_list(text[question_idx:answer_idx])
@@ -390,9 +414,12 @@ class Message():
* Message.model_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional]
""" """
with open(file_path, "r") as fd: with open(file_path, "r") as fd:
try:
data = yaml.load(fd, Loader=yaml.FullLoader) data = yaml.load(fd, Loader=yaml.FullLoader)
data[cls.file_yaml_key] = file_path data[cls.file_yaml_key] = file_path
return cls.from_dict(data) return cls.from_dict(data)
except Exception:
raise MessageError(f"'{file_path}' does not contain a valid message")
def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str: def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str:
""" """
@@ -415,9 +442,6 @@ class Message():
output.append(self.answer) output.append(self.answer)
return '\n'.join(output) return '\n'.join(output)
def __str__(self) -> str:
return self.to_str(True, True, False)
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
""" """
Write a Message to the given file. Type is determined based on the suffix. Write a Message to the given file. Type is determined based on the suffix.
@@ -537,13 +561,17 @@ class Message():
if self.tags: if self.tags:
self.tags = rename_tags(self.tags, tags_rename) self.tags = rename_tags(self.tags, tags_rename)
def clear_answer(self) -> None:
self.answer = None
def msg_id(self) -> str: def msg_id(self) -> str:
""" """
Returns an ID that is unique throughout all messages in the same (DB) directory. Returns an ID that is unique throughout all messages in the same (DB) directory.
Currently this is the file name. The ID is also used for sorting messages. Currently this is the file name without suffix. The ID is also used for sorting
messages.
""" """
if self.file_path: if self.file_path:
return self.file_path.name return self.file_path.stem
else: else:
raise MessageError("Can't create file ID without a file path") raise MessageError("Can't create file ID without a file path")
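The switch from 'file_path.name' to 'file_path.stem' in msg_id() makes the ID independent of the file suffix and directory. A short sketch with pathlib shows the difference; `file_id` is a hypothetical free function used only for this illustration.

```python
from pathlib import Path


def file_id(file_path: Path) -> str:
    """Hypothetical equivalent of the new msg_id(): file name without suffix."""
    return file_path.stem


# the same ID results for different suffixes and directories
ids = sorted(file_id(Path(p)) for p in ['db/0002.yaml', 'cache/0001.txt'])
with_suffix = Path('/foo/bla/0003.txt').name
without_suffix = file_id(Path('/foo/bla/0003.txt'))
```

One consequence is that a '.txt' and a '.yaml' file with the same base name now share an ID.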
+225 -76
@@ -2,6 +2,7 @@ import unittest
import pathlib import pathlib
import tempfile import tempfile
import time import time
import yaml
from io import StringIO from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from chatmastermind.tags import TagLine from chatmastermind.tags import TagLine
@@ -9,7 +10,18 @@ from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, ChatError from chatmastermind.chat import Chat, ChatDB, ChatError
class TestChat(unittest.TestCase): class TestChatBase(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using more than just Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
class TestChat(TestChatBase):
def setUp(self) -> None: def setUp(self) -> None:
self.chat = Chat([]) self.chat = Chat([])
self.message1 = Message(Question('Question 1'), self.message1 = Message(Question('Question 1'),
@@ -20,80 +32,103 @@ class TestChat(unittest.TestCase):
Answer('Answer 2'), Answer('Answer 2'),
{Tag('btag2')}, {Tag('btag2')},
file_path=pathlib.Path('0002.txt')) file_path=pathlib.Path('0002.txt'))
self.maxDiff = None
def test_unique_id(self) -> None:
# test with two identical messages
self.chat.msg_add([self.message1, self.message1])
self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_id()
self.assert_messages_equal(self.chat.messages, [self.message1])
# test with two different messages
self.chat.msg_add([self.message2])
self.chat.msg_unique_id()
self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])
def test_unique_content(self) -> None:
# test with two identical messages
self.chat.msg_add([self.message1, self.message1])
self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_content()
self.assert_messages_equal(self.chat.messages, [self.message1])
# test with two different messages
self.chat.msg_add([self.message2])
self.chat.msg_unique_content()
self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])
def test_filter(self) -> None: def test_filter(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
self.chat.filter(MessageFilter(answer_contains='Answer 1')) self.chat.msg_filter(MessageFilter(answer_contains='Answer 1'))
self.assertEqual(len(self.chat.messages), 1) self.assertEqual(len(self.chat.messages), 1)
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
def test_sort(self) -> None: def test_sort(self) -> None:
self.chat.add_messages([self.message2, self.message1]) self.chat.msg_add([self.message2, self.message1])
self.chat.sort() self.chat.msg_sort()
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2') self.assertEqual(self.chat.messages[1].question, 'Question 2')
self.chat.sort(reverse=True) self.chat.msg_sort(reverse=True)
self.assertEqual(self.chat.messages[0].question, 'Question 2') self.assertEqual(self.chat.messages[0].question, 'Question 2')
self.assertEqual(self.chat.messages[1].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 1')
def test_clear(self) -> None: def test_clear(self) -> None:
self.chat.add_messages([self.message1]) self.chat.msg_add([self.message1])
self.chat.clear() self.chat.msg_clear()
self.assertEqual(len(self.chat.messages), 0) self.assertEqual(len(self.chat.messages), 0)
def test_add_messages(self) -> None: def test_add_messages(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
self.assertEqual(len(self.chat.messages), 2) self.assertEqual(len(self.chat.messages), 2)
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2') self.assertEqual(self.chat.messages[1].question, 'Question 2')
def test_tags(self) -> None: def test_tags(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
tags_all = self.chat.tags() tags_all = self.chat.msg_tags()
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
tags_pref = self.chat.tags(prefix='a') tags_pref = self.chat.msg_tags(prefix='a')
self.assertSetEqual(tags_pref, {Tag('atag1')}) self.assertSetEqual(tags_pref, {Tag('atag1')})
tags_cont = self.chat.tags(contain='2') tags_cont = self.chat.msg_tags(contain='2')
self.assertSetEqual(tags_cont, {Tag('btag2')}) self.assertSetEqual(tags_cont, {Tag('btag2')})
def test_tags_frequency(self) -> None: def test_tags_frequency(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
tags_freq = self.chat.tags_frequency() tags_freq = self.chat.msg_tags_frequency()
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
def test_find_remove_messages(self) -> None: def test_find_remove_messages(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
msgs = self.chat.find_messages(['0001.txt']) msgs = self.chat.msg_find(['0001.txt'])
self.assertListEqual(msgs, [self.message1]) self.assertListEqual(msgs, [self.message1])
msgs = self.chat.find_messages(['0001.txt', '0002.txt']) msgs = self.chat.msg_find(['0001.txt', '0002.txt'])
self.assertListEqual(msgs, [self.message1, self.message2]) self.assertListEqual(msgs, [self.message1, self.message2])
# add new Message with full path # add new Message with full path
message3 = Message(Question('Question 2'), message3 = Message(Question('Question 2'),
Answer('Answer 2'), Answer('Answer 2'),
{Tag('btag2')}, {Tag('btag2')},
file_path=pathlib.Path('/foo/bla/0003.txt')) file_path=pathlib.Path('/foo/bla/0003.txt'))
self.chat.add_messages([message3]) self.chat.msg_add([message3])
# find new Message by full path # find new Message by full path
msgs = self.chat.find_messages(['/foo/bla/0003.txt']) msgs = self.chat.msg_find(['/foo/bla/0003.txt'])
self.assertListEqual(msgs, [message3]) self.assertListEqual(msgs, [message3])
# find Message with full path only by filename # find Message with full path only by filename
msgs = self.chat.find_messages(['0003.txt']) msgs = self.chat.msg_find(['0003.txt'])
self.assertListEqual(msgs, [message3]) self.assertListEqual(msgs, [message3])
# remove last message # remove last message
self.chat.remove_messages(['0003.txt']) self.chat.msg_remove(['0003.txt'])
self.assertListEqual(self.chat.messages, [self.message1, self.message2]) self.assertListEqual(self.chat.messages, [self.message1, self.message2])
def test_latest_message(self) -> None: def test_latest_message(self) -> None:
self.assertIsNone(self.chat.latest_message()) self.assertIsNone(self.chat.msg_latest())
self.chat.add_messages([self.message1]) self.chat.msg_add([self.message1])
self.assertEqual(self.chat.latest_message(), self.message1) self.assertEqual(self.chat.msg_latest(), self.message1)
self.chat.add_messages([self.message2]) self.chat.msg_add([self.message2])
self.assertEqual(self.chat.latest_message(), self.message2) self.assertEqual(self.chat.msg_latest(), self.message2)
@patch('sys.stdout', new_callable=StringIO) @patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None: def test_print(self, mock_stdout: StringIO) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False) self.chat.print(paged=False)
expected_output = f"""{Question.txt_header} expected_output = f"""{Question.txt_header}
Question 1 Question 1
@@ -108,7 +143,7 @@ Answer 2
@patch('sys.stdout', new_callable=StringIO) @patch('sys.stdout', new_callable=StringIO)
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False, with_tags=True, with_files=True) self.chat.print(paged=False, with_tags=True, with_files=True)
expected_output = f"""{TagLine.prefix} atag1 btag2 expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: 0001.txt FILE: 0001.txt
@@ -126,7 +161,7 @@ Answer 2
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
class TestChatDB(unittest.TestCase): class TestChatDB(TestChatBase):
def setUp(self) -> None: def setUp(self) -> None:
self.db_path = tempfile.TemporaryDirectory() self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory() self.cache_path = tempfile.TemporaryDirectory()
@@ -156,20 +191,41 @@ class TestChatDB(unittest.TestCase):
next_fname = pathlib.Path(self.db_path.name) / '.next' next_fname = pathlib.Path(self.db_path.name) / '.next'
with open(next_fname, 'w') as f: with open(next_fname, 'w') as f:
f.write('4') f.write('4')
# add some "trash" in order to test if it's correctly handled / ignored
self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt']
for file in self.trash_files:
with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
f.write('test trash')
# also create a file with actual yaml content
with open(pathlib.Path(self.db_path.name) / 'content.yaml', 'w') as f:
yaml.dump({'key': 'value'}, f)
self.trash_files.append('content.yaml')
self.maxDiff = None
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]: def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
""" """
List all Message files in the given TemporaryDirectory. List all Message files in the given TemporaryDirectory.
""" """
# exclude '.next' # exclude '.next'
return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*')) return [f for f in pathlib.Path(tmp_dir.name).glob('*.[ty]*') if f.name not in self.trash_files]
def tearDown(self) -> None: def tearDown(self) -> None:
self.db_path.cleanup() self.db_path.cleanup()
self.cache_path.cleanup() self.cache_path.cleanup()
pass pass
def test_chat_db_from_dir(self) -> None: def test_validate(self) -> None:
duplicate_message = Message(Question('Question 4'),
Answer('Answer 4'),
{Tag('tag4')},
file_path=pathlib.Path('0004.txt'))
duplicate_message.to_file(pathlib.Path(self.db_path.name, '0004.txt'))
with self.assertRaises(ChatError) as cm:
ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(str(cm.exception), "Validation failed")
def test_from_dir(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
self.assertEqual(len(chat_db.messages), 4) self.assertEqual(len(chat_db.messages), 4)
@@ -185,7 +241,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[3].file_path, self.assertEqual(chat_db.messages[3].file_path,
pathlib.Path(self.db_path.name, '0004.yaml')) pathlib.Path(self.db_path.name, '0004.yaml'))
def test_chat_db_from_dir_glob(self) -> None: def test_from_dir_glob(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
glob='*.txt') glob='*.txt')
@@ -197,7 +253,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[1].file_path, self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0003.txt')) pathlib.Path(self.db_path.name, '0003.txt'))
def test_chat_db_from_dir_filter_tags(self) -> None: def test_from_dir_filter_tags(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or={Tag('tag1')})) mfilter=MessageFilter(tags_or={Tag('tag1')}))
@@ -207,7 +263,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[0].file_path, self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt')) pathlib.Path(self.db_path.name, '0001.txt'))
def test_chat_db_from_dir_filter_tags_empty(self) -> None: def test_from_dir_filter_tags_empty(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or=set(), mfilter=MessageFilter(tags_or=set(),
@@ -215,7 +271,7 @@ class TestChatDB(unittest.TestCase):
tags_not=set())) tags_not=set()))
self.assertEqual(len(chat_db.messages), 0) self.assertEqual(len(chat_db.messages), 0)
def test_chat_db_from_dir_filter_answer(self) -> None: def test_from_dir_filter_answer(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
mfilter=MessageFilter(answer_contains='Answer 2')) mfilter=MessageFilter(answer_contains='Answer 2'))
@@ -226,7 +282,7 @@ class TestChatDB(unittest.TestCase):
pathlib.Path(self.db_path.name, '0002.yaml')) pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[0].answer, 'Answer 2') self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
def test_chat_db_from_messages(self) -> None: def test_from_messages(self) -> None:
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
messages=[self.message1, self.message2, messages=[self.message1, self.message2,
@@ -235,7 +291,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
def test_chat_db_fids(self) -> None: def test_fids(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.get_next_fid(), 5) self.assertEqual(chat_db.get_next_fid(), 5)
@@ -244,7 +300,26 @@ class TestChatDB(unittest.TestCase):
with open(chat_db.next_path, 'r') as f: with open(chat_db.next_path, 'r') as f:
self.assertEqual(f.read(), '7') self.assertEqual(f.read(), '7')
def test_chat_db_write(self) -> None: def test_msg_in_db_or_cache(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertTrue(chat_db.msg_in_db(self.message1))
self.assertTrue(chat_db.msg_in_db(str(self.message1.file_path)))
self.assertTrue(chat_db.msg_in_db(self.message1.msg_id()))
self.assertFalse(chat_db.msg_in_cache(self.message1))
self.assertFalse(chat_db.msg_in_cache(str(self.message1.file_path)))
self.assertFalse(chat_db.msg_in_cache(self.message1.msg_id()))
# add new message to the cache dir
cache_message = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
chat_db.cache_add([cache_message])
self.assertTrue(chat_db.msg_in_cache(cache_message))
self.assertTrue(chat_db.msg_in_cache(cache_message.msg_id()))
self.assertFalse(chat_db.msg_in_db(cache_message))
self.assertFalse(chat_db.msg_in_db(str(cache_message.file_path)))
def test_db_write(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@@ -255,7 +330,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory # write the messages to the cache directory
chat_db.write_cache() chat_db.cache_write()
# check if the written files are in the cache directory # check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4) self.assertEqual(len(cache_dir_files), 4)
@@ -275,7 +350,7 @@ class TestChatDB(unittest.TestCase):
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
# overwrite the messages in the db directory # overwrite the messages in the db directory
time.sleep(0.05) time.sleep(0.05)
chat_db.write_db() chat_db.db_write()
# check if the written files are in the DB directory # check if the written files are in the DB directory
db_dir_files = self.message_list(self.db_path) db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4) self.assertEqual(len(db_dir_files), 4)
@@ -292,7 +367,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
def test_chat_db_read(self) -> None: def test_db_read(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@@ -308,7 +383,7 @@ class TestChatDB(unittest.TestCase):
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read and check them # read and check them
chat_db.read_db() chat_db.db_read()
self.assertEqual(len(chat_db.messages), 6) self.assertEqual(len(chat_db.messages), 6)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
@@ -323,7 +398,7 @@ class TestChatDB(unittest.TestCase):
new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt')) new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml')) new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml'))
# read and check them # read and check them
chat_db.read_cache() chat_db.cache_read()
self.assertEqual(len(chat_db.messages), 8) self.assertEqual(len(chat_db.messages), 8)
# check that the new message have the cache dir path # check that the new message have the cache dir path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt')) self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt'))
@@ -338,7 +413,7 @@ class TestChatDB(unittest.TestCase):
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read from the DB dir and check if the modified messages have been updated # read from the DB dir and check if the modified messages have been updated
chat_db.read_db() chat_db.db_read()
self.assertEqual(len(chat_db.messages), 8) self.assertEqual(len(chat_db.messages), 8)
self.assertEqual(chat_db.messages[4].question, 'New Question 1') self.assertEqual(chat_db.messages[4].question, 'New Question 1')
self.assertEqual(chat_db.messages[5].question, 'New Question 2') self.assertEqual(chat_db.messages[5].question, 'New Question 2')
@@ -349,13 +424,13 @@ class TestChatDB(unittest.TestCase):
new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt')) new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml')) new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml'))
# read and check them # read and check them
chat_db.read_db() chat_db.db_read()
self.assertEqual(len(chat_db.messages), 8) self.assertEqual(len(chat_db.messages), 8)
# check that they now have the DB path # check that they now have the DB path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml'))
def test_chat_db_clear(self) -> None: def test_cache_clear(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@@ -366,13 +441,13 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory # write the messages to the cache directory
chat_db.write_cache() chat_db.cache_write()
# check if the written files are in the cache directory # check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4) self.assertEqual(len(cache_dir_files), 4)
# now rewrite them to the DB dir and check for modified paths # now rewrite them to the DB dir and check for modified paths
chat_db.write_db() chat_db.db_write()
db_dir_files = self.message_list(self.db_path) db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4) self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
@@ -387,10 +462,10 @@ class TestChatDB(unittest.TestCase):
message_cache = Message(question=Question("What the hell am I doing here?"), message_cache = Message(question=Question("What the hell am I doing here?"),
answer=Answer("You're a creep!"), answer=Answer("You're a creep!"),
file_path=pathlib.Path(self.cache_path.name, '0005.txt')) file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
chat_db.add_messages([message_empty, message_cache]) chat_db.msg_add([message_empty, message_cache])
# clear the cache and check the cache dir # clear the cache and check the cache dir
chat_db.clear_cache() chat_db.cache_clear()
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0) self.assertEqual(len(cache_dir_files), 0)
# make sure that the DB messages (and the new message) are still there # make sure that the DB messages (and the new message) are still there
@@ -400,7 +475,7 @@ class TestChatDB(unittest.TestCase):
# but not the message with the cache dir path # but not the message with the cache dir path
self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages)) self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))
def test_chat_db_add(self) -> None: def test_add(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@@ -411,7 +486,7 @@ class TestChatDB(unittest.TestCase):
# add new messages to the cache dir # add new messages to the cache dir
message1 = Message(question=Question("Question 1"), message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1")) answer=Answer("Answer 1"))
chat_db.add_to_cache([message1]) chat_db.cache_add([message1])
# check if the file_path has been correctly set # check if the file_path has been correctly set
self.assertIsNotNone(message1.file_path) self.assertIsNotNone(message1.file_path)
self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
@@ -421,7 +496,7 @@ class TestChatDB(unittest.TestCase):
# add new messages to the DB dir # add new messages to the DB dir
message2 = Message(question=Question("Question 2"), message2 = Message(question=Question("Question 2"),
answer=Answer("Answer 2")) answer=Answer("Answer 2"))
chat_db.add_to_db([message2]) chat_db.db_add([message2])
# check if the file_path has been correctly set # check if the file_path has been correctly set
self.assertIsNotNone(message2.file_path) self.assertIsNotNone(message2.file_path)
self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
@@ -429,9 +504,9 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(len(db_dir_files), 5) self.assertEqual(len(db_dir_files), 5)
with self.assertRaises(ChatError): with self.assertRaises(ChatError):
chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))]) chat_db.cache_add([Message(Question("?"), file_path=pathlib.Path("foo"))])
def test_chat_db_write_messages(self) -> None: def test_msg_write(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@@ -445,16 +520,16 @@ class TestChatDB(unittest.TestCase):
message = Message(question=Question("Question 1"), message = Message(question=Question("Question 1"),
answer=Answer("Answer 1")) answer=Answer("Answer 1"))
with self.assertRaises(ChatError): with self.assertRaises(ChatError):
chat_db.write_messages([message]) chat_db.msg_write([message])
# write a message with a valid file_path # write a message with a valid file_path
message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt' message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt'
chat_db.write_messages([message]) chat_db.msg_write([message])
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1) self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
def test_chat_db_update_messages(self) -> None: def test_msg_update(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@@ -467,37 +542,111 @@ class TestChatDB(unittest.TestCase):
message = chat_db.messages[0] message = chat_db.messages[0]
message.answer = Answer("New answer") message.answer = Answer("New answer")
# update message without writing # update message without writing
chat_db.update_messages([message], write=False) chat_db.msg_update([message], write=False)
self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# re-read the message and check for old content # re-read the message and check for old content
chat_db.read_db() chat_db.db_read()
self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1")) self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1"))
# now check with writing (message should be overwritten) # now check with writing (message should be overwritten)
chat_db.update_messages([message], write=True) chat_db.msg_update([message], write=True)
chat_db.read_db() chat_db.db_read()
self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# test without file_path -> expect error # test without file_path -> expect error
message1 = Message(question=Question("Question 1"), message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1")) answer=Answer("Answer 1"))
with self.assertRaises(ChatError): with self.assertRaises(ChatError):
chat_db.update_messages([message1]) chat_db.msg_update([message1])
def test_chat_db_latest_message(self) -> None: def test_msg_find(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.latest_message(source='mem'), self.message4) # search for a DB file in memory
self.assertEqual(chat_db.latest_message(source='db'), self.message4) self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1])
self.assertEqual(chat_db.latest_message(source='disk'), self.message4) self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1])
self.assertEqual(chat_db.latest_message(source='all'), self.message4) self.assertEqual(chat_db.msg_find(['0001.txt'], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1])
# and on disk
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.yaml'], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2])
# now search the cache -> expect empty result
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003.txt'], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), [])
# search for multiple messages
# -> search one twice, expect result to be unique
search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, loc='all')
self.assert_messages_equal(result, expected_result)
def test_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.msg_latest(loc='mem'), self.message4)
self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)
self.assertEqual(chat_db.msg_latest(loc='disk'), self.message4)
self.assertEqual(chat_db.msg_latest(loc='all'), self.message4)
# the cache is currently empty: # the cache is currently empty:
self.assertIsNone(chat_db.latest_message(source='cache')) self.assertIsNone(chat_db.msg_latest(loc='cache'))
# add new messages to the cache dir # add new messages to the cache dir
new_message = Message(question=Question("New Question"), new_message = Message(question=Question("New Question"),
answer=Answer("New Answer")) answer=Answer("New Answer"))
chat_db.add_to_cache([new_message]) chat_db.cache_add([new_message])
self.assertEqual(chat_db.latest_message(source='cache'), new_message) self.assertEqual(chat_db.msg_latest(loc='cache'), new_message)
self.assertEqual(chat_db.latest_message(source='mem'), new_message) self.assertEqual(chat_db.msg_latest(loc='mem'), new_message)
self.assertEqual(chat_db.latest_message(source='disk'), new_message) self.assertEqual(chat_db.msg_latest(loc='disk'), new_message)
self.assertEqual(chat_db.latest_message(source='all'), new_message) self.assertEqual(chat_db.msg_latest(loc='all'), new_message)
# the DB does not contain the new message # the DB does not contain the new message
self.assertEqual(chat_db.latest_message(source='db'), self.message4) self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)
def test_msg_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# add a new message, but only to the internal list
new_message = Message(Question("What?"))
all_messages_mem = all_messages + [new_message]
chat_db.msg_add([new_message])
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages_mem)
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages_mem)
# the nr. of messages on disk did not change -> expect old result
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# test with MessageFilter
self.assert_messages_equal(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})),
[self.message1])
self.assert_messages_equal(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})),
[self.message2])
self.assert_messages_equal(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})),
[])
self.assert_messages_equal(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")),
[new_message])
def test_msg_move_and_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# move first message to the cache
chat_db.cache_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [self.message1])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4])
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages)
# now move first message back to the DB
chat_db.db_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
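Note on the `message_list()` helper changed in the diff above: it now combines the `'*.[ty]*'` glob with a name filter so that the deliberately created "trash" files are not counted as messages. The following is a minimal standalone sketch of that pattern, not the project's code. The function name `list_message_files` and the `ignored` parameter are hypothetical, introduced only for illustration.

```python
import pathlib
import tempfile


def list_message_files(directory: pathlib.Path, ignored: set[str]) -> list[pathlib.Path]:
    """Return files matching '*.[ty]*' whose names are not in `ignored`, sorted by name."""
    # The glob needs a 't' or 'y' right after a dot, so '.next' never matches.
    return sorted(f for f in directory.glob('*.[ty]*') if f.name not in ignored)


with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    # two "real" message files, three trash files and the '.next' counter file
    for name in ['0001.txt', '0002.yaml', 'bla.txt', 'foo.yaml', '.config.yaml', '.next']:
        (root / name).write_text('placeholder')
    ignored = {'bla.txt', 'foo.yaml', '.config.yaml'}
    found = [f.name for f in list_message_files(root, ignored)]
    print(found)
```

The glob alone cannot tell `bla.txt` from `0001.txt`, which is presumably why the test keeps an explicit list of trash file names instead of tightening the pattern.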
+7
@@ -57,6 +57,7 @@ class TestConfig(unittest.TestCase):
def test_from_dict_should_create_config_from_dict(self) -> None: def test_from_dict_should_create_config_from_dict(self) -> None:
source_dict = { source_dict = {
'cache': '.',
'db': './test_db/', 'db': './test_db/',
'ais': { 'ais': {
'myopenai': { 'myopenai': {
@@ -73,6 +74,7 @@ class TestConfig(unittest.TestCase):
} }
} }
config = Config.from_dict(source_dict) config = Config.from_dict(source_dict)
self.assertEqual(config.cache, '.')
self.assertEqual(config.db, './test_db/') self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1) self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['myopenai'].name, 'openai') self.assertEqual(config.ais['myopenai'].name, 'openai')
@@ -89,6 +91,7 @@ class TestConfig(unittest.TestCase):
def test_from_file_should_load_config_from_file(self) -> None: def test_from_file_should_load_config_from_file(self) -> None:
source_dict = { source_dict = {
'cache': './test_cache/',
'db': './test_db/', 'db': './test_db/',
'ais': { 'ais': {
'default': { 'default': {
@@ -108,6 +111,7 @@ class TestConfig(unittest.TestCase):
yaml.dump(source_dict, f) yaml.dump(source_dict, f)
config = Config.from_file(self.test_file.name) config = Config.from_file(self.test_file.name)
self.assertIsInstance(config, Config) self.assertIsInstance(config, Config)
self.assertEqual(config.cache, './test_cache/')
self.assertEqual(config.db, './test_db/') self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1) self.assertEqual(len(config.ais), 1)
self.assertIsInstance(config.ais['default'], AIConfig) self.assertIsInstance(config.ais['default'], AIConfig)
@@ -115,6 +119,7 @@ class TestConfig(unittest.TestCase):
def test_to_file_should_save_config_to_file(self) -> None: def test_to_file_should_save_config_to_file(self) -> None:
config = Config( config = Config(
cache='./test_cache/',
db='./test_db/', db='./test_db/',
ais={ ais={
'myopenai': OpenAIConfig( 'myopenai': OpenAIConfig(
@@ -133,12 +138,14 @@ class TestConfig(unittest.TestCase):
config.to_file(Path(self.test_file.name)) config.to_file(Path(self.test_file.name))
with open(self.test_file.name, 'r') as f: with open(self.test_file.name, 'r') as f:
saved_config = yaml.load(f, Loader=yaml.FullLoader) saved_config = yaml.load(f, Loader=yaml.FullLoader)
self.assertEqual(saved_config['cache'], './test_cache/')
self.assertEqual(saved_config['db'], './test_db/') self.assertEqual(saved_config['db'], './test_db/')
self.assertEqual(len(saved_config['ais']), 1) self.assertEqual(len(saved_config['ais']), 1)
self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system') self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system')
def test_from_file_error_unknown_ai(self) -> None: def test_from_file_error_unknown_ai(self) -> None:
source_dict = { source_dict = {
'cache': './test_cache/',
'db': './test_db/', 'db': './test_db/',
'ais': { 'ais': {
'default': { 'default': {
+1 -1
@@ -730,7 +730,7 @@ class MessageIDTestCase(unittest.TestCase):
         self.file_path.unlink()

     def test_msg_id_txt(self) -> None:
-        self.assertEqual(self.message.msg_id(), self.file_path.name)
+        self.assertEqual(self.message.msg_id(), self.file_path.stem)

     def test_msg_id_txt_exception(self) -> None:
         with self.assertRaises(MessageError):
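Review note (illustrative, not part of the commit): the hunk above switches the expected message ID from `file_path.name` to `file_path.stem`. The sketch below shows the difference between the two `pathlib` attributes. The file names are hypothetical and chosen only for illustration; how `msg_id()` itself is implemented is not shown in this diff.

```python
from pathlib import Path

# Hypothetical message file path, used only for illustration.
file_path = Path("/tmp/chat_db/0001.txt")

# 'name' is the final path component, including the suffix.
print(file_path.name)  # 0001.txt

# 'stem' is the final path component without its last suffix.
print(file_path.stem)  # 0001

# Only the last suffix is stripped, so multi-suffix names keep the rest.
print(Path("0001.tar.gz").stem)  # 0001.tar
```

A test that compares an ID against `.stem` therefore stays valid regardless of whether the message file ends in `.txt` or another suffix.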
+329 -11
@@ -3,29 +3,45 @@ import unittest
 import argparse
 import tempfile
 from pathlib import Path
-from unittest.mock import MagicMock
-from chatmastermind.commands.question import create_message
+from unittest import mock
+from unittest.mock import MagicMock, call, ANY
+from typing import Optional
+from chatmastermind.configuration import Config
+from chatmastermind.commands.question import create_message, question_cmd
+from chatmastermind.tags import Tag
 from chatmastermind.message import Message, Question, Answer
-from chatmastermind.chat import ChatDB
+from chatmastermind.chat import Chat, ChatDB
+from chatmastermind.ai import AI, AIResponse, Tokens, AIError


-class TestMessageCreate(unittest.TestCase):
+class TestQuestionCmdBase(unittest.TestCase):
+    def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
+        """
+        Compare messages using more than just Question and Answer.
+        """
+        self.assertEqual(len(msg1), len(msg2))
+        for m1, m2 in zip(msg1, msg2):
+            # exclude the file_path, compare only Q, A and metadata
+            self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
+
+
+class TestMessageCreate(TestQuestionCmdBase):
     """
     Test if messages created by the 'question' command have
     the correct format.
     """
     def setUp(self) -> None:
         # create ChatDB structure
-        self.db_path = tempfile.TemporaryDirectory()
-        self.cache_path = tempfile.TemporaryDirectory()
-        self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name),
-                                    db_path=Path(self.db_path.name))
+        self.db_dir = tempfile.TemporaryDirectory()
+        self.cache_dir = tempfile.TemporaryDirectory()
+        self.chat = ChatDB.from_dir(cache_path=Path(self.cache_dir.name),
+                                    db_path=Path(self.db_dir.name))
         # create some messages
         self.message_text = Message(Question("What is this?"),
                                     Answer("It is pure text"))
         self.message_code = Message(Question("What is this?"),
                                     Answer("Text\n```\nIt is embedded code\n```\ntext"))
-        self.chat.add_to_db([self.message_text, self.message_code])
+        self.chat.db_add([self.message_text, self.message_code])
         # create arguments mock
         self.args = MagicMock(spec=argparse.Namespace)
         self.args.source_text = None
@@ -74,6 +90,7 @@ Aaaand again some text."""
         os.remove(self.source_file1.name)
         os.remove(self.source_file2.name)
         os.remove(self.source_file3.name)
+        os.remove(self.source_file4.name)

     def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
         # exclude '.next'
@@ -81,10 +98,10 @@ Aaaand again some text."""
     def test_message_file_created(self) -> None:
         self.args.ask = ["What is this?"]
-        cache_dir_files = self.message_list(self.cache_path)
+        cache_dir_files = self.message_list(self.cache_dir)
         self.assertEqual(len(cache_dir_files), 0)
         create_message(self.chat, self.args)
-        cache_dir_files = self.message_list(self.cache_path)
+        cache_dir_files = self.message_list(self.cache_dir)
         self.assertEqual(len(cache_dir_files), 1)
         message = Message.from_file(cache_dir_files[0])
         self.assertIsInstance(message, Message)
@@ -193,3 +210,304 @@ This is embedded source code.
 It is embedded code
 ```
 """))
+
+
+class TestQuestionCmd(TestQuestionCmdBase):
+
+    def setUp(self) -> None:
+        # create DB and cache
+        self.db_dir = tempfile.TemporaryDirectory()
+        self.cache_dir = tempfile.TemporaryDirectory()
+        # create configuration
+        self.config = Config()
+        self.config.cache = self.cache_dir.name
+        self.config.db = self.db_dir.name
+        # create a mock argparse.Namespace
+        self.args = argparse.Namespace(
+            ask=['What is the meaning of life?'],
+            num_answers=1,
+            output_tags=['science'],
+            AI='openai',
+            model='gpt-3.5-turbo',
+            or_tags=None,
+            and_tags=None,
+            exclude_tags=None,
+            source_text=None,
+            source_code=None,
+            create=None,
+            repeat=None,
+            process=None,
+            overwrite=None
+        )
+        # create a mock AI instance
+        self.ai = MagicMock(spec=AI)
+        self.ai.request.side_effect = self.mock_request
+
+    def input_message(self, args: argparse.Namespace) -> Message:
+        """
+        Create the expected input message for a question using the
+        given arguments.
+        """
+        # NOTE: we only use the first question from the "ask" list
+        # -> message creation using "question.create_message()" is
+        #    tested above
+        # the answer is always empty for the input message
+        return Message(Question(args.ask[0]),
+                       tags=args.output_tags,
+                       ai=args.AI,
+                       model=args.model)
+
+    def mock_request(self,
+                     question: Message,
+                     chat: Chat,
+                     num_answers: int = 1,
+                     otags: Optional[set[Tag]] = None) -> AIResponse:
+        """
+        Mock the 'ai.request()' function
+        """
+        question.answer = Answer("Answer 0")
+        question.tags = set(otags) if otags else None
+        question.ai = 'FakeAI'
+        question.model = 'FakeModel'
+        answers: list[Message] = [question]
+        for n in range(1, num_answers):
+            answers.append(Message(question=question.question,
+                                   answer=Answer(f"Answer {n}"),
+                                   tags=otags,
+                                   ai='FakeAI',
+                                   model='FakeModel'))
+        return AIResponse(answers, Tokens(10, 10, 20))
+
+    def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
+        # exclude '.next'
+        return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
+
+    @mock.patch('chatmastermind.commands.question.create_ai')
+    def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
+        """
+        Test single answer with no errors.
+        """
+        mock_create_ai.return_value = self.ai
+        expected_question = self.input_message(self.args)
+        expected_responses = self.mock_request(expected_question,
+                                               Chat([]),
+                                               self.args.num_answers,
+                                               self.args.output_tags).messages
+        # execute the command
+        question_cmd(self.args, self.config)
+        # check for correct request call
+        self.ai.request.assert_called_once_with(expected_question,
+                                                ANY,
+                                                self.args.num_answers,
+                                                self.args.output_tags)
+        # check for the expected message files
+        chat = ChatDB.from_dir(Path(self.cache_dir.name),
+                               Path(self.db_dir.name))
+        cached_msg = chat.msg_gather(loc='cache')
+        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
+        self.assert_messages_equal(cached_msg, expected_responses)
+
+    @mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
+    @mock.patch('chatmastermind.commands.question.create_ai')
+    def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None:
+        """
+        Test single answer with no errors (mocked ChatDB version).
+        """
+        chat = MagicMock(spec=ChatDB)
+        mock_from_dir.return_value = chat
+        mock_create_ai.return_value = self.ai
+        expected_question = self.input_message(self.args)
+        expected_responses = self.mock_request(expected_question,
+                                               Chat([]),
+                                               self.args.num_answers,
+                                               self.args.output_tags).messages
+        # execute the command
+        question_cmd(self.args, self.config)
+        # check for correct request call
+        self.ai.request.assert_called_once_with(expected_question,
+                                                chat,
+                                                self.args.num_answers,
+                                                self.args.output_tags)
+        # check for the correct ChatDB calls:
+        # - initial question has been written (prior to the actual request)
+        # - responses have been written (after the request)
+        chat.cache_write.assert_has_calls([call([expected_question]),
+                                           call(expected_responses)],
+                                          any_order=False)
+        # check that the messages have not been added to the internal message list
+        chat.cache_add.assert_not_called()
+
+    @mock.patch('chatmastermind.commands.question.create_ai')
+    def test_ask_with_error(self, mock_create_ai: MagicMock) -> None:
+        """
+        Provoke an error during the AI request and verify that the question
+        has been correctly stored in the cache.
+        """
+        mock_create_ai.return_value = self.ai
+        expected_question = self.input_message(self.args)
+        self.ai.request.side_effect = AIError
+        # execute the command
+        with self.assertRaises(AIError):
+            question_cmd(self.args, self.config)
+        # check for correct request call
+        self.ai.request.assert_called_once_with(expected_question,
+                                                ANY,
+                                                self.args.num_answers,
+                                                self.args.output_tags)
+        # check for the expected message files
+        chat = ChatDB.from_dir(Path(self.cache_dir.name),
+                               Path(self.db_dir.name))
+        cached_msg = chat.msg_gather(loc='cache')
+        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
+        self.assert_messages_equal(cached_msg, [expected_question])
+
+    @mock.patch('chatmastermind.commands.question.create_ai')
+    def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
+        """
+        Repeat a single question.
+        """
+        # 1. ask a question
+        mock_create_ai.return_value = self.ai
+        expected_question = self.input_message(self.args)
+        expected_responses = self.mock_request(expected_question,
+                                               Chat([]),
+                                               self.args.num_answers,
+                                               self.args.output_tags).messages
+        question_cmd(self.args, self.config)
+        chat = ChatDB.from_dir(Path(self.cache_dir.name),
+                               Path(self.db_dir.name))
+        cached_msg = chat.msg_gather(loc='cache')
+        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
+        self.assert_messages_equal(cached_msg, expected_responses)
+        # 2. repeat the last question (without overwriting)
+        # -> expect two identical messages (except for the file_path)
+        self.args.ask = None
+        self.args.repeat = []
+        self.args.overwrite = False
+        expected_responses += expected_responses
+        question_cmd(self.args, self.config)
+        cached_msg = chat.msg_gather(loc='cache')
+        self.assertEqual(len(self.message_list(self.cache_dir)), 2)
+        self.assert_messages_equal(cached_msg, expected_responses)
+
+    @mock.patch('chatmastermind.commands.question.create_ai')
+    def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
+        """
+        Repeat a single question and overwrite the old one.
+        """
+        # 1. ask a question
+        mock_create_ai.return_value = self.ai
+        expected_question = self.input_message(self.args)
+        expected_responses = self.mock_request(expected_question,
+                                               Chat([]),
+                                               self.args.num_answers,
+                                               self.args.output_tags).messages
+        question_cmd(self.args, self.config)
+        chat = ChatDB.from_dir(Path(self.cache_dir.name),
+                               Path(self.db_dir.name))
+        cached_msg = chat.msg_gather(loc='cache')
+        assert cached_msg[0].file_path
+        cached_msg_file_id = cached_msg[0].file_path.stem
+        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
+        self.assert_messages_equal(cached_msg, expected_responses)
+        # 2. repeat the last question (WITH overwriting)
+        # -> expect a single message afterwards
+        self.args.ask = None
+        self.args.repeat = []
+        self.args.overwrite = True
+        question_cmd(self.args, self.config)
+        cached_msg = chat.msg_gather(loc='cache')
+        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
+        self.assert_messages_equal(cached_msg, expected_responses)
+        # also check that the file ID has not been changed
+        assert cached_msg[0].file_path
+        self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
+
+    @mock.patch('chatmastermind.commands.question.create_ai')
+    def test_repeat_single_question_after_error(self, mock_create_ai: MagicMock) -> None:
+        """
+        Repeat a single question after an error.
+        """
+        # 1. ask a question and provoke an error
+        mock_create_ai.return_value = self.ai
+        expected_question = self.input_message(self.args)
+        self.ai.request.side_effect = AIError
+        with self.assertRaises(AIError):
+            question_cmd(self.args, self.config)
+        chat = ChatDB.from_dir(Path(self.cache_dir.name),
+                               Path(self.db_dir.name))
+        cached_msg = chat.msg_gather(loc='cache')
+        assert cached_msg[0].file_path
+        cached_msg_file_id = cached_msg[0].file_path.stem
+        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
+        self.assert_messages_equal(cached_msg, [expected_question])
+        # 2. repeat the last question (without overwriting)
+        # -> expect a single message because if the original has
+        #    no answer, it should be overwritten by default
+        self.args.ask = None
+        self.args.repeat = []
+        self.args.overwrite = False
+        self.ai.request.side_effect = self.mock_request
+        expected_responses = self.mock_request(expected_question,
+                                               Chat([]),
+                                               self.args.num_answers,
+                                               self.args.output_tags).messages
+        question_cmd(self.args, self.config)
+        cached_msg = chat.msg_gather(loc='cache')
+        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
+        self.assert_messages_equal(cached_msg, expected_responses)
+        # also check that the file ID has not been changed
+        assert cached_msg[0].file_path
+        self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
+
+    @mock.patch('chatmastermind.commands.question.create_ai')
+    def test_repeat_single_question_new_args(self, mock_create_ai: MagicMock) -> None:
+        """
+        Repeat a single question with new arguments.
+        """
+        # 1. ask a question
+        mock_create_ai.return_value = self.ai
+        expected_question = self.input_message(self.args)
+        expected_responses = self.mock_request(expected_question,
+                                               Chat([]),
+                                               self.args.num_answers,
+                                               self.args.output_tags).messages
+        question_cmd(self.args, self.config)
+        chat = ChatDB.from_dir(Path(self.cache_dir.name),
+                               Path(self.db_dir.name))
+        cached_msg = chat.msg_gather(loc='cache')
+        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
+        self.assert_messages_equal(cached_msg, expected_responses)
+        # 2. repeat the last question with new arguments (without overwriting)
+        # -> expect two messages with identical question and answer, but different metadata
+        self.args.ask = None
+        self.args.repeat = []
+        self.args.overwrite = False
+        self.args.output_tags = ['newtag']
+        self.args.AI = 'newai'
+        self.args.model = 'newmodel'
+        new_expected_question = Message(question=Question(expected_question.question),
+                                        tags=set(self.args.output_tags),
+                                        ai=self.args.AI,
+                                        model=self.args.model)
+        expected_responses += self.mock_request(new_expected_question,
+                                                Chat([]),
+                                                self.args.num_answers,
+                                                set(self.args.output_tags)).messages
+        question_cmd(self.args, self.config)
+        cached_msg = chat.msg_gather(loc='cache')
+        self.assertEqual(len(self.message_list(self.cache_dir)), 2)
+        self.assert_messages_equal(cached_msg, expected_responses)
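
Review note (illustrative, not part of the commit): the tests above rely on three `unittest.mock` features: a callable `side_effect` that computes the return value, an exception class as `side_effect` to provoke a failure, and `assert_has_calls` with `any_order=False` to check call order. The sketch below shows these in isolation. `FakeError`, `fake_request` and the `cache` mock are hypothetical stand-ins invented for this note; they are not the project's `AIError`, `AI.request()` or `ChatDB`.

```python
from unittest.mock import MagicMock, call, ANY


class FakeError(Exception):
    """Hypothetical stand-in for a backend error."""


def fake_request(question, chat, num_answers=1):
    # A callable side_effect receives the call arguments, and its
    # return value becomes the return value of the mocked call.
    return [f"Answer {n}" for n in range(num_answers)]


ai = MagicMock()
ai.request.side_effect = fake_request
answers = ai.request("Why?", object(), 2)
print(answers)  # ['Answer 0', 'Answer 1']

# ANY matches whatever was passed in that position.
ai.request.assert_called_once_with("Why?", ANY, 2)

# An exception class as side_effect is raised when the mock is called.
ai.request.side_effect = FakeError
try:
    ai.request("Why?", object(), 1)
except FakeError:
    print("request failed")

# assert_has_calls with any_order=False checks that the listed calls
# appear in this order in the mock's call history.
cache = MagicMock()
cache.write(["question"])
cache.write(answers)
cache.write.assert_has_calls([call(["question"]), call(answers)],
                             any_order=False)
```

Resetting `side_effect` from the exception class back to a callable, as `test_repeat_single_question_after_error` does, restores normal return values for later calls on the same mock.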