6 Commits

32 changed files with 1160 additions and 3024 deletions
-1
@@ -106,7 +106,6 @@ celerybeat.pid
.venv .venv
env/ env/
venv/ venv/
.old/
ENV/ ENV/
env.bak/ env.bak/
venv.bak/ venv.bak/
+64 -105
@@ -37,97 +37,63 @@ cmm [global options] command [command options]
### Global Options ### Global Options
- `-C`, `--config`: Config file name (defaults to `.config.yaml`). - `-c`, `--config`: Config file name (defaults to `.config.yaml`).
### Commands
- `ask`: Ask a question.
- `hist`: Print chat history.
- `tag`: Manage tags.
- `config`: Manage configuration.
- `print`: Print files.
### Command Options ### Command Options
#### Question #### `ask` Command Options
The `question` command is used to ask, create, and process questions. - `-q`, `--question`: Question to ask (required).
- `-m`, `--max-tokens`: Max tokens to use.
- `-T`, `--temperature`: Temperature to use.
- `-M`, `--model`: Model to use.
- `-n`, `--number`: Number of answers to produce (default is 3).
- `-s`, `--source`: Add content of a file to the query.
- `-S`, `--only-source-code`: Add pure source code to the chat history.
- `-t`, `--tags`: List of tag names.
- `-e`, `--extags`: List of tag names to exclude.
- `-o`, `--output-tags`: List of output tag names (default is the input tags).
- `-a`, `--match-all-tags`: All given tags must match when selecting chat history entries.
```bash #### `hist` Command Options
cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI_ID] [-M MODEL] [-n NUM] [-m MAX] [-T TEMP] (-a QUESTION | -c QUESTION | -r [MESSAGE ...] | -p [MESSAGE ...]) [-O] [-s FILE]... [-S FILE]...
```
* `-t, --or-tags OTAGS`: List of tags (one must match) - `-d`, `--dump`: Print chat history as Python structure.
* `-k, --and-tags ATAGS`: List of tags (all must match) - `-w`, `--with-tags`: Print chat history with tags.
* `-x, --exclude-tags XTAGS`: List of tags to exclude - `-W`, `--with-files`: Print chat history with filenames.
* `-o, --output-tags OUTTAGS`: List of output tags (default: use input tags) - `-S`, `--only-source-code`: Print only source code.
* `-A, --AI AI_ID`: AI ID to use - `-t`, `--tags`: List of tag names.
* `-M, --model MODEL`: Model to use - `-e`, `--extags`: List of tag names to exclude.
* `-n, --num-answers NUM`: Number of answers to request - `-a`, `--match-all-tags`: All given tags must match when selecting chat history entries.
* `-m, --max-tokens MAX`: Max. number of tokens
* `-T, --temperature TEMP`: Temperature value
* `-a, --ask QUESTION`: Ask a question
* `-c, --create QUESTION`: Create a question
* `-r, --repeat [MESSAGE ...]`: Repeat a question
* `-p, --process [MESSAGE ...]`: Process existing questions
* `-O, --overwrite`: Overwrite existing messages when repeating them
* `-s, --source-text FILE`: Add content of a file to the query
* `-S, --source-code FILE`: Add source code file content to the chat history
#### Hist #### `tag` Command Options
The `hist` command is used to print and manage the chat history. - `-l`, `--list`: List all tags and their frequency.
```bash #### `config` Command Options
cmm hist [--print | --convert FORMAT] [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A SUBSTRING] [-Q SUBSTRING]
```
* `-p, --print`: Print the DB chat history - `-l`, `--list-models`: List all available models.
* `-c, --convert FORMAT`: Convert all messages to the given format - `-m`, `--print-model`: Print the currently configured model.
* `-t, --or-tags OTAGS`: List of tags (one must match) - `-M`, `--model`: Set model in the config file.
* `-k, --and-tags ATAGS`: List of tags (all must match)
* `-x, --exclude-tags XTAGS`: List of tags to exclude
* `-w, --with-metadata`: Print chat history with metadata (tags, filenames, AI, etc.)
* `-S, --source-code-only`: Only print embedded source code
* `-A, --answer SUBSTRING`: Filter for answer substring
* `-Q, --question SUBSTRING`: Filter for question substring
#### Tags #### `print` Command Options
The `tags` command is used to manage tags. - `-f`, `--file`: File to print (required).
- `-S`, `--only-source-code`: Print only source code.
```bash
cmm tags (-l | -p PREFIX | -c SUBSTRING)
```
* `-l, --list`: List all tags and their frequency
* `-p, --prefix PREFIX`: Filter tags by prefix
* `-c, --contain SUBSTRING`: Filter tags by contained substring
#### Config
The `config` command is used to manage the configuration.
```bash
cmm config (-l | -m | -c FILE)
```
* `-l, --list-models`: List all available models
* `-m, --print-model`: Print the currently configured model
* `-c, --create FILE`: Create config with default settings in the given file
#### Print
The `print` command is used to print message files.
```bash
cmm print (-f FILE | -l) [-q | -a | -S]
```
* `-f, --file FILE`: Print given file
* `-l, --latest`: Print latest message
* `-q, --question`: Only print the question
* `-a, --answer`: Only print the answer
* `-S, --only-source-code`: Only print embedded source code
### Examples ### Examples
1. Ask a question: 1. Ask a question:
```bash ```bash
cmm question -a "What is the meaning of life?" -t philosophy -x religion cmm ask -q "What is the meaning of life?" -t philosophy -e religion
``` ```
2. Display the chat history: 2. Display the chat history:
@@ -139,19 +105,19 @@ cmm hist
3. Filter chat history by tags: 3. Filter chat history by tags:
```bash ```bash
cmm hist --or-tags tag1 tag2 cmm hist -t tag1 tag2
``` ```
4. Exclude chat history by tags: 4. Exclude chat history by tags:
```bash ```bash
cmm hist --exclude-tags tag3 tag4 cmm hist -e tag3 tag4
``` ```
5. List all tags and their frequency: 5. List all tags and their frequency:
```bash ```bash
cmm tags -l cmm tag -l
``` ```
6. Print the contents of a file: 6. Print the contents of a file:
@@ -162,27 +128,18 @@ cmm print -f example.yaml
## Configuration ## Configuration
The default configuration filename is `.config.yaml` (it is searched in the current working directory). The configuration file (`.config.yaml`) should contain the following fields:
Use the command `cmm config --create <FILENAME>` to create a default configuration:
``` - `openai`:
cache: . - `api_key`: Your OpenAI API key.
db: ./db/ - `model`: The name of the OpenAI model to use (e.g. "text-davinci-002").
ais: - `temperature`: The temperature value for the model.
myopenai: - `max_tokens`: The maximum number of tokens for the model.
name: openai - `top_p`: The top P value for the model.
model: gpt-3.5-turbo-16k - `frequency_penalty`: The frequency penalty value.
api_key: 0123456789 - `presence_penalty`: The presence penalty value.
temperature: 1.0 - `system`: The system message used to set the behavior of the AI.
max_tokens: 4000 - `db`: The directory where the question-answer pairs are stored in YAML files.
top_p: 1.0
frequency_penalty: 0.0
presence_penalty: 0.0
system: You are an assistant
```
Each AI has its own section and the name of that section is called the 'AI ID' (in the example above it is `myopenai`).
The AI ID can be any string, as long as it's unique within the `ais` section. The AI ID is used for all commands that support the `AI` parameter and it's also stored within each message file.
## Autocompletion ## Autocompletion
@@ -197,33 +154,33 @@ After adding this line, restart your shell or run `source <your-shell-config-fil
## Contributing ## Contributing
### Enable commit hooks ### Enable commit hooks
```bash ```
pip install pre-commit pip install pre-commit
pre-commit install pre-commit install
``` ```
### Execute tests before opening a PR ### Execute tests before opening a PR
```bash ```
pytest pytest
``` ```
### Consider using `pyenv` / `pyenv-virtualenv` ### Consider using `pyenv` / `pyenv-virtualenv`
Short installation instructions: Short installation instructions:
* install `pyenv`: * install `pyenv`:
```bash ```
cd ~ cd ~
git clone https://github.com/pyenv/pyenv .pyenv git clone https://github.com/pyenv/pyenv .pyenv
cd ~/.pyenv && src/configure && make -C src cd ~/.pyenv && src/configure && make -C src
``` ```
* make sure that `~/.pyenv/shims` and `~/.pyenv/bin` are the first entries in your `PATH`, e.g., by setting it in `~/.bashrc` * make sure that `~/.pyenv/shims` and `~/.pyenv/bin` are the first entries in your `PATH`, e. g. by setting it in `~/.bashrc`
* add the following to your `~/.bashrc` (after setting `PATH`): `eval "$(pyenv init -)"` * add the following to your `~/.bashrc` (after setting `PATH`): `eval "$(pyenv init -)"`
* create a new terminal or source the changes (e.g., `source ~/.bashrc`) * create a new terminal or source the changes (e. g. `source ~/.bashrc`)
* install `virtualenv` * install `virtualenv`
```bash ```
git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv
``` ```
* add the following to your `~/.bashrc` (after the commands above): `eval "$(pyenv virtualenv-init -)"` * add the following to your `~/.bashrc` (after the commands above): `eval "$(pyenv virtualenv-init -)"`
* create a new terminal or source the changes (e.g., `source ~/.bashrc`) * create a new terminal or source the changes (e. g. `source ~/.bashrc`)
* go back to the `ChatMasterMind` repo and create a virtual environment with the latest `Python`, e.g., `3.11.4`: * go back to the `ChatMasterMind` repo and create a virtual environment with the latest `Python`, e. g. `3.11.4`:
```bash ```
cd <CMM_REPO_PATH> cd <CMM_REPO_PATH>
pyenv install 3.11.4 pyenv install 3.11.4
pyenv virtualenv 3.11.4 py311 pyenv virtualenv 3.11.4 py311
@@ -234,3 +191,5 @@ pyenv activate py311
## License ## License
This project is licensed under the terms of the WTFPL License. This project is licensed under the terms of the WTFPL License.
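The old side of the README diff above describes three tag filters: `--or-tags` (one must match), `--and-tags` (all must match), and `--exclude-tags`. The following is a minimal sketch of how such a filter could be evaluated, assuming tags are plain strings. The function `match_tags` and its parameter names are hypothetical and are not taken from the project.

```python
from typing import Optional


def match_tags(tags: set[str],
               tags_or: Optional[set[str]] = None,
               tags_and: Optional[set[str]] = None,
               tags_not: Optional[set[str]] = None) -> bool:
    """Hypothetical helper: decide whether a message's tags pass the filters."""
    # Exclusion is checked first: any excluded tag rejects the message.
    if tags_not and tags & tags_not:
        return False
    # 'and' filter: every requested tag must be present.
    if tags_and and not tags_and <= tags:
        return False
    # 'or' filter: at least one requested tag must be present.
    if tags_or and not tags & tags_or:
        return False
    # No filter given, or all given filters passed.
    return True
```

Under these assumptions, a message tagged `philosophy` passes `tags_or={"philosophy"}` with `tags_not={"religion"}`, while a message tagged with both `philosophy` and `religion` is rejected.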
+8 -24
@@ -1,7 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from abc import abstractmethod
from typing import Protocol, Optional, Union from typing import Protocol, Optional, Union
from .configuration import AIConfig from .configuration import AIConfig
from .tags import Tag
from .message import Message from .message import Message
from .chat import Chat from .chat import Chat
@@ -33,38 +33,28 @@ class AI(Protocol):
The base class for AI clients. The base class for AI clients.
""" """
ID: str
name: str name: str
config: AIConfig config: AIConfig
@abstractmethod
def request(self, def request(self,
question: Message, question: Message,
chat: Chat, context: Chat,
num_answers: int = 1, num_answers: int = 1) -> AIResponse:
otags: Optional[set[Tag]] = None) -> AIResponse:
""" """
Make an AI request. Parameters: Make an AI request, asking the given question with the given
* question: the question to ask context (i. e. chat history). The nr. of requested answers
* chat: the chat history to be added as context corresponds to the nr. of messages in the 'AIResponse'.
* num_answers: nr. of requested answers (corresponds
to the nr. of messages in the 'AIResponse')
* otags: the output tags, i. e. the tags that all
returned messages should contain
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def models(self) -> list[str]: def models(self) -> list[str]:
""" """
Return all models supported by this AI. Return all models supported by this AI.
""" """
raise NotImplementedError raise NotImplementedError
def print_models(self) -> None:
"""
Print all models supported by this AI.
"""
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int: def tokens(self, data: Union[Message, Chat]) -> int:
""" """
Computes the nr. of AI language tokens for the given message Computes the nr. of AI language tokens for the given message
@@ -72,9 +62,3 @@ class AI(Protocol):
and is not implemented for all AIs. and is not implemented for all AIs.
""" """
raise NotImplementedError raise NotImplementedError
def print(self) -> None:
"""
Print some info about the current AI, like system message.
"""
pass
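Because `AI` is declared as a `typing.Protocol`, a client does not have to inherit from it; it only needs matching attributes and methods. The sketch below illustrates this with simplified stand-ins for `Message`, `Chat`, and `AIResponse` (the real classes carry more fields) and a hypothetical `EchoAI` client that is not part of the project.

```python
from dataclasses import dataclass, field
from typing import Optional, Protocol


@dataclass
class Message:
    """Simplified stand-in for the project's Message class."""
    question: str
    answer: Optional[str] = None


@dataclass
class Chat:
    """Simplified stand-in for the project's Chat class."""
    messages: list[Message] = field(default_factory=list)


@dataclass
class AIResponse:
    """Simplified stand-in: holds one Message per requested answer."""
    messages: list[Message]


class AI(Protocol):
    """Reduced version of the protocol shown in the diff."""
    name: str

    def request(self, question: Message, context: Chat,
                num_answers: int = 1) -> AIResponse:
        ...

    def models(self) -> list[str]:
        ...


class EchoAI:
    """Hypothetical client that satisfies the protocol structurally."""
    name = "echo"

    def request(self, question: Message, context: Chat,
                num_answers: int = 1) -> AIResponse:
        # Return exactly 'num_answers' messages, as the docstring requires.
        history = len(context.messages)
        return AIResponse([
            Message(question.question, f"echo #{i + 1} ({history} in context)")
            for i in range(num_answers)
        ])

    def models(self) -> list[str]:
        return ["echo-1"]


def ask(ai: AI, text: str, num_answers: int = 1) -> list[str]:
    """Send a question with an empty context and collect the answer texts."""
    response = ai.request(Message(text), Chat(), num_answers)
    return [m.answer or "" for m in response.messages]
```

A static type checker accepts `ask(EchoAI(), "hi", 2)` because `EchoAI` has the members the protocol names, even though it never subclasses `AI`.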
-49
@@ -1,49 +0,0 @@
"""
Creates different AI instances, based on the given configuration.
"""
import argparse
from typing import cast, Optional
from .configuration import Config, AIConfig, OpenAIConfig
from .ai import AI, AIError
from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config, # noqa: 11
def_ai: Optional[str] = None,
def_model: Optional[str] = None) -> AI:
"""
Creates an AI subclass instance from the given arguments and configuration file.
If AI has not been set in the arguments, it searches for the ID 'default'. If
that is not found, it uses the first AI in the list. It's also possible to
specify a default AI and model using 'def_ai' and 'def_model'.
"""
ai_conf: AIConfig
if hasattr(args, 'AI') and args.AI:
try:
ai_conf = config.ais[args.AI]
except KeyError:
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
elif def_ai:
ai_conf = config.ais[def_ai]
elif 'default' in config.ais:
ai_conf = config.ais['default']
else:
try:
ai_conf = next(iter(config.ais.values()))
except StopIteration:
raise AIError("No AI found in this configuration")
if ai_conf.name == 'openai':
ai = OpenAI(cast(OpenAIConfig, ai_conf))
if hasattr(args, 'model') and args.model:
ai.config.model = args.model
elif def_model:
ai.config.model = def_model
if hasattr(args, 'max_tokens') and args.max_tokens:
ai.config.max_tokens = args.max_tokens
if hasattr(args, 'temperature') and args.temperature:
ai.config.temperature = args.temperature
return ai
else:
raise AIError(f"AI '{args.AI}' is not supported")
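The docstring of the removed `create_ai` describes a fallback order: the AI named in the arguments, then `def_ai`, then the ID `default`, then the first configured AI. A minimal sketch of just that selection step follows, assuming the `ais` section is a plain dict keyed by AI ID. `select_ai_id` is a hypothetical name, and reporting an unknown `def_ai` as `AIError` is a choice made for this sketch (the removed code would raise `KeyError` there).

```python
from typing import Optional


class AIError(Exception):
    """Stand-in for the project's AIError."""


def select_ai_id(ais: dict[str, dict],
                 requested: Optional[str] = None,
                 def_ai: Optional[str] = None) -> str:
    """Return the AI ID to use, following the documented fallback order."""
    # 1. An explicitly requested ID must exist.
    if requested:
        if requested not in ais:
            raise AIError(f"AI ID '{requested}' does not exist in this configuration")
        return requested
    # 2. A caller-supplied default comes next.
    if def_ai:
        if def_ai not in ais:
            raise AIError(f"Default AI ID '{def_ai}' does not exist in this configuration")
        return def_ai
    # 3. An entry literally named 'default'.
    if "default" in ais:
        return "default"
    # 4. Otherwise the first configured AI (dicts preserve insertion order).
    try:
        return next(iter(ais))
    except StopIteration:
        raise AIError("No AI found in this configuration") from None
```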
+10 -31
@@ -2,12 +2,12 @@
Implements the OpenAI client classes and functions. Implements the OpenAI client classes and functions.
""" """
import openai import openai
from typing import Optional, Union from typing import Optional
from ..tags import Tag from ..tags import Tag
from ..message import Message, Answer from ..message import Message, Answer
from ..chat import Chat from ..chat import Chat
from ..ai import AI, AIResponse, Tokens from ..ai import AI, AIResponse, Tokens
from ..configuration import OpenAIConfig from ..config import OpenAIConfig
ChatType = list[dict[str, str]] ChatType = list[dict[str, str]]
@@ -17,11 +17,7 @@ class OpenAI(AI):
The OpenAI AI client. The OpenAI AI client.
""" """
def __init__(self, config: OpenAIConfig) -> None: config: OpenAIConfig
self.ID = config.ID
self.name = config.name
self.config = config
openai.api_key = config.api_key
def request(self, def request(self,
question: Message, question: Message,
@@ -43,31 +39,22 @@ class OpenAI(AI):
n=num_answers, n=num_answers,
frequency_penalty=self.config.frequency_penalty, frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty) presence_penalty=self.config.presence_penalty)
question.answer = Answer(response['choices'][0]['message']['content']) answers: list[Message] = []
question.tags = set(otags) if otags is not None else None for choice in response['choices']: # type: ignore
question.ai = self.ID
question.model = self.config.model
answers: list[Message] = [question]
for choice in response['choices'][1:]: # type: ignore
answers.append(Message(question=question.question, answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']), answer=Answer(choice['message']['content']),
tags=otags, tags=otags,
ai=self.ID, ai=self.name,
model=self.config.model)) model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt_tokens'], return AIResponse(answers, Tokens(response['usage']['prompt'],
response['usage']['completion_tokens'], response['usage']['completion'],
response['usage']['total_tokens'])) response['usage']['total']))
def models(self) -> list[str]: def models(self) -> list[str]:
""" """
Return all models supported by this AI. Return all models supported by this AI.
""" """
ret = [] raise NotImplementedError
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']:
ret.append(engine['id'])
ret.sort()
return ret
def print_models(self) -> None: def print_models(self) -> None:
""" """
@@ -101,11 +88,3 @@ class OpenAI(AI):
if question: if question:
append('user', question.question) append('user', question.question)
return oai_chat return oai_chat
def tokens(self, data: Union[Message, Chat]) -> int:
raise NotImplementedError
def print(self) -> None:
print(f"MODEL: {self.config.model}")
print("=== SYSTEM ===")
print(self.config.system)
+45
@@ -0,0 +1,45 @@
import openai
from .utils import ChatType
from .configuration import Config
def openai_api_key(api_key: str) -> None:
openai.api_key = api_key
def print_models() -> None:
"""
Print all models supported by the current AI.
"""
not_ready = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']:
print(engine['id'])
else:
not_ready.append(engine['id'])
if len(not_ready) > 0:
print('\nNot ready: ' + ', '.join(not_ready))
def ai(chat: ChatType,
config: Config,
number: int
) -> tuple[list[str], dict[str, int]]:
"""
Make AI request with the given chat history and configuration.
Return AI response and tokens used.
"""
response = openai.ChatCompletion.create(
model=config.openai.model,
messages=chat,
temperature=config.openai.temperature,
max_tokens=config.openai.max_tokens,
top_p=config.openai.top_p,
n=number,
frequency_penalty=config.openai.frequency_penalty,
presence_penalty=config.openai.presence_penalty)
result = []
for choice in response['choices']: # type: ignore
result.append(choice['message']['content'].strip())
return result, dict(response['usage']) # type: ignore
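The `ai()` function above collects the stripped text of every choice and returns it together with the usage dictionary. The sketch below isolates that post-processing so it can be run without network access. `extract_answers` is a hypothetical name and `sample` is hand-written data shaped like a legacy Chat Completions response, which reports usage under the keys `prompt_tokens`, `completion_tokens`, and `total_tokens`.

```python
from typing import Any


def extract_answers(response: dict[str, Any]) -> tuple[list[str], dict[str, int]]:
    """Collect the stripped text of every choice plus the token usage."""
    answers = [choice["message"]["content"].strip()
               for choice in response["choices"]]
    return answers, dict(response["usage"])


# Hand-written sample data, not a recorded API response.
sample = {
    "choices": [
        {"message": {"role": "assistant", "content": " 42 \n"}},
        {"message": {"role": "assistant", "content": "Nobody knows."}},
    ],
    "usage": {"prompt_tokens": 12, "completion_tokens": 9, "total_tokens": 21},
}
answers, usage = extract_answers(sample)
```

Requesting `n` answers yields `n` entries in `choices`, so the length of the returned list follows the `number` argument.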
+115 -393
@@ -2,23 +2,17 @@
Module implementing various chat classes and functions for managing a chat history. Module implementing various chat classes and functions for managing a chat history.
""" """
import shutil import shutil
from pathlib import Path import pathlib
from pprint import PrettyPrinter from pprint import PrettyPrinter
from pydoc import pager from pydoc import pager
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union from typing import TypeVar, Type, Optional, ClassVar, Any, Callable
from .configuration import default_config_file from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in
from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats
from .tags import Tag from .tags import Tag
ChatInst = TypeVar('ChatInst', bound='Chat') ChatInst = TypeVar('ChatInst', bound='Chat')
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next'
ignored_files = [db_next_file, default_config_file]
msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
msg_suffix = Message.file_suffix_write
class ChatError(Exception): class ChatError(Exception):
pass pass
@@ -36,7 +30,7 @@ def print_paged(text: str) -> None:
pager(text) pager(text)
def read_dir(dir_path: Path, def read_dir(dir_path: pathlib.Path,
glob: Optional[str] = None, glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> list[Message]: mfilter: Optional[MessageFilter] = None) -> list[Message]:
""" """
@@ -51,33 +45,30 @@ def read_dir(dir_path: Path,
messages: list[Message] = [] messages: list[Message] = []
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in sorted(file_iter): for file_path in sorted(file_iter):
if (file_path.is_file() if file_path.is_file() and file_path.suffix in Message.file_suffixes:
and file_path.name not in ignored_files # noqa: W503
and file_path.suffix in Message.file_suffixes_read): # noqa: W503
try: try:
message = Message.from_file(file_path, mfilter) message = Message.from_file(file_path, mfilter)
if message: if message:
messages.append(message) messages.append(message)
except MessageError as e: except MessageError as e:
print(f"WARNING: Skipping message in '{file_path}': {str(e)}") print(f"Error processing message in '{file_path}': {str(e)}")
return messages return messages
def make_file_path(dir_path: Path, def make_file_path(dir_path: pathlib.Path,
next_fid: Callable[[], int]) -> Path: file_suffix: str,
next_fid: Callable[[], int]) -> pathlib.Path:
""" """
Create a file_path for the given directory using the given ID generator function. Create a file_path for the given directory using the
given file_suffix and ID generator function.
""" """
file_path = dir_path / f"{next_fid():04d}{msg_suffix}" return dir_path / f"{next_fid():04d}{file_suffix}"
while file_path.exists():
file_path = dir_path / f"{next_fid():04d}{msg_suffix}"
return file_path
def write_dir(dir_path: Path, def write_dir(dir_path: pathlib.Path,
messages: list[Message], messages: list[Message],
next_fid: Callable[[], int], file_suffix: str,
mformat: MessageFormat = Message.default_format) -> None: next_fid: Callable[[], int]) -> None:
""" """
Write all messages to the given directory. If a message has no file_path, Write all messages to the given directory. If a message has no file_path,
a new one will be created. If message.file_path exists, it will be modified a new one will be created. If message.file_path exists, it will be modified
@@ -85,29 +76,28 @@ def write_dir(dir_path: Path,
Parameters: Parameters:
* 'dir_path': destination directory * 'dir_path': destination directory
* 'messages': list of messages to write * 'messages': list of messages to write
* 'file_suffix': suffix for the message files ['.txt'|'.yaml']
* 'next_fid': callable that returns the next file ID * 'next_fid': callable that returns the next file ID
""" """
for message in messages: for message in messages:
file_path = message.file_path file_path = message.file_path
# message has no file_path: create one # message has no file_path: create one
if not file_path: if not file_path:
file_path = make_file_path(dir_path, next_fid) file_path = make_file_path(dir_path, file_suffix, next_fid)
# file_path does not point to given directory: modify it # file_path does not point to given directory: modify it
elif not file_path.parent.samefile(dir_path): elif not file_path.parent.samefile(dir_path):
file_path = dir_path / file_path.name file_path = dir_path / file_path.name
message.to_file(file_path, mformat=mformat) message.to_file(file_path)
def clear_dir(dir_path: Path, def clear_dir(dir_path: pathlib.Path,
glob: Optional[str] = None) -> None: glob: Optional[str] = None) -> None:
""" """
Deletes all Message files in the given directory. Deletes all Message files in the given directory.
""" """
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in file_iter: for file_path in file_iter:
if (file_path.is_file() if file_path.is_file() and file_path.suffix in Message.file_suffixes:
and file_path.name not in ignored_files # noqa: W503
and file_path.suffix in Message.file_suffixes_read): # noqa: W503
file_path.unlink(missing_ok=True) file_path.unlink(missing_ok=True)
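The two sides of the `make_file_path` hunk above differ in two ways: one side takes an explicit `file_suffix`, the other keeps requesting IDs until the path is unused. The sketch below combines both ideas for illustration only; it is not the code of either side. The in-memory counter and the `demo` function are hypothetical.

```python
import itertools
import tempfile
from pathlib import Path
from typing import Callable


def make_file_path(dir_path: Path, file_suffix: str,
                   next_fid: Callable[[], int]) -> Path:
    """Build a zero-padded file name, skipping IDs that are already taken."""
    file_path = dir_path / f"{next_fid():04d}{file_suffix}"
    while file_path.exists():
        file_path = dir_path / f"{next_fid():04d}{file_suffix}"
    return file_path


def demo() -> list[str]:
    """Create two paths in a scratch directory; the second skips a taken ID."""
    counter = itertools.count(1)  # hypothetical in-memory ID source
    with tempfile.TemporaryDirectory() as tmp:
        dir_path = Path(tmp)
        first = make_file_path(dir_path, ".txt", lambda: next(counter))
        first.touch()
        (dir_path / "0002.txt").touch()  # simulate an ID that already exists
        second = make_file_path(dir_path, ".txt", lambda: next(counter))
        return [first.name, second.name]
```

Without the `while` loop, the second call would return `0002.txt` and a later write would overwrite the existing file.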
@@ -119,43 +109,14 @@ class Chat:
messages: list[Message] messages: list[Message]
def __post_init__(self) -> None: def filter(self, mfilter: MessageFilter) -> None:
self.validate()
def validate(self) -> None:
"""
Validate this Chat instance.
"""
def msg_paths(stem: str) -> list[str]:
return [str(fp) for fp in file_paths if fp.stem == stem]
file_paths: set[Path] = {m.file_path for m in self.messages if m.file_path is not None}
file_stems = [m.file_path.stem for m in self.messages if m.file_path is not None]
error = False
for fp in file_paths:
if file_stems.count(fp.stem) > 1:
print(f"ERROR: Found multiple copies of message '{fp.stem}': {msg_paths(fp.stem)}")
error = True
if error:
raise ChatError("Validation failed")
def msg_name_matches(self, file_path: Path, name: str) -> bool:
"""
Return True if the given name matches the given file_path.
Matching is True if:
* 'name' matches the full 'file_path'
* 'name' matches 'file_path.name' (i. e. including the suffix)
* 'name' matches 'file_path.stem' (i. e. without the suffix)
"""
return Path(name) == file_path or name == file_path.name or name == file_path.stem
def msg_filter(self, mfilter: MessageFilter) -> None:
""" """
Use 'Message.match(mfilter)' to remove all messages that Use 'Message.match(mfilter)' to remove all messages that
don't fulfill the filter requirements. don't fulfill the filter requirements.
""" """
self.messages = [m for m in self.messages if m.match(mfilter)] self.messages = [m for m in self.messages if m.match(mfilter)]
def msg_sort(self, reverse: bool = False) -> None: def sort(self, reverse: bool = False) -> None:
""" """
Sort the messages according to 'Message.msg_id()'. Sort the messages according to 'Message.msg_id()'.
""" """
@@ -165,71 +126,20 @@ class Chat:
except MessageError: except MessageError:
pass pass
def msg_unique_id(self) -> None: def clear(self) -> None:
"""
Remove duplicates from the internal messages, based on the msg_id (i. e. file_path).
Messages without a file_path are kept.
"""
old_msgs = self.messages.copy()
self.messages = []
for m in old_msgs:
if not message_in(m, self.messages):
self.messages.append(m)
self.msg_sort()
def msg_unique_content(self) -> None:
"""
Remove duplicates from the internal messages, based on the content (i. e. question + answer).
"""
self.messages = list(set(self.messages))
self.msg_sort()
def msg_clear(self) -> None:
""" """
Delete all messages. Delete all messages.
""" """
self.messages = [] self.messages = []
def msg_add(self, messages: list[Message]) -> None: def add_messages(self, messages: list[Message]) -> None:
""" """
Add new messages and sort them if possible. Add new messages and sort them if possible.
""" """
self.messages += messages self.messages += messages
self.msg_sort() self.sort()
def msg_latest(self, mfilter: Optional[MessageFilter] = None) -> Optional[Message]: def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
"""
Return the last added message (according to the file ID) that matches the given filter.
When containing messages without a valid file_path, it returns the latest message in
the internal list.
"""
if len(self.messages) > 0:
self.msg_sort()
for m in reversed(self.messages):
if mfilter is None or m.match(mfilter):
return m
return None
def msg_find(self, msg_names: list[str]) -> list[Message]:
"""
Search and return the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Messages that can't be
found are ignored (i. e. the caller should check the result if they require all
messages).
"""
return [m for m in self.messages
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def msg_remove(self, msg_names: list[str]) -> None:
"""
Remove the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id().
"""
self.messages = [m for m in self.messages
if not any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
self.msg_sort()
def msg_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
""" """
Get the tags of all messages, optionally filtered by prefix or substring. Get the tags of all messages, optionally filtered by prefix or substring.
""" """
@@ -238,7 +148,7 @@ class Chat:
tags |= m.filter_tags(prefix, contain) tags |= m.filter_tags(prefix, contain)
return set(sorted(tags)) return set(sorted(tags))
def msg_tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]: def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
""" """
Get the frequency of all tags of all messages, optionally filtered by prefix or substring. Get the frequency of all tags of all messages, optionally filtered by prefix or substring.
""" """
@@ -255,17 +165,23 @@ class Chat:
return sum(m.tokens() for m in self.messages) return sum(m.tokens() for m in self.messages)
def print(self, source_code_only: bool = False, def print(self, source_code_only: bool = False,
with_metadata: bool = False, with_tags: bool = False, with_files: bool = False,
paged: bool = True, paged: bool = True) -> None:
tight: bool = False) -> None:
output: list[str] = [] output: list[str] = []
for message in self.messages: for message in self.messages:
if source_code_only: if source_code_only:
output.append(message.to_str(source_code_only=True)) output.extend(source_code(message.question, include_delims=True))
continue continue
output.append(message.to_str(with_metadata)) output.append('-' * terminal_width())
if not tight: if with_tags:
output.append('\n' + ('-' * terminal_width()) + '\n') output.append(message.tags_str())
if with_files:
output.append('FILE: ' + str(message.file_path))
output.append(Question.txt_header)
output.append(message.question)
if message.answer:
output.append(Answer.txt_header)
output.append(message.answer)
if paged: if paged:
print_paged('\n'.join(output)) print_paged('\n'.join(output))
else: else:
@@ -282,27 +198,27 @@ class ChatDB(Chat):
persistently. persistently.
""" """
cache_path: Path default_file_suffix: ClassVar[str] = '.txt'
db_path: Path
cache_path: pathlib.Path
db_path: pathlib.Path
# a MessageFilter that all messages must match (if given) # a MessageFilter that all messages must match (if given)
mfilter: Optional[MessageFilter] = None mfilter: Optional[MessageFilter] = None
file_suffix: str = default_file_suffix
# the glob pattern for all messages # the glob pattern for all messages
glob: Optional[str] = None glob: Optional[str] = None
# message format (for writing)
mformat: MessageFormat = Message.default_format
def __post_init__(self) -> None: def __post_init__(self) -> None:
# contains the latest message ID # contains the latest message ID
self.next_path = self.db_path / db_next_file self.next_fname = self.db_path / '.next'
# make all paths absolute # make all paths absolute
self.cache_path = self.cache_path.absolute() self.cache_path = self.cache_path.absolute()
self.db_path = self.db_path.absolute() self.db_path = self.db_path.absolute()
self.validate()
@classmethod @classmethod
def from_dir(cls: Type[ChatDBInst], def from_dir(cls: Type[ChatDBInst],
cache_path: Path, cache_path: pathlib.Path,
db_path: Path, db_path: pathlib.Path,
glob: Optional[str] = None, glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> ChatDBInst: mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
""" """
@@ -317,12 +233,13 @@ class ChatDB(Chat):
when reading them. when reading them.
""" """
messages = read_dir(db_path, glob, mfilter) messages = read_dir(db_path, glob, mfilter)
return cls(messages, cache_path, db_path, mfilter, glob) return cls(messages, cache_path, db_path, mfilter,
cls.default_file_suffix, glob)
@classmethod @classmethod
def from_messages(cls: Type[ChatDBInst], def from_messages(cls: Type[ChatDBInst],
cache_path: Path, cache_path: pathlib.Path,
db_path: Path, db_path: pathlib.Path,
messages: list[Message], messages: list[Message],
mfilter: Optional[MessageFilter] = None) -> ChatDBInst: mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
""" """
@@ -332,7 +249,7 @@ class ChatDB(Chat):
def get_next_fid(self) -> int: def get_next_fid(self) -> int:
try: try:
with open(self.next_path, 'r') as f: with open(self.next_fname, 'r') as f:
next_fid = int(f.read()) + 1 next_fid = int(f.read()) + 1
self.set_next_fid(next_fid) self.set_next_fid(next_fid)
return next_fid return next_fid
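`get_next_fid` reads the last ID from the `.next` file, stores the incremented value, and returns it. A self-contained sketch of that counter follows. It takes the counter file as a parameter instead of using an instance attribute, and it assumes that a missing or unreadable file restarts the count at 1 and persists that value; the part of the original that handles this case is not fully visible in the diff.

```python
import tempfile
from pathlib import Path


def get_next_fid(next_file: Path) -> int:
    """Read the last ID from 'next_file', store the incremented value, return it."""
    try:
        next_fid = int(next_file.read_text()) + 1
    except (FileNotFoundError, ValueError):
        # Assumption: no counter yet (or unparsable content) means start at 1.
        next_fid = 1
    next_file.write_text(str(next_fid))
    return next_fid


def demo_counter() -> list[int]:
    """Call the counter three times against a scratch '.next' file."""
    with tempfile.TemporaryDirectory() as tmp:
        next_file = Path(tmp) / ".next"
        return [get_next_fid(next_file) for _ in range(3)]
```

Keeping the counter in a file lets IDs stay monotonic across program runs, at the cost of one small read and write per new message.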
@@ -341,266 +258,69 @@ class ChatDB(Chat):
return 1 return 1
def set_next_fid(self, fid: int) -> None: def set_next_fid(self, fid: int) -> None:
with open(self.next_path, 'w') as f: with open(self.next_fname, 'w') as f:
f.write(f'{fid}') f.write(f'{fid}')
def set_msg_format(self, mformat: MessageFormat) -> None: def read_db(self) -> None:
""" """
Set message format for writing messages. Reads new messages from the DB directory. New ones are added to the internal list,
""" existing ones are replaced. A message is determined as 'existing' if a message with
if mformat not in message_valid_formats:
raise ChatError(f"Message format '{mformat}' is not supported")
self.mformat = mformat
def msg_write(self,
messages: Optional[list[Message]] = None,
mformat: Optional[MessageFormat] = None) -> None:
"""
Write either the given messages or the internal ones to their CURRENT file_path.
If messages are given, they all must have a valid file_path. When writing the
internal messages, the ones with a valid file_path are written, the others
are ignored.
"""
if messages and any(m.file_path is None for m in messages):
raise ChatError("Can't write files without a valid file_path")
msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)):
m.to_file(mformat=mformat if mformat else self.mformat)
def msg_update(self, messages: list[Message], write: bool = True) -> None:
"""
Update EXISTING messages. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list. the same base filename (i. e. 'file_path.name') is already in the list.
Only accepts existing messages.
"""
if any(not message_in(m, self.messages) for m in messages):
raise ChatError("Can't update messages that are not in the internal list")
# remove old versions and add new ones
self.messages = [m for m in self.messages if not message_in(m, messages)]
self.messages += messages
self.msg_sort()
# write the UPDATED messages if requested
if write:
self.msg_write(messages)
def msg_gather(self,
loc: msg_location,
require_file_path: bool = False,
glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> list[Message]:
"""
Gather and return messages from the given locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
If 'require_file_path' is True, return only messages with a valid file_path.
"""
loc_messages: list[Message] = []
if loc in ['mem', 'all']:
if require_file_path:
loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
else:
loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
if loc in ['cache', 'disk', 'all']:
loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter)
if loc in ['db', 'disk', 'all']:
loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter)
# remove_duplicates and sort the list
unique_messages: list[Message] = []
for m in loc_messages:
if not message_in(m, unique_messages):
unique_messages.append(m)
try:
unique_messages.sort(key=lambda m: m.msg_id())
# messages in 'mem' can have an empty file_path
except MessageError:
pass
return unique_messages
def msg_find(self,
msg_names: list[str],
loc: msg_location = 'mem',
) -> list[Message]:
"""
Search and return the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Messages that can't be
found are ignored (i. e. the caller should check the result if they require all
messages).
Searches one of the following locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
loc_messages = self.msg_gather(loc, require_file_path=True)
return [m for m in loc_messages
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def msg_remove(self, msg_names: list[str], loc: msg_location = 'mem') -> None:
"""
Remove the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Also deletes the
files of all given messages with a valid file_path.
Delete files from one of the following locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
if loc != 'mem':
# delete the message files first
rm_messages = self.msg_find(msg_names, loc=loc)
for m in rm_messages:
if (m.file_path):
m.file_path.unlink()
# then remove them from the internal list
super().msg_remove(msg_names)
def msg_latest(self,
mfilter: Optional[MessageFilter] = None,
loc: msg_location = 'mem') -> Optional[Message]:
"""
Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if loc is 'mem').
Searches one of the following locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
# only consider messages with a valid file_path so they can be sorted
loc_messages = self.msg_gather(loc, require_file_path=True)
loc_messages.sort(key=lambda m: m.msg_id(), reverse=True)
for m in loc_messages:
if mfilter is None or m.match(mfilter):
return m
return None
def msg_in_cache(self, message: Union[Message, str]) -> bool:
"""
Return true if the given Message (or filename or Message.msg_id())
is located in the cache directory. False otherwise.
"""
if isinstance(message, Message):
return (message.file_path is not None
and message.file_path.parent.samefile(self.cache_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc='cache')) > 0
def msg_in_db(self, message: Union[Message, str]) -> bool:
"""
Return true if the given Message (or filename or Message.msg_id())
is located in the DB directory. False otherwise.
"""
if isinstance(message, Message):
return (message.file_path is not None
and message.file_path.parent.samefile(self.db_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc='db')) > 0
def cache_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None:
"""
Read messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message
with the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.cache_path, glob, mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.msg_sort()
def cache_write(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the cache directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point to
the cache directory.
Does NOT add the messages to the internal list (use 'cache_add()' for that)!
"""
write_dir(self.cache_path,
messages if messages else self.messages,
self.get_next_fid,
self.mformat)
def cache_add(self, messages: list[Message], write: bool = True) -> None:
"""
Add NEW messages and set the file_path to the cache directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.cache_path,
messages,
self.get_next_fid,
self.mformat)
else:
for m in messages:
m.file_path = make_file_path(self.cache_path, self.get_next_fid)
self.messages += messages
self.msg_sort()
def cache_clear(self, glob: Optional[str] = None) -> None:
"""
Delete all message files from the cache dir and remove them from the internal list.
"""
clear_dir(self.cache_path, glob)
# only keep messages from DB dir (or those that have not yet been written)
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
def cache_move(self, message: Message) -> None:
"""
Move the given message to the cache directory.
"""
# remember the old path (if any)
old_path: Optional[Path] = None
if message.file_path:
old_path = message.file_path
# write message to the new destination
self.cache_write([message])
# remove the old one (if any)
if old_path:
self.msg_remove([str(old_path)], loc='db')
# (re)add it to the internal list
self.msg_add([message])
def db_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None:
"""
Read messages from the DB directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message
with the same base filename (i. e. 'file_path.name') is already in the list.
""" """
new_messages = read_dir(self.db_path, self.glob, self.mfilter) new_messages = read_dir(self.db_path, self.glob, self.mfilter)
# remove all messages from self.messages that are in the new list # remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)] self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them # copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages self.messages += new_messages
self.msg_sort() self.sort()
def db_write(self, messages: Optional[list[Message]] = None) -> None: def read_cache(self) -> None:
"""
Reads new messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.cache_path, self.glob, self.mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.sort()
def write_db(self, messages: Optional[list[Message]] = None) -> None:
""" """
Write messages to the DB directory. If a message has no file_path, a new one Write messages to the DB directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point will be created. If message.file_path exists, it will be modified to point
to the DB directory. to the DB directory.
Does NOT add the messages to the internal list (use 'db_add()' for that)!
""" """
write_dir(self.db_path, write_dir(self.db_path,
messages if messages else self.messages, messages if messages else self.messages,
self.get_next_fid, self.file_suffix,
self.mformat) self.get_next_fid)
def db_add(self, messages: list[Message], write: bool = True) -> None: def write_cache(self, messages: Optional[list[Message]] = None) -> None:
""" """
Add NEW messages and set the file_path to the DB directory. Write messages to the cache directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point to
the cache directory.
"""
write_dir(self.cache_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def clear_cache(self) -> None:
"""
Deletes all Message files from the cache dir and removes those messages from
the internal list.
"""
clear_dir(self.cache_path, self.glob)
# only keep messages from DB dir (or those that have not yet been written)
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
def add_to_db(self, messages: list[Message], write: bool = True) -> None:
"""
Adds the given new messages and sets the file_path to the DB directory.
Only accepts messages without a file_path. Only accepts messages without a file_path.
""" """
if any(m.file_path is not None for m in messages): if any(m.file_path is not None for m in messages):
@@ -608,26 +328,28 @@ class ChatDB(Chat):
if write: if write:
write_dir(self.db_path, write_dir(self.db_path,
messages, messages,
self.get_next_fid, self.file_suffix,
self.mformat) self.get_next_fid)
else: else:
for m in messages: for m in messages:
m.file_path = make_file_path(self.db_path, self.get_next_fid) m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages self.messages += messages
self.msg_sort() self.sort()
def db_move(self, message: Message) -> None: def add_to_cache(self, messages: list[Message], write: bool = True) -> None:
""" """
Moves the given messages to the db directory. Adds the given new messages and sets the file_path to the cache directory.
Only accepts messages without a file_path.
""" """
# remember the old path (if any) if any(m.file_path is not None for m in messages):
old_path: Optional[Path] = None raise ChatError("Can't add new messages with existing file_path")
if message.file_path: if write:
old_path = message.file_path write_dir(self.cache_path,
# write message to the new destination messages,
self.db_write([message]) self.file_suffix,
# remove the old one (if any) self.get_next_fid)
if old_path: else:
self.msg_remove([str(old_path)], loc='cache') for m in messages:
# (re)add it to the internal list m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
self.msg_add([message]) self.messages += messages
self.sort()
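The `read_db()` and `read_cache()` methods above both merge freshly read messages into the in-memory list by base filename. A minimal sketch of that merge rule follows, assuming a stand-in `Msg` dataclass and hypothetical `msg_in()` and `merge_by_name()` helpers. None of these exist in the repository, which uses `Message`, `message_in()` and `read_dir()`, and the sort order chosen for unsaved messages is an assumption of the sketch.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass
class Msg:
    """Hypothetical stand-in for the repository's Message class."""
    question: str
    file_path: Optional[Path] = None


def msg_in(m: Msg, others: list[Msg]) -> bool:
    """Return True if a message with the same base filename is in 'others'."""
    if m.file_path is None:
        return False
    return any(o.file_path is not None and o.file_path.name == m.file_path.name
               for o in others)


def merge_by_name(current: list[Msg], new: list[Msg]) -> list[Msg]:
    """Drop the entries of 'current' that 'new' replaces, then append 'new'."""
    kept = [m for m in current if not msg_in(m, new)]
    merged = kept + new
    # Unsaved messages (no file_path) sort first here; this ordering is an
    # assumption of the sketch, not taken from the repository.
    merged.sort(key=lambda m: m.file_path.name if m.file_path else '')
    return merged


in_memory = [Msg('old text', Path('db/0001.txt')), Msg('unsaved')]
from_disk = [Msg('new text', Path('cache/0001.txt')), Msg('other', Path('db/0002.txt'))]
result = merge_by_name(in_memory, from_disk)
```

Because only `file_path.name` is compared, a file in the cache directory replaces an in-memory message that was loaded from the DB directory under the same name.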
-20
View File
@@ -1,20 +0,0 @@
import argparse
from pathlib import Path
from ..configuration import Config
from ..ai import AI
from ..ai_factory import create_ai


def config_cmd(args: argparse.Namespace) -> None:
    """
    Handler for the 'config' command.
    """
    if args.create:
        Config.create_default(Path(args.create))
    elif args.list_models or args.print_model:
        config: Config = Config.from_file(args.config)
        ai: AI = create_ai(args, config)
        if args.list_models:
            ai.print_models()
        else:
            print(ai.config.model)
-72
View File
@@ -1,72 +0,0 @@
import sys
import argparse
from pathlib import Path
from ..configuration import Config
from ..chat import ChatDB
from ..message import MessageFilter, Message

msg_suffix = Message.file_suffix_write  # currently '.msg'


def convert_messages(args: argparse.Namespace, config: Config) -> None:
    """
    Convert messages to a new format. Also used to change old suffixes
    ('.txt', '.yaml') to the latest default message file suffix ('.msg').
    """
    chat = ChatDB.from_dir(Path(config.cache),
                           Path(config.db))
    # read all known message files
    msgs = chat.msg_gather(loc='disk', glob='*.*')
    # make a set of all message IDs
    msg_ids = set([m.msg_id() for m in msgs])
    # set requested format and write all messages
    chat.set_msg_format(args.convert)
    # delete the current suffix
    # -> a new one will automatically be created
    for m in msgs:
        if m.file_path:
            m.file_path = m.file_path.with_suffix('')
    chat.msg_write(msgs)
    # read all messages with the current default suffix
    msgs = chat.msg_gather(loc='disk', glob=f'*{msg_suffix}')
    # make sure we converted all of the original messages
    for mid in msg_ids:
        if not any(mid == m.msg_id() for m in msgs):
            print(f"Message '{mid}' has not been found after conversion. Aborting.")
            sys.exit(1)
    # delete messages with old suffixes
    msgs = chat.msg_gather(loc='disk', glob='*.*')
    for m in msgs:
        if m.file_path and m.file_path.suffix != msg_suffix:
            m.rm_file()
    print(f"Successfully converted {len(msg_ids)} messages.")


def print_chat(args: argparse.Namespace, config: Config) -> None:
    """
    Print the DB chat history.
    """
    mfilter = MessageFilter(tags_or=args.or_tags,
                            tags_and=args.and_tags,
                            tags_not=args.exclude_tags,
                            question_contains=args.question,
                            answer_contains=args.answer)
    chat = ChatDB.from_dir(Path(config.cache),
                           Path(config.db),
                           mfilter=mfilter)
    chat.print(args.source_code_only,
               args.with_metadata,
               paged=not args.no_paging,
               tight=args.tight)


def hist_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'hist' command.
    """
    if args.print:
        print_chat(args, config)
    elif args.convert:
        convert_messages(args, config)
-45
View File
@@ -1,45 +0,0 @@
import sys
import argparse
from pathlib import Path
from ..configuration import Config
from ..message import Message, MessageError
from ..chat import ChatDB


def print_message(message: Message, args: argparse.Namespace) -> None:
    """
    Print the given message according to the given arguments.
    """
    if args.question:
        print(message.question)
    elif args.answer:
        print(message.answer)
    elif message.answer and args.only_source_code:
        for code in message.answer.source_code():
            print(code)
    else:
        print(message.to_str())


def print_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'print' command.
    """
    # print given file
    if args.file is not None:
        fname = Path(args.file)
        try:
            message = Message.from_file(fname)
            if message:
                print_message(message, args)
        except MessageError:
            print(f"File is not a valid message: {args.file}")
            sys.exit(1)
    # print latest message
    elif args.latest:
        chat = ChatDB.from_dir(Path(config.cache), Path(config.db))
        latest = chat.msg_latest(loc='disk')
        if not latest:
            print("No message found!")
            sys.exit(1)
        print_message(latest, args)
-218
View File
@@ -1,218 +0,0 @@
import sys
import argparse
from pathlib import Path
from itertools import zip_longest
from copy import deepcopy
from ..configuration import Config
from ..chat import ChatDB
from ..message import Message, MessageFilter, MessageError, Question, source_code
from ..ai_factory import create_ai
from ..ai import AI, AIResponse


class QuestionCmdError(Exception):
    pass


def add_file_as_text(question_parts: list[str], file: str) -> None:
    """
    Add the given file as plain text to the question part list.
    If the file is a Message, add the answer.
    """
    file_path = Path(file)
    content: str
    try:
        message = Message.from_file(file_path)
        if message and message.answer:
            content = message.answer
    except MessageError:
        with open(file) as r:
            content = r.read().strip()
    if len(content) > 0:
        question_parts.append(content)


def add_file_as_code(question_parts: list[str], file: str) -> None:
    """
    Add all source code from the given file. If no code segments can be extracted,
    the whole content is added as a source code segment. If the file is a Message,
    extract the source code from the answer.
    """
    file_path = Path(file)
    content: str
    try:
        message = Message.from_file(file_path)
        if message and message.answer:
            content = message.answer
    except MessageError:
        with open(file) as r:
            content = r.read().strip()
    # extract and add source code
    code_parts = source_code(content, include_delims=True)
    if len(code_parts) > 0:
        question_parts += code_parts
    else:
        question_parts.append(f"```\n{content}\n```")


def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
    """
    Takes an existing message and CLI arguments, and returns modified args based
    on the members of the given message. Used e.g. when repeating messages, where
    it's necessary to determine the correct AI, model and output tags to use
    (either from the existing message or the given args).
    """
    msg_args = args
    # if AI, model or output tags have not been specified,
    # use those from the original message
    if (args.AI is None
            or args.model is None  # noqa: W503
            or args.output_tags is None):  # noqa: W503
        msg_args = deepcopy(args)
        if args.AI is None and msg.ai is not None:
            msg_args.AI = msg.ai
        if args.model is None and msg.model is not None:
            msg_args.model = msg.model
        if args.output_tags is None and msg.tags is not None:
            msg_args.output_tags = msg.tags
    return msg_args


def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
    """
    Create a new message from the given arguments and write it
    to the cache directory.
    """
    question_parts = []
    if args.create is not None:
        question_list = args.create
    elif args.ask is not None:
        question_list = args.ask
    else:
        raise QuestionCmdError("No question found")
    text_files = args.source_text if args.source_text is not None else []
    code_files = args.source_code if args.source_code is not None else []

    for question, text_file, code_file in zip_longest(question_list, text_files, code_files, fillvalue=None):
        if question is not None and len(question.strip()) > 0:
            question_parts.append(question)
        if text_file is not None and len(text_file) > 0:
            add_file_as_text(question_parts, text_file)
        if code_file is not None and len(code_file) > 0:
            add_file_as_code(question_parts, code_file)

    full_question = '\n\n'.join(question_parts)

    message = Message(question=Question(full_question),
                      tags=args.output_tags,
                      ai=args.AI,
                      model=args.model)
    # only write the new message to the cache,
    # don't add it to the internal list
    chat.cache_write([message])
    return message


def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespace) -> None:
    """
    Make an AI request with the given AI, chat history, message and arguments.
    Write the response(s) to the cache directory, without appending them to the
    given chat history. Then print the response(s).
    """
    # print history and message question before making the request
    ai.print()
    chat.print(paged=False)
    print(message.to_str())
    response: AIResponse = ai.request(message,
                                      chat,
                                      args.num_answers,
                                      args.output_tags)
    # only write the response messages to the cache,
    # don't add them to the internal list
    chat.cache_write(response.messages)
    for idx, msg in enumerate(response.messages):
        print(f"=== ANSWER {idx+1} ===")
        print(msg.answer)
    if response.tokens:
        print("===============")
        print(response.tokens)


def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
    """
    Repeat the given messages using the given arguments.
    """
    ai: AI
    for msg in messages:
        msg_args = create_msg_args(msg, args)
        ai = create_ai(msg_args, config)
        print(f"--------- Repeating message '{msg.msg_id()}': ---------")
        # overwrite the latest message if requested or empty
        # -> but not if it's in the DB!
        if ((msg.answer is None or msg_args.overwrite is True)
                and (not chat.msg_in_db(msg))):  # noqa: W503
            msg.clear_answer()
            make_request(ai, chat, msg, msg_args)
        # otherwise create a new one
        else:
            msg_args.ask = [msg.question]
            message = create_message(chat, msg_args)
            make_request(ai, chat, message, msg_args)


def invert_input_tag_args(args: argparse.Namespace) -> None:
    """
    Changes the semantics of the INPUT tags for this command:
    * no tags specified on the CLI -> no tags are selected
    * empty tags specified on the CLI -> all tags are selected
    """
    if args.or_tags is None:
        args.or_tags = set()
    elif len(args.or_tags) == 0:
        args.or_tags = None
    if args.and_tags is None:
        args.and_tags = set()
    elif len(args.and_tags) == 0:
        args.and_tags = None


def question_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'question' command.
    """
    invert_input_tag_args(args)
    mfilter = MessageFilter(tags_or=args.or_tags,
                            tags_and=args.and_tags,
                            tags_not=args.exclude_tags)
    chat = ChatDB.from_dir(cache_path=Path(config.cache),
                           db_path=Path(config.db),
                           mfilter=mfilter)
    # if it's a new question, create and store it immediately
    if args.ask or args.create:
        message = create_message(chat, args)
        if args.create:
            return

    # === ASK ===
    if args.ask:
        ai: AI = create_ai(args, config)
        make_request(ai, chat, message, args)
    # === REPEAT ===
    elif args.repeat is not None:
        repeat_msgs: list[Message] = []
        # repeat latest message
        if len(args.repeat) == 0:
            lmessage = chat.msg_latest(loc='cache')
            if lmessage is None:
                print("No message found to repeat!")
                sys.exit(1)
            repeat_msgs.append(lmessage)
        # repeat given message(s)
        else:
            repeat_msgs = chat.msg_find(args.repeat, loc='disk')
        repeat_messages(repeat_msgs, chat, args, config)
    # === PROCESS ===
    elif args.process is not None:
        # TODO: process either all questions without an
        # answer or the one(s) given in 'args.process'
        pass
-17
View File
@@ -1,17 +0,0 @@
import argparse
from pathlib import Path
from ..configuration import Config
from ..chat import ChatDB


def tags_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'tags' command.
    """
    chat = ChatDB.from_dir(cache_path=Path(config.cache),
                           db_path=Path(config.db))
    if args.list:
        tags_freq = chat.msg_tags_frequency(args.prefix, args.contain)
        for tag, freq in tags_freq.items():
            print(f"- {tag}: {freq}")
    # TODO: add renaming
+23 -114
View File
@@ -1,52 +1,17 @@
import yaml import yaml
from pathlib import Path from typing import Type, TypeVar, Any
from typing import Type, TypeVar, Any, Optional, ClassVar from dataclasses import dataclass, asdict
from dataclasses import dataclass, asdict, field
ConfigInst = TypeVar('ConfigInst', bound='Config') ConfigInst = TypeVar('ConfigInst', bound='Config')
AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig')
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
supported_ais: list[str] = ['openai']
default_config_file = '.config.yaml'
class ConfigError(Exception):
pass
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
"""
Changes the YAML dump style to multiline syntax for multiline strings.
"""
if len(data.splitlines()) > 1:
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
return dumper.represent_scalar('tag:yaml.org,2002:str', data)
yaml.add_representer(str, str_presenter)
@dataclass @dataclass
class AIConfig: class AIConfig:
""" """
The base class of all AI configurations. The base class of all AI configurations.
""" """
# the name of the AI the config class represents name: str
# -> it's a class variable and thus not part of the
# dataclass constructor
name: ClassVar[str]
# a user-defined ID for an AI configuration entry
ID: str
model: str = 'n/a'
# the name must not be changed
def __setattr__(self, name: str, value: Any) -> None:
if name == 'name':
raise AttributeError(f"'{name}' is not allowed to be changed")
else:
super().__setattr__(name, value)
@dataclass @dataclass
@@ -54,59 +19,29 @@ class OpenAIConfig(AIConfig):
""" """
The OpenAI section of the configuration file. The OpenAI section of the configuration file.
""" """
name: ClassVar[str] = 'openai' api_key: str
model: str
# all members have default values, so we can easily create temperature: float
# a default configuration max_tokens: int
ID: str = 'myopenai' top_p: float
api_key: str = '0123456789' frequency_penalty: float
model: str = 'gpt-3.5-turbo-16k' presence_penalty: float
temperature: float = 1.0
max_tokens: int = 4000
top_p: float = 1.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
system: str = 'You are an assistant'
@classmethod @classmethod
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
""" """
Create OpenAIConfig from a dict. Create OpenAIConfig from a dict.
""" """
res = cls( return cls(
name='OpenAI',
api_key=str(source['api_key']), api_key=str(source['api_key']),
model=str(source['model']), model=str(source['model']),
max_tokens=int(source['max_tokens']), max_tokens=int(source['max_tokens']),
temperature=float(source['temperature']), temperature=float(source['temperature']),
top_p=float(source['top_p']), top_p=float(source['top_p']),
frequency_penalty=float(source['frequency_penalty']), frequency_penalty=float(source['frequency_penalty']),
presence_penalty=float(source['presence_penalty']), presence_penalty=float(source['presence_penalty'])
system=str(source['system'])
) )
# overwrite default ID if provided
if 'ID' in source:
res.ID = source['ID']
return res
def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
"""
Creates an AIConfig instance of the given name.
"""
if name.lower() == 'openai':
if conf_dict is None:
return OpenAIConfig()
else:
return OpenAIConfig.from_dict(conf_dict)
else:
raise ConfigError(f"Unknown AI '{name}'")
def create_default_ai_configs() -> dict[str, AIConfig]:
"""
Create a dict containing default configurations for all supported AIs.
"""
return {ai_config_instance(name).ID: ai_config_instance(name) for name in supported_ais}
@dataclass @dataclass
@@ -114,56 +49,30 @@ class Config:
""" """
The configuration file structure. The configuration file structure.
""" """
# all members have default values, so we can easily create system: str
# a default configuration db: str
cache: str = '.' openai: OpenAIConfig
db: str = './db/'
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
@classmethod @classmethod
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
""" """
Create Config from a dict (with the same format as the config file). Create Config from a dict.
""" """
# create the correct AI type instances
ais: dict[str, AIConfig] = {}
for ID, conf in source['ais'].items():
# add the AI ID to the config (for easy internal access)
conf['ID'] = ID
ai_conf = ai_config_instance(conf['name'], conf)
ais[ID] = ai_conf
return cls( return cls(
cache=str(source['cache']) if 'cache' in source else '.', system=str(source['system']),
db=str(source['db']), db=str(source['db']),
ais=ais openai=OpenAIConfig.from_dict(source['openai'])
) )
@classmethod
def create_default(cls, file_path: Path) -> None:
"""
Creates a default Config in the given file.
"""
conf = Config()
conf.to_file(file_path)
@classmethod @classmethod
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
with open(path, 'r') as f: with open(path, 'r') as f:
source = yaml.load(f, Loader=yaml.FullLoader) source = yaml.load(f, Loader=yaml.FullLoader)
return cls.from_dict(source) return cls.from_dict(source)
def to_file(self, file_path: Path) -> None: def to_file(self, path: str) -> None:
# remove the AI name from the config (for a cleaner format) with open(path, 'w') as f:
data = self.as_dict() yaml.dump(asdict(self), f, sort_keys=False)
for conf in data['ais'].values():
del (conf['ID'])
with open(file_path, 'w') as f:
yaml.dump(data, f, sort_keys=False)
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
res = asdict(self) return asdict(self)
# add the AI name manually (as first element)
# (not done by 'asdict' because it's a class variable)
for ID, conf in res['ais'].items():
res['ais'][ID] = {**{'name': self.ais[ID].name}, **conf}
return res
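The new-side `Config` nests a single `OpenAIConfig` dataclass and serializes through `asdict()`. A minimal sketch of that `from_dict()` / `asdict()` round trip follows, using hypothetical, trimmed-down stand-ins (`OpenAISection`, `Settings`) rather than the repository's classes, and leaving out the YAML step so that the example needs only the standard library.

```python
from dataclasses import dataclass, asdict
from typing import Any


@dataclass
class OpenAISection:
    """Hypothetical, trimmed-down stand-in for OpenAIConfig."""
    api_key: str
    model: str
    temperature: float

    @classmethod
    def from_dict(cls, source: dict[str, Any]) -> 'OpenAISection':
        # Coerce explicitly, since a config file may deliver 1 instead of 1.0.
        return cls(api_key=str(source['api_key']),
                   model=str(source['model']),
                   temperature=float(source['temperature']))


@dataclass
class Settings:
    """Hypothetical stand-in for Config."""
    system: str
    db: str
    openai: OpenAISection

    @classmethod
    def from_dict(cls, source: dict[str, Any]) -> 'Settings':
        return cls(system=str(source['system']),
                   db=str(source['db']),
                   openai=OpenAISection.from_dict(source['openai']))


raw = {'system': 'You are an assistant',
       'db': './db/',
       'openai': {'api_key': 'dummy', 'model': 'gpt-3.5-turbo-16k', 'temperature': 1}}
settings = Settings.from_dict(raw)
# asdict() recurses into nested dataclasses, so the result can be fed
# straight back into from_dict().
round_trip = Settings.from_dict(asdict(settings))
```

Since `asdict()` only covers instance fields, this flat layout needs no manual fix-up of the kind the removed `as_dict()` performed for the `ClassVar` name.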
+167 -63
View File
@@ -2,18 +2,21 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 : # vim: set fileencoding=utf-8 :
import yaml
import sys import sys
import argcomplete import argcomplete
import argparse import argparse
from pathlib import Path from pathlib import Path
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType
from .storage import save_answers, create_chat_hist, read_file, dump_data
from .api_client import ai, openai_api_key, print_models
from .configuration import Config
from .chat import ChatDB
from .message import Message, MessageFilter
from itertools import zip_longest
from typing import Any from typing import Any
from .configuration import Config, default_config_file
from .message import Message default_config = '.config.yaml'
from .commands.question import question_cmd
from .commands.tags import tags_cmd
from .commands.config import config_cmd
from .commands.hist import hist_cmd
from .commands.print import print_cmd
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
@@ -21,10 +24,128 @@ def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
return list(Message.tags_from_dir(Path(config.db), prefix=prefix)) return list(Message.tags_from_dir(Path(config.db), prefix=prefix))
def create_question_with_hist(args: argparse.Namespace,
config: Config,
) -> tuple[ChatType, str, list[str]]:
"""
Creates the "AI request", including the question and chat history as determined
by the specified tags.
"""
tags = args.tags or []
etags = args.etags or []
otags = args.output_tags or []
if not args.source_code_only:
print_tag_args(tags, etags, otags)
question_parts = []
question_list = args.question if args.question is not None else []
source_list = args.source if args.source is not None else []
for question, source in zip_longest(question_list, source_list, fillvalue=None):
if question is not None and source is not None:
with open(source) as r:
question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```")
elif question is not None:
question_parts.append(question)
elif source is not None:
with open(source) as r:
question_parts.append(f"```\n{r.read().strip()}\n```")
full_question = '\n\n'.join(question_parts)
chat = create_chat_hist(full_question, tags, etags, config,
match_all_tags=True if args.atags else False, # FIXME
with_tags=False,
with_file=False)
return chat, full_question, tags
def tags_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'tags' command.
"""
chat = ChatDB.from_dir(cache_path=Path('.'),
db_path=Path(config.db))
if args.list:
tags_freq = chat.tags_frequency(args.prefix, args.contain)
for tag, freq in tags_freq.items():
print(f"- {tag}: {freq}")
# TODO: add renaming
def config_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'config' command.
"""
if args.list_models:
print_models()
elif args.print_model:
print(config.openai.model)
elif args.model:
config.openai.model = args.model
config.to_file(args.config)
def ask_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'ask' command.
"""
if args.max_tokens:
config.openai.max_tokens = args.max_tokens
if args.temperature:
config.openai.temperature = args.temperature
if args.model:
config.openai.model = args.model
chat, question, tags = create_question_with_hist(args, config)
print_chat_hist(chat, False, args.source_code_only)
otags = args.output_tags or []
answers, usage = ai(chat, config, args.number)
save_answers(question, answers, tags, otags, config)
print("-" * terminal_width())
print(f"Usage: {usage}")
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'hist' command.
"""
mfilter = MessageFilter(tags_or=args.tags,
tags_and=args.atags,
tags_not=args.etags,
question_contains=args.question,
answer_contains=args.answer)
chat = ChatDB.from_dir(Path('.'),
Path(config.db),
mfilter=mfilter)
chat.print(args.source_code_only,
args.with_tags,
args.with_files)
def print_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'print' command.
"""
fname = Path(args.file)
if fname.suffix == '.yaml':
with open(args.file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
elif fname.suffix == '.txt':
data = read_file(fname)
else:
print(f"Unknown file type: {args.file}")
sys.exit(1)
if args.source_code_only:
display_source_code(data['answer'])
else:
print(dump_data(data).strip())
def create_parser() -> argparse.ArgumentParser: def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="ChatMastermind is a Python application that automates conversation with AI") description="ChatMastermind is a Python application that automates conversation with AI")
parser.add_argument('-C', '--config', help='Config file name.', default=default_config_file) parser.add_argument('-c', '--config', help='Config file name.', default=default_config)
# subcommand-parser # subcommand-parser
cmdparser = parser.add_subparsers(dest='command', cmdparser = parser.add_subparsers(dest='command',
@@ -34,59 +155,48 @@ def create_parser() -> argparse.ArgumentParser:
# a parent parser for all commands that support tag selection # a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False) tag_parser = argparse.ArgumentParser(add_help=False)
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='*', tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+',
help='List of tags (one must match)', metavar='OTAGS') help='List of tag names (one must match)', metavar='TAGS')
tag_arg.completer = tags_completer # type: ignore tag_arg.completer = tags_completer # type: ignore
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='*', atag_arg = tag_parser.add_argument('-a', '--atags', nargs='+',
help='List of tags (all must match)', metavar='ATAGS') help='List of tag names (all must match)', metavar='TAGS')
atag_arg.completer = tags_completer # type: ignore atag_arg.completer = tags_completer # type: ignore
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='*', etag_arg = tag_parser.add_argument('-e', '--etags', nargs='+',
help='List of tags to exclude', metavar='XTAGS') help='List of tag names to exclude', metavar='ETAGS')
etag_arg.completer = tags_completer # type: ignore etag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
help='List of output tags (default: use input tags)', metavar='OUTAGS') help='List of output tag names, default is input', metavar='OTAGS')
otag_arg.completer = tags_completer # type: ignore otag_arg.completer = tags_completer # type: ignore
# a parent parser for all commands that support AI configuration # 'ask' command parser
ai_parser = argparse.ArgumentParser(add_help=False) ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser],
ai_parser.add_argument('-A', '--AI', help='AI ID to use', metavar='AI_ID') help="Ask a question.",
ai_parser.add_argument('-M', '--model', help='Model to use', metavar='MODEL') aliases=['a'])
ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1) ask_cmd_parser.set_defaults(func=ask_cmd)
ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int) ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask',
ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float) required=True)
ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
# 'question' command parser ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser, ai_parser], ask_cmd_parser.add_argument('-M', '--model', help='Model to use')
help="ask, create and process questions.", ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int,
aliases=['q']) default=1)
question_cmd_parser.set_defaults(func=question_cmd) ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Add content of a file to the query')
question_group = question_cmd_parser.add_mutually_exclusive_group(required=True) ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history',
question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question', metavar='QUESTION')
question_group.add_argument('-c', '--create', nargs='+', help='Create a question', metavar='QUESTION')
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question', metavar='MESSAGE')
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions', metavar='MESSAGE')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true') action='store_true')
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query', metavar='FILE')
question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history',
metavar='FILE')
# 'hist' command parser # 'hist' command parser
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
help="Print and manage chat history.", help="Print chat history.",
aliases=['h']) aliases=['h'])
hist_cmd_parser.set_defaults(func=hist_cmd) hist_cmd_parser.set_defaults(func=hist_cmd)
hist_group = hist_cmd_parser.add_mutually_exclusive_group(required=True) hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.",
hist_group.add_argument('-p', '--print', help='Print the DB chat history', action='store_true')
hist_group.add_argument('-c', '--convert', help='Convert all message files to the given format [txt|yaml]', metavar='FORMAT')
hist_cmd_parser.add_argument('-w', '--with-metadata', help="Print chat history with metadata (tags, filename, AI, etc.).",
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Only print embedded source code', hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.",
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-A', '--answer', help='Print only answers with given substring', metavar='SUBSTRING') hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code',
hist_cmd_parser.add_argument('-Q', '--question', help='Print only questions with given substring', metavar='SUBSTRING') action='store_true')
hist_cmd_parser.add_argument('-d', '--tight', help='Print without message separators', action='store_true') hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring')
hist_cmd_parser.add_argument('-P', '--no-paging', help='Print without paging', action='store_true') hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring')
# 'tags' command parser # 'tags' command parser
tags_cmd_parser = cmdparser.add_parser('tags', tags_cmd_parser = cmdparser.add_parser('tags',
@@ -96,34 +206,29 @@ def create_parser() -> argparse.ArgumentParser:
tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True) tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True)
tags_group.add_argument('-l', '--list', help="List all tags and their frequency", tags_group.add_argument('-l', '--list', help="List all tags and their frequency",
action='store_true') action='store_true')
tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix", metavar='PREFIX') tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix")
tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring", metavar='SUBSTRING') tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring")
# 'config' command parser # 'config' command parser
config_cmd_parser = cmdparser.add_parser('config', config_cmd_parser = cmdparser.add_parser('config',
help="Manage configuration", help="Manage configuration",
aliases=['c']) aliases=['c'])
config_cmd_parser.set_defaults(func=config_cmd) config_cmd_parser.set_defaults(func=config_cmd)
config_cmd_parser.add_argument('-A', '--AI', help='AI ID to use')
config_group = config_cmd_parser.add_mutually_exclusive_group(required=True) config_group = config_cmd_parser.add_mutually_exclusive_group(required=True)
config_group.add_argument('-l', '--list-models', help="List all available models", config_group.add_argument('-l', '--list-models', help="List all available models",
action='store_true') action='store_true')
config_group.add_argument('-m', '--print-model', help="Print the currently configured model", config_group.add_argument('-m', '--print-model', help="Print the currently configured model",
action='store_true') action='store_true')
config_group.add_argument('-c', '--create', help="Create config with default settings in the given file", metavar='FILE') config_group.add_argument('-M', '--model', help="Set model in the config file")
# 'print' command parser # 'print' command parser
print_cmd_parser = cmdparser.add_parser('print', print_cmd_parser = cmdparser.add_parser('print',
help="Print message files.", help="Print files.",
aliases=['p']) aliases=['p'])
print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.set_defaults(func=print_cmd)
print_group = print_cmd_parser.add_mutually_exclusive_group(required=True) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True)
print_group.add_argument('-f', '--file', help='Print given message file', metavar='FILE') print_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code',
print_group.add_argument('-l', '--latest', help='Print latest message', action='store_true') action='store_true')
print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group()
print_cmd_modes.add_argument('-q', '--question', help='Only print the question', action='store_true')
print_cmd_modes.add_argument('-a', '--answer', help='Only print the answer', action='store_true')
print_cmd_modes.add_argument('-S', '--only-source-code', help='Only print embedded source code', action='store_true')
argcomplete.autocomplete(parser) argcomplete.autocomplete(parser)
return parser return parser
@@ -133,11 +238,10 @@ def main() -> int:
parser = create_parser() parser = create_parser()
args = parser.parse_args() args = parser.parse_args()
command = parser.parse_args() command = parser.parse_args()
if command.func == config_cmd:
command.func(command)
else:
config = Config.from_file(args.config) config = Config.from_file(args.config)
openai_api_key(config.openai.api_key)
command.func(command, config) command.func(command, config)
return 0 return 0
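The parser layout used in this file (one shared parent parser for tag options, one subparser per command, and a handler bound with `set_defaults(func=...)`) can be tried in isolation. The following is a minimal sketch, not the project's actual code: the handler names `ask_handler` and `hist_handler`, the helper `run`, and the returned strings are hypothetical, and the real handlers also take a `Config` argument.

```python
import argparse


def ask_handler(args: argparse.Namespace) -> str:
    # Hypothetical handler: only reports what it received.
    return f"ask question={' '.join(args.question)} tags={args.tags or []}"


def hist_handler(args: argparse.Namespace) -> str:
    # Hypothetical handler for the 'hist' command.
    return f"hist tags={args.tags or []}"


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="subcommand dispatch demo")
    sub = parser.add_subparsers(dest='command', required=True)

    # Parent parser: options shared by every command that selects by tag.
    # add_help=False avoids a duplicate -h/--help in the child parsers.
    tag_parser = argparse.ArgumentParser(add_help=False)
    tag_parser.add_argument('-t', '--tags', nargs='+', metavar='TAGS')

    ask = sub.add_parser('ask', parents=[tag_parser], aliases=['a'])
    ask.add_argument('-q', '--question', nargs='+', required=True)
    ask.set_defaults(func=ask_handler)

    hist = sub.add_parser('hist', parents=[tag_parser], aliases=['h'])
    hist.set_defaults(func=hist_handler)
    return parser


def run(argv: list[str]) -> str:
    # Parse once, then dispatch to whichever handler the subparser bound.
    args = build_parser().parse_args(argv)
    return args.func(args)
```

Calling `run(['ask', '-t', 'a', 'b', '-q', 'why', 'not'])` dispatches to `ask_handler`, and `run(['h'])` reaches `hist_handler` through the alias. Parsing the arguments once and reusing the resulting namespace for dispatch is enough; a second `parse_args()` call is not needed.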
+43 -118
@@ -3,10 +3,7 @@ Module implementing message related functions and classes.
""" """
import pathlib import pathlib
import yaml import yaml
import tempfile from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
import shutil
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple
from typing import get_args as typing_get_args
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags from .tags import Tag, TagLine, TagError, match_tags, rename_tags
@@ -16,9 +13,6 @@ MessageInst = TypeVar('MessageInst', bound='Message')
AILineInst = TypeVar('AILineInst', bound='AILine') AILineInst = TypeVar('AILineInst', bound='AILine')
ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine') ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine')
YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]] YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]]
MessageFormat = Literal['txt', 'yaml']
message_valid_formats: Final[Tuple[MessageFormat, ...]] = typing_get_args(MessageFormat)
message_default_format: Final[MessageFormat] = 'txt'
class MessageError(Exception): class MessageError(Exception):
@@ -96,7 +90,7 @@ class MessageFilter:
class AILine(str): class AILine(str):
""" """
A line that represents the AI name in the 'txt' format. A line that represents the AI name in a '.txt' file.
""" """
prefix: Final[str] = 'AI:' prefix: Final[str] = 'AI:'
@@ -116,7 +110,7 @@ class AILine(str):
class ModelLine(str): class ModelLine(str):
""" """
A line that represents the model name in the 'txt' format. A line that represents the model name in a '.txt' file.
""" """
prefix: Final[str] = 'MODEL:' prefix: Final[str] = 'MODEL:'
@@ -220,44 +214,18 @@ class Message():
model: Optional[str] = field(default=None, compare=False) model: Optional[str] = field(default=None, compare=False)
file_path: Optional[pathlib.Path] = field(default=None, compare=False) file_path: Optional[pathlib.Path] = field(default=None, compare=False)
# class variables # class variables
file_suffixes_read: ClassVar[list[str]] = ['.msg', '.txt', '.yaml'] file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml']
file_suffix_write: ClassVar[str] = '.msg'
default_format: ClassVar[MessageFormat] = message_default_format
tags_yaml_key: ClassVar[str] = 'tags' tags_yaml_key: ClassVar[str] = 'tags'
file_yaml_key: ClassVar[str] = 'file_path' file_yaml_key: ClassVar[str] = 'file_path'
ai_yaml_key: ClassVar[str] = 'ai' ai_yaml_key: ClassVar[str] = 'ai'
model_yaml_key: ClassVar[str] = 'model' model_yaml_key: ClassVar[str] = 'model'
def __post_init__(self) -> None:
# convert some types that are often set wrong
if self.tags is not None and not isinstance(self.tags, set):
self.tags = set(self.tags)
if self.file_path is not None and not isinstance(self.file_path, pathlib.Path):
self.file_path = pathlib.Path(self.file_path)
def __hash__(self) -> int: def __hash__(self) -> int:
""" """
The hash value is computed based on immutable members. The hash value is computed based on immutable members.
""" """
return hash((self.question, self.answer)) return hash((self.question, self.answer))
def equals(self, other: MessageInst, tags: bool = True, ai: bool = True,
model: bool = True, file_path: bool = True, verbose: bool = False) -> bool:
"""
Compare this message with another one, including the metadata.
Return True if everything is identical, False otherwise.
"""
equal: bool = ((not tags or (self.tags == other.tags))
and (not ai or (self.ai == other.ai)) # noqa: W503
and (not model or (self.model == other.model)) # noqa: W503
and (not file_path or (self.file_path == other.file_path)) # noqa: W503
and (self == other)) # noqa: W503
if not equal and verbose:
print("Messages not equal:")
print(self)
print(other)
return equal
@classmethod @classmethod
def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst: def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
""" """
@@ -282,8 +250,16 @@ class Message():
tags: set[Tag] = set() tags: set[Tag] = set()
if not file_path.exists(): if not file_path.exists():
raise MessageError(f"Message file '{file_path}' does not exist") raise MessageError(f"Message file '{file_path}' does not exist")
if file_path.suffix not in cls.file_suffixes_read: if file_path.suffix not in cls.file_suffixes:
raise MessageError(f"File type '{file_path.suffix}' is not supported") raise MessageError(f"File type '{file_path.suffix}' is not supported")
# for TXT, it's enough to read the TagLine
if file_path.suffix == '.txt':
with open(file_path, "r") as fd:
try:
tags = TagLine(fd.readline()).tags(prefix, contain)
except TagError:
pass # message without tags
else: # '.yaml'
try: try:
message = cls.from_file(file_path) message = cls.from_file(file_path)
if message: if message:
@@ -326,18 +302,17 @@ class Message():
""" """
if not file_path.exists(): if not file_path.exists():
raise MessageError(f"Message file '{file_path}' does not exist") raise MessageError(f"Message file '{file_path}' does not exist")
if file_path.suffix not in cls.file_suffixes_read: if file_path.suffix not in cls.file_suffixes:
raise MessageError(f"File type '{file_path.suffix}' is not supported") raise MessageError(f"File type '{file_path.suffix}' is not supported")
# try TXT first
try: if file_path.suffix == '.txt':
message = cls.__from_file_txt(file_path, message = cls.__from_file_txt(file_path,
mfilter.tags_or if mfilter else None, mfilter.tags_or if mfilter else None,
mfilter.tags_and if mfilter else None, mfilter.tags_and if mfilter else None,
mfilter.tags_not if mfilter else None) mfilter.tags_not if mfilter else None)
# then YAML else:
except MessageError:
message = cls.__from_file_yaml(file_path) message = cls.__from_file_yaml(file_path)
if message and (mfilter is None or message.match(mfilter)): if message and (not mfilter or (mfilter and message.match(mfilter))):
return message return message
else: else:
return None return None
@@ -372,6 +347,10 @@ class Message():
tags = TagLine(fd.readline()).tags() tags = TagLine(fd.readline()).tags()
except TagError: except TagError:
fd.seek(pos) fd.seek(pos)
if tags_or or tags_and or tags_not:
# match with an empty set if the file has no tags
if not match_tags(tags, tags_or, tags_and, tags_not):
return None
# AILine (Optional) # AILine (Optional)
try: try:
pos = fd.tell() pos = fd.tell()
@@ -389,19 +368,13 @@ class Message():
try: try:
question_idx = text.index(Question.txt_header) + 1 question_idx = text.index(Question.txt_header) + 1
except ValueError: except ValueError:
raise MessageError(f"'{file_path}' does not contain a valid message") raise MessageError(f"Question header '{Question.txt_header}' not found in '{file_path}'")
try: try:
answer_idx = text.index(Answer.txt_header) answer_idx = text.index(Answer.txt_header)
question = Question.from_list(text[question_idx:answer_idx]) question = Question.from_list(text[question_idx:answer_idx])
answer = Answer.from_list(text[answer_idx + 1:]) answer = Answer.from_list(text[answer_idx + 1:])
except ValueError: except ValueError:
question = Question.from_list(text[question_idx:]) question = Question.from_list(text[question_idx:])
# match tags AFTER reading the whole file
# -> make sure it's a valid 'txt' file format
if tags_or or tags_and or tags_not:
# match with an empty set if the file has no tags
if not match_tags(tags, tags_or, tags_and, tags_not):
return None
return cls(question, answer, tags, ai, model, file_path) return cls(question, answer, tags, ai, model, file_path)
@classmethod @classmethod
@@ -415,58 +388,25 @@ class Message():
* Message.model_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional]
""" """
with open(file_path, "r") as fd: with open(file_path, "r") as fd:
try:
data = yaml.load(fd, Loader=yaml.FullLoader) data = yaml.load(fd, Loader=yaml.FullLoader)
data[cls.file_yaml_key] = file_path data[cls.file_yaml_key] = file_path
return cls.from_dict(data) return cls.from_dict(data)
except Exception:
raise MessageError(f"'{file_path}' does not contain a valid message")
def to_str(self, with_metadata: bool = False, source_code_only: bool = False) -> str: def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
""" """
Return the current Message as a string. Write a Message to the given file. Type is determined based on the suffix.
""" Currently supported suffixes: ['.txt', '.yaml']
output: list[str] = []
if source_code_only:
# use the source code from answer only
if self.answer:
output.extend(self.answer.source_code(include_delims=True))
return '\n'.join(output) if len(output) > 0 else ''
if with_metadata:
output.append(self.tags_str())
output.append('FILE: ' + str(self.file_path))
output.append('AI: ' + str(self.ai))
output.append('MODEL: ' + str(self.model))
output.append(Question.txt_header)
output.append(self.question)
if self.answer:
output.append(Answer.txt_header)
output.append(self.answer)
return '\n'.join(output)
def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11
"""
Write a Message to the given file. Supported message file formats are 'txt' and 'yaml'.
Suffix is always '.msg'.
""" """
if file_path: if file_path:
self.file_path = file_path self.file_path = file_path
if not self.file_path: if not self.file_path:
raise MessageError("Got no valid path to write message") raise MessageError("Got no valid path to write message")
if mformat not in message_valid_formats: if self.file_path.suffix not in self.file_suffixes:
raise MessageError(f"File format '{mformat}' is not supported") raise MessageError(f"File type '{self.file_path.suffix}' is not supported")
# check for valid suffix
# -> add one if it's empty
# -> refuse old or otherwise unsupported suffixes
if not self.file_path.suffix:
self.file_path = self.file_path.with_suffix(self.file_suffix_write)
elif self.file_path.suffix != self.file_suffix_write:
raise MessageError(f"File suffix '{self.file_path.suffix}' is not supported")
# TXT # TXT
if mformat == 'txt': if self.file_path.suffix == '.txt':
return self.__to_file_txt(self.file_path) return self.__to_file_txt(self.file_path)
# YAML elif self.file_path.suffix == '.yaml':
elif mformat == 'yaml':
return self.__to_file_yaml(self.file_path) return self.__to_file_yaml(self.file_path)
def __to_file_txt(self, file_path: pathlib.Path) -> None: def __to_file_txt(self, file_path: pathlib.Path) -> None:
@@ -478,21 +418,19 @@ class Message():
* Model [Optional] * Model [Optional]
* Question.txt_header * Question.txt_header
* Question * Question
* Answer.txt_header [Optional] * Answer.txt_header
* Answer [Optional] * Answer
""" """
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: with open(file_path, "w") as fd:
temp_file_path = pathlib.Path(temp_fd.name)
if self.tags: if self.tags:
temp_fd.write(f'{TagLine.from_set(self.tags)}\n') fd.write(f'{TagLine.from_set(self.tags)}\n')
if self.ai: if self.ai:
temp_fd.write(f'{AILine.from_ai(self.ai)}\n') fd.write(f'{AILine.from_ai(self.ai)}\n')
if self.model: if self.model:
temp_fd.write(f'{ModelLine.from_model(self.model)}\n') fd.write(f'{ModelLine.from_model(self.model)}\n')
temp_fd.write(f'{Question.txt_header}\n{self.question}\n') fd.write(f'{Question.txt_header}\n{self.question}\n')
if self.answer: if self.answer:
temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n') fd.write(f'{Answer.txt_header}\n{self.answer}\n')
shutil.move(temp_file_path, file_path)
def __to_file_yaml(self, file_path: pathlib.Path) -> None: def __to_file_yaml(self, file_path: pathlib.Path) -> None:
""" """
@@ -504,8 +442,7 @@ class Message():
* Message.ai_yaml_key: str [Optional] * Message.ai_yaml_key: str [Optional]
* Message.model_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional]
""" """
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: with open(file_path, "w") as fd:
temp_file_path = pathlib.Path(temp_fd.name)
data: YamlDict = {Question.yaml_key: str(self.question)} data: YamlDict = {Question.yaml_key: str(self.question)}
if self.answer: if self.answer:
data[Answer.yaml_key] = str(self.answer) data[Answer.yaml_key] = str(self.answer)
@@ -515,15 +452,7 @@ class Message():
data[self.model_yaml_key] = self.model data[self.model_yaml_key] = self.model
if self.tags: if self.tags:
data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
yaml.dump(data, temp_fd, sort_keys=False) yaml.dump(data, fd, sort_keys=False)
shutil.move(temp_file_path, file_path)
def rm_file(self) -> None:
"""
Delete the message file. Ignore empty file_path and not existing files.
"""
if self.file_path is not None:
self.file_path.unlink(missing_ok=True)
def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
""" """
@@ -555,7 +484,7 @@ class Message():
Return True if all attributes match, else False. Return True if all attributes match, else False.
""" """
mytags = self.tags or set() mytags = self.tags or set()
if (((mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None) if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not)
and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503 and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503
or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503
or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503
@@ -578,17 +507,13 @@ class Message():
if self.tags: if self.tags:
self.tags = rename_tags(self.tags, tags_rename) self.tags = rename_tags(self.tags, tags_rename)
def clear_answer(self) -> None:
self.answer = None
def msg_id(self) -> str: def msg_id(self) -> str:
""" """
Returns an ID that is unique throughout all messages in the same (DB) directory. Returns an ID that is unique throughout all messages in the same (DB) directory.
Currently this is the file name without suffix. The ID is also used for sorting Currently this is the file name. The ID is also used for sorting messages.
messages.
""" """
if self.file_path: if self.file_path:
return self.file_path.stem return self.file_path.name
else: else:
raise MessageError("Can't create file ID without a file path") raise MessageError("Can't create file ID without a file path")
+121
@@ -0,0 +1,121 @@
import yaml
import io
import pathlib
from .utils import terminal_width, append_message, message_to_chat, ChatType
from .configuration import Config
from typing import Any, Optional
def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]:
with open(fname, "r") as fd:
tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip()
# also support tags separated by ',' (old format)
separator = ',' if ',' in tagline else ' '
tags = [t.strip() for t in tagline.split(separator)]
if tags_only:
return {"tags": tags}
text = fd.read().strip().split('\n')
question_idx = text.index("=== QUESTION ===") + 1
answer_idx = text.index("==== ANSWER ====")
question = "\n".join(text[question_idx:answer_idx]).strip()
answer = "\n".join(text[answer_idx + 1:]).strip()
return {"question": question, "answer": answer, "tags": tags,
"file": fname.name}
def dump_data(data: dict[str, Any]) -> str:
with io.StringIO() as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
fd.write(f'==== ANSWER ====\n{data["answer"]}\n')
return fd.getvalue()
def write_file(fname: str, data: dict[str, Any]) -> None:
with open(fname, "w") as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
fd.write(f'==== ANSWER ====\n{data["answer"]}\n')
def save_answers(question: str,
answers: list[str],
tags: list[str],
otags: Optional[list[str]],
config: Config
) -> None:
wtags = otags or tags
num, inum = 0, 0
next_fname = pathlib.Path(str(config.db)) / '.next'
try:
with open(next_fname, 'r') as f:
num = int(f.read())
except Exception:
pass
for answer in answers:
num += 1
inum += 1
title = f'-- ANSWER {inum} '
title_end = '-' * (terminal_width() - len(title))
print(f'{title}{title_end}')
print(answer)
write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags})
with open(next_fname, 'w') as f:
f.write(f'{num}')
def create_chat_hist(question: Optional[str],
tags: Optional[list[str]],
extags: Optional[list[str]],
config: Config,
match_all_tags: bool = False,
with_tags: bool = False,
with_file: bool = False
) -> ChatType:
chat: ChatType = []
append_message(chat, 'system', str(config.system).strip())
for file in sorted(pathlib.Path(str(config.db)).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
data['file'] = file.name
elif file.suffix == '.txt':
data = read_file(file)
else:
continue
data_tags = set(data.get('tags', []))
tags_match: bool
if match_all_tags:
tags_match = not tags or set(tags).issubset(data_tags)
else:
tags_match = not tags or bool(data_tags.intersection(tags))
extags_do_not_match = \
not extags or not data_tags.intersection(extags)
if tags_match and extags_do_not_match:
message_to_chat(data, chat, with_tags, with_file)
if question:
append_message(chat, 'user', question)
return chat
def get_tags(config: Config, prefix: Optional[str]) -> list[str]:
result = []
for file in sorted(pathlib.Path(str(config.db)).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
elif file.suffix == '.txt':
data = read_file(file, tags_only=True)
else:
continue
for tag in data.get('tags', []):
if prefix and len(prefix) > 0:
if tag.startswith(prefix):
result.append(tag)
else:
result.append(tag)
return result
def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]:
return list(set(get_tags(config, prefix)))
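The tag selection rule inside `create_chat_hist` is easier to check when pulled out into a pure function. The sketch below mirrors the three expressions from that function; the name `tags_selected` is hypothetical and does not exist in the module.

```python
from typing import Optional


def tags_selected(data_tags: set[str],
                  tags: Optional[list[str]],
                  extags: Optional[list[str]],
                  match_all_tags: bool = False) -> bool:
    """
    Return True if a message with 'data_tags' should be included.
    An empty or missing 'tags' list selects everything; 'extags'
    always excludes a message that carries any of the listed tags.
    """
    if match_all_tags:
        # every requested tag must be present on the message
        tags_match = not tags or set(tags).issubset(data_tags)
    else:
        # a single common tag is enough
        tags_match = not tags or bool(data_tags.intersection(tags))
    extags_do_not_match = not extags or not data_tags.intersection(extags)
    return bool(tags_match and extags_do_not_match)
```

For a message tagged `{'a', 'b'}`, the request `['a', 'c']` matches in the default mode but not with `match_all_tags=True`, and adding `extags=['b']` excludes the message in either mode.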
+80
@@ -0,0 +1,80 @@
import shutil
from pprint import PrettyPrinter
from typing import Any

ChatType = list[dict[str, str]]


def terminal_width() -> int:
    return shutil.get_terminal_size().columns


def pp(*args: Any, **kwargs: Any) -> None:
    return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)


def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None:
    """
    Prints the tags specified in the given args.
    """
    printed_messages = []
    if tags:
        printed_messages.append(f"Tags: {' '.join(tags)}")
    if extags:
        printed_messages.append(f"Excluding tags: {' '.join(extags)}")
    if otags:
        printed_messages.append(f"Output tags: {' '.join(otags)}")
    if printed_messages:
        print("\n".join(printed_messages))
        print()


def append_message(chat: ChatType,
                   role: str,
                   content: str
                   ) -> None:
    chat.append({'role': role, 'content': content.replace("''", "'")})


def message_to_chat(message: dict[str, str],
                    chat: ChatType,
                    with_tags: bool = False,
                    with_file: bool = False
                    ) -> None:
    append_message(chat, 'user', message['question'])
    append_message(chat, 'assistant', message['answer'])
    if with_tags:
        tags = " ".join(message['tags'])
        append_message(chat, 'tags', tags)
    if with_file:
        append_message(chat, 'file', message['file'])


def display_source_code(content: str) -> None:
    try:
        content_start = content.index('```')
        content_end = content.rindex('```')
        if content_start + 3 < content_end:
            print(content[content_start + 3:content_end].strip())
    except ValueError:
        pass


def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None:
    if dump:
        pp(chat)
        return
    for message in chat:
        text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2
        if source_code:
            display_source_code(message['content'])
            continue
        if message['role'] == 'user':
            print('-' * terminal_width())
        if text_too_long:
            print(f"{message['role'].upper()}:")
            print(message['content'])
        else:
            print(f"{message['role'].upper()}: {message['content']}")
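
Review note: `display_source_code` in this file prints whatever lies between the first and the last fence marker of a message. The sketch below restates that slicing as a function that returns the text instead of printing it, which makes the edge cases easier to see. It is an illustration only; `extract_source_code` and `FENCE` are hypothetical names that do not appear in the change.

```python
from typing import Optional

# Three backticks, built programmatically so this example stays readable
# inside a fenced block.
FENCE = '`' * 3


def extract_source_code(content: str) -> Optional[str]:
    """Return the text between the first and the last fence marker, or None."""
    try:
        start = content.index(FENCE)
        end = content.rindex(FENCE)
    except ValueError:
        # No fence marker at all.
        return None
    if start + len(FENCE) < end:
        return content[start + len(FENCE):end].strip()
    # Only one marker, or nothing between the two markers.
    return None
```

With this slicing, a language tag placed directly after the opening marker (for example `python`) stays part of the returned text, and several separate code blocks in one message are returned as a single span including the prose between them.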
-56
@@ -1,56 +0,0 @@
<?php
$secret_key = '123';

// check for POST request
if ($_SERVER['REQUEST_METHOD'] != 'POST') {
    error_log('FAILED - not POST - '. $_SERVER['REQUEST_METHOD']);
    exit();
}

// get content type
$content_type = isset($_SERVER['CONTENT_TYPE']) ? strtolower(trim($_SERVER['CONTENT_TYPE'])) : '';
if ($content_type != 'application/json') {
    error_log('FAILED - not application/json - '. $content_type);
    exit();
}

// get payload
$payload = trim(file_get_contents("php://input"));
if (empty($payload)) {
    error_log('FAILED - no payload');
    exit();
}

// get header signature
$header_signature = isset($_SERVER['HTTP_X_GITEA_SIGNATURE']) ? $_SERVER['HTTP_X_GITEA_SIGNATURE'] : '';
if (empty($header_signature)) {
    error_log('FAILED - header signature missing');
    exit();
}

// calculate payload signature
$payload_signature = hash_hmac('sha256', $payload, $secret_key, false);

// check payload signature against header signature
if ($header_signature !== $payload_signature) {
    error_log('FAILED - payload signature');
    exit();
}

// convert json to array
$decoded = json_decode($payload, true);

// check for json decode errors
if (json_last_error() !== JSON_ERROR_NONE) {
    error_log('FAILED - json decode - '. json_last_error());
    exit();
}

// success, do something
$output = shell_exec('/home/kaizen/repos/ChatMastermind/hooks/push_hook.sh');
echo "$output";
?>
-8
@@ -1,8 +0,0 @@
#!/usr/bin/bash
. /home/kaizen/.bashrc
set -e
cd /home/kaizen/repos/ChatMastermind
git pull
pre-commit run -a
pytest
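
Review note: the removed PHP handler that launched this script compared the computed HMAC-SHA256 hex digest with the `X-Gitea-Signature` header using `!==`. If a check like this is reintroduced, a constant-time comparison is the usual choice. The Python sketch below illustrates the same check under that assumption; it is not a port of the removed file (it does not trim the payload), and `signature_is_valid` is a hypothetical name.

```python
import hashlib
import hmac


def signature_is_valid(payload: bytes, secret: bytes, header_signature: str) -> bool:
    """Check a hex-encoded HMAC-SHA256 signature in constant time."""
    expected = hmac.new(secret, payload, hashlib.sha256).hexdigest()
    # compare_digest avoids leaking, via timing, how many leading
    # characters of the signature matched. Encoding both sides to bytes
    # keeps the call valid even if the header contains non-ASCII text.
    return hmac.compare_digest(expected.encode(), header_signature.encode())
```

The secret would come from configuration rather than being hard-coded as in the removed handler.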
-1
@@ -2,4 +2,3 @@ openai
 PyYAML
 argcomplete
 pytest
-Jinja2
+7 -4
@@ -2,8 +2,6 @@ from setuptools import setup, find_packages
 with open("README.md", "r", encoding="utf-8") as fh:
     long_description = fh.read()
-with open("requirements.txt", "r") as fh:
-    install_requirements = [line.strip() for line in fh]
 setup(
     name="ChatMastermind",
@@ -14,7 +12,7 @@ setup(
     long_description=long_description,
     long_description_content_type="text/markdown",
     url="https://github.com/ok2/ChatMastermind",
-    packages=find_packages() + ["chatmastermind.ais", "chatmastermind.commands"],
+    packages=find_packages(),
     classifiers=[
         "Development Status :: 3 - Alpha",
         "Environment :: Console",
@@ -30,7 +28,12 @@ setup(
         "Topic :: Utilities",
         "Topic :: Text Processing",
     ],
-    install_requires=install_requirements,
+    install_requires=[
+        "openai",
+        "PyYAML",
+        "argcomplete",
+        "pytest"
+    ],
     python_requires=">=3.9",
     test_suite="tests",
     entry_points={
-48
@@ -1,48 +0,0 @@
import argparse
import unittest
from unittest.mock import MagicMock
from chatmastermind.ai_factory import create_ai
from chatmastermind.configuration import Config
from chatmastermind.ai import AIError
from chatmastermind.ais.openai import OpenAI


class TestCreateAI(unittest.TestCase):
    def setUp(self) -> None:
        self.args = MagicMock(spec=argparse.Namespace)
        self.args.AI = 'myopenai'
        self.args.model = None
        self.args.max_tokens = None
        self.args.temperature = None

    def test_create_ai_from_args(self) -> None:
        # Create an AI with the default configuration
        config = Config()
        self.args.AI = 'myopenai'
        ai = create_ai(self.args, config)
        self.assertIsInstance(ai, OpenAI)

    def test_create_ai_from_default(self) -> None:
        self.args.AI = None
        # Create an AI with the default configuration
        config = Config()
        ai = create_ai(self.args, config)
        self.assertIsInstance(ai, OpenAI)

    def test_create_empty_ai_error(self) -> None:
        self.args.AI = None
        # Create Config with empty AIs
        config = Config()
        config.ais = {}
        # Call create_ai function and assert that it raises AIError
        with self.assertRaises(AIError):
            create_ai(self.args, config)

    def test_create_unsupported_ai_error(self) -> None:
        # Mock argparse.Namespace with ai='invalid_ai'
        self.args.AI = 'invalid_ai'
        # Create default Config
        config = Config()
        # Call create_ai function and assert that it raises AIError
        with self.assertRaises(AIError):
            create_ai(self.args, config)
-81
@@ -1,81 +0,0 @@
import unittest
from unittest import mock
from chatmastermind.ais.openai import OpenAI
from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat
from chatmastermind.ai import AIResponse, Tokens
from chatmastermind.configuration import OpenAIConfig


class OpenAITest(unittest.TestCase):

    @mock.patch('openai.ChatCompletion.create')
    def test_request(self, mock_create: mock.MagicMock) -> None:
        # Create a test instance of OpenAI
        config = OpenAIConfig()
        openai = OpenAI(config)

        # Set up the mock response from openai.ChatCompletion.create
        mock_response = {
            'choices': [
                {
                    'message': {
                        'content': 'Answer 1'
                    }
                },
                {
                    'message': {
                        'content': 'Answer 2'
                    }
                }
            ],
            'usage': {
                'prompt_tokens': 10,
                'completion_tokens': 20,
                'total_tokens': 30
            }
        }
        mock_create.return_value = mock_response

        # Create test data
        question = Message(Question('Question'))
        chat = Chat([
            Message(Question('Question 1'), answer=Answer('Answer 1')),
            Message(Question('Question 2'), answer=Answer('Answer 2')),
            # add message without an answer -> expect to be skipped
            Message(Question('Question 3'))
        ])

        # Make the request
        response = openai.request(question, chat, num_answers=2)

        # Assert the AIResponse
        self.assertIsInstance(response, AIResponse)
        self.assertEqual(len(response.messages), 2)
        self.assertEqual(response.messages[0].answer, 'Answer 1')
        self.assertEqual(response.messages[1].answer, 'Answer 2')
        self.assertIsNotNone(response.tokens)
        self.assertIsInstance(response.tokens, Tokens)
        assert response.tokens
        self.assertEqual(response.tokens.prompt, 10)
        self.assertEqual(response.tokens.completion, 20)
        self.assertEqual(response.tokens.total, 30)

        # Assert the mock call to openai.ChatCompletion.create
        mock_create.assert_called_once_with(
            model=f'{config.model}',
            messages=[
                {'role': 'system', 'content': f'{config.system}'},
                {'role': 'user', 'content': 'Question 1'},
                {'role': 'assistant', 'content': 'Answer 1'},
                {'role': 'user', 'content': 'Question 2'},
                {'role': 'assistant', 'content': 'Answer 2'},
                {'role': 'user', 'content': 'Question'}
            ],
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            top_p=config.top_p,
            n=2,
            frequency_penalty=config.frequency_penalty,
            presence_penalty=config.presence_penalty
        )
+127 -427
@@ -1,157 +1,77 @@
import unittest
import pathlib import pathlib
import tempfile import tempfile
import time import time
import yaml
from io import StringIO from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from chatmastermind.tags import TagLine from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, ChatError from chatmastermind.chat import Chat, ChatDB, terminal_width
from .test_main import CmmTestCase
msg_suffix: str = Message.file_suffix_write class TestChat(CmmTestCase):
def msg_to_file_force_suffix(msg: Message) -> None:
"""
Force writing a message file with illegal suffixes.
"""
def_suffix = Message.file_suffix_write
assert msg.file_path
Message.file_suffix_write = msg.file_path.suffix
msg.to_file()
Message.file_suffix_write = def_suffix
class TestChatBase(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using more than just Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
class TestChat(TestChatBase):
def setUp(self) -> None: def setUp(self) -> None:
self.chat = Chat([]) self.chat = Chat([])
self.message1 = Message(Question('Question 1'), self.message1 = Message(Question('Question 1'),
Answer('Answer 1'), Answer('Answer 1'),
{Tag('atag1'), Tag('btag2')}, {Tag('atag1'), Tag('btag2')},
ai='FakeAI', file_path=pathlib.Path('0001.txt'))
model='FakeModel',
file_path=pathlib.Path(f'0001{msg_suffix}'))
self.message2 = Message(Question('Question 2'), self.message2 = Message(Question('Question 2'),
Answer('Answer 2'), Answer('Answer 2'),
{Tag('btag2')}, {Tag('btag2')},
ai='FakeAI', file_path=pathlib.Path('0002.txt'))
model='FakeModel',
file_path=pathlib.Path(f'0002{msg_suffix}'))
self.maxDiff = None
def test_unique_id(self) -> None:
# test with two identical messages
self.chat.msg_add([self.message1, self.message1])
self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_id()
self.assert_messages_equal(self.chat.messages, [self.message1])
# test with two different messages
self.chat.msg_add([self.message2])
self.chat.msg_unique_id()
self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])
def test_unique_content(self) -> None:
# test with two identical messages
self.chat.msg_add([self.message1, self.message1])
self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_content()
self.assert_messages_equal(self.chat.messages, [self.message1])
# test with two different messages
self.chat.msg_add([self.message2])
self.chat.msg_unique_content()
self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])
def test_filter(self) -> None: def test_filter(self) -> None:
self.chat.msg_add([self.message1, self.message2]) self.chat.add_messages([self.message1, self.message2])
self.chat.msg_filter(MessageFilter(answer_contains='Answer 1')) self.chat.filter(MessageFilter(answer_contains='Answer 1'))
self.assertEqual(len(self.chat.messages), 1) self.assertEqual(len(self.chat.messages), 1)
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
def test_sort(self) -> None: def test_sort(self) -> None:
self.chat.msg_add([self.message2, self.message1]) self.chat.add_messages([self.message2, self.message1])
self.chat.msg_sort() self.chat.sort()
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2') self.assertEqual(self.chat.messages[1].question, 'Question 2')
self.chat.msg_sort(reverse=True) self.chat.sort(reverse=True)
self.assertEqual(self.chat.messages[0].question, 'Question 2') self.assertEqual(self.chat.messages[0].question, 'Question 2')
self.assertEqual(self.chat.messages[1].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 1')
def test_clear(self) -> None: def test_clear(self) -> None:
self.chat.msg_add([self.message1]) self.chat.add_messages([self.message1])
self.chat.msg_clear() self.chat.clear()
self.assertEqual(len(self.chat.messages), 0) self.assertEqual(len(self.chat.messages), 0)
def test_add_messages(self) -> None: def test_add_messages(self) -> None:
self.chat.msg_add([self.message1, self.message2]) self.chat.add_messages([self.message1, self.message2])
self.assertEqual(len(self.chat.messages), 2) self.assertEqual(len(self.chat.messages), 2)
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2') self.assertEqual(self.chat.messages[1].question, 'Question 2')
def test_tags(self) -> None: def test_tags(self) -> None:
self.chat.msg_add([self.message1, self.message2]) self.chat.add_messages([self.message1, self.message2])
tags_all = self.chat.msg_tags() tags_all = self.chat.tags()
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
tags_pref = self.chat.msg_tags(prefix='a') tags_pref = self.chat.tags(prefix='a')
self.assertSetEqual(tags_pref, {Tag('atag1')}) self.assertSetEqual(tags_pref, {Tag('atag1')})
tags_cont = self.chat.msg_tags(contain='2') tags_cont = self.chat.tags(contain='2')
self.assertSetEqual(tags_cont, {Tag('btag2')}) self.assertSetEqual(tags_cont, {Tag('btag2')})
def test_tags_frequency(self) -> None: def test_tags_frequency(self) -> None:
self.chat.msg_add([self.message1, self.message2]) self.chat.add_messages([self.message1, self.message2])
tags_freq = self.chat.msg_tags_frequency() tags_freq = self.chat.tags_frequency()
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
def test_find_remove_messages(self) -> None:
self.chat.msg_add([self.message1, self.message2])
msgs = self.chat.msg_find(['0001'])
self.assertListEqual(msgs, [self.message1])
msgs = self.chat.msg_find(['0001', '0002'])
self.assertListEqual(msgs, [self.message1, self.message2])
# add new Message with full path
message3 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('btag2')},
file_path=pathlib.Path(f'/foo/bla/0003{msg_suffix}'))
self.chat.msg_add([message3])
# find new Message by full path
msgs = self.chat.msg_find([f'/foo/bla/0003{msg_suffix}'])
self.assertListEqual(msgs, [message3])
# find Message with full path only by filename
msgs = self.chat.msg_find([f'0003{msg_suffix}'])
self.assertListEqual(msgs, [message3])
# remove last message
self.chat.msg_remove(['0003'])
self.assertListEqual(self.chat.messages, [self.message1, self.message2])
def test_latest_message(self) -> None:
self.assertIsNone(self.chat.msg_latest())
self.chat.msg_add([self.message1])
self.assertEqual(self.chat.msg_latest(), self.message1)
self.chat.msg_add([self.message2])
self.assertEqual(self.chat.msg_latest(), self.message2)
@patch('sys.stdout', new_callable=StringIO) @patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None: def test_print(self, mock_stdout: StringIO) -> None:
self.chat.msg_add([self.message1, self.message2]) self.chat.add_messages([self.message1, self.message2])
self.chat.print(paged=False, tight=True) self.chat.print(paged=False)
expected_output = f"""{Question.txt_header} expected_output = f"""{'-'*terminal_width()}
{Question.txt_header}
Question 1 Question 1
{Answer.txt_header} {Answer.txt_header}
Answer 1 Answer 1
{'-'*terminal_width()}
{Question.txt_header} {Question.txt_header}
Question 2 Question 2
{Answer.txt_header} {Answer.txt_header}
@@ -160,21 +80,19 @@ Answer 2
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
@patch('sys.stdout', new_callable=StringIO) @patch('sys.stdout', new_callable=StringIO)
def test_print_with_metadata(self, mock_stdout: StringIO) -> None: def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
self.chat.msg_add([self.message1, self.message2]) self.chat.add_messages([self.message1, self.message2])
self.chat.print(paged=False, with_metadata=True, tight=True) self.chat.print(paged=False, with_tags=True, with_files=True)
expected_output = f"""{TagLine.prefix} atag1 btag2 expected_output = f"""{'-'*terminal_width()}
FILE: 0001{msg_suffix} {TagLine.prefix} atag1 btag2
AI: FakeAI FILE: 0001.txt
MODEL: FakeModel
{Question.txt_header} {Question.txt_header}
Question 1 Question 1
{Answer.txt_header} {Answer.txt_header}
Answer 1 Answer 1
{'-'*terminal_width()}
{TagLine.prefix} btag2 {TagLine.prefix} btag2
FILE: 0002{msg_suffix} FILE: 0002.txt
AI: FakeAI
MODEL: FakeModel
{Question.txt_header} {Question.txt_header}
Question 2 Question 2
{Answer.txt_header} {Answer.txt_header}
@@ -183,85 +101,50 @@ Answer 2
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
class TestChatDB(TestChatBase): class TestChatDB(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.db_path = tempfile.TemporaryDirectory() self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory() self.cache_path = tempfile.TemporaryDirectory()
self.message1 = Message(Question('Question 1'), self.message1 = Message(Question('Question 1'),
Answer('Answer 1'), Answer('Answer 1'),
{Tag('tag1')}) {Tag('tag1')},
file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'), self.message2 = Message(Question('Question 2'),
Answer('Answer 2'), Answer('Answer 2'),
{Tag('tag2')}) {Tag('tag2')},
file_path=pathlib.Path('0002.yaml'))
self.message3 = Message(Question('Question 3'), self.message3 = Message(Question('Question 3'),
Answer('Answer 3'), Answer('Answer 3'),
{Tag('tag3')}) {Tag('tag3')},
file_path=pathlib.Path('0003.txt'))
self.message4 = Message(Question('Question 4'), self.message4 = Message(Question('Question 4'),
Answer('Answer 4'), Answer('Answer 4'),
{Tag('tag4')}) {Tag('tag4')},
file_path=pathlib.Path('0004.yaml'))
self.message1.to_file(pathlib.Path(self.db_path.name, '0001'), mformat='txt') self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt'))
self.message2.to_file(pathlib.Path(self.db_path.name, '0002'), mformat='yaml') self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
self.message3.to_file(pathlib.Path(self.db_path.name, '0003'), mformat='txt') self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
self.message4.to_file(pathlib.Path(self.db_path.name, '0004'), mformat='yaml') self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
# make the next FID match the current state # make the next FID match the current state
next_fname = pathlib.Path(self.db_path.name) / '.next' next_fname = pathlib.Path(self.db_path.name) / '.next'
with open(next_fname, 'w') as f: with open(next_fname, 'w') as f:
f.write('4') f.write('4')
# add some "trash" in order to test if it's correctly handled / ignored
self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt', 'fubar.msg']
for file in self.trash_files:
with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
f.write('test trash')
# also create a file with actual yaml content
with open(pathlib.Path(self.db_path.name) / 'content.yaml', 'w') as f:
yaml.dump({'key': 'value'}, f)
self.trash_files.append('content.yaml')
self.maxDiff = None
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]: def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
""" """
List all Message files in the given TemporaryDirectory. List all Message files in the given TemporaryDirectory.
""" """
# exclude '.next' # exclude '.next'
return [f for f in pathlib.Path(tmp_dir.name).glob('*.[tym]*') if f.name not in self.trash_files] return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*'))
def tearDown(self) -> None: def tearDown(self) -> None:
self.db_path.cleanup() self.db_path.cleanup()
self.cache_path.cleanup() self.cache_path.cleanup()
pass pass
def test_validate(self) -> None: def test_chat_db_from_dir(self) -> None:
duplicate_message = Message(Question('Question 4'),
Answer('Answer 4'),
{Tag('tag4')},
file_path=pathlib.Path(self.db_path.name, '0004.txt'))
msg_to_file_force_suffix(duplicate_message)
with self.assertRaises(ChatError) as cm:
ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(str(cm.exception), "Validation failed")
def test_file_path_ID_exists(self) -> None:
"""
Tests if the CacheDB chooses another ID if a file path with
the given one exists.
"""
# create a new and empty CacheDB
db_path = tempfile.TemporaryDirectory()
cache_path = tempfile.TemporaryDirectory()
chat_db = ChatDB.from_dir(pathlib.Path(cache_path.name),
pathlib.Path(db_path.name))
# add a message file
message = Message(Question('What?'),
file_path=pathlib.Path(cache_path.name) / f'0001{msg_suffix}')
message.to_file()
message1 = Message(Question('Where?'))
chat_db.cache_write([message1])
self.assertEqual(message1.msg_id(), '0002')
def test_from_dir(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
self.assertEqual(len(chat_db.messages), 4) self.assertEqual(len(chat_db.messages), 4)
@@ -269,43 +152,27 @@ class TestChatDB(TestChatBase):
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
# check that the files are sorted # check that the files are sorted
self.assertEqual(chat_db.messages[0].file_path, self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, f'0001{msg_suffix}')) pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, f'0002{msg_suffix}')) pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, self.assertEqual(chat_db.messages[2].file_path,
pathlib.Path(self.db_path.name, f'0003{msg_suffix}')) pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, self.assertEqual(chat_db.messages[3].file_path,
pathlib.Path(self.db_path.name, f'0004{msg_suffix}')) pathlib.Path(self.db_path.name, '0004.yaml'))
def test_from_dir_glob(self) -> None: def test_chat_db_from_dir_glob(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
glob='*1.*') glob='*.txt')
self.assertEqual(len(chat_db.messages), 1) self.assertEqual(len(chat_db.messages), 2)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path, self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, f'0001{msg_suffix}')) pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0003.txt'))
def test_from_dir_filter_tags(self) -> None: def test_chat_db_filter(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or={Tag('tag1')}))
self.assertEqual(len(chat_db.messages), 1)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
def test_from_dir_filter_tags_empty(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or=set(),
tags_and=set(),
tags_not=set()))
self.assertEqual(len(chat_db.messages), 0)
def test_from_dir_filter_answer(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
mfilter=MessageFilter(answer_contains='Answer 2')) mfilter=MessageFilter(answer_contains='Answer 2'))
@@ -313,10 +180,10 @@ class TestChatDB(TestChatBase):
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path, self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, f'0002{msg_suffix}')) pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[0].answer, 'Answer 2') self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
def test_from_messages(self) -> None: def test_chat_db_from_messges(self) -> None:
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
messages=[self.message1, self.message2, messages=[self.message1, self.message2,
@@ -325,58 +192,39 @@ class TestChatDB(TestChatBase):
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
def test_fids(self) -> None: def test_chat_db_fids(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.get_next_fid(), 5) self.assertEqual(chat_db.get_next_fid(), 5)
self.assertEqual(chat_db.get_next_fid(), 6) self.assertEqual(chat_db.get_next_fid(), 6)
self.assertEqual(chat_db.get_next_fid(), 7) self.assertEqual(chat_db.get_next_fid(), 7)
with open(chat_db.next_path, 'r') as f: with open(chat_db.next_fname, 'r') as f:
self.assertEqual(f.read(), '7') self.assertEqual(f.read(), '7')
def test_msg_in_db_or_cache(self) -> None: def test_chat_db_write(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertTrue(chat_db.msg_in_db(self.message1))
self.assertTrue(chat_db.msg_in_db(str(self.message1.file_path)))
self.assertTrue(chat_db.msg_in_db(self.message1.msg_id()))
self.assertFalse(chat_db.msg_in_cache(self.message1))
self.assertFalse(chat_db.msg_in_cache(str(self.message1.file_path)))
self.assertFalse(chat_db.msg_in_cache(self.message1.msg_id()))
# add new message to the cache dir
cache_message = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
chat_db.cache_add([cache_message])
self.assertTrue(chat_db.msg_in_cache(cache_message))
self.assertTrue(chat_db.msg_in_cache(cache_message.msg_id()))
self.assertFalse(chat_db.msg_in_db(cache_message))
self.assertFalse(chat_db.msg_in_db(str(cache_message.file_path)))
def test_db_write(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
# check that Message.file_path is correct # check that Message.file_path is correct
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}')) self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}')) self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}')) self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory # write the messages to the cache directory
chat_db.cache_write() chat_db.write_cache()
# check if the written files are in the cache directory # check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4) self.assertEqual(len(cache_dir_files), 4)
self.assertIn(pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files)
# check that Message.file_path has been correctly updated # check that Message.file_path has been correctly updated
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, f'0001{msg_suffix}')) self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, f'0002{msg_suffix}')) self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, f'0003{msg_suffix}')) self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, f'0004{msg_suffix}')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
# check the timestamp of the files in the DB directory # check the timestamp of the files in the DB directory
db_dir_files = self.message_list(self.db_path) db_dir_files = self.message_list(self.db_path)
@@ -384,24 +232,24 @@ class TestChatDB(TestChatBase):
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
# overwrite the messages in the db directory # overwrite the messages in the db directory
time.sleep(0.05) time.sleep(0.05)
chat_db.db_write() chat_db.write_db()
# check if the written files are in the DB directory # check if the written files are in the DB directory
db_dir_files = self.message_list(self.db_path) db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4) self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
# check if all files in the DB dir have actually been overwritten # check if all files in the DB dir have actually been overwritten
for file in db_dir_files: for file in db_dir_files:
self.assertGreater(file.stat().st_mtime, old_timestamps[file]) self.assertGreater(file.stat().st_mtime, old_timestamps[file])
# check that Message.file_path has been correctly updated (again) # check that Message.file_path has been correctly updated (again)
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}')) self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}')) self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}')) self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
def test_db_read(self) -> None: def test_chat_db_read(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@@ -414,80 +262,80 @@ class TestChatDB(TestChatBase):
new_message2 = Message(Question('Question 6'), new_message2 = Message(Question('Question 6'),
Answer('Answer 6'), Answer('Answer 6'),
{Tag('tag6')}) {Tag('tag6')})
new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt') new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml') new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read and check them # read and check them
chat_db.db_read() chat_db.read_db()
self.assertEqual(len(chat_db.messages), 6) self.assertEqual(len(chat_db.messages), 6)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}')) self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}')) self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
# create 2 new files in the cache directory # create 2 new files in the cache directory
new_message3 = Message(Question('Question 7'), new_message3 = Message(Question('Question 7'),
Answer('Answer 7'), Answer('Answer 5'),
{Tag('tag7')}) {Tag('tag7')})
new_message4 = Message(Question('Question 8'), new_message4 = Message(Question('Question 8'),
Answer('Answer 8'), Answer('Answer 6'),
{Tag('tag8')}) {Tag('tag8')})
new_message3.to_file(pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'), mformat='txt') new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'), mformat='yaml') new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml'))
# read and check them # read and check them
chat_db.cache_read() chat_db.read_cache()
self.assertEqual(len(chat_db.messages), 8) self.assertEqual(len(chat_db.messages), 8)
# check that the new messages have the cache dir path # check that the new messages have the cache dir path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, f'0007{msg_suffix}')) self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, f'0008{msg_suffix}')) self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml'))
# and the old ones keep their path (since they have not been replaced) # and the old ones keep their path (since they have not been replaced)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}')) self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}')) self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
# now overwrite two messages in the DB directory # now overwrite two messages in the DB directory
new_message1.question = Question('New Question 1') new_message1.question = Question('New Question 1')
new_message2.question = Question('New Question 2') new_message2.question = Question('New Question 2')
new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt') new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml') new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read from the DB dir and check if the modified messages have been updated # read from the DB dir and check if the modified messages have been updated
chat_db.db_read() chat_db.read_db()
self.assertEqual(len(chat_db.messages), 8) self.assertEqual(len(chat_db.messages), 8)
self.assertEqual(chat_db.messages[4].question, 'New Question 1') self.assertEqual(chat_db.messages[4].question, 'New Question 1')
self.assertEqual(chat_db.messages[5].question, 'New Question 2') self.assertEqual(chat_db.messages[5].question, 'New Question 2')
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}')) self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}')) self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
# now write the messages from the cache to the DB directory # now write the messages from the cache to the DB directory
new_message3.to_file(pathlib.Path(self.db_path.name, f'0007{msg_suffix}')) new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.db_path.name, f'0008{msg_suffix}')) new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml'))
# read and check them # read and check them
chat_db.db_read() chat_db.read_db()
self.assertEqual(len(chat_db.messages), 8) self.assertEqual(len(chat_db.messages), 8)
# check that they now have the DB path # check that they now have the DB path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, f'0007{msg_suffix}')) self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, f'0008{msg_suffix}')) self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml'))
def test_cache_clear(self) -> None: def test_chat_db_clear(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
# check that Message.file_path is correct # check that Message.file_path is correct
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}')) self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}')) self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}')) self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory # write the messages to the cache directory
chat_db.cache_write() chat_db.write_cache()
# check if the written files are in the cache directory # check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4) self.assertEqual(len(cache_dir_files), 4)
# now rewrite them to the DB dir and check for modified paths # now rewrite them to the DB dir and check for modified paths
chat_db.db_write() chat_db.write_db()
db_dir_files = self.message_list(self.db_path) db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4) self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
# add a new message with empty file_path # add a new message with empty file_path
message_empty = Message(question=Question("What the hell am I doing here?"), message_empty = Message(question=Question("What the hell am I doing here?"),
@@ -495,11 +343,11 @@ class TestChatDB(TestChatBase):
# and one for the cache dir # and one for the cache dir
message_cache = Message(question=Question("What the hell am I doing here?"), message_cache = Message(question=Question("What the hell am I doing here?"),
answer=Answer("You're a creep!"), answer=Answer("You're a creep!"),
file_path=pathlib.Path(self.cache_path.name, '0005')) file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
chat_db.msg_add([message_empty, message_cache]) chat_db.add_messages([message_empty, message_cache])
# clear the cache and check the cache dir # clear the cache and check the cache dir
chat_db.cache_clear() chat_db.clear_cache()
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0) self.assertEqual(len(cache_dir_files), 0)
# make sure that the DB messages (and the new message) are still there # make sure that the DB messages (and the new message) are still there
@@ -509,7 +357,7 @@ class TestChatDB(TestChatBase):
# but not the message with the cache dir path # but not the message with the cache dir path
self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages)) self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))
def test_add(self) -> None: def test_chat_db_add(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@@ -520,7 +368,7 @@ class TestChatDB(TestChatBase):
# add new messages to the cache dir # add new messages to the cache dir
message1 = Message(question=Question("Question 1"), message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1")) answer=Answer("Answer 1"))
chat_db.cache_add([message1]) chat_db.add_to_cache([message1])
# check if the file_path has been correctly set # check if the file_path has been correctly set
self.assertIsNotNone(message1.file_path) self.assertIsNotNone(message1.file_path)
self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
@@ -530,157 +378,9 @@ class TestChatDB(TestChatBase):
# add new messages to the DB dir # add new messages to the DB dir
message2 = Message(question=Question("Question 2"), message2 = Message(question=Question("Question 2"),
answer=Answer("Answer 2")) answer=Answer("Answer 2"))
chat_db.db_add([message2]) chat_db.add_to_db([message2])
# check if the file_path has been correctly set # check if the file_path has been correctly set
self.assertIsNotNone(message2.file_path) self.assertIsNotNone(message2.file_path)
self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
db_dir_files = self.message_list(self.db_path) db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 5) self.assertEqual(len(db_dir_files), 5)
with self.assertRaises(ChatError):
chat_db.cache_add([Message(Question("?"), file_path=pathlib.Path("foo"))])
def test_msg_write(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
# try to write a message without a valid file_path
message = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
with self.assertRaises(ChatError):
chat_db.msg_write([message])
# write a message with a valid file_path
message.file_path = pathlib.Path(self.cache_path.name) / '123456'
chat_db.msg_write([message])
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, f'123456{msg_suffix}'), cache_dir_files)
def test_msg_update(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
message = chat_db.messages[0]
message.answer = Answer("New answer")
# update message without writing
chat_db.msg_update([message], write=False)
self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# re-read the message and check for old content
chat_db.db_read()
self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1"))
# now check with writing (message should be overwritten)
chat_db.msg_update([message], write=True)
chat_db.db_read()
self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# test without file_path -> expect error
message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
with self.assertRaises(ChatError):
chat_db.msg_update([message1])
def test_msg_find(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# search for a DB file in memory
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.msg'], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1])
# and on disk
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.msg'], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2])
# now search the cache -> expect empty result
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003.msg'], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), [])
# search for multiple messages
# -> search one twice, expect result to be unique
search_names = ['0001', '0002.msg', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, loc='all')
self.assert_messages_equal(result, expected_result)
def test_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.msg_latest(loc='mem'), self.message4)
self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)
self.assertEqual(chat_db.msg_latest(loc='disk'), self.message4)
self.assertEqual(chat_db.msg_latest(loc='all'), self.message4)
# the cache is currently empty:
self.assertIsNone(chat_db.msg_latest(loc='cache'))
# add new messages to the cache dir
new_message = Message(question=Question("New Question"),
answer=Answer("New Answer"))
chat_db.cache_add([new_message])
self.assertEqual(chat_db.msg_latest(loc='cache'), new_message)
self.assertEqual(chat_db.msg_latest(loc='mem'), new_message)
self.assertEqual(chat_db.msg_latest(loc='disk'), new_message)
self.assertEqual(chat_db.msg_latest(loc='all'), new_message)
# the DB does not contain the new message
self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)
def test_msg_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# add a new message, but only to the internal list
new_message = Message(Question("What?"))
all_messages_mem = all_messages + [new_message]
chat_db.msg_add([new_message])
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages_mem)
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages_mem)
# the nr. of messages on disk did not change -> expect old result
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# test with MessageFilter
self.assert_messages_equal(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})),
[self.message1])
self.assert_messages_equal(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})),
[self.message2])
self.assert_messages_equal(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})),
[])
self.assert_messages_equal(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")),
[new_message])
def test_msg_move_and_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# move first message to the cache
chat_db.cache_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [self.message1])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4])
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages)
# now move first message back to the DB
chat_db.db_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
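The overwrite check used in the tests above (record each file's `st_mtime`, rewrite the files, assert that every timestamp grew) can be sketched in isolation. This is a minimal illustration under stated assumptions: `write_all` and `overwritten_files` are hypothetical helpers standing in for a "write everything to disk" call, not part of the project's API, and the files are backdated with `os.utime` instead of sleeping, which is a substitution for the `time.sleep(0.05)` used in the test.

```python
import os
import pathlib
import tempfile


def write_all(directory: pathlib.Path, contents: dict[str, str]) -> None:
    """Hypothetical stand-in for a 'write everything to disk' call."""
    for name, text in contents.items():
        (directory / name).write_text(text)


def overwritten_files(directory: pathlib.Path, contents: dict[str, str]) -> list[str]:
    """Rewrite all files and return the names whose mtime increased."""
    old = {p.name: p.stat().st_mtime for p in directory.iterdir()}
    write_all(directory, contents)
    return sorted(p.name for p in directory.iterdir()
                  if p.stat().st_mtime > old.get(p.name, 0.0))


with tempfile.TemporaryDirectory() as tmp:
    db_dir = pathlib.Path(tmp)
    files = {'0001.txt': 'Question 1', '0002.yaml': 'question: Question 2'}
    write_all(db_dir, files)
    # Backdate the files instead of sleeping, so the check does not
    # depend on the filesystem's timestamp resolution.
    for path in db_dir.iterdir():
        stat = path.stat()
        os.utime(path, (stat.st_atime - 10, stat.st_mtime - 10))
    result = overwritten_files(db_dir, files)
```

Backdating keeps the check deterministic on filesystems with coarse timestamps; a short sleep, as in the test, is simpler but relies on the timestamp resolution being finer than the delay.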
-100
@@ -1,100 +0,0 @@
import unittest
import argparse
from typing import Union, Optional
from chatmastermind.configuration import Config, AIConfig
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Answer
from chatmastermind.chat import Chat
from chatmastermind.ai import AI, AIResponse, Tokens, AIError
class FakeAI(AI):
"""
A mocked version of the 'AI' class.
"""
ID: str
name: str
config: AIConfig
def models(self) -> list[str]:
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int:
return 123
def print(self) -> None:
pass
def print_models(self) -> None:
pass
def __init__(self, ID: str, model: str, error: bool = False):
self.ID = ID
self.model = model
self.error = error
def request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Mock the 'ai.request()' function by either returning fake
answers or raising an exception.
"""
if self.error:
raise AIError
question.answer = Answer("Answer 0")
question.tags = set(otags) if otags is not None else None
question.ai = self.ID
question.model = self.model
answers: list[Message] = [question]
for n in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(f"Answer {n}"),
tags=otags,
ai=self.ID,
model=self.model))
return AIResponse(answers, Tokens(10, 10, 20))
class TestWithFakeAI(unittest.TestCase):
"""
Base class for all tests that need to use the FakeAI.
"""
def assert_msgs_equal_except_file_path(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using Question, Answer and all metadata except for the file_path.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
def assert_msgs_all_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using Question, Answer and ALL metadata.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
self.assertTrue(m1.equals(m2, verbose=True))
def assert_msgs_content_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using only Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
self.assertEqual(m1, m2)
def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance.
"""
return FakeAI(args.AI, args.model)
def mock_create_ai_with_error(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance whose 'request()' raises an 'AIError'.
"""
return FakeAI(args.AI, args.model, error=True)
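How a fake like the one in this removed file is typically swapped in for the real factory can be sketched as follows. This is a hedged, self-contained illustration and not the project's actual wiring: `ServiceError`, `FakeService`, `Client`, `fake_factory` and `failing_factory` are all hypothetical names, loosely modelled on `AIError`, `FakeAI`, and the two `mock_create_ai*` helpers.

```python
from unittest import mock


class ServiceError(Exception):
    """Hypothetical error type, standing in for 'AIError'."""


class FakeService:
    """Hypothetical test double, loosely modelled on 'FakeAI'."""

    def __init__(self, ident: str, error: bool = False) -> None:
        self.ident = ident
        self.error = error

    def request(self, question: str, num_answers: int = 1) -> list[str]:
        # Either fail on demand or return numbered fake answers.
        if self.error:
            raise ServiceError
        return [f"Answer {n}" for n in range(num_answers)]


class Client:
    """Hypothetical code under test; it obtains its service from a factory."""

    def create_service(self, ident: str) -> FakeService:
        raise RuntimeError("the real service must not be used in tests")

    def ask(self, ident: str, question: str, num_answers: int = 1) -> list[str]:
        return self.create_service(ident).request(question, num_answers)


def fake_factory(ident: str) -> FakeService:
    return FakeService(ident)


def failing_factory(ident: str) -> FakeService:
    return FakeService(ident, error=True)


client = Client()

# Replace the factory with one that returns a working fake.
with mock.patch.object(Client, 'create_service', side_effect=fake_factory):
    answers = client.ask('fake', 'What?', num_answers=3)

# Replace the factory with one whose fake raises on request.
with mock.patch.object(Client, 'create_service', side_effect=failing_factory):
    try:
        client.ask('fake', 'What?')
        raised = False
    except ServiceError:
        raised = True
```

Keeping the error switch inside the fake, rather than in each test, lets one double cover both the success path and the failure path.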
-167
@@ -1,167 +0,0 @@
import os
import unittest
import yaml
from tempfile import NamedTemporaryFile
from pathlib import Path
from typing import cast
from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config
class TestAIConfigInstance(unittest.TestCase):
def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None:
ai_config = cast(OpenAIConfig, ai_config_instance('openai'))
ai_reference = OpenAIConfig()
self.assertEqual(ai_config.ID, ai_reference.ID)
self.assertEqual(ai_config.name, ai_reference.name)
self.assertEqual(ai_config.api_key, ai_reference.api_key)
self.assertEqual(ai_config.system, ai_reference.system)
self.assertEqual(ai_config.model, ai_reference.model)
self.assertEqual(ai_config.temperature, ai_reference.temperature)
self.assertEqual(ai_config.max_tokens, ai_reference.max_tokens)
self.assertEqual(ai_config.top_p, ai_reference.top_p)
self.assertEqual(ai_config.frequency_penalty, ai_reference.frequency_penalty)
self.assertEqual(ai_config.presence_penalty, ai_reference.presence_penalty)
def test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None:
conf_dict = {
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
ai_config = cast(OpenAIConfig, ai_config_instance('openai', conf_dict))
self.assertEqual(ai_config.system, 'Custom system')
self.assertEqual(ai_config.api_key, '9876543210')
self.assertEqual(ai_config.model, 'custom_model')
self.assertEqual(ai_config.max_tokens, 5000)
self.assertAlmostEqual(ai_config.temperature, 0.5)
self.assertAlmostEqual(ai_config.top_p, 0.8)
self.assertAlmostEqual(ai_config.frequency_penalty, 0.7)
self.assertAlmostEqual(ai_config.presence_penalty, 0.2)
def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None:
with self.assertRaises(ConfigError):
ai_config_instance('invalid_name')
class TestConfig(unittest.TestCase):
def setUp(self) -> None:
self.test_file = NamedTemporaryFile(delete=False)
def tearDown(self) -> None:
os.remove(self.test_file.name)
def test_from_dict_should_create_config_from_dict(self) -> None:
source_dict = {
'cache': '.',
'db': './test_db/',
'ais': {
'myopenai': {
'name': 'openai',
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
}
}
config = Config.from_dict(source_dict)
self.assertEqual(config.cache, '.')
self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['myopenai'].name, 'openai')
self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system')
# check that 'ID' has been added
self.assertEqual(config.ais['myopenai'].ID, 'myopenai')
def test_create_default_should_create_default_config(self) -> None:
Config.create_default(Path(self.test_file.name))
with open(self.test_file.name, 'r') as f:
default_config = yaml.load(f, Loader=yaml.FullLoader)
config_reference = Config()
self.assertEqual(default_config['db'], config_reference.db)
def test_from_file_should_load_config_from_file(self) -> None:
source_dict = {
'cache': './test_cache/',
'db': './test_db/',
'ais': {
'default': {
'name': 'openai',
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
}
}
with open(self.test_file.name, 'w') as f:
yaml.dump(source_dict, f)
config = Config.from_file(self.test_file.name)
self.assertIsInstance(config, Config)
self.assertEqual(config.cache, './test_cache/')
self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1)
self.assertIsInstance(config.ais['default'], AIConfig)
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
def test_to_file_should_save_config_to_file(self) -> None:
config = Config(
cache='./test_cache/',
db='./test_db/',
ais={
'myopenai': OpenAIConfig(
ID='myopenai',
system='Custom system',
api_key='9876543210',
model='custom_model',
max_tokens=5000,
temperature=0.5,
top_p=0.8,
frequency_penalty=0.7,
presence_penalty=0.2
)
}
)
config.to_file(Path(self.test_file.name))
with open(self.test_file.name, 'r') as f:
saved_config = yaml.load(f, Loader=yaml.FullLoader)
self.assertEqual(saved_config['cache'], './test_cache/')
self.assertEqual(saved_config['db'], './test_db/')
self.assertEqual(len(saved_config['ais']), 1)
self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system')
def test_from_file_error_unknown_ai(self) -> None:
source_dict = {
'cache': './test_cache/',
'db': './test_db/',
'ais': {
'default': {
'name': 'foobla',
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
}
}
with open(self.test_file.name, 'w') as f:
yaml.dump(source_dict, f)
with self.assertRaises(ConfigError):
Config.from_file(self.test_file.name)
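The behaviour these removed tests exercise (look up a config class by name, override defaults from a dict, reject unknown names) can be sketched with a small registry. This is an illustrative sketch only: `BackendConfig`, `KNOWN_BACKENDS`, `config_instance` and the default values are assumptions made for the example, and the `ConfigError` defined here is a local stand-in rather than the project's class.

```python
from dataclasses import dataclass
from typing import Any, Optional


class ConfigError(Exception):
    """Raised for unknown backend names (local stand-in for the project's error)."""


@dataclass
class BackendConfig:
    """Hypothetical config record with a few of the fields used in the tests."""
    name: str = 'openai'
    model: str = 'default_model'
    max_tokens: int = 1000
    temperature: float = 1.0


# Registry of known backend names; only one entry in this sketch.
KNOWN_BACKENDS: dict[str, type[BackendConfig]] = {'openai': BackendConfig}


def config_instance(name: str, conf: Optional[dict[str, Any]] = None) -> BackendConfig:
    """Return a config for 'name', with defaults overridden by 'conf'."""
    try:
        cls = KNOWN_BACKENDS[name]
    except KeyError:
        raise ConfigError(f"unknown backend '{name}'") from None
    return cls(**(conf or {}))


default = config_instance('openai')
custom = config_instance('openai', {'model': 'custom_model', 'max_tokens': 5000})
try:
    config_instance('foobla')
    unknown_rejected = False
except ConfigError:
    unknown_rejected = True
```

Raising a dedicated error for unknown names is what allows a caller to assert on `ConfigError` instead of a generic `KeyError`.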
-62
@@ -1,62 +0,0 @@
import unittest
import argparse
import tempfile
import yaml
from pathlib import Path
from chatmastermind.message import Message, Question
from chatmastermind.chat import ChatDB, ChatError
from chatmastermind.configuration import Config
from chatmastermind.commands.hist import convert_messages
msg_suffix = Message.file_suffix_write
class TestConvertMessages(unittest.TestCase):
def setUp(self) -> None:
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
self.db_path = Path(self.db_dir.name)
self.cache_path = Path(self.cache_dir.name)
self.args = argparse.Namespace()
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
# Prepare some messages
self.chat = ChatDB.from_dir(Path(self.cache_path),
Path(self.db_path))
self.messages = [Message(Question(f'Question {i}')) for i in range(0, 6)]
self.chat.db_write(self.messages[0:2])
self.chat.cache_write(self.messages[2:])
# Change some of the suffixes
assert self.messages[0].file_path
assert self.messages[1].file_path
self.messages[0].file_path.rename(self.messages[0].file_path.with_suffix('.txt'))
self.messages[1].file_path.rename(self.messages[1].file_path.with_suffix('.yaml'))
def tearDown(self) -> None:
self.db_dir.cleanup()
self.cache_dir.cleanup()
def test_convert_messages(self) -> None:
self.args.convert = 'yaml'
convert_messages(self.args, self.config)
msgs = self.chat.msg_gather(loc='disk', glob='*.*')
# Check if the number of messages is the same as before
self.assertEqual(len(msgs), len(self.messages))
# Check if all messages have the requested suffix
for msg in msgs:
assert msg.file_path
self.assertEqual(msg.file_path.suffix, msg_suffix)
# Check if the message IDs are correctly maintained
for m_new, m_old in zip(msgs, self.messages):
self.assertEqual(m_new.msg_id(), m_old.msg_id())
# check if all messages have the new format
for m in msgs:
with open(str(m.file_path), "r") as fd:
yaml.load(fd, Loader=yaml.FullLoader)
def test_convert_messages_wrong_format(self) -> None:
self.args.convert = 'foo'
with self.assertRaises(ChatError):
convert_messages(self.args, self.config)
+236
@@ -0,0 +1,236 @@
import unittest
import io
import pathlib
import argparse
from chatmastermind.utils import terminal_width
from chatmastermind.main import create_parser, ask_cmd
from chatmastermind.api_client import ai
from chatmastermind.configuration import Config
from chatmastermind.storage import create_chat_hist, save_answers, dump_data
from unittest import mock
from unittest.mock import patch, MagicMock, Mock, ANY
class CmmTestCase(unittest.TestCase):
"""
Base class for all cmm testcases.
"""
def dummy_config(self, db: str) -> Config:
"""
Creates a dummy configuration.
"""
return Config.from_dict(
{'system': 'dummy_system',
'db': db,
'openai': {'api_key': 'dummy_key',
'model': 'dummy_model',
'max_tokens': 4000,
'temperature': 1.0,
'top_p': 1,
'frequency_penalty': 0,
'presence_penalty': 0}}
)
class TestCreateChat(CmmTestCase):
def setUp(self) -> None:
self.config = self.dummy_config(db='test_files')
self.question = "test question"
self.tags = ['test_tag']
@patch('os.listdir')
@patch('pathlib.Path.iterdir')
@patch('builtins.open')
def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
{'question': 'test_content', 'answer': 'some answer',
'tags': ['test_tag']}))
test_chat = create_chat_hist(self.question, self.tags, None, self.config)
self.assertEqual(len(test_chat), 4)
self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1],
{'role': 'user', 'content': 'test_content'})
self.assertEqual(test_chat[2],
{'role': 'assistant', 'content': 'some answer'})
self.assertEqual(test_chat[3],
{'role': 'user', 'content': self.question})
@patch('os.listdir')
@patch('pathlib.Path.iterdir')
@patch('builtins.open')
def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
{'question': 'test_content', 'answer': 'some answer',
'tags': ['other_tag']}))
test_chat = create_chat_hist(self.question, self.tags, None, self.config)
self.assertEqual(len(test_chat), 2)
self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1],
{'role': 'user', 'content': self.question})
@patch('os.listdir')
@patch('pathlib.Path.iterdir')
@patch('builtins.open')
def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt', 'testfile2.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.side_effect = (
io.StringIO(dump_data({'question': 'test_content',
'answer': 'some answer',
'tags': ['test_tag']})),
io.StringIO(dump_data({'question': 'test_content2',
'answer': 'some answer2',
'tags': ['test_tag2']})),
)
test_chat = create_chat_hist(self.question, [], None, self.config)
self.assertEqual(len(test_chat), 6)
self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1],
{'role': 'user', 'content': 'test_content'})
self.assertEqual(test_chat[2],
{'role': 'assistant', 'content': 'some answer'})
self.assertEqual(test_chat[3],
{'role': 'user', 'content': 'test_content2'})
self.assertEqual(test_chat[4],
{'role': 'assistant', 'content': 'some answer2'})
class TestHandleQuestion(CmmTestCase):
def setUp(self) -> None:
self.question = "test question"
self.args = argparse.Namespace(
tags=['tag1'],
atags=None,
etags=['etag1'],
output_tags=None,
question=[self.question],
source=None,
source_code_only=False,
number=3,
max_tokens=None,
temperature=None,
model=None,
match_all_tags=False,
with_tags=False,
with_file=False,
)
self.config = self.dummy_config(db='test_files')
@patch("chatmastermind.main.create_chat_hist", return_value="test_chat")
@patch("chatmastermind.main.print_tag_args")
@patch("chatmastermind.main.print_chat_hist")
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
@patch("chatmastermind.utils.pp")
@patch("builtins.print")
def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock,
mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock,
mock_create_chat_hist: MagicMock) -> None:
open_mock = MagicMock()
with patch("chatmastermind.storage.open", open_mock):
ask_cmd(self.args, self.config)
mock_print_tag_args.assert_called_once_with(self.args.tags,
self.args.etags,
[])
mock_create_chat_hist.assert_called_once_with(self.question,
self.args.tags,
self.args.etags,
self.config,
match_all_tags=False,
with_tags=False,
with_file=False)
mock_print_chat_hist.assert_called_once_with('test_chat',
False,
self.args.source_code_only)
mock_ai.assert_called_with("test_chat",
self.config,
self.args.number)
expected_calls = []
for num, answer in enumerate(mock_ai.return_value[0], start=1):
title = f'-- ANSWER {num} '
title_end = '-' * (terminal_width() - len(title))
expected_calls.append(((f'{title}{title_end}',),))
expected_calls.append(((answer,),))
expected_calls.append((("-" * terminal_width(),),))
expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
self.assertEqual(mock_print.call_args_list, expected_calls)
open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)])
open_mock.assert_has_calls(open_expected_calls, any_order=True)
class TestSaveAnswers(CmmTestCase):
@mock.patch('builtins.open')
@mock.patch('chatmastermind.storage.print')
def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None:
question = "Test question?"
answers = ["Answer 1", "Answer 2"]
tags = ["tag1", "tag2"]
otags = ["otag1", "otag2"]
config = self.dummy_config(db='test_db')
with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \
mock.patch('chatmastermind.storage.yaml.dump'), \
mock.patch('io.StringIO') as stringio_mock:
stringio_instance = stringio_mock.return_value
stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"]
save_answers(question, answers, tags, otags, config)
open_calls = [
mock.call(pathlib.Path('test_db/.next'), 'r'),
mock.call(pathlib.Path('test_db/.next'), 'w'),
]
open_mock.assert_has_calls(open_calls, any_order=True)
class TestAI(CmmTestCase):
@patch("openai.ChatCompletion.create")
def test_ai(self, mock_create: MagicMock) -> None:
mock_create.return_value = {
'choices': [
{'message': {'content': 'response_text_1'}},
{'message': {'content': 'response_text_2'}}
],
'usage': {'tokens': 10}
}
chat = [{"role": "system", "content": "hello ai"}]
config = self.dummy_config(db='dummy')
config.openai.model = "text-davinci-002"
config.openai.max_tokens = 150
config.openai.temperature = 0.5
result = ai(chat, config, 2)
expected_result = (['response_text_1', 'response_text_2'],
{'tokens': 10})
self.assertEqual(result, expected_result)
class TestCreateParser(CmmTestCase):
def test_create_parser(self) -> None:
with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers:
mock_cmdparser = Mock()
mock_add_subparsers.return_value = mock_cmdparser
parser = create_parser()
self.assertIsInstance(parser, argparse.ArgumentParser)
mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True)
mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY)
self.assertTrue('.config.yaml' in parser.get_default('config'))
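Several tests in this file stack multiple `@patch` decorators and receive the mocks as positional parameters. `unittest.mock` applies stacked decorators bottom-up, so the decorator closest to the function supplies the first mock argument. The following sketch uses the same three patch targets as `TestCreateChat`; the function name `check_patch_order` is hypothetical and exists only to show the ordering:

```python
import builtins
import os
import pathlib
from unittest.mock import MagicMock, patch


@patch('os.listdir')              # outermost decorator: its mock is passed last
@patch('pathlib.Path.iterdir')
@patch('builtins.open')           # innermost decorator: its mock is passed first
def check_patch_order(open_mock: MagicMock,
                      iterdir_mock: MagicMock,
                      listdir_mock: MagicMock) -> tuple[bool, bool, bool]:
    # While the function runs, each patched name is bound to the matching mock.
    return (builtins.open is open_mock,
            pathlib.Path.iterdir is iterdir_mock,
            os.listdir is listdir_mock)


order_ok = check_patch_order()
```

All three patches are undone as soon as the function returns, so `builtins.open`, `pathlib.Path.iterdir` and `os.listdir` are restored for any code that runs afterwards.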
+55 -148
@@ -1,17 +1,12 @@
import unittest
import pathlib import pathlib
import tempfile import tempfile
import itertools
from typing import cast from typing import cast
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine,\ from .test_main import CmmTestCase
MessageFilter, message_in, message_valid_formats from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
from chatmastermind.tags import Tag, TagLine from chatmastermind.tags import Tag, TagLine
msg_suffix: str = Message.file_suffix_write class SourceCodeTestCase(CmmTestCase):
class SourceCodeTestCase(unittest.TestCase):
def test_source_code_with_include_delims(self) -> None: def test_source_code_with_include_delims(self) -> None:
text = """ text = """
Some text before the code block Some text before the code block
@@ -65,7 +60,7 @@ class SourceCodeTestCase(unittest.TestCase):
self.assertEqual(result, expected_result) self.assertEqual(result, expected_result)
class QuestionTestCase(unittest.TestCase): class QuestionTestCase(CmmTestCase):
def test_question_with_header(self) -> None: def test_question_with_header(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
Question(f"{Question.txt_header}\nWhat is your name?") Question(f"{Question.txt_header}\nWhat is your name?")
@@ -88,7 +83,7 @@ class QuestionTestCase(unittest.TestCase):
self.assertEqual(question, "What is your favorite color?") self.assertEqual(question, "What is your favorite color?")
class AnswerTestCase(unittest.TestCase): class AnswerTestCase(CmmTestCase):
def test_answer_with_header(self) -> None: def test_answer_with_header(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
Answer(f"{Answer.txt_header}\nno") Answer(f"{Answer.txt_header}\nno")
@@ -104,9 +99,9 @@ class AnswerTestCase(unittest.TestCase):
self.assertEqual(answer, "No") self.assertEqual(answer, "No")
class MessageToFileTxtTestCase(unittest.TestCase): class MessageToFileTxtTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
self.message_complete = Message(Question('This is a question.'), self.message_complete = Message(Question('This is a question.'),
Answer('This is an answer.'), Answer('This is an answer.'),
@@ -122,7 +117,7 @@ class MessageToFileTxtTestCase(unittest.TestCase):
self.file_path.unlink() self.file_path.unlink()
def test_to_file_txt_complete(self) -> None: def test_to_file_txt_complete(self) -> None:
self.message_complete.to_file(self.file_path, mformat='txt') self.message_complete.to_file(self.file_path)
with open(self.file_path, "r") as fd: with open(self.file_path, "r") as fd:
content = fd.read() content = fd.read()
@@ -137,7 +132,7 @@ This is an answer.
self.assertEqual(content, expected_content) self.assertEqual(content, expected_content)
def test_to_file_txt_min(self) -> None: def test_to_file_txt_min(self) -> None:
self.message_min.to_file(self.file_path, mformat='txt') self.message_min.to_file(self.file_path)
with open(self.file_path, "r") as fd: with open(self.file_path, "r") as fd:
content = fd.read() content = fd.read()
@@ -146,17 +141,11 @@ This is a question.
""" """
self.assertEqual(content, expected_content) self.assertEqual(content, expected_content)
def test_to_file_unsupported_file_suffix(self) -> None: def test_to_file_unsupported_file_type(self) -> None:
unsupported_file_path = pathlib.Path("example.doc") unsupported_file_path = pathlib.Path("example.doc")
with self.assertRaises(MessageError) as cm: with self.assertRaises(MessageError) as cm:
self.message_complete.to_file(unsupported_file_path) self.message_complete.to_file(unsupported_file_path)
self.assertEqual(str(cm.exception), "File suffix '.doc' is not supported") self.assertEqual(str(cm.exception), "File type '.doc' is not supported")
def test_to_file_unsupported_file_format(self) -> None:
unsupported_file_format = pathlib.Path(f"example{msg_suffix}")
with self.assertRaises(MessageError) as cm:
self.message_complete.to_file(unsupported_file_format, mformat='doc') # type: ignore [arg-type]
self.assertEqual(str(cm.exception), "File format 'doc' is not supported")
def test_to_file_no_file_path(self) -> None: def test_to_file_no_file_path(self) -> None:
""" """
@@ -170,24 +159,10 @@ This is a question.
# reset the internal file_path # reset the internal file_path
self.message_complete.file_path = self.file_path self.message_complete.file_path = self.file_path
def test_to_file_txt_auto_suffix(self) -> None:
"""
Test if suffix is auto-generated if omitted.
"""
file_path_no_suffix = self.file_path.with_suffix('')
# test with file_path member
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(mformat='txt')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
# test with explicit file_path
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(file_path=file_path_no_suffix, mformat='txt')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
class MessageToFileYamlTestCase(CmmTestCase):
class MessageToFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
self.message_complete = Message(Question('This is a question.'), self.message_complete = Message(Question('This is a question.'),
Answer('This is an answer.'), Answer('This is an answer.'),
@@ -209,7 +184,7 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.file_path.unlink() self.file_path.unlink()
def test_to_file_yaml_complete(self) -> None: def test_to_file_yaml_complete(self) -> None:
self.message_complete.to_file(self.file_path, mformat='yaml') self.message_complete.to_file(self.file_path)
with open(self.file_path, "r") as fd: with open(self.file_path, "r") as fd:
content = fd.read() content = fd.read()
@@ -224,7 +199,7 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.assertEqual(content, expected_content) self.assertEqual(content, expected_content)
def test_to_file_yaml_multiline(self) -> None: def test_to_file_yaml_multiline(self) -> None:
self.message_multiline.to_file(self.file_path, mformat='yaml') self.message_multiline.to_file(self.file_path)
with open(self.file_path, "r") as fd: with open(self.file_path, "r") as fd:
content = fd.read() content = fd.read()
@@ -243,31 +218,17 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.assertEqual(content, expected_content) self.assertEqual(content, expected_content)
def test_to_file_yaml_min(self) -> None: def test_to_file_yaml_min(self) -> None:
self.message_min.to_file(self.file_path, mformat='yaml') self.message_min.to_file(self.file_path)
with open(self.file_path, "r") as fd: with open(self.file_path, "r") as fd:
content = fd.read() content = fd.read()
expected_content = f"{Question.yaml_key}: This is a question.\n" expected_content = f"{Question.yaml_key}: This is a question.\n"
self.assertEqual(content, expected_content) self.assertEqual(content, expected_content)
def test_to_file_yaml_auto_suffix(self) -> None:
"""
Test if suffix is auto-generated if omitted.
"""
file_path_no_suffix = self.file_path.with_suffix('')
# test with file_path member
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(mformat='yaml')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
# test with explicit file_path
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(file_path=file_path_no_suffix, mformat='yaml')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
class MessageFromFileTxtTestCase(CmmTestCase):
class MessageFromFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd: with open(self.file_path, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2 fd.write(f"""{TagLine.prefix} tag1 tag2
@@ -278,7 +239,7 @@ This is a question.
{Answer.txt_header} {Answer.txt_header}
This is an answer. This is an answer.
""") """)
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_min = pathlib.Path(self.file_min.name) self.file_path_min = pathlib.Path(self.file_min.name)
with open(self.file_path_min, "w") as fd: with open(self.file_path_min, "w") as fd:
fd.write(f"""{Question.txt_header} fd.write(f"""{Question.txt_header}
@@ -298,7 +259,7 @@ This is a question.
message = Message.from_file(self.file_path) message = Message.from_file(self.file_path)
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
assert message if message: # mypy bug
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.') self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
@@ -313,7 +274,7 @@ This is a question.
message = Message.from_file(self.file_path_min) message = Message.from_file(self.file_path_min)
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
assert message if message: # mypy bug
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.file_path, self.file_path_min) self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer) self.assertIsNone(message.answer)
@@ -323,7 +284,7 @@ This is a question.
MessageFilter(tags_or={Tag('tag1')})) MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
assert message if message: # mypy bug
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.') self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
@@ -339,24 +300,18 @@ This is a question.
MessageFilter(tags_or={Tag('tag1')})) MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNone(message) self.assertIsNone(message)
def test_from_file_txt_empty_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(tags_or=set(),
tags_and=set()))
self.assertIsNone(message)
def test_from_file_txt_no_tags_match_tags_not(self) -> None: def test_from_file_txt_no_tags_match_tags_not(self) -> None:
message = Message.from_file(self.file_path_min, message = Message.from_file(self.file_path_min,
MessageFilter(tags_not={Tag('tag1')})) MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
assert message if message: # mypy bug
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set()) self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min) self.assertEqual(message.file_path, self.file_path_min)
def test_from_file_not_exists(self) -> None: def test_from_file_not_exists(self) -> None:
file_not_exists = pathlib.Path(f"example{msg_suffix}") file_not_exists = pathlib.Path("example.txt")
with self.assertRaises(MessageError) as cm: with self.assertRaises(MessageError) as cm:
Message.from_file(file_not_exists) Message.from_file(file_not_exists)
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
@@ -433,9 +388,9 @@ This is a question.
self.assertIsNone(message) self.assertIsNone(message)
class MessageFromFileYamlTestCase(unittest.TestCase): class MessageFromFileYamlTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd: with open(self.file_path, "w") as fd:
fd.write(f""" fd.write(f"""
@@ -449,7 +404,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
- tag1 - tag1
- tag2 - tag2
""") """)
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path_min = pathlib.Path(self.file_min.name) self.file_path_min = pathlib.Path(self.file_min.name)
with open(self.file_path_min, "w") as fd: with open(self.file_path_min, "w") as fd:
fd.write(f""" fd.write(f"""
@@ -470,7 +425,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
message = Message.from_file(self.file_path) message = Message.from_file(self.file_path)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
self.assertIsNotNone(message) self.assertIsNotNone(message)
assert message if message: # mypy bug
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.') self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
@@ -485,14 +440,14 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
message = Message.from_file(self.file_path_min) message = Message.from_file(self.file_path_min)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
self.assertIsNotNone(message) self.assertIsNotNone(message)
assert message if message: # mypy bug
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set()) self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min) self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer) self.assertIsNone(message.answer)
def test_from_file_not_exists(self) -> None: def test_from_file_not_exists(self) -> None:
file_not_exists = pathlib.Path(f"example{msg_suffix}") file_not_exists = pathlib.Path("example.yaml")
with self.assertRaises(MessageError) as cm: with self.assertRaises(MessageError) as cm:
Message.from_file(file_not_exists) Message.from_file(file_not_exists)
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
@@ -502,7 +457,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
MessageFilter(tags_or={Tag('tag1')})) MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
assert message if message: # mypy bug
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.') self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
@@ -523,7 +478,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
MessageFilter(tags_not={Tag('tag1')})) MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
assert message if message: # mypy bug
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set()) self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min) self.assertEqual(message.file_path, self.file_path_min)
@@ -600,9 +555,9 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
self.assertIsNone(message) self.assertIsNone(message)
class TagsFromFileTestCase(unittest.TestCase): class TagsFromFileTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt = pathlib.Path(self.file_txt.name) self.file_path_txt = pathlib.Path(self.file_txt.name)
with open(self.file_path_txt, "w") as fd: with open(self.file_path_txt, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3 fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3
@@ -611,7 +566,7 @@ This is a question.
{Answer.txt_header} {Answer.txt_header}
This is an answer. This is an answer.
""") """)
self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name) self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name)
with open(self.file_path_txt_no_tags, "w") as fd: with open(self.file_path_txt_no_tags, "w") as fd:
fd.write(f"""{Question.txt_header} fd.write(f"""{Question.txt_header}
@@ -619,7 +574,7 @@ This is a question.
{Answer.txt_header} {Answer.txt_header}
This is an answer. This is an answer.
""") """)
self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name) self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name)
with open(self.file_path_txt_tags_empty, "w") as fd: with open(self.file_path_txt_tags_empty, "w") as fd:
fd.write(f"""TAGS: fd.write(f"""TAGS:
@@ -628,7 +583,7 @@ This is a question.
{Answer.txt_header} {Answer.txt_header}
This is an answer. This is an answer.
""") """)
self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path_yaml = pathlib.Path(self.file_yaml.name) self.file_path_yaml = pathlib.Path(self.file_yaml.name)
with open(self.file_path_yaml, "w") as fd: with open(self.file_path_yaml, "w") as fd:
fd.write(f""" fd.write(f"""
@@ -641,7 +596,7 @@ This is an answer.
- tag2 - tag2
- ptag3 - ptag3
""") """)
self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name) self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name)
with open(self.file_path_yaml_no_tags, "w") as fd: with open(self.file_path_yaml_no_tags, "w") as fd:
fd.write(f""" fd.write(f"""
@@ -708,7 +663,7 @@ This is an answer.
self.assertSetEqual(tags, set()) self.assertSetEqual(tags, set())
class TagsFromDirTestCase(unittest.TestCase): class TagsFromDirTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir = tempfile.TemporaryDirectory()
self.temp_dir_no_tags = tempfile.TemporaryDirectory() self.temp_dir_no_tags = tempfile.TemporaryDirectory()
@@ -718,25 +673,24 @@ class TagsFromDirTestCase(unittest.TestCase):
{Tag('ctag5'), Tag('ctag6')} {Tag('ctag5'), Tag('ctag6')}
] ]
self.files = [ self.files = [
pathlib.Path(self.temp_dir.name, f'file1{msg_suffix}'), pathlib.Path(self.temp_dir.name, 'file1.txt'),
pathlib.Path(self.temp_dir.name, f'file2{msg_suffix}'), pathlib.Path(self.temp_dir.name, 'file2.yaml'),
pathlib.Path(self.temp_dir.name, f'file3{msg_suffix}') pathlib.Path(self.temp_dir.name, 'file3.txt')
] ]
self.files_no_tags = [ self.files_no_tags = [
pathlib.Path(self.temp_dir_no_tags.name, f'file4{msg_suffix}'), pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'),
pathlib.Path(self.temp_dir_no_tags.name, f'file5{msg_suffix}'), pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'),
pathlib.Path(self.temp_dir_no_tags.name, f'file6{msg_suffix}') pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt')
] ]
mformats = itertools.cycle(message_valid_formats)
for file, tags in zip(self.files, self.tag_sets): for file, tags in zip(self.files, self.tag_sets):
message = Message(Question('This is a question.'), message = Message(Question('This is a question.'),
Answer('This is an answer.'), Answer('This is an answer.'),
tags) tags)
message.to_file(file, next(mformats)) message.to_file(file)
for file in self.files_no_tags: for file in self.files_no_tags:
message = Message(Question('This is a question.'), message = Message(Question('This is a question.'),
Answer('This is an answer.')) Answer('This is an answer.'))
message.to_file(file, next(mformats)) message.to_file(file)
def tearDown(self) -> None: def tearDown(self) -> None:
self.temp_dir.cleanup() self.temp_dir.cleanup()
@@ -757,9 +711,9 @@ class TagsFromDirTestCase(unittest.TestCase):
self.assertSetEqual(all_tags, set()) self.assertSetEqual(all_tags, set())
class MessageIDTestCase(unittest.TestCase): class MessageIDTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix) self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
self.message = Message(Question('This is a question.'), self.message = Message(Question('This is a question.'),
file_path=self.file_path) file_path=self.file_path)
@@ -770,14 +724,14 @@ class MessageIDTestCase(unittest.TestCase):
self.file_path.unlink() self.file_path.unlink()
def test_msg_id_txt(self) -> None: def test_msg_id_txt(self) -> None:
self.assertEqual(self.message.msg_id(), self.file_path.stem) self.assertEqual(self.message.msg_id(), self.file_path.name)
def test_msg_id_txt_exception(self) -> None: def test_msg_id_txt_exception(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
self.message_no_file_path.msg_id() self.message_no_file_path.msg_id()
class MessageHashTestCase(unittest.TestCase): class MessageHashTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message1 = Message(Question('This is a question.'), self.message1 = Message(Question('This is a question.'),
tags={Tag('tag1')}, tags={Tag('tag1')},
@@ -801,7 +755,7 @@ class MessageHashTestCase(unittest.TestCase):
self.assertIn(msg, msgs) self.assertIn(msg, msgs)
class MessageTagsStrTestCase(unittest.TestCase): class MessageTagsStrTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message = Message(Question('This is a question.'), self.message = Message(Question('This is a question.'),
tags={Tag('tag1')}, tags={Tag('tag1')},
@@ -811,7 +765,7 @@ class MessageTagsStrTestCase(unittest.TestCase):
 self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')
-class MessageFilterTagsTestCase(unittest.TestCase):
+class MessageFilterTagsTestCase(CmmTestCase):
 def setUp(self) -> None:
 self.message = Message(Question('This is a question.'),
 tags={Tag('atag1'), Tag('btag2')},
@@ -826,7 +780,7 @@ class MessageFilterTagsTestCase(unittest.TestCase):
 self.assertSetEqual(tags_cont, {Tag('btag2')})
-class MessageInTestCase(unittest.TestCase):
+class MessageInTestCase(CmmTestCase):
 def setUp(self) -> None:
 self.message1 = Message(Question('This is a question.'),
 tags={Tag('atag1'), Tag('btag2')},
@@ -840,7 +794,7 @@ class MessageInTestCase(unittest.TestCase):
 self.assertFalse(message_in(self.message1, [self.message2]))
-class MessageRenameTagsTestCase(unittest.TestCase):
+class MessageRenameTagsTestCase(CmmTestCase):
 def setUp(self) -> None:
 self.message = Message(Question('This is a question.'),
 tags={Tag('atag1'), Tag('btag2')},
@@ -850,50 +804,3 @@ class MessageRenameTagsTestCase(unittest.TestCase):
 self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))})
 self.assertIsNotNone(self.message.tags)
 self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type]
-class MessageToStrTestCase(unittest.TestCase):
-def setUp(self) -> None:
-self.message = Message(Question('This is a question.'),
-Answer('This is an answer.'),
-ai=('FakeAI'),
-model=('FakeModel'),
-tags={Tag('atag1'), Tag('btag2')},
-file_path=pathlib.Path('/tmp/foo/bla'))
-def test_to_str(self) -> None:
-expected_output = f"""{Question.txt_header}
-This is a question.
-{Answer.txt_header}
-This is an answer."""
-self.assertEqual(self.message.to_str(), expected_output)
-def test_to_str_with_tags_and_file(self) -> None:
-expected_output = f"""{TagLine.prefix} atag1 btag2
-FILE: /tmp/foo/bla
-AI: FakeAI
-MODEL: FakeModel
-{Question.txt_header}
-This is a question.
-{Answer.txt_header}
-This is an answer."""
-self.assertEqual(self.message.to_str(with_metadata=True), expected_output)
-class MessageRmFileTestCase(unittest.TestCase):
-def setUp(self) -> None:
-self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
-self.file_path = pathlib.Path(self.file.name)
-self.message = Message(Question('This is a question.'),
-file_path=self.file_path)
-self.message.to_file()
-def tearDown(self) -> None:
-self.file.close()
-self.file_path.unlink(missing_ok=True)
-def test_rm_file(self) -> None:
-assert self.message.file_path
-self.assertTrue(self.message.file_path.exists())
-self.message.rm_file()
-self.assertFalse(self.message.file_path.exists())
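Note on the `test_msg_id_txt` change in this file: the expected message ID moves from `file_path.stem` to `file_path.name`, and the temporary file now gets a fixed `'.txt'` suffix. What `msg_id()` returns is defined outside this diff, so the following is only a minimal `pathlib` sketch of the difference between the two attributes, using a hypothetical file name that does not come from the repository:

```python
import pathlib

# Hypothetical path, chosen only to illustrate the two attributes.
path = pathlib.Path('/tmp/foo/0001.txt')

# .name is the final path component, suffix included.
name = path.name

# .stem is the final path component without its last suffix.
stem = path.stem

# .suffix is the last suffix, including the leading dot.
suffix = path.suffix

print(name, stem, suffix)
```

With the new assertion, a test that compares against `.name` expects the ID to carry the file extension, whereas the old assertion expected it to be stripped.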
-593
@@ -1,593 +0,0 @@
import os
import argparse
import tempfile
from copy import copy
from pathlib import Path
from unittest import mock
from unittest.mock import MagicMock, call
from chatmastermind.configuration import Config
from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AIError
from .test_common import TestWithFakeAI
msg_suffix = Message.file_suffix_write
class TestMessageCreate(TestWithFakeAI):
"""
Test if messages created by the 'question' command have
the correct format.
"""
def setUp(self) -> None:
# create ChatDB structure
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
self.chat = ChatDB.from_dir(cache_path=Path(self.cache_dir.name),
db_path=Path(self.db_dir.name))
# create some messages
self.message_text = Message(Question("What is this?"),
Answer("It is pure text"))
self.message_code = Message(Question("What is this?"),
Answer("Text\n```\nIt is embedded code\n```\ntext"))
self.chat.db_add([self.message_text, self.message_code])
# create arguments mock
self.args = MagicMock(spec=argparse.Namespace)
self.args.source_text = None
self.args.source_code = None
self.args.AI = None
self.args.model = None
self.args.output_tags = None
self.args.ask = None
self.args.create = None
# File 1 : no source code block, only text
self.source_file1 = tempfile.NamedTemporaryFile(delete=False)
self.source_file1_content = """This is just text.
No source code.
Nope. Go look elsewhere!"""
with open(self.source_file1.name, 'w') as f:
f.write(self.source_file1_content)
# File 2 : one embedded source code block
self.source_file2 = tempfile.NamedTemporaryFile(delete=False)
self.source_file2_content = """This is just text.
```
This is embedded source code.
```
And some text again."""
with open(self.source_file2.name, 'w') as f:
f.write(self.source_file2_content)
# File 3 : all source code
self.source_file3 = tempfile.NamedTemporaryFile(delete=False)
self.source_file3_content = """This is all source code.
Yes, really.
Language is called 'brainfart'."""
with open(self.source_file3.name, 'w') as f:
f.write(self.source_file3_content)
# File 4 : two source code blocks
self.source_file4 = tempfile.NamedTemporaryFile(delete=False)
self.source_file4_content = """This is just text.
```
This is embedded source code.
```
And some text again.
```
This is embedded source code.
```
Aaaand again some text."""
with open(self.source_file4.name, 'w') as f:
f.write(self.source_file4_content)
def tearDown(self) -> None:
os.remove(self.source_file1.name)
os.remove(self.source_file2.name)
os.remove(self.source_file3.name)
os.remove(self.source_file4.name)
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return list(Path(tmp_dir.name).glob(f'*{msg_suffix}'))
def test_message_file_created(self) -> None:
self.args.ask = ["What is this?"]
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr]
def test_single_question(self) -> None:
self.args.ask = ["What is this?"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("What is this?"))
self.assertEqual(len(message.question.source_code()), 0)
def test_multipart_question(self) -> None:
self.args.ask = ["What is this", "'bard' thing?", "Is it good?"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("""What is this
'bard' thing?
Is it good?"""))
def test_single_question_with_text_only_file(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_text = [f"{self.source_file1.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains no source code (only text)
# -> don't expect any in the question
self.assertEqual(len(message.question.source_code()), 0)
self.assertEqual(message.question, Question(f"""What is this?
{self.source_file1_content}"""))
def test_single_question_with_text_file_and_embedded_code(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.source_file2.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains 1 source code block
# -> expect it in the question
self.assertEqual(len(message.question.source_code()), 1)
self.assertEqual(message.question, Question("""What is this?
```
This is embedded source code.
```
"""))
def test_single_question_with_code_only_file(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.source_file3.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file is complete source code
self.assertEqual(len(message.question.source_code()), 1)
self.assertEqual(message.question, Question(f"""What is this?
```
{self.source_file3_content}
```"""))
def test_single_question_with_text_file_and_multi_embedded_code(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.source_file4.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains 2 source code blocks
# -> expect them in the question
self.assertEqual(len(message.question.source_code()), 2)
self.assertEqual(message.question, Question("""What is this?
```
This is embedded source code.
```
```
This is embedded source code.
```
"""))
def test_single_question_with_text_only_message(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_text = [f"{self.chat.messages[0].file_path}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains no source code (only text)
# -> don't expect any in the question
self.assertEqual(len(message.question.source_code()), 0)
self.assertEqual(message.question, Question(f"""What is this?
{self.message_text.answer}"""))
def test_single_question_with_message_and_embedded_code(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.chat.messages[1].file_path}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# answer contains 1 source code block
# -> expect it in the question
self.assertEqual(len(message.question.source_code()), 1)
self.assertEqual(message.question, Question("""What is this?
```
It is embedded code
```
"""))
class TestCreateOption(TestMessageCreate):
def test_message_file_created(self) -> None:
self.args.create = ["How does question --create work?"]
self.args.ask = None
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("How does question --create work?")) # type: ignore [union-attr]
class TestQuestionCmd(TestWithFakeAI):
def setUp(self) -> None:
# create DB and cache
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
# create configuration
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
# create a mock argparse.Namespace
self.args = argparse.Namespace(
ask=['What is the meaning of life?'],
num_answers=1,
output_tags=['science'],
AI='FakeAI',
model='FakeModel',
or_tags=None,
and_tags=None,
exclude_tags=None,
source_text=None,
source_code=None,
create=None,
repeat=None,
process=None,
overwrite=None
)
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return sorted([f for f in Path(tmp_dir.name).glob(f'*{msg_suffix}')])
class TestQuestionCmdAsk(TestQuestionCmd):
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
"""
Test single answer with no errors.
"""
mock_create_ai.side_effect = self.mock_create_ai
expected_question = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path('<NOT COMPARED>'))
fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command
question_cmd(self.args, self.config)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None:
"""
Test single answer with no errors (mocked ChatDB version).
"""
chat = MagicMock(spec=ChatDB)
mock_from_dir.return_value = chat
mock_create_ai.side_effect = self.mock_create_ai
expected_question = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path('<NOT COMPARED>'))
fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command
question_cmd(self.args, self.config)
# check for the correct ChatDB calls:
# - initial question has been written (prior to the actual request)
# - responses have been written (after the request)
chat.cache_write.assert_has_calls([call([expected_question]),
call(expected_responses)],
any_order=False)
# check that the messages have not been added to the internal message list
chat.cache_add.assert_not_called()
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_with_error(self, mock_create_ai: MagicMock) -> None:
"""
Provoke an error during the AI request and verify that the question
has been correctly stored in the cache.
"""
mock_create_ai.side_effect = self.mock_create_ai_with_error
expected_question = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path('<NOT COMPARED>'))
# execute the command
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_question])
class TestQuestionCmdRepeat(TestQuestionCmd):
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
# repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path)
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai=message.ai,
model=message.model,
tags=message.tags,
file_path=Path('<NOT COMPARED>'))
# we expect the original message + the one with the new response
expected_responses = [message] + [expected_response]
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
print(self.message_list(self.cache_dir))
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question and overwrite the old one.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
# repeat the last question (WITH overwriting)
# -> expect a single message afterwards (with a new answer)
self.args.ask = None
self.args.repeat = []
self.args.overwrite = True
expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai=message.ai,
model=message.model,
tags=message.tags,
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_after_error(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question after an error.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a question WITHOUT an answer
# -> just like after an error, which is tested above
message = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
# repeat the last question (without overwriting)
# -> expect a single message because if the original has
# no answer, it should be overwritten by default
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai=message.ai,
model=message.model,
tags=message.tags,
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_new_args(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question with new arguments.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
# repeat the last question with new arguments (without overwriting)
# -> expect two messages with identical question but different metadata and new answer
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
self.args.output_tags = ['newtag']
self.args.AI = 'newai'
self.args.model = 'newmodel'
new_expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai='newai',
model='newmodel',
tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_msgs_equal_except_file_path(cached_msg, [message] + [new_expected_response])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_new_args_overwrite(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question with new arguments, overwriting the old one.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
# repeat the last question with new arguments
self.args.ask = None
self.args.repeat = []
self.args.overwrite = True
self.args.output_tags = ['newtag']
self.args.AI = 'newai'
self.args.model = 'newmodel'
new_expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai='newai',
model='newmodel',
tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None:
"""
Repeat multiple questions.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# 1. === create three questions ===
# cached message without an answer
message1 = Message(Question(self.args.ask[0]),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
# cached message with an answer
message2 = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0002{msg_suffix}')
# DB message without an answer
message3 = Message(Question(self.args.ask[0]),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.db_dir.name) / f'0003{msg_suffix}')
chat.msg_write([message1, message2, message3])
questions = [message1, message2, message3]
expected_responses: list[Message] = []
fake_ai = self.mock_create_ai(self.args, self.config)
for question in questions:
# since the message's answer is modified, we use a copy
# -> the original is used for comparison below
expected_responses += fake_ai.request(copy(question),
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
# 2. === repeat all three questions (without overwriting) ===
self.args.ask = None
self.args.repeat = ['0001', '0002', '0003']
self.args.overwrite = False
question_cmd(self.args, self.config)
# two new files should be in the cache directory
# * the repeated cached message with answer
# * the repeated DB message
# -> the cached message without answer should be overwritten
self.assertEqual(len(self.message_list(self.cache_dir)), 4)
self.assertEqual(len(self.message_list(self.db_dir)), 1)
expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
cached_msg = chat.msg_gather(loc='cache')
self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
# check that the DB message has not been modified at all
db_msg = chat.msg_gather(loc='db')
self.assert_msgs_all_equal(db_msg, [message3])
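Note on the removed `test_ask_single_answer_mocked`: it replaces the database object with a `MagicMock(spec=...)` and then checks that the question is written before the responses, and that nothing is added to the in-memory list. The pattern can be sketched in isolation as follows. `FakeStore` and `ask` are hypothetical stand-ins invented for this sketch; they are not the project's `ChatDB` or `question_cmd`:

```python
from unittest.mock import MagicMock, call


class FakeStore:
    """Hypothetical stand-in for a message store (not the real ChatDB)."""

    def cache_write(self, messages: list[str]) -> None:
        raise NotImplementedError

    def cache_add(self, messages: list[str]) -> None:
        raise NotImplementedError


def ask(store: FakeStore, question: str) -> list[str]:
    """Hypothetical command: persist the question first, then the responses."""
    store.cache_write([question])
    responses = [f'Answer to: {question}']
    store.cache_write(responses)
    return responses


# spec= restricts the mock to attributes that exist on FakeStore,
# so a typo such as store.cache_wirte raises AttributeError.
store = MagicMock(spec=FakeStore)
responses = ask(store, 'What is this?')

# Both writes must have happened, in this order.
store.cache_write.assert_has_calls([call(['What is this?']),
                                    call(responses)],
                                   any_order=False)

# The messages must not have been added to the internal list.
store.cache_add.assert_not_called()
```

Using `spec=` keeps the mock honest about the interface, and `assert_has_calls(..., any_order=False)` is what turns "both writes happened" into "the question was stored before the request completed".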
+3 -3
@@ -1,8 +1,8 @@
-import unittest
+from .test_main import CmmTestCase
 from chatmastermind.tags import Tag, TagLine, TagError
-class TestTag(unittest.TestCase):
+class TestTag(CmmTestCase):
 def test_valid_tag(self) -> None:
 tag = Tag('mytag')
 self.assertEqual(tag, 'mytag')
@@ -18,7 +18,7 @@ class TestTag(unittest.TestCase):
 self.assertEqual(Tag.alternative_separators, [','])
-class TestTagLine(unittest.TestCase):
+class TestTagLine(CmmTestCase):
 def test_valid_tagline(self) -> None:
 tagline = TagLine('TAGS: tag1 tag2')
 self.assertEqual(tagline, 'TAGS: tag1 tag2')