Compare commits
236 Commits
caf5244d52
..
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 9a957a89ac | |||
| 5d1bb1f9e4 | |||
| 75a123eb72 | |||
| 7c1c67f8ff | |||
| dbe72ff11c | |||
| bbc1ab5a0a | |||
| 2aee018708 | |||
| 17c6fa2453 | |||
| 5774278fb7 | |||
| 40d0de50de | |||
| 72d31c26e9 | |||
| 980e5ac51f | |||
| 114282dfd8 | |||
| 9a493b57da | |||
| 9b0951cb3f | |||
| 5f29f60168 | |||
| 3ea1f49027 | |||
| 8f56399844 | |||
| e4cb6eb22b | |||
| e19c6bb1ea | |||
| 811b2e6830 | |||
| 2a8f01aee4 | |||
| efdb3cae2f | |||
| aecfd1088d | |||
| 140dbed809 | |||
| 01860ace2c | |||
| df42bcee09 | |||
| e34eab6519 | |||
| d07fd13e8e | |||
| b8681e8274 | |||
| d2be53aeab | |||
| 9ca9a23569 | |||
| 6f3758e12e | |||
| dd836cd72d | |||
| 601ebe731a | |||
| 87b25993be | |||
| a478408449 | |||
| b83b396c7b | |||
| 3c932aa88e | |||
| b50caa345c | |||
| 80c5dcc801 | |||
| 33df84beaa | |||
| 0657a1bab8 | |||
| e9175aface | |||
| 21f81f3569 | |||
| 4538624247 | |||
| ac3c19739d | |||
| ed379ed535 | |||
| c43bafe47a | |||
| 7dd83428fb | |||
| 3ad4b96b8f | |||
| 561003aabe | |||
| 59eb45a3ca | |||
| 29a20bd2d8 | |||
| 80a1457dd1 | |||
| f964c5471e | |||
| 25fffb6fea | |||
| cf572e1882 | |||
| 2fb7410b43 | |||
| 33ae27f00e | |||
| f6a6e6036b | |||
| 525cdb92a1 | |||
| fc82f85b7c | |||
| d90845b58b | |||
| 98777295d6 | |||
| f6109949c8 | |||
| 071871f929 | |||
| 5cb88dad1b | |||
| 17a0264025 | |||
| 7f4a16894e | |||
| 26e3d38afb | |||
| b5af751193 | |||
| a7345cbc41 | |||
| 310cb9421e | |||
| 1ec3d6fcda | |||
| 544bf0bf06 | |||
| f96e82bdd7 | |||
| 2b62cb8c4b | |||
| a895c1fc6a | |||
| ddfcc71510 | |||
| 17de0b9967 | |||
| 33023d29f9 | |||
| 481f9ecf7c | |||
| 22fa187e5f | |||
| b840ebd792 | |||
| 66908f5fed | |||
| 2e08ccf606 | |||
| 595ff8e294 | |||
| faac42d3c2 | |||
| 864ab7aeb1 | |||
| cc76da2ab3 | |||
| f99cd3ed41 | |||
| 6f3ea98425 | |||
| 54ece6efeb | |||
| 86eebc39ea | |||
| 3eca53998b | |||
| c4f7bcc94e | |||
| c52713c833 | |||
| ecb6994783 | |||
| 61e710a4b1 | |||
| 21d39c6c66 | |||
| 6a4cc7a65d | |||
| d6bb5800b1 | |||
| 034e4093f1 | |||
| 7d15452242 | |||
| 823d3bf7dc | |||
| 4bd144c4d7 | |||
| e186afbef0 | |||
| 5e4ec70072 | |||
| 4c378dde85 | |||
| 8923a13352 | |||
| e1414835c8 | |||
| abb7fdacb6 | |||
| 2e2228bd60 | |||
| 713b55482a | |||
| d35de86c67 | |||
| aba3eb783d | |||
| 8e63831701 | |||
| c318b99671 | |||
| 48c8e951e1 | |||
| b22a4b07ed | |||
| 33565d351d | |||
| 6737fa98c7 | |||
| 815a21893c | |||
| 64893949a4 | |||
| a093f9b867 | |||
| dc3f3dc168 | |||
| 74c39070d6 | |||
| fde0ae4652 | |||
| 238dbbee60 | |||
| 17f7b2fb45 | |||
| 9c2598a4b8 | |||
| acec5f1d55 | |||
| c0f50bace5 | |||
| 30ccec2462 | |||
| 09da312657 | |||
| 33567df15f | |||
| 264979a60d | |||
| 061e5f8682 | |||
| 2d456e68f1 | |||
| 8bd659e888 | |||
| d4021eeb11 | |||
| c143c001f9 | |||
| 59b851650a | |||
| 6f71a2ff69 | |||
| eca44b14cb | |||
| b48667bfa0 | |||
| 533ee1c1a9 | |||
| cf50818f28 | |||
| dd3d3ffc82 | |||
| 1e3bfdd67f | |||
| 53582a7123 | |||
| 39b518a8a6 | |||
| d22877a0f1 | |||
| 7cf62c54ef | |||
| 3ef1339cc0 | |||
| 5fb5dde550 | |||
| c0b7d17587 | |||
| 76f2373397 | |||
| eaa399bcb9 | |||
| b1a23394fc | |||
| ed567afbea | |||
| 2df9dd6427 | |||
| 74a26b8c2f | |||
| 6e447018d5 | |||
| 893917e455 | |||
| ba5aa1fbc7 | |||
| eb2fcba99d | |||
| b7e3ca7ca7 | |||
| aa322de718 | |||
| bf1cbff6a2 | |||
| f93a57c00d | |||
| b0504aedbe | |||
| eb0d97ddc8 | |||
| 7e25a08d6e | |||
| 63040b3688 | |||
| 6e2d5009c1 | |||
| 44cd1fab45 | |||
| 4b0f40bccd | |||
| fa292fb73a | |||
| f9d749cdd8 | |||
| ba56caf013 | |||
| d80c3962bd | |||
| ddfe29b951 | |||
| d93598a74f | |||
| 7f612bfc17 | |||
| 93290da5b5 | |||
| 9f4897a5b8 | |||
| 214a6919db | |||
| b83cbb719b | |||
| 8e1cdee3bf | |||
| 73d2a9ea3b | |||
| 169f1bb458 | |||
| 7f91a2b567 | |||
| fc1b8006a0 | |||
| aa89270876 | |||
| 0d6a6dd604 | |||
| 580c86e948 | |||
| 879831d7f5 | |||
| dfc1261931 | |||
| 173a46a9b5 | |||
| 604e5ccf73 | |||
| ef46f5efc9 | |||
| b13a68836a | |||
| a5c91adc41 | |||
| 380b7c1b67 | |||
| e8343fde01 | |||
| ee8deed320 | |||
| dc13213c4d | |||
| 4303fb414f | |||
| ba41794f4e | |||
| a5075b14a0 | |||
| e8eba0b755 | |||
| 1e15a52e26 | |||
| c4a7c07a0c | |||
| 22bebc16ed | |||
| f7ba0c000f | |||
| b6eb7d9af8 | |||
| f371a6146e | |||
| 6ed459be6f | |||
| 1fb9144192 | |||
| 4b2f634b79 | |||
| e4d055b900 | |||
| bc5e6228a6 | |||
| 056bf4c6b5 | |||
| 93a8b0081a | |||
| 5119b3a874 | |||
| 5a435c5f8f | |||
| f90e7bcd47 | |||
| 6406d2f5b5 | |||
| df91ca863a | |||
| bc9baff0dc | |||
| 7a92ebe539 | |||
| 9b6b13993c | |||
| c5c4a6628f | |||
| f8ed0e3636 |
@@ -106,6 +106,7 @@ celerybeat.pid
|
|||||||
.venv
|
.venv
|
||||||
env/
|
env/
|
||||||
venv/
|
venv/
|
||||||
|
.old/
|
||||||
ENV/
|
ENV/
|
||||||
env.bak/
|
env.bak/
|
||||||
venv.bak/
|
venv.bak/
|
||||||
@@ -131,3 +132,4 @@ dmypy.json
|
|||||||
.config.yaml
|
.config.yaml
|
||||||
db
|
db
|
||||||
noweb
|
noweb
|
||||||
|
Session.vim
|
||||||
|
|||||||
@@ -4,9 +4,11 @@ ChatMastermind is a Python application that automates conversation with AI, stor
|
|||||||
|
|
||||||
The project uses the OpenAI API to generate responses and stores the data in YAML files. It also allows you to filter chat history based on tags and supports autocompletion for tags.
|
The project uses the OpenAI API to generate responses and stores the data in YAML files. It also allows you to filter chat history based on tags and supports autocompletion for tags.
|
||||||
|
|
||||||
|
Official repository URL: https://kaizenkodo.no/gitea/kaizenkodo/ChatMastermind.git
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
- Python 3.6 or higher
|
- Python 3.9 or higher
|
||||||
- openai
|
- openai
|
||||||
- PyYAML
|
- PyYAML
|
||||||
- argcomplete
|
- argcomplete
|
||||||
@@ -27,81 +29,164 @@ pip install .
|
|||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
|
The `cmm` script has global options, a list of commands, and options per command:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cmm [-h] [-p PRINT | -q QUESTION | -D | -d | -l] [-c CONFIG] [-m MAX_TOKENS] [-T TEMPERATURE] [-M MODEL] [-n NUMBER] [-t [TAGS [TAGS ...]]] [-e [EXTAGS [EXTAGS ...]]] [-o [OTAGS [OTAGS ...]]] [-a] [-w] [-W]
|
cmm [global options] command [command options]
|
||||||
```
|
```
|
||||||
|
|
||||||
### Arguments
|
### Global Options
|
||||||
|
|
||||||
- `-p`, `--print`: YAML file to print.
|
- `-C`, `--config`: Config file name (defaults to `.config.yaml`).
|
||||||
- `-q`, `--question`: Question to ask.
|
|
||||||
- `-D`, `--chat-dump`: Print chat history as a Python structure.
|
### Command Options
|
||||||
- `-d`, `--chat`: Print chat history as readable text.
|
|
||||||
- `-a`, `--match-all-tags`: All given tags must match when selecting chat history entries.
|
#### Question
|
||||||
- `-w`, `--with-tags`: Print chat history with tags.
|
|
||||||
- `-W`, `--with-tags`: Print chat history with filenames.
|
The `question` command is used to ask, create, and process questions.
|
||||||
- `-l`, `--list-tags`: List all tags and their frequency.
|
|
||||||
- `-c`, `--config`: Config file name (defaults to `.config.yaml`).
|
```bash
|
||||||
- `-m`, `--max-tokens`: Max tokens to use.
|
cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI_ID] [-M MODEL] [-n NUM] [-m MAX] [-T TEMP] (-a QUESTION | -c QUESTION | -r [MESSAGE ...] | -p [MESSAGE ...]) [-O] [-s FILE]... [-S FILE]...
|
||||||
- `-T`, `--temperature`: Temperature to use.
|
```
|
||||||
- `-M`, `--model`: Model to use.
|
|
||||||
- `-n`, `--number`: Number of answers to produce (default is 3).
|
* `-t, --or-tags OTAGS`: List of tags (one must match)
|
||||||
- `-t`, `--tags`: List of tag names.
|
* `-k, --and-tags ATAGS`: List of tags (all must match)
|
||||||
- `-e`, `--extags`: List of tag names to exclude.
|
* `-x, --exclude-tags XTAGS`: List of tags to exclude
|
||||||
- `-o`, `--output-tags`: List of output tag names (default is the input tags).
|
* `-o, --output-tags OUTTAGS`: List of output tags (default: use input tags)
|
||||||
|
* `-A, --AI AI_ID`: AI ID to use
|
||||||
|
* `-M, --model MODEL`: Model to use
|
||||||
|
* `-n, --num-answers NUM`: Number of answers to request
|
||||||
|
* `-m, --max-tokens MAX`: Max. number of tokens
|
||||||
|
* `-T, --temperature TEMP`: Temperature value
|
||||||
|
* `-a, --ask QUESTION`: Ask a question
|
||||||
|
* `-c, --create QUESTION`: Create a question
|
||||||
|
* `-r, --repeat [MESSAGE ...]`: Repeat a question
|
||||||
|
* `-p, --process [MESSAGE ...]`: Process existing questions
|
||||||
|
* `-O, --overwrite`: Overwrite existing messages when repeating them
|
||||||
|
* `-s, --source-text FILE`: Add content of a file to the query
|
||||||
|
* `-S, --source-code FILE`: Add source code file content to the chat history
|
||||||
|
* `-l, --location {cache,db,all}`: Use given location when building the chat history (default: 'db')
|
||||||
|
* `-g, --glob GLOB`: Filter message files using the given glob pattern
|
||||||
|
|
||||||
|
#### Hist
|
||||||
|
|
||||||
|
The `hist` command is used to print and manage the chat history.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cmm hist [--print | --convert FORMAT] [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A SUBSTRING] [-Q SUBSTRING]
|
||||||
|
```
|
||||||
|
|
||||||
|
* `-p, --print`: Print the DB chat history
|
||||||
|
* `-c, --convert FORMAT`: Convert all messages to the given format
|
||||||
|
* `-t, --or-tags OTAGS`: List of tags (one must match)
|
||||||
|
* `-k, --and-tags ATAGS`: List of tags (all must match)
|
||||||
|
* `-x, --exclude-tags XTAGS`: List of tags to exclude
|
||||||
|
* `-w, --with-metadata`: Print chat history with metadata (tags, filenames, AI, etc.)
|
||||||
|
* `-S, --source-code-only`: Only print embedded source code
|
||||||
|
* `-A, --answer SUBSTRING`: Filter for answer substring
|
||||||
|
* `-Q, --question SUBSTRING`: Filter for question substring
|
||||||
|
* `-l, --location {cache,db,all}`: Use given location when building the chat history (default: 'db')
|
||||||
|
* `-g, --glob GLOB`: Filter message files using the given glob pattern
|
||||||
|
|
||||||
|
#### Tags
|
||||||
|
|
||||||
|
The `tags` command is used to manage tags.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cmm tags (-l | -p PREFIX | -c SUBSTRING)
|
||||||
|
```
|
||||||
|
|
||||||
|
* `-l, --list`: List all tags and their frequency
|
||||||
|
* `-p, --prefix PREFIX`: Filter tags by prefix
|
||||||
|
* `-c, --contain SUBSTRING`: Filter tags by contained substring
|
||||||
|
|
||||||
|
#### Config
|
||||||
|
|
||||||
|
The `config` command is used to manage the configuration.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cmm config (-l | -m | -c FILE)
|
||||||
|
```
|
||||||
|
|
||||||
|
* `-l, --list-models`: List all available models
|
||||||
|
* `-m, --print-model`: Print the currently configured model
|
||||||
|
* `-c, --create FILE`: Create config with default settings in the given file
|
||||||
|
|
||||||
|
#### Print
|
||||||
|
|
||||||
|
The `print` command is used to print message files.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cmm print (-f FILE | -l) [-q | -a | -S]
|
||||||
|
```
|
||||||
|
|
||||||
|
* `-f, --file FILE`: Print given file
|
||||||
|
* `-l, --latest`: Print latest message
|
||||||
|
* `-q, --question`: Only print the question
|
||||||
|
* `-a, --answer`: Only print the answer
|
||||||
|
* `-S, --only-source-code`: Only print embedded source code
|
||||||
|
|
||||||
### Examples
|
### Examples
|
||||||
|
|
||||||
1. Print the contents of a YAML file:
|
1. Ask a question:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cmm -p example.yaml
|
cmm question -a "What is the meaning of life?" -t philosophy -x religion
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Ask a question:
|
2. Display the chat history:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cmm -q "What is the meaning of life?" -t philosophy -e religion
|
cmm hist
|
||||||
```
|
```
|
||||||
|
|
||||||
3. Display the chat history as a Python structure:
|
3. Filter chat history by tags:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cmm -D
|
cmm hist --or-tags tag1 tag2
|
||||||
```
|
```
|
||||||
|
|
||||||
4. Display the chat history as readable text:
|
4. Exclude chat history by tags:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cmm -d
|
cmm hist --exclude-tags tag3 tag4
|
||||||
```
|
```
|
||||||
|
|
||||||
5. Filter chat history by tags:
|
5. List all tags and their frequency:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cmm -d -t tag1 tag2
|
cmm tags -l
|
||||||
```
|
```
|
||||||
|
|
||||||
6. Exclude chat history by tags:
|
6. Print the contents of a file:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cmm -d -e tag3 tag4
|
cmm print -f example.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
The configuration file (`.config.yaml`) should contain the following fields:
|
The default configuration filename is `.config.yaml` (it is searched in the current working directory).
|
||||||
|
Use the command `cmm config --create <FILENAME>` to create a default configuration:
|
||||||
|
|
||||||
- `openai`:
|
```
|
||||||
- `api_key`: Your OpenAI API key.
|
cache: .
|
||||||
- `model`: The name of the OpenAI model to use (e.g. "text-davinci-002").
|
db: ./db/
|
||||||
- `temperature`: The temperature value for the model.
|
ais:
|
||||||
- `max_tokens`: The maximum number of tokens for the model.
|
myopenai:
|
||||||
- `top_p`: The top P value for the model.
|
name: openai
|
||||||
- `frequency_penalty`: The frequency penalty value.
|
model: gpt-3.5-turbo-16k
|
||||||
- `presence_penalty`: The presence penalty value.
|
api_key: 0123456789
|
||||||
- `system`: The system message used to set the behavior of the AI.
|
temperature: 1.0
|
||||||
- `db`: The directory where the question-answer pairs are stored in YAML files.
|
max_tokens: 4000
|
||||||
|
top_p: 1.0
|
||||||
|
frequency_penalty: 0.0
|
||||||
|
presence_penalty: 0.0
|
||||||
|
system: You are an assistant
|
||||||
|
```
|
||||||
|
|
||||||
|
Each AI has its own section and the name of that section is called the 'AI ID' (in the example above it is `myopenai`).
|
||||||
|
The AI ID can be any string, as long as it's unique within the `ais` section. The AI ID is used for all commands that support the `AI` parameter and it's also stored within each message file.
|
||||||
|
|
||||||
## Autocompletion
|
## Autocompletion
|
||||||
|
|
||||||
@@ -113,6 +198,43 @@ eval "$(register-python-argcomplete cmm)"
|
|||||||
|
|
||||||
After adding this line, restart your shell or run `source <your-shell-config-file>` to enable autocompletion for the `cmm` script.
|
After adding this line, restart your shell or run `source <your-shell-config-file>` to enable autocompletion for the `cmm` script.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
### Enable commit hooks
|
||||||
|
```bash
|
||||||
|
pip install pre-commit
|
||||||
|
pre-commit install
|
||||||
|
```
|
||||||
|
### Execute tests before opening a PR
|
||||||
|
```bash
|
||||||
|
pytest
|
||||||
|
```
|
||||||
|
### Consider using `pyenv` / `pyenv-virtualenv`
|
||||||
|
Short installation instructions:
|
||||||
|
* install `pyenv`:
|
||||||
|
```bash
|
||||||
|
cd ~
|
||||||
|
git clone https://github.com/pyenv/pyenv .pyenv
|
||||||
|
cd ~/.pyenv && src/configure && make -C src
|
||||||
|
```
|
||||||
|
* make sure that `~/.pyenv/shims` and `~/.pyenv/bin` are the first entries in your `PATH`, e.g., by setting it in `~/.bashrc`
|
||||||
|
* add the following to your `~/.bashrc` (after setting `PATH`): `eval "$(pyenv init -)"`
|
||||||
|
* create a new terminal or source the changes (e.g., `source ~/.bashrc`)
|
||||||
|
* install `pyenv-virtualenv`:
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv
|
||||||
|
```
|
||||||
|
* add the following to your `~/.bashrc` (after the commands above): `eval "$(pyenv virtualenv-init -)"`
|
||||||
|
* create a new terminal or source the changes (e.g., `source ~/.bashrc`)
|
||||||
|
* go back to the `ChatMastermind` repo and create a virtual environment with the latest `Python`, e.g., `3.11.4`:
|
||||||
|
```bash
|
||||||
|
cd <CMM_REPO_PATH>
|
||||||
|
pyenv install 3.11.4
|
||||||
|
pyenv virtualenv 3.11.4 py311
|
||||||
|
pyenv activate py311
|
||||||
|
```
|
||||||
|
* see also the [official pyenv documentation](https://github.com/pyenv/pyenv#readme)
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
This project is licensed under the terms of the WTFPL License.
|
This project is licensed under the terms of the WTFPL License.
|
||||||
|
|||||||
@@ -0,0 +1,80 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Protocol, Optional, Union
|
||||||
|
from .configuration import AIConfig
|
||||||
|
from .tags import Tag
|
||||||
|
from .message import Message
|
||||||
|
from .chat import Chat
|
||||||
|
|
||||||
|
|
||||||
|
class AIError(Exception):
    """
    Raised for errors related to AI selection, configuration or requests.
    """
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Tokens:
    """
    Token counters of an AI request. All counters default to zero.
    """
    # nr. of tokens used by the prompt
    prompt: int = 0
    # nr. of tokens used by the completion(s)
    completion: int = 0
    # overall nr. of tokens
    total: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AIResponse:
    """
    Result of an AI request: a list of one or more messages (every
    message holds the question together with exactly one answer) plus,
    optionally, the nr. of tokens that were used.
    """
    # one message per answer
    messages: list[Message]
    # token usage (None if it is not available)
    tokens: Optional[Tokens] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AI(Protocol):
    """
    Interface that all AI clients have to provide.
    """

    # the AI ID
    ID: str
    # the name of the AI
    name: str
    # the configuration used by this AI
    config: AIConfig

    def request(self,
                question: Message,
                chat: Chat,
                num_answers: int = 1,
                otags: Optional[set[Tag]] = None) -> AIResponse:
        """
        Send a request to the AI. Parameters:
         * question: the question that is asked
         * chat: chat history that is added as context
         * num_answers: how many answers are requested (equals the
                        nr. of messages in the returned 'AIResponse')
         * otags: output tags, i. e. the tags every returned
                  message should contain
        """
        raise NotImplementedError

    def models(self) -> list[str]:
        """
        Return the names of all models this AI supports.
        """
        raise NotImplementedError

    def print_models(self) -> None:
        """
        Print the names of all models this AI supports.
        """
        raise NotImplementedError

    def tokens(self, data: Union[Message, Chat]) -> int:
        """
        Compute the nr. of AI language tokens of the given message or
        chat. The result may not be 100% accurate and the computation
        is not implemented for all AIs.
        """
        raise NotImplementedError

    def print(self) -> None:
        """
        Print some info about the current AI, e. g. the system message.
        Does nothing by default.
        """
        return None
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
"""
|
||||||
|
Creates different AI instances, based on the given configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from typing import cast, Optional
|
||||||
|
from .configuration import Config, AIConfig, OpenAIConfig
|
||||||
|
from .ai import AI, AIError
|
||||||
|
from .ais.openai import OpenAI
|
||||||
|
|
||||||
|
|
||||||
|
def create_ai(args: argparse.Namespace, config: Config,  # noqa: 11
              def_ai: Optional[str] = None,
              def_model: Optional[str] = None) -> AI:
    """
    Creates an AI subclass instance from the given arguments and configuration file.
    The AI configuration is selected in the following order:
     1. the AI ID given in 'args.AI' (if present and set)
     2. the AI ID given in 'def_ai'
     3. the AI with the ID 'default'
     4. the first AI in the configuration
    The model is taken from 'args.model' (if present and set), otherwise from
    'def_model', otherwise from the configuration. 'args.max_tokens' and
    'args.temperature' override the configured values if they are set.
    Note that the overrides are written to the selected configuration object.

    Raises 'AIError' if the selected AI ID does not exist, if the configuration
    contains no AI at all, or if the selected AI is not supported.
    """
    requested_ai = getattr(args, 'AI', None)
    ai_conf: AIConfig
    if requested_ai:
        try:
            ai_conf = config.ais[requested_ai]
        except KeyError as err:
            raise AIError(f"AI ID '{requested_ai}' does not exist in this configuration") from err
    elif def_ai:
        # report an unknown default AI ID the same way as an unknown 'args.AI'
        try:
            ai_conf = config.ais[def_ai]
        except KeyError as err:
            raise AIError(f"AI ID '{def_ai}' does not exist in this configuration") from err
    elif 'default' in config.ais:
        ai_conf = config.ais['default']
    else:
        try:
            ai_conf = next(iter(config.ais.values()))
        except StopIteration as err:
            raise AIError("No AI found in this configuration") from err

    if ai_conf.name != 'openai':
        # use the name of the selected configuration: 'args.AI' may be
        # missing or unset when the AI was selected by one of the fallbacks
        raise AIError(f"AI '{ai_conf.name}' is not supported")

    ai = OpenAI(cast(OpenAIConfig, ai_conf))
    if getattr(args, 'model', None):
        ai.config.model = args.model
    elif def_model:
        ai.config.model = def_model
    if getattr(args, 'max_tokens', None):
        ai.config.max_tokens = args.max_tokens
    # compare with None: a temperature of 0 is a valid value
    # NOTE(review): assumes the argument parser uses None as default — confirm
    temperature = getattr(args, 'temperature', None)
    if temperature is not None:
        ai.config.temperature = temperature
    return ai
|
||||||
@@ -0,0 +1,160 @@
|
|||||||
|
"""
|
||||||
|
Implements the OpenAI client classes and functions.
|
||||||
|
"""
|
||||||
|
import openai
|
||||||
|
import tiktoken
|
||||||
|
from typing import Optional, Union, Generator
|
||||||
|
from ..tags import Tag
|
||||||
|
from ..message import Message, Answer
|
||||||
|
from ..chat import Chat
|
||||||
|
from ..ai import AI, AIResponse, Tokens
|
||||||
|
from ..configuration import OpenAIConfig
|
||||||
|
|
||||||
|
ChatType = list[dict[str, str]]
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIAnswer:
    """
    A single answer (choice) of a streamed OpenAI response. The answers
    of one request are registered in the 'streams' dictionary under their
    choice index. Whichever answer reads a chunk from the response
    distributes its content to the answers registered in 'streams'.
    """

    def __init__(self,
                 idx: int,
                 streams: dict[int, 'OpenAIAnswer'],
                 response: openai.ChatCompletion,
                 tokens: Tokens,
                 encoding: tiktoken.core.Encoding) -> None:
        """
        Parameters:
         * idx: the choice index of this answer
         * streams: all answers of the request, by choice index
         * response: the streamed response (an iterator of chunks)
         * tokens: token counters, updated while reading chunks
         * encoding: the encoding used for counting completion tokens
        """
        self.idx = idx
        self.streams = streams
        self.response = response
        self.position: int = 0
        self.encoding = encoding
        self.data: list[str] = []
        self.finished: bool = False
        self.tokens = tokens

    def stream(self) -> Generator[str, None, None]:
        """
        Yield the text fragments of this answer, reading new chunks
        from the response on demand.
        """
        while True:
            # read chunks until one contains data for this answer
            # (or until the response is exhausted)
            if not self.next():
                continue
            if len(self.data) <= self.position:
                break
            yield self.data[self.position]
            self.position += 1

    def next(self) -> bool:
        """
        Read the next chunk of the response and append the content of
        its choices to the corresponding answers. Returns False if the
        chunk contained no content for this answer, True otherwise
        (including the case that the response is exhausted).
        """
        if self.finished:
            return True
        try:
            chunk = next(self.response)
        except StopIteration:
            self.finished = True
            return True
        found_choice = False
        for choice in chunk.choices:
            if choice.finish_reason:
                continue
            content = choice.delta.content
            # a delta without content can't be encoded or yielded as text
            if content is None:
                continue
            self.streams[choice.index].data.append(content)
            self.tokens.completion += len(self.encoding.encode(content))
            self.tokens.total = self.tokens.prompt + self.tokens.completion
            if choice.index == self.idx:
                found_choice = True
        return found_choice
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAI(AI):
    """
    The OpenAI AI client.
    """

    def __init__(self, config: OpenAIConfig) -> None:
        """
        Create a client using the API key of the given configuration.
        """
        self.ID = config.ID
        self.name = config.name
        self.config = config
        self.client = openai.OpenAI(api_key=self.config.api_key)

    def _completions(self, *args, **kw):  # type: ignore
        """
        Forward all arguments to the chat completion API of the client.
        """
        return self.client.chat.completions.create(*args, **kw)

    def _model_encoding(self) -> tiktoken.core.Encoding:
        """
        Return the token encoding of the configured model. Falls back
        to 'cl100k_base' if the model is unknown to 'tiktoken' (token
        counts are approximate in that case).
        """
        try:
            return tiktoken.encoding_for_model(self.config.model)
        except KeyError:
            return tiktoken.get_encoding('cl100k_base')

    def request(self,
                question: Message,
                chat: Chat,
                num_answers: int = 1,
                otags: Optional[set[Tag]] = None) -> AIResponse:
        """
        Make an AI request, asking the given question with the given
        chat history. The nr. of requested answers corresponds to the
        nr. of messages in the 'AIResponse'. The given question is
        modified (answer, tags, AI and model are set) and returned as
        the first message. The answers are streamed.
        """
        self.encoding = self._model_encoding()
        oai_chat, prompt_tokens = self.openai_chat(chat, self.config.system, question)
        tokens: Tokens = Tokens(prompt_tokens, 0, prompt_tokens)
        response = self._completions(
            model=self.config.model,
            messages=oai_chat,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            top_p=self.config.top_p,
            n=num_answers,
            stream=True,
            frequency_penalty=self.config.frequency_penalty,
            presence_penalty=self.config.presence_penalty)
        # all answers share the same response and the same token counters
        streams: dict[int, OpenAIAnswer] = {}
        for n in range(num_answers):
            streams[n] = OpenAIAnswer(n, streams, response, tokens, self.encoding)
        question.answer = Answer(streams[0].stream())
        question.tags = set(otags) if otags is not None else None
        question.ai = self.ID
        question.model = self.config.model
        answers: list[Message] = [question]
        for idx in range(1, num_answers):
            answers.append(Message(question=question.question,
                                   answer=Answer(streams[idx].stream()),
                                   # every message gets its own copy of the tags
                                   tags=set(otags) if otags is not None else None,
                                   ai=self.ID,
                                   model=self.config.model))
        return AIResponse(answers, tokens)

    def models(self) -> list[str]:
        """
        Return all models supported by this AI (sorted by ID).
        """
        return sorted(model.id for model in self.client.models.list().data)

    def print_models(self) -> None:
        """
        Print all models supported by the current AI.
        """
        for model in self.models():
            print(model)

    def openai_chat(self, chat: Chat, system: str,
                    question: Optional[Message] = None) -> tuple[ChatType, int]:
        """
        Create a chat history with system message in OpenAI format.
        Optionally append a new question. Only messages that have an
        answer are added. Returns the chat and the nr. of tokens
        computed for it.
        """
        # make this method usable before the first call of 'request()'
        encoding = getattr(self, 'encoding', None)
        if encoding is None:
            encoding = self.encoding = self._model_encoding()
        oai_chat: ChatType = []
        prompt_tokens: int = 0

        def append(role: str, content: str) -> int:
            # add the entry and return the nr. of tokens computed for it
            oai_chat.append({'role': role, 'content': content.replace("''", "'")})
            return len(encoding.encode(', '.join(['role:', oai_chat[-1]['role'], 'content:', oai_chat[-1]['content']])))

        prompt_tokens += append('system', system)
        for message in chat.messages:
            if message.answer:
                prompt_tokens += append('user', message.question)
                prompt_tokens += append('assistant', str(message.answer))
        if question:
            prompt_tokens += append('user', question.question)
        return oai_chat, prompt_tokens

    def tokens(self, data: Union[Message, Chat]) -> int:
        """
        Not implemented for this AI.
        """
        raise NotImplementedError

    def print(self) -> None:
        """
        Print the configured model and the system message.
        """
        print(f"MODEL: {self.config.model}")
        print("=== SYSTEM ===")
        print(self.config.system)
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
import openai
|
|
||||||
|
|
||||||
|
|
||||||
def openai_api_key(api_key: str) -> None:
    """
    Set the API key of the 'openai' module to the given value.
    """
    openai.api_key = api_key
|
|
||||||
|
|
||||||
|
|
||||||
def display_models() -> None:
    """
    Print the IDs of all engines that are ready (sorted by ID),
    followed by a list of the engines that are not ready (if any).
    """
    engines = sorted(openai.Engine.list()['data'], key=lambda entry: entry['id'])
    pending = []
    for entry in engines:
        if not entry['ready']:
            pending.append(entry['id'])
            continue
        print(entry['id'])
    if pending:
        print('\nNot ready: ' + ', '.join(pending))
|
|
||||||
|
|
||||||
|
|
||||||
def ai(chat: list[dict[str, str]],
       config: dict,
       number: int
       ) -> tuple[list[str], dict[str, int]]:
    """
    Request 'number' chat completions for the given chat, using the
    settings of the 'openai' section of the given configuration.
    Returns the stripped answers and the usage data of the response.
    """
    settings = config['openai']
    response = openai.ChatCompletion.create(
        model=settings['model'],
        messages=chat,
        temperature=settings['temperature'],
        max_tokens=settings['max_tokens'],
        top_p=settings['top_p'],
        n=number,
        frequency_penalty=settings['frequency_penalty'],
        presence_penalty=settings['presence_penalty'])
    answers = [choice['message']['content'].strip()
               for choice in response['choices']]  # type: ignore
    usage = dict(response['usage'])  # type: ignore
    return answers, usage
|
|
||||||
@@ -0,0 +1,649 @@
|
|||||||
|
"""
|
||||||
|
Module implementing various chat classes and functions for managing a chat history.
|
||||||
|
"""
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from pprint import PrettyPrinter
|
||||||
|
from pydoc import pager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import TypeVar, Type, Optional, Any, Callable, Union
|
||||||
|
from .configuration import default_config_file
|
||||||
|
from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats
|
||||||
|
from .tags import Tag
|
||||||
|
|
||||||
|
ChatInst = TypeVar('ChatInst', bound='Chat')
|
||||||
|
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
|
||||||
|
|
||||||
|
db_next_file = '.next'
|
||||||
|
ignored_files = [db_next_file, default_config_file]
|
||||||
|
msg_suffix = Message.file_suffix_write
|
||||||
|
|
||||||
|
|
||||||
|
class msg_location(Enum):
    """
    Symbolic names for message locations.
    """
    MEM = 'mem'
    DISK = 'disk'
    CACHE = 'cache'
    DB = 'db'
    ALL = 'all'
|
||||||
|
|
||||||
|
|
||||||
|
class ChatError(Exception):
    """
    Raised for errors related to the chat history.
    """
|
||||||
|
|
||||||
|
|
||||||
|
def terminal_width() -> int:
    """
    Return the nr. of columns of the terminal.
    """
    columns, _ = shutil.get_terminal_size()
    return columns
|
||||||
|
|
||||||
|
|
||||||
|
def pp(*args: Any, **kwargs: Any) -> None:
    """
    Pretty-print the given arguments using the width of the terminal.
    """
    printer = PrettyPrinter(width=terminal_width())
    return printer.pprint(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def print_paged(text: str) -> None:
    """
    Show the given text using the pager of 'pydoc'.
    """
    pager(text)
|
||||||
|
|
||||||
|
|
||||||
|
def read_dir(dir_path: Path,
             glob: Optional[str] = None,
             mfilter: Optional[MessageFilter] = None) -> list[Message]:
    """
    Read all messages from the given directory (sorted by file path).
    Files that can't be read as a message are skipped with a warning.
    Parameters:
     * 'dir_path': the directory to read from
     * 'glob': pattern for 'path.glob()' that selects the files; if it
               is not given, all files with the default message suffix
               are read
     * 'mfilter': passed to 'Message.from_file()' in order to filter
                  the messages while reading them
    """
    pattern = glob if glob else f'*{msg_suffix}'
    loaded: list[Message] = []
    for candidate in sorted(dir_path.glob(pattern)):
        if not candidate.is_file():
            continue
        if candidate.name in ignored_files:
            continue
        if candidate.suffix not in Message.file_suffixes_read:
            continue
        try:
            msg = Message.from_file(candidate, mfilter)
            if msg:
                loaded.append(msg)
        except MessageError as err:
            print(f"WARNING: Skipping message in '{candidate}': {str(err)}")
    return loaded
|
||||||
|
|
||||||
|
|
||||||
|
def make_file_path(dir_path: Path,
                   next_fid: Callable[[], int]) -> Path:
    """
    Return a path in the given directory that does not exist yet. The
    file name consists of an ID returned by 'next_fid' (zero-padded to
    at least four digits) and the message suffix. New IDs are requested
    until an unused file name has been found.
    """
    while True:
        candidate = dir_path / f"{next_fid():04d}{msg_suffix}"
        if not candidate.exists():
            return candidate
|
||||||
|
|
||||||
|
|
||||||
|
def write_dir(dir_path: Path,
              messages: list[Message],
              next_fid: Callable[[], int],
              mformat: MessageFormat = Message.default_format) -> None:
    """
    Write all given messages into the given directory. A message without
    a file_path gets a newly created one. A message whose file_path points
    to another directory is written to the given directory under the same
    file name.
    Parameters:
    * 'dir_path': destination directory
    * 'messages': list of messages to write
    * 'next_fid': callable that returns the next file ID
    * 'mformat': message format passed to 'Message.to_file()'
    """
    for msg in messages:
        current = msg.file_path
        if current and current.parent.samefile(dir_path):
            # already located in the destination directory
            destination = current
        elif current:
            # located elsewhere: keep the name, change the directory
            destination = dir_path / current.name
        else:
            # never written before: create a new path
            destination = make_file_path(dir_path, next_fid)
        msg.to_file(destination, mformat=mformat)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_dir(dir_path: Path,
              glob: Optional[str] = None) -> None:
    """
    Delete the message files in the given directory. If 'glob' is given,
    only files matching that pattern are considered. Ignored files and
    files without a known message suffix are never deleted.
    """
    candidates = dir_path.glob(glob) if glob else dir_path.iterdir()
    for entry in candidates:
        if not entry.is_file():
            continue
        if entry.name in ignored_files:
            continue
        if entry.suffix not in Message.file_suffixes_read:
            continue
        entry.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Chat:
    """
    A class containing a complete chat history.
    """

    messages: list[Message]

    def __post_init__(self) -> None:
        self.validate()

    def validate(self) -> None:
        """
        Validate this Chat instance: no two messages may share the same
        file stem. Every offending path is reported on stdout.
        Raises 'ChatError' if the validation fails.
        """
        file_paths: set[Path] = {m.file_path for m in self.messages if m.file_path is not None}
        # count the stems once instead of scanning the list for every path
        stem_counts: dict[str, int] = {}
        for m in self.messages:
            if m.file_path is not None:
                stem_counts[m.file_path.stem] = stem_counts.get(m.file_path.stem, 0) + 1

        def msg_paths(stem: str) -> list[str]:
            return [str(fp) for fp in file_paths if fp.stem == stem]

        error = False
        for fp in file_paths:
            if stem_counts[fp.stem] > 1:
                print(f"ERROR: Found multiple copies of message '{fp.stem}': {msg_paths(fp.stem)}")
                error = True
        if error:
            raise ChatError("Validation failed")

    def msg_name_matches(self, file_path: Path, name: str) -> bool:
        """
        Return True if the given name matches the given file_path.
        Matching is True if:
        * 'name' matches the full 'file_path'
        * 'name' matches 'file_path.name' (i. e. including the suffix)
        * 'name' matches 'file_path.stem' (i. e. without the suffix)
        """
        return Path(name) == file_path or name == file_path.name or name == file_path.stem

    def msg_filter(self, mfilter: MessageFilter) -> None:
        """
        Use 'Message.match(mfilter)' to remove all messages that
        don't fulfill the filter requirements.
        """
        self.messages = [m for m in self.messages if m.match(mfilter)]

    def msg_sort(self, reverse: bool = False) -> None:
        """
        Sort the messages according to 'Message.msg_id()'. If the ID of
        a message can't be determined ('MessageError'), the sort is skipped.
        """
        try:
            # the message may not have an ID if it doesn't have a file_path
            self.messages.sort(key=lambda m: m.msg_id(), reverse=reverse)
        except MessageError:
            pass

    def msg_unique_id(self) -> None:
        """
        Remove duplicates from the internal messages, based on the msg_id (i. e. file_path).
        Messages without a file_path are kept.
        """
        old_msgs = self.messages.copy()
        self.messages = []
        for m in old_msgs:
            if not message_in(m, self.messages):
                self.messages.append(m)
        self.msg_sort()

    def msg_unique_content(self) -> None:
        """
        Remove duplicates from the internal messages, based on the content (i. e. question + answer).
        The first occurrence of each message is kept.
        """
        # 'dict.fromkeys()' removes duplicates like 'set()' but keeps the
        # original order, so the result stays deterministic even if the
        # subsequent sort is skipped (messages without a file_path)
        self.messages = list(dict.fromkeys(self.messages))
        self.msg_sort()

    def msg_clear(self) -> None:
        """
        Delete all messages.
        """
        self.messages = []

    def msg_add(self, messages: list[Message]) -> None:
        """
        Add new messages and sort them if possible.
        """
        self.messages += messages
        self.msg_sort()

    def msg_latest(self, mfilter: Optional[MessageFilter] = None) -> Optional[Message]:
        """
        Return the last added message (according to the file ID) that matches the given filter.
        When containing messages without a valid file_path, it returns the latest message in
        the internal list. Returns None if no message matches.
        """
        if self.messages:
            self.msg_sort()
            for m in reversed(self.messages):
                if mfilter is None or m.match(mfilter):
                    return m
        return None

    def msg_find(self, msg_names: list[str]) -> list[Message]:
        """
        Search and return the messages with the given names. Names can either be filenames
        (with or without suffix), full paths or Message.msg_id(). Messages that can't be
        found are ignored (i. e. the caller should check the result if they require all
        messages).
        """
        return [m for m in self.messages
                if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]

    def msg_remove(self, msg_names: list[str]) -> None:
        """
        Remove the messages with the given names. Names can either be filenames
        (with or without suffix), full paths or Message.msg_id().
        """
        self.messages = [m for m in self.messages
                         if not any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
        self.msg_sort()

    def msg_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
        """
        Get the tags of all messages, optionally filtered by prefix or substring.
        """
        tags: set[Tag] = set()
        for m in self.messages:
            tags |= m.filter_tags(prefix, contain)
        # a set has no order, so sorting before returning would be pointless
        return tags

    def msg_tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
        """
        Get the frequency of all tags of all messages, optionally filtered by prefix or substring.
        The returned dict is ordered by tag.
        """
        # count in a single pass instead of calling 'list.count()' per tag
        counts: dict[Tag, int] = {}
        for m in self.messages:
            for tag in m.filter_tags(prefix, contain):
                counts[tag] = counts.get(tag, 0) + 1
        return {tag: counts[tag] for tag in sorted(counts)}

    def tokens(self) -> int:
        """
        Returns the nr. of AI language tokens used by all messages in this chat.
        If unknown, 0 is returned.
        """
        return sum(m.tokens() for m in self.messages)

    def print(self, source_code_only: bool = False,
              with_metadata: bool = False,
              paged: bool = True,
              tight: bool = False) -> None:
        """
        Print all messages of this chat.
        Parameters:
        * 'source_code_only': print only the source code of each message
        * 'with_metadata': passed to 'Message.to_str()'
        * 'paged': send the output through the pager instead of printing it directly
        * 'tight': don't print a separator line after each message
        """
        output: list[str] = []
        for message in self.messages:
            if source_code_only:
                output.append(message.to_str(source_code_only=True))
                continue
            output.append(message.to_str(with_metadata))
            if not tight:
                output.append('\n' + ('-' * terminal_width()) + '\n')
        if paged:
            print_paged('\n'.join(output))
        else:
            print(*output, sep='\n')
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ChatDB(Chat):
    """
    A 'Chat' class that is bound to a given directory structure. Supports reading
    and writing messages from / to that structure. Such a structure consists of
    two directories: a 'cache directory', where all messages are temporarily
    stored, and a 'DB' directory, where selected messages can be stored
    persistently.
    """

    cache_path: Path
    db_path: Path
    # a MessageFilter that all messages must match (if given)
    mfilter: Optional[MessageFilter] = None
    # the glob pattern for all messages
    glob: str = f'*{msg_suffix}'
    # message format (for writing)
    mformat: MessageFormat = Message.default_format

    def __post_init__(self) -> None:
        # make all paths absolute
        self.cache_path = self.cache_path.absolute()
        self.db_path = self.db_path.absolute()
        # contains the latest message ID; derived from the absolute DB path
        # so it stays valid if the working directory changes later
        self.next_path = self.db_path / db_next_file
        self.validate()

    @classmethod
    def from_dir(cls: Type[ChatDBInst],
                 cache_path: Path,
                 db_path: Path,
                 glob: str = f'*{msg_suffix}',
                 mfilter: Optional[MessageFilter] = None,
                 loc: msg_location = msg_location.DB) -> ChatDBInst:
        """
        Create a 'ChatDB' instance from the given directory structure.
        By default, reads all messages from 'db_path' into the local message list.
        Parameters:
        * 'cache_path': path to the directory for temporary messages
        * 'db_path': path to the directory for persistent messages
        * 'glob': if specified, files will be filtered using 'path.glob()'
        * 'mfilter': use with 'Message.from_file()' to filter messages
                     when reading them.
        * 'loc': read messages from given location instead of 'db_path'
        Raises 'ChatError' if 'loc' is the 'mem' location.
        """
        if loc == msg_location.MEM:
            raise ChatError(f"Can't build ChatDB from message location '{loc}'")
        messages: list[Message] = []
        if loc in [msg_location.DB, msg_location.DISK, msg_location.ALL]:
            messages.extend(read_dir(db_path, glob, mfilter))
        if loc in [msg_location.CACHE, msg_location.DISK, msg_location.ALL]:
            messages.extend(read_dir(cache_path, glob, mfilter))
        messages.sort(key=lambda x: x.msg_id())
        return cls(messages, cache_path, db_path, mfilter, glob)

    @classmethod
    def from_messages(cls: Type[ChatDBInst],
                      cache_path: Path,
                      db_path: Path,
                      messages: list[Message],
                      mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
        """
        Create a ChatDB instance from the given message list.
        """
        return cls(messages, cache_path, db_path, mfilter)

    def get_next_fid(self) -> int:
        """
        Return the next file ID and store it in the ID file ('next_path').
        If the ID file can't be read or doesn't contain a number, the
        counter (re)starts at 1.
        """
        try:
            with open(self.next_path, 'r') as f:
                next_fid = int(f.read()) + 1
        except (OSError, ValueError):
            # missing / unreadable file or invalid content
            next_fid = 1
        self.set_next_fid(next_fid)
        return next_fid

    def set_next_fid(self, fid: int) -> None:
        """
        Store the given file ID in the ID file ('next_path').
        """
        with open(self.next_path, 'w') as f:
            f.write(f'{fid}')

    def set_msg_format(self, mformat: MessageFormat) -> None:
        """
        Set message format for writing messages.
        Raises 'ChatError' if the format is not supported.
        """
        if mformat not in message_valid_formats:
            raise ChatError(f"Message format '{mformat}' is not supported")
        self.mformat = mformat

    def msg_write(self,
                  messages: Optional[list[Message]] = None,
                  mformat: Optional[MessageFormat] = None) -> None:
        """
        Write either the given messages or the internal ones to their CURRENT file_path.
        If messages are given, they all must have a valid file_path. When writing the
        internal messages, the ones with a valid file_path are written, the others
        are ignored.
        If 'mformat' is not given, the format of this instance is used.
        Raises 'ChatError' if one of the given messages has no file_path.
        """
        if messages and any(m.file_path is None for m in messages):
            raise ChatError("Can't write files without a valid file_path")
        write_format = mformat if mformat else self.mformat
        for m in (messages if messages else self.messages):
            # only internal messages can get here without a file_path
            if m.file_path is None:
                continue
            m.to_file(mformat=write_format)

    def msg_update(self, messages: list[Message], write: bool = True) -> None:
        """
        Update EXISTING messages. A message is determined as 'existing' if a message with
        the same base filename (i. e. 'file_path.name') is already in the list.
        Only accepts existing messages (raises 'ChatError' otherwise).
        """
        if any(not message_in(m, self.messages) for m in messages):
            raise ChatError("Can't update messages that are not in the internal list")
        # remove old versions and add new ones
        self.messages = [m for m in self.messages if not message_in(m, messages)]
        self.messages += messages
        self.msg_sort()
        # write the UPDATED messages if requested
        if write:
            self.msg_write(messages)

    def msg_gather(self,
                   loc: msg_location,
                   require_file_path: bool = False,
                   glob: str = f'*{msg_suffix}',
                   mfilter: Optional[MessageFilter] = None) -> list[Message]:
        """
        Gather and return messages from the given locations:
        * 'mem'  : messages currently in memory
        * 'disk' : messages on disk (cache + DB directory), but not in memory
        * 'cache': messages in the cache directory
        * 'db'   : messages in the DB directory
        * 'all'  : all messages ('mem' + 'disk')

        If 'require_file_path' is True, return only files with a valid file_path.
        Duplicates are removed and the result is sorted by message ID if possible.
        """
        loc_messages: list[Message] = []
        if loc in [msg_location.MEM, msg_location.ALL]:
            loc_messages += [m for m in self.messages
                             if (not require_file_path or m.file_path is not None)
                             and (mfilter is None or m.match(mfilter))]  # noqa: W503
        if loc in [msg_location.CACHE, msg_location.DISK, msg_location.ALL]:
            loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter)
        if loc in [msg_location.DB, msg_location.DISK, msg_location.ALL]:
            loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter)
        # remove_duplicates and sort the list
        unique_messages: list[Message] = []
        for m in loc_messages:
            if not message_in(m, unique_messages):
                unique_messages.append(m)
        try:
            unique_messages.sort(key=lambda m: m.msg_id())
        # messages in 'mem' can have an empty file_path
        except MessageError:
            pass
        return unique_messages

    def msg_find(self,
                 msg_names: list[str],
                 loc: msg_location = msg_location.MEM,
                 ) -> list[Message]:
        """
        Search and return the messages with the given names. Names can either be filenames
        (with or without suffix), full paths or Message.msg_id(). Messages that can't be
        found are ignored (i. e. the caller should check the result if they require all
        messages).
        Searches one of the following locations:
        * 'mem'  : messages currently in memory
        * 'disk' : messages on disk (cache + DB directory), but not in memory
        * 'cache': messages in the cache directory
        * 'db'   : messages in the DB directory
        * 'all'  : all messages ('mem' + 'disk')
        """
        loc_messages = self.msg_gather(loc, require_file_path=True)
        return [m for m in loc_messages
                if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]

    def msg_remove(self, msg_names: list[str], loc: msg_location = msg_location.MEM) -> None:
        """
        Remove the messages with the given names from the internal list. Names can
        either be filenames (with or without suffix), full paths or Message.msg_id().
        Unless 'loc' is 'mem', the files of all matching messages found in the
        given location are deleted as well:
        * 'mem'  : messages currently in memory (no files are deleted)
        * 'disk' : messages on disk (cache + DB directory), but not in memory
        * 'cache': messages in the cache directory
        * 'db'   : messages in the DB directory
        * 'all'  : all messages ('mem' + 'disk')
        """
        if loc != msg_location.MEM:
            # delete the message files first
            for m in self.msg_find(msg_names, loc=loc):
                if m.file_path:
                    # with 'all', in-memory messages may not be on disk (yet)
                    m.file_path.unlink(missing_ok=True)
        # then remove them from the internal list
        super().msg_remove(msg_names)

    def msg_latest(self,
                   mfilter: Optional[MessageFilter] = None,
                   loc: msg_location = msg_location.MEM) -> Optional[Message]:
        """
        Return the last added message (according to the file ID) that matches the given filter.
        Only messages with a valid file_path are considered (for every location).
        Returns None if no message matches.
        Searches one of the following locations:
        * 'mem'  : messages currently in memory
        * 'disk' : messages on disk (cache + DB directory), but not in memory
        * 'cache': messages in the cache directory
        * 'db'   : messages in the DB directory
        * 'all'  : all messages ('mem' + 'disk')
        """
        # only consider messages with a valid file_path so they can be sorted
        loc_messages = self.msg_gather(loc, require_file_path=True)
        loc_messages.sort(key=lambda m: m.msg_id(), reverse=True)
        for m in loc_messages:
            if mfilter is None or m.match(mfilter):
                return m
        return None

    def msg_in_cache(self, message: Union[Message, str]) -> bool:
        """
        Return true if the given Message (or filename or Message.msg_id())
        is located in the cache directory. False otherwise.
        """
        if isinstance(message, Message):
            # check 'exists()' first: 'samefile()' raises if the parent
            # directory of the message doesn't exist
            return (message.file_path is not None
                    and message.file_path.exists()  # noqa: W503
                    and message.file_path.parent.samefile(self.cache_path))  # noqa: W503
        else:
            return len(self.msg_find([message], loc=msg_location.CACHE)) > 0

    def msg_in_db(self, message: Union[Message, str]) -> bool:
        """
        Return true if the given Message (or filename or Message.msg_id())
        is located in the DB directory. False otherwise.
        """
        if isinstance(message, Message):
            # check 'exists()' first: 'samefile()' raises if the parent
            # directory of the message doesn't exist
            return (message.file_path is not None
                    and message.file_path.exists()  # noqa: W503
                    and message.file_path.parent.samefile(self.db_path))  # noqa: W503
        else:
            return len(self.msg_find([message], loc=msg_location.DB)) > 0

    def cache_read(self, glob: str = f'*{msg_suffix}', mfilter: Optional[MessageFilter] = None) -> None:
        """
        Read messages from the cache directory. New ones are added to the internal list,
        existing ones are replaced. A message is determined as 'existing' if a message
        with the same base filename (i. e. 'file_path.name') is already in the list.
        """
        new_messages = read_dir(self.cache_path, glob, mfilter)
        # remove all messages from self.messages that are in the new list
        self.messages = [m for m in self.messages if not message_in(m, new_messages)]
        # copy the messages from the temporary list to self.messages and sort them
        self.messages += new_messages
        self.msg_sort()

    def cache_write(self, messages: Optional[list[Message]] = None) -> None:
        """
        Write messages to the cache directory. If a message has no file_path, a new one
        will be created. If message.file_path exists, it will be modified to point to
        the cache directory. If no messages are given, the internal ones are written.
        Does NOT add the messages to the internal list (use 'cache_add()' for that)!
        """
        write_dir(self.cache_path,
                  messages if messages else self.messages,
                  self.get_next_fid,
                  self.mformat)

    def cache_add(self, messages: list[Message], write: bool = True) -> None:
        """
        Add NEW messages and set the file_path to the cache directory.
        Only accepts messages without a file_path (raises 'ChatError' otherwise).
        If 'write' is False, the messages only get a file_path but are not written.
        """
        if any(m.file_path is not None for m in messages):
            raise ChatError("Can't add new messages with existing file_path")
        if write:
            write_dir(self.cache_path,
                      messages,
                      self.get_next_fid,
                      self.mformat)
        else:
            for m in messages:
                m.file_path = make_file_path(self.cache_path, self.get_next_fid)
        self.messages += messages
        self.msg_sort()

    def cache_clear(self, glob: str = f'*{msg_suffix}') -> None:
        """
        Delete all message files from the cache dir and remove them from the internal list.
        """
        clear_dir(self.cache_path, glob)
        # only keep messages from DB dir (or those that have not yet been written)
        self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]

    def cache_move(self, message: Message) -> None:
        """
        Moves the given message to the cache directory.
        """
        # remember the old path (if any)
        old_path: Optional[Path] = None
        if message.file_path:
            old_path = message.file_path
        # write message to the new destination
        self.cache_write([message])
        # remove the old one (if any)
        if old_path:
            self.msg_remove([str(old_path)], loc=msg_location.DB)
        # (re)add it to the internal list
        self.msg_add([message])

    def db_read(self, glob: str = f'*{msg_suffix}', mfilter: Optional[MessageFilter] = None) -> None:
        """
        Read messages from the DB directory. New ones are added to the internal list,
        existing ones are replaced. A message is determined as 'existing' if a message
        with the same base filename (i. e. 'file_path.name') is already in the list.
        """
        # honor the given arguments (same behaviour as 'cache_read()')
        new_messages = read_dir(self.db_path, glob, mfilter)
        # remove all messages from self.messages that are in the new list
        self.messages = [m for m in self.messages if not message_in(m, new_messages)]
        # copy the messages from the temporary list to self.messages and sort them
        self.messages += new_messages
        self.msg_sort()

    def db_write(self, messages: Optional[list[Message]] = None) -> None:
        """
        Write messages to the DB directory. If a message has no file_path, a new one
        will be created. If message.file_path exists, it will be modified to point
        to the DB directory. If no messages are given, the internal ones are written.
        Does NOT add the messages to the internal list (use 'db_add()' for that)!
        """
        write_dir(self.db_path,
                  messages if messages else self.messages,
                  self.get_next_fid,
                  self.mformat)

    def db_add(self, messages: list[Message], write: bool = True) -> None:
        """
        Add NEW messages and set the file_path to the DB directory.
        Only accepts messages without a file_path (raises 'ChatError' otherwise).
        If 'write' is False, the messages only get a file_path but are not written.
        """
        if any(m.file_path is not None for m in messages):
            raise ChatError("Can't add new messages with existing file_path")
        if write:
            write_dir(self.db_path,
                      messages,
                      self.get_next_fid,
                      self.mformat)
        else:
            for m in messages:
                m.file_path = make_file_path(self.db_path, self.get_next_fid)
        self.messages += messages
        self.msg_sort()

    def db_move(self, message: Message) -> None:
        """
        Moves the given message to the db directory.
        """
        # remember the old path (if any)
        old_path: Optional[Path] = None
        if message.file_path:
            old_path = message.file_path
        # write message to the new destination
        self.db_write([message])
        # remove the old one (if any)
        if old_path:
            self.msg_remove([str(old_path)], loc=msg_location.CACHE)
        # (re)add it to the internal list
        self.msg_add([message])
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from ..configuration import Config
|
||||||
|
from ..ai import AI
|
||||||
|
from ..ai_factory import create_ai
|
||||||
|
|
||||||
|
|
||||||
|
def config_cmd(args: argparse.Namespace) -> None:
    """
    Handler for the 'config' command: either creates a default
    configuration file, lists the available models or prints the
    currently configured model.
    """
    if args.create:
        Config.create_default(Path(args.create))
        return
    if not (args.list_models or args.print_model):
        return
    # both remaining actions need an AI instance
    config: Config = Config.from_file(args.config)
    ai: AI = create_ai(args, config)
    if args.list_models:
        ai.print_models()
    else:
        print(ai.config.model)
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from ..configuration import Config
|
||||||
|
from ..chat import ChatDB, msg_location
|
||||||
|
from ..message import MessageFilter, Message
|
||||||
|
|
||||||
|
|
||||||
|
msg_suffix = Message.file_suffix_write  # suffix used when writing message files (currently '.msg')
|
||||||
|
|
||||||
|
|
||||||
|
def convert_messages(args: argparse.Namespace, config: Config) -> None:
    """
    Convert messages to a new format. Also used to change old suffixes
    ('.txt', '.yaml') to the latest default message file suffix ('.msg').
    The requested format is taken from 'args.convert'. Files with an old
    suffix are only deleted after all messages have been found with the
    new suffix; otherwise the program exits with status 1.
    """
    chat = ChatDB.from_dir(Path(config.cache),
                           Path(config.db),
                           glob='*')
    # read all known message files
    msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*')
    # make a set of all message IDs
    msg_ids = {m.msg_id() for m in msgs}
    # set requested format and write all messages
    chat.set_msg_format(args.convert)
    # delete the current suffix
    # -> a new one will automatically be created
    for m in msgs:
        if m.file_path:
            m.file_path = m.file_path.with_suffix('')
    chat.msg_write(msgs)
    # read all messages with the current default suffix
    msgs = chat.msg_gather(loc=msg_location.DISK, glob=f'*{msg_suffix}')
    # make sure we converted all of the original messages
    missing_ids = msg_ids - {m.msg_id() for m in msgs}
    if missing_ids:
        # report the smallest missing ID so the output is deterministic
        mid = sorted(missing_ids)[0]
        print(f"Message '{mid}' has not been found after conversion. Aborting.")
        sys.exit(1)
    # delete messages with old suffixes
    msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*')
    for m in msgs:
        if m.file_path and m.file_path.suffix != msg_suffix:
            m.rm_file()
    print(f"Successfully converted {len(msg_ids)} messages.")
|
||||||
|
|
||||||
|
|
||||||
|
def print_chat(args: argparse.Namespace, config: Config) -> None:
    """
    Print the chat history read from the location given by 'args.location',
    restricted to the messages matching the tag / content filters in 'args'.
    """
    tag_and_content_filter = MessageFilter(tags_or=args.or_tags,
                                           tags_and=args.and_tags,
                                           tags_not=args.exclude_tags,
                                           question_contains=args.question,
                                           answer_contains=args.answer)
    cache_dir = Path(config.cache)
    db_dir = Path(config.db)
    history = ChatDB.from_dir(cache_dir,
                              db_dir,
                              mfilter=tag_and_content_filter,
                              loc=msg_location(args.location),
                              glob=args.glob)
    use_pager = not args.no_paging
    history.print(args.source_code_only,
                  args.with_metadata,
                  paged=use_pager,
                  tight=args.tight)
|
||||||
|
|
||||||
|
|
||||||
|
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'hist' command. Printing takes precedence
    over converting if both are requested.
    """
    if args.print:
        print_chat(args, config)
        return
    if args.convert:
        convert_messages(args, config)
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from ..configuration import Config
|
||||||
|
from ..message import Message, MessageError
|
||||||
|
from ..chat import ChatDB, msg_location
|
||||||
|
|
||||||
|
|
||||||
|
def print_message(message: Message, args: argparse.Namespace) -> None:
    """
    Print the given message according to the given arguments. Exactly one
    of the following is printed, in this order of precedence:
    * the question ('args.question')
    * the answer ('args.answer')
    * the source code of the answer ('args.only_source_code', if there is an answer)
    * the whole message
    """
    if args.question:
        print(message.question)
        return
    if args.answer:
        print(message.answer)
        return
    if message.answer and args.only_source_code:
        for code in message.answer.source_code():
            print(code)
        return
    print(message.to_str())
|
||||||
|
|
||||||
|
|
||||||
|
def print_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'print' command. Prints either the message stored in
    'args.file' or, if 'args.latest' is set, the latest message on disk.
    Exits with status 1 if the file is not a valid message or if no
    message could be found.
    """
    # print given file
    if args.file is not None:
        try:
            message = Message.from_file(Path(args.file))
            if message:
                print_message(message, args)
        except MessageError:
            print(f"File is not a valid message: {args.file}")
            sys.exit(1)
        return
    if not args.latest:
        return
    # print latest message
    chat = ChatDB.from_dir(Path(config.cache), Path(config.db))
    latest = chat.msg_latest(loc=msg_location.DISK)
    if not latest:
        print("No message found!")
        sys.exit(1)
    print_message(latest, args)
|
||||||
@@ -0,0 +1,223 @@
|
|||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from itertools import zip_longest
|
||||||
|
from copy import deepcopy
|
||||||
|
from ..configuration import Config
|
||||||
|
from ..chat import ChatDB, msg_location
|
||||||
|
from ..message import Message, MessageFilter, MessageError, Question, source_code
|
||||||
|
from ..ai_factory import create_ai
|
||||||
|
from ..ai import AI, AIResponse
|
||||||
|
|
||||||
|
|
||||||
|
class QuestionCmdError(Exception):
    """
    Error type of the question command module.
    """
|
||||||
|
|
||||||
|
|
||||||
|
def add_file_as_text(question_parts: list[str], file: str) -> None:
    """
    Add the given file as plain text to the question part list.
    If the file is a Message, add the answer. Nothing is added if
    there is no content (empty file or message without an answer).
    """
    file_path = Path(file)
    # stays empty if the file is a message that has no answer (or if
    # 'Message.from_file()' returns no message at all)
    content: str = ''
    try:
        message = Message.from_file(file_path)
        if message and message.answer:
            content = message.answer
    except MessageError:
        # not a message: use the stripped file content
        with open(file) as r:
            content = r.read().strip()
    if content:
        question_parts.append(content)
|
||||||
|
|
||||||
|
|
||||||
|
def add_file_as_code(question_parts: list[str], file: str) -> None:
    """
    Add all source code from the given file. If no code segments can be extracted,
    the whole content is added as source code segment. If the file is a Message,
    extract the source code from the answer. A Message without an answer adds
    nothing.
    """
    file_path = Path(file)
    content: str
    try:
        message = Message.from_file(file_path)
        if not (message and message.answer):
            # without this guard 'content' would stay unbound and the
            # code extraction below would raise an 'UnboundLocalError'
            return
        content = message.answer
    except MessageError:
        # not a message file -> use the stripped file content
        with open(file) as r:
            content = r.read().strip()
    # extract and add source code
    code_parts = source_code(content, include_delims=True)
    if len(code_parts) > 0:
        question_parts += code_parts
    else:
        # no code sections found -> wrap the whole content
        question_parts.append(f"```\n{content}\n```")
|
||||||
|
|
||||||
|
|
||||||
|
def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
    """
    Combine the CLI arguments with the members of an existing message.
    For each of AI, model and output tags that has not been specified
    on the CLI, the value of the message is used (if it has one). If all
    three were specified, 'args' itself is returned, otherwise a modified
    deep copy. Used e.g. when repeating messages.
    """
    # everything given on the CLI -> nothing to take from the message
    if args.AI is not None and args.model is not None and args.output_tags is not None:
        return args
    merged = deepcopy(args)
    # (argument name, name of the corresponding message member)
    for arg_name, msg_member in (('AI', 'ai'), ('model', 'model'), ('output_tags', 'tags')):
        if getattr(args, arg_name) is None:
            msg_value = getattr(msg, msg_member)
            if msg_value is not None:
                setattr(merged, arg_name, msg_value)
    return merged
|
||||||
|
|
||||||
|
|
||||||
|
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
    """
    Build a new message from the questions ('args.create' or 'args.ask')
    and the optional text / source code files, then write it to the
    cache directory. Raises 'QuestionCmdError' if no question is given.
    """
    # 'create' is checked first, then 'ask'
    questions = args.create if args.create is not None else args.ask
    if questions is None:
        raise QuestionCmdError("No question found")
    text_files = [] if args.source_text is None else args.source_text
    code_files = [] if args.source_code is None else args.source_code

    # interleave the n-th question with the n-th text and code file
    parts: list[str] = []
    for question, text_file, code_file in zip_longest(questions, text_files, code_files, fillvalue=None):
        if question is not None and question.strip():
            parts.append(question)
        if text_file:
            add_file_as_text(parts, text_file)
        if code_file:
            add_file_as_code(parts, code_file)

    new_message = Message(question=Question('\n\n'.join(str(part) for part in parts)),
                          tags=args.output_tags,
                          ai=args.AI,
                          model=args.model)
    # only write the new message to the cache,
    # don't add it to the internal list
    chat.cache_write([new_message])
    return new_message
|
||||||
|
|
||||||
|
|
||||||
|
def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespace) -> None:
    """
    Send the given message (together with the chat history) to the given AI,
    print the response(s) and write them to the cache directory. The
    responses are not appended to the given chat history.
    """
    # print history and message question before making the request
    ai.print()
    chat.print(paged=False)
    print(message.to_str())
    response: AIResponse = ai.request(message, chat, args.num_answers, args.output_tags)
    for number, answer_msg in enumerate(response.messages, start=1):
        print(f"=== ANSWER {number} ===", flush=True)
        if answer_msg.answer:
            for piece in answer_msg.answer:
                print(piece, end='', flush=True)
            print()
    if response.tokens:
        print("===============")
        print(response.tokens)
    # only write the response messages to the cache,
    # don't add them to the internal list
    chat.cache_write(response.messages)
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
    """
    Make a new AI request for each of the given messages. A message is
    answered in place if it has no answer (or overwriting was requested)
    and it is not part of the DB, otherwise a new message with the same
    question is created.
    """
    for original in messages:
        msg_args = create_msg_args(original, args)
        ai: AI = create_ai(msg_args, config)
        print(f"--------- Repeating message '{original.msg_id()}': ---------")
        # overwrite the message if requested or empty
        # -> but not if it's in the DB!
        may_overwrite = original.answer is None or msg_args.overwrite is True
        if may_overwrite and not chat.msg_in_db(original):
            original.clear_answer()
            target = original
        else:
            # otherwise create a new one
            msg_args.ask = [original.question]
            target = create_message(chat, msg_args)
        make_request(ai, chat, target, msg_args)
|
||||||
|
|
||||||
|
|
||||||
|
def invert_input_tag_args(args: argparse.Namespace) -> None:
    """
    Inverts the meaning of the INPUT tag arguments ('or_tags' and
    'and_tags') for this command (in place):
    * no tags specified on the CLI (None) -> empty set (no tags are selected)
    * empty tags specified on the CLI -> None (all tags are selected)
    """
    for tag_arg in ('or_tags', 'and_tags'):
        given = getattr(args, tag_arg)
        if given is None:
            setattr(args, tag_arg, set())
        elif len(given) == 0:
            setattr(args, tag_arg, None)
|
||||||
|
|
||||||
|
|
||||||
|
def question_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'question' command. Builds the chat history from the
    tag / glob / location arguments and then creates, asks or repeats
    questions (processing is not implemented yet).
    """
    invert_input_tag_args(args)
    tag_filter = MessageFilter(tags_or=args.or_tags,
                               tags_and=args.and_tags,
                               tags_not=args.exclude_tags)
    chat = ChatDB.from_dir(cache_path=Path(config.cache),
                           db_path=Path(config.db),
                           mfilter=tag_filter,
                           glob=args.glob,
                           loc=msg_location(args.location))

    # === CREATE ===
    # the new question is only created and stored, no request is made
    if args.create:
        create_message(chat, args)
        return

    # === ASK ===
    # the new question is created, stored and sent immediately
    if args.ask:
        message = create_message(chat, args)
        ai: AI = create_ai(args, config)
        make_request(ai, chat, message, args)
        return

    # === REPEAT ===
    if args.repeat is not None:
        repeat_msgs: list[Message]
        if len(args.repeat) > 0:
            # repeat given message(s)
            repeat_msgs = chat.msg_find(args.repeat, loc=msg_location.DISK)
        else:
            # repeat latest message
            latest = chat.msg_latest(loc=msg_location.CACHE)
            if latest is None:
                print("No message found to repeat!")
                sys.exit(1)
            repeat_msgs = [latest]
        repeat_messages(repeat_msgs, chat, args, config)
        return

    # === PROCESS ===
    if args.process is not None:
        # TODO: process either all questions without an
        # answer or the one(s) given in 'args.process'
        pass
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from ..configuration import Config
|
||||||
|
from ..chat import ChatDB
|
||||||
|
|
||||||
|
|
||||||
|
def tags_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'tags' command. With 'args.list', prints every tag
    (filtered by 'args.prefix' / 'args.contain') with its frequency.
    """
    chat = ChatDB.from_dir(cache_path=Path(config.cache),
                           db_path=Path(config.db))
    if not args.list:
        return
    frequencies = chat.msg_tags_frequency(args.prefix, args.contain)
    for tag, count in frequencies.items():
        print(f"- {tag}: {count}")
    # TODO: add renaming
|
||||||
@@ -0,0 +1,169 @@
|
|||||||
|
import yaml
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Type, TypeVar, Any, Optional, ClassVar
|
||||||
|
from dataclasses import dataclass, asdict, field
|
||||||
|
|
||||||
|
# type variables bound to the configuration classes, used to annotate
# the 'cls' parameter and the return type of their alternate constructors
ConfigInst = TypeVar('ConfigInst', bound='Config')
AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig')
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')


# names of all AIs for which a default configuration is created
supported_ais: list[str] = ['openai']
# default name of the configuration file
default_config_file = '.config.yaml'
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigError(Exception):
    """
    Error raised for invalid configurations (e. g. an unknown AI name).
    """
|
||||||
|
|
||||||
|
|
||||||
|
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
    """
    YAML representer for strings: strings with more than one line are
    dumped using the multiline syntax ('|'), all others in the default style.
    """
    is_multiline = len(data.splitlines()) > 1
    if not is_multiline:
        return dumper.represent_scalar('tag:yaml.org,2002:str', data)
    return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
|
||||||
|
|
||||||
|
|
||||||
|
# use 'str_presenter' whenever a 'str' is dumped
# (module-level side effect, executed on import)
yaml.add_representer(str, str_presenter)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AIConfig:
    """
    The base class of all AI configurations.
    """
    # the name of the AI the config class represents
    # -> it's a class variable and thus not part of the
    #    dataclass constructor
    name: ClassVar[str]
    # a user-defined ID for an AI configuration entry
    ID: str
    model: str = 'n/a'

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Set an attribute of the instance. The AI name must not be changed,
        so assigning to 'name' raises an AttributeError.
        """
        if name == 'name':
            # f-string, so the message contains the attribute name
            # instead of the literal placeholder '{name}'
            raise AttributeError(f"'{name}' is not allowed to be changed")
        super().__setattr__(name, value)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class OpenAIConfig(AIConfig):
    """
    The OpenAI section of the configuration file.
    """
    name: ClassVar[str] = 'openai'

    # every member has a default value, which makes it easy
    # to create a default configuration
    ID: str = 'myopenai'
    api_key: str = '0123456789'
    model: str = 'gpt-3.5-turbo-16k'
    temperature: float = 1.0
    max_tokens: int = 4000
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    system: str = 'You are an assistant'

    @classmethod
    def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
        """
        Create OpenAIConfig from a dict. All settings except 'ID' must be
        present in the dict and are converted to their member type.
        """
        settings: dict[str, Any] = {
            'api_key': str(source['api_key']),
            'model': str(source['model']),
            'max_tokens': int(source['max_tokens']),
            'temperature': float(source['temperature']),
            'top_p': float(source['top_p']),
            'frequency_penalty': float(source['frequency_penalty']),
            'presence_penalty': float(source['presence_penalty']),
            'system': str(source['system']),
        }
        instance = cls(**settings)
        # overwrite default ID if provided
        if 'ID' in source:
            instance.ID = source['ID']
        return instance
|
||||||
|
|
||||||
|
|
||||||
|
def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
    """
    Creates an AIConfig instance for the AI with the given name (case
    insensitive), either with default values or from the given dict.
    Raises a ConfigError for unknown AI names.
    """
    if name.lower() != 'openai':
        raise ConfigError(f"Unknown AI '{name}'")
    if conf_dict is None:
        return OpenAIConfig()
    return OpenAIConfig.from_dict(conf_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def create_default_ai_configs() -> dict[str, AIConfig]:
    """
    Create a dict with the default configuration of every supported AI,
    using the ID of the configuration as key.
    """
    defaults: dict[str, AIConfig] = {}
    for ai_name in supported_ais:
        default_conf = ai_config_instance(ai_name)
        defaults[default_conf.ID] = default_conf
    return defaults
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Config:
    """
    The configuration file structure.
    """
    # all members have default values, so we can easily create
    # a default configuration
    cache: str = '.'
    db: str = './db/'
    ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)

    @classmethod
    def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
        """
        Create Config from a dict (with the same format as the config file).
        The given dict is not modified.
        """
        # create the correct AI type instances
        ais: dict[str, AIConfig] = {}
        for ID, conf in source['ais'].items():
            # add the AI ID to a copy of the config (for easy internal
            # access), so the caller's dict is left untouched
            ai_conf = ai_config_instance(conf['name'], {**conf, 'ID': ID})
            ais[ID] = ai_conf
        return cls(
            cache=str(source['cache']) if 'cache' in source else '.',
            db=str(source['db']),
            ais=ais
        )

    @classmethod
    def create_default(cls, file_path: Path) -> None:
        """
        Creates a default Config in the given file.
        """
        conf = cls()
        conf.to_file(file_path)

    @classmethod
    def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
        """
        Create Config from the YAML file with the given path.
        """
        with open(path, 'r') as f:
            # NOTE(review): 'yaml.load' with 'FullLoader' is more permissive
            # than 'yaml.safe_load'; consider switching if config files
            # may come from untrusted sources.
            source = yaml.load(f, Loader=yaml.FullLoader)
        return cls.from_dict(source)

    def to_file(self, file_path: Path) -> None:
        """
        Write the configuration as YAML to the given file.
        """
        # remove the AI ID from each AI config (for a cleaner format):
        # it's already the key of the entry
        data = self.as_dict()
        for conf in data['ais'].values():
            del (conf['ID'])
        with open(file_path, 'w') as f:
            yaml.dump(data, f, sort_keys=False)

    def as_dict(self) -> dict[str, Any]:
        """
        Return the configuration as a dict, with the AI name added
        to each AI config.
        """
        res = asdict(self)
        # add the AI name manually (as first element)
        # (not done by 'asdict' because it's a class variable)
        for ID, conf in res['ais'].items():
            res['ais'][ID] = {**{'name': self.ais[ID].name}, **conf}
        return res
|
||||||
+128
-130
@@ -2,122 +2,140 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# vim: set fileencoding=utf-8 :
|
# vim: set fileencoding=utf-8 :
|
||||||
|
|
||||||
import yaml
|
|
||||||
import sys
|
import sys
|
||||||
import argcomplete
|
import argcomplete
|
||||||
import argparse
|
import argparse
|
||||||
import pathlib
|
from pathlib import Path
|
||||||
from .utils import terminal_width, process_tags, display_chat, display_source_code, display_tags_frequency
|
from typing import Any
|
||||||
from .storage import save_answers, create_chat, get_tags, get_tags_unique, read_file, dump_data
|
from .configuration import Config, default_config_file
|
||||||
from .api_client import ai, openai_api_key, display_models
|
from .message import Message
|
||||||
from itertools import zip_longest
|
from .commands.question import question_cmd
|
||||||
|
from .commands.tags import tags_cmd
|
||||||
|
from .commands.config import config_cmd
|
||||||
|
from .commands.hist import hist_cmd
|
||||||
|
from .commands.print import print_cmd
|
||||||
|
from .chat import msg_location
|
||||||
|
|
||||||
|
|
||||||
def run_print_command(args: argparse.Namespace, config: dict) -> None:
|
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
|
||||||
fname = pathlib.Path(args.print)
|
config = Config.from_file(parsed_args.config)
|
||||||
if fname.suffix == '.yaml':
|
return list(Message.tags_from_dir(Path(config.db), prefix=prefix))
|
||||||
with open(args.print, 'r') as f:
|
|
||||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
|
||||||
elif fname.suffix == '.txt':
|
|
||||||
data = read_file(fname)
|
|
||||||
else:
|
|
||||||
print(f"Unknown file type: {args.print}")
|
|
||||||
sys.exit(1)
|
|
||||||
if args.only_source_code:
|
|
||||||
display_source_code(data['answer'])
|
|
||||||
else:
|
|
||||||
print(dump_data(data).strip())
|
|
||||||
|
|
||||||
|
|
||||||
def process_and_display_chat(args: argparse.Namespace,
|
|
||||||
config: dict,
|
|
||||||
dump: bool = False
|
|
||||||
) -> tuple[list[dict[str, str]], str, list[str]]:
|
|
||||||
tags = args.tags or []
|
|
||||||
extags = args.extags or []
|
|
||||||
otags = args.output_tags or []
|
|
||||||
|
|
||||||
if not args.only_source_code:
|
|
||||||
process_tags(tags, extags, otags)
|
|
||||||
|
|
||||||
question_parts = []
|
|
||||||
question_list = args.question if args.question is not None else []
|
|
||||||
source_list = args.source if args.source is not None else []
|
|
||||||
|
|
||||||
for question, source in zip_longest(question_list, source_list, fillvalue=None):
|
|
||||||
if question is not None and source is not None:
|
|
||||||
with open(source) as r:
|
|
||||||
question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```")
|
|
||||||
elif question is not None:
|
|
||||||
question_parts.append(question)
|
|
||||||
elif source is not None:
|
|
||||||
with open(source) as r:
|
|
||||||
question_parts.append(f"```\n{r.read().strip()}\n```")
|
|
||||||
|
|
||||||
full_question = '\n\n'.join(question_parts)
|
|
||||||
chat = create_chat(full_question, tags, extags, config,
|
|
||||||
args.match_all_tags, args.with_tags,
|
|
||||||
args.with_file)
|
|
||||||
display_chat(chat, dump, args.only_source_code)
|
|
||||||
return chat, full_question, tags
|
|
||||||
|
|
||||||
|
|
||||||
def process_and_display_tags(args: argparse.Namespace,
|
|
||||||
config: dict,
|
|
||||||
dump: bool = False
|
|
||||||
) -> None:
|
|
||||||
display_tags_frequency(get_tags(config, None), dump)
|
|
||||||
|
|
||||||
|
|
||||||
def handle_question(args: argparse.Namespace,
|
|
||||||
config: dict,
|
|
||||||
dump: bool = False
|
|
||||||
) -> None:
|
|
||||||
chat, question, tags = process_and_display_chat(args, config, dump)
|
|
||||||
otags = args.output_tags or []
|
|
||||||
answers, usage = ai(chat, config, args.number)
|
|
||||||
save_answers(question, answers, tags, otags, config)
|
|
||||||
print("-" * terminal_width())
|
|
||||||
print(f"Usage: {usage}")
|
|
||||||
|
|
||||||
|
|
||||||
def tags_completer(prefix, parsed_args, **kwargs):
|
|
||||||
with open(parsed_args.config, 'r') as f:
|
|
||||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
|
||||||
return get_tags_unique(config, prefix)
|
|
||||||
|
|
||||||
|
|
||||||
def create_parser() -> argparse.ArgumentParser:
|
def create_parser() -> argparse.ArgumentParser:
|
||||||
default_config = '.config.yaml'
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="ChatMastermind is a Python application that automates conversation with AI")
|
description="ChatMastermind is a Python application that automates conversation with AI")
|
||||||
group = parser.add_mutually_exclusive_group(required=True)
|
parser.add_argument('-C', '--config', help='Config file name.', default=default_config_file)
|
||||||
group.add_argument('-p', '--print', help='File to print')
|
|
||||||
group.add_argument('-q', '--question', nargs='*', help='Question to ask')
|
# subcommand-parser
|
||||||
group.add_argument('-D', '--chat-dump', help="Print chat history as Python structure", action='store_true')
|
cmdparser = parser.add_subparsers(dest='command',
|
||||||
group.add_argument('-d', '--chat', help="Print chat history as readable text", action='store_true')
|
title='commands',
|
||||||
group.add_argument('-l', '--list-tags', help="List all tags and their frequency", action='store_true')
|
description='supported commands',
|
||||||
group.add_argument('-L', '--list-models', help="List all available models", action='store_true')
|
required=True)
|
||||||
parser.add_argument('-c', '--config', help='Config file name.', default=default_config)
|
|
||||||
parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
|
# a parent parser for all commands that support tag selection
|
||||||
parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
|
tag_parser = argparse.ArgumentParser(add_help=False)
|
||||||
parser.add_argument('-M', '--model', help='Model to use')
|
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='*',
|
||||||
parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=1)
|
help='List of tags (one must match)', metavar='OTAGS')
|
||||||
parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query')
|
tag_arg.completer = tags_completer # type: ignore
|
||||||
parser.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true')
|
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='*',
|
||||||
parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", action='store_true')
|
help='List of tags (all must match)', metavar='ATAGS')
|
||||||
parser.add_argument('-W', '--with-file',
|
atag_arg.completer = tags_completer # type: ignore
|
||||||
help="Print chat history with filename.",
|
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='*',
|
||||||
|
help='List of tags to exclude', metavar='XTAGS')
|
||||||
|
etag_arg.completer = tags_completer # type: ignore
|
||||||
|
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
|
||||||
|
help='List of output tags (default: use input tags)', metavar='OUTAGS')
|
||||||
|
otag_arg.completer = tags_completer # type: ignore
|
||||||
|
|
||||||
|
# a parent parser for all commands that support AI configuration
|
||||||
|
ai_parser = argparse.ArgumentParser(add_help=False)
|
||||||
|
ai_parser.add_argument('-A', '--AI', help='AI ID to use', metavar='AI_ID')
|
||||||
|
ai_parser.add_argument('-M', '--model', help='Model to use', metavar='MODEL')
|
||||||
|
ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1)
|
||||||
|
ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int)
|
||||||
|
ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float)
|
||||||
|
|
||||||
|
# 'question' command parser
|
||||||
|
question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser, ai_parser],
|
||||||
|
help="ask, create and process questions.",
|
||||||
|
aliases=['q'])
|
||||||
|
question_cmd_parser.set_defaults(func=question_cmd)
|
||||||
|
question_group = question_cmd_parser.add_mutually_exclusive_group(required=True)
|
||||||
|
question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question', metavar='QUESTION')
|
||||||
|
question_group.add_argument('-c', '--create', nargs='+', help='Create a question', metavar='QUESTION')
|
||||||
|
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question', metavar='MESSAGE')
|
||||||
|
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions', metavar='MESSAGE')
|
||||||
|
question_cmd_parser.add_argument('-l', '--location',
|
||||||
|
choices=[x.value for x in msg_location if x not in [msg_location.MEM, msg_location.DISK]],
|
||||||
|
default='db',
|
||||||
|
help='Use given location when building the chat history (default: \'db\')')
|
||||||
|
question_cmd_parser.add_argument('-g', '--glob', help='Filter message files using the given glob pattern')
|
||||||
|
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
|
||||||
action='store_true')
|
action='store_true')
|
||||||
parser.add_argument('-a', '--match-all-tags',
|
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query', metavar='FILE')
|
||||||
help="All given tags must match when selecting chat history entries.",
|
question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history',
|
||||||
|
metavar='FILE')
|
||||||
|
|
||||||
|
# 'hist' command parser
|
||||||
|
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
|
||||||
|
help="Print and manage chat history.",
|
||||||
|
aliases=['h'])
|
||||||
|
hist_cmd_parser.set_defaults(func=hist_cmd)
|
||||||
|
hist_group = hist_cmd_parser.add_mutually_exclusive_group(required=True)
|
||||||
|
hist_group.add_argument('-p', '--print', help='Print the DB chat history', action='store_true')
|
||||||
|
hist_group.add_argument('-c', '--convert', help='Convert all message files to the given format [txt|yaml]', metavar='FORMAT')
|
||||||
|
hist_cmd_parser.add_argument('-w', '--with-metadata', help="Print chat history with metadata (tags, filename, AI, etc.).",
|
||||||
action='store_true')
|
action='store_true')
|
||||||
tags_arg = parser.add_argument('-t', '--tags', nargs='*', help='List of tag names', metavar='TAGS')
|
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Only print embedded source code',
|
||||||
tags_arg.completer = tags_completer # type: ignore
|
action='store_true')
|
||||||
extags_arg = parser.add_argument('-e', '--extags', nargs='*', help='List of tag names to exclude', metavar='EXTAGS')
|
hist_cmd_parser.add_argument('-A', '--answer', help='Print only answers with given substring', metavar='SUBSTRING')
|
||||||
extags_arg.completer = tags_completer # type: ignore
|
hist_cmd_parser.add_argument('-Q', '--question', help='Print only questions with given substring', metavar='SUBSTRING')
|
||||||
otags_arg = parser.add_argument('-o', '--output-tags', nargs='*', help='List of output tag names, default is input', metavar='OTAGS')
|
hist_cmd_parser.add_argument('-d', '--tight', help='Print without message separators', action='store_true')
|
||||||
otags_arg.completer = tags_completer # type: ignore
|
hist_cmd_parser.add_argument('-P', '--no-paging', help='Print without paging', action='store_true')
|
||||||
|
hist_cmd_parser.add_argument('-l', '--location',
|
||||||
|
choices=[x.value for x in msg_location if x not in [msg_location.MEM, msg_location.DISK]],
|
||||||
|
default='db',
|
||||||
|
help='Use given location when building the chat history (default: \'db\')')
|
||||||
|
hist_cmd_parser.add_argument('-g', '--glob', help='Filter message files using the given glob pattern')
|
||||||
|
|
||||||
|
# 'tags' command parser
|
||||||
|
tags_cmd_parser = cmdparser.add_parser('tags',
|
||||||
|
help="Manage tags.",
|
||||||
|
aliases=['t'])
|
||||||
|
tags_cmd_parser.set_defaults(func=tags_cmd)
|
||||||
|
tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True)
|
||||||
|
tags_group.add_argument('-l', '--list', help="List all tags and their frequency",
|
||||||
|
action='store_true')
|
||||||
|
tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix", metavar='PREFIX')
|
||||||
|
tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring", metavar='SUBSTRING')
|
||||||
|
|
||||||
|
# 'config' command parser
|
||||||
|
config_cmd_parser = cmdparser.add_parser('config',
|
||||||
|
help="Manage configuration",
|
||||||
|
aliases=['c'])
|
||||||
|
config_cmd_parser.set_defaults(func=config_cmd)
|
||||||
|
config_cmd_parser.add_argument('-A', '--AI', help='AI ID to use')
|
||||||
|
config_group = config_cmd_parser.add_mutually_exclusive_group(required=True)
|
||||||
|
config_group.add_argument('-l', '--list-models', help="List all available models",
|
||||||
|
action='store_true')
|
||||||
|
config_group.add_argument('-m', '--print-model', help="Print the currently configured model",
|
||||||
|
action='store_true')
|
||||||
|
config_group.add_argument('-c', '--create', help="Create config with default settings in the given file", metavar='FILE')
|
||||||
|
|
||||||
|
# 'print' command parser
|
||||||
|
print_cmd_parser = cmdparser.add_parser('print',
|
||||||
|
help="Print message files.",
|
||||||
|
aliases=['p'])
|
||||||
|
print_cmd_parser.set_defaults(func=print_cmd)
|
||||||
|
print_group = print_cmd_parser.add_mutually_exclusive_group(required=True)
|
||||||
|
print_group.add_argument('-f', '--file', help='Print given message file', metavar='FILE')
|
||||||
|
print_group.add_argument('-l', '--latest', help='Print latest message', action='store_true')
|
||||||
|
print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group()
|
||||||
|
print_cmd_modes.add_argument('-q', '--question', help='Only print the question', action='store_true')
|
||||||
|
print_cmd_modes.add_argument('-a', '--answer', help='Only print the answer', action='store_true')
|
||||||
|
print_cmd_modes.add_argument('-S', '--only-source-code', help='Only print embedded source code', action='store_true')
|
||||||
|
|
||||||
argcomplete.autocomplete(parser)
|
argcomplete.autocomplete(parser)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@@ -125,33 +143,13 @@ def create_parser() -> argparse.ArgumentParser:
|
|||||||
def main() -> int:
|
def main() -> int:
|
||||||
parser = create_parser()
|
parser = create_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
command = parser.parse_args()
|
||||||
|
|
||||||
with open(args.config, 'r') as f:
|
if command.func == config_cmd:
|
||||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
command.func(command)
|
||||||
|
else:
|
||||||
openai_api_key(config['openai']['api_key'])
|
config = Config.from_file(args.config)
|
||||||
|
command.func(command, config)
|
||||||
if args.max_tokens:
|
|
||||||
config['openai']['max_tokens'] = args.max_tokens
|
|
||||||
|
|
||||||
if args.temperature:
|
|
||||||
config['openai']['temperature'] = args.temperature
|
|
||||||
|
|
||||||
if args.model:
|
|
||||||
config['openai']['model'] = args.model
|
|
||||||
|
|
||||||
if args.print:
|
|
||||||
run_print_command(args, config)
|
|
||||||
elif args.question:
|
|
||||||
handle_question(args, config)
|
|
||||||
elif args.chat_dump:
|
|
||||||
process_and_display_chat(args, config, dump=True)
|
|
||||||
elif args.chat:
|
|
||||||
process_and_display_chat(args, config)
|
|
||||||
elif args.list_tags:
|
|
||||||
process_and_display_tags(args, config)
|
|
||||||
elif args.list_models:
|
|
||||||
display_models()
|
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,678 @@
|
|||||||
|
"""
|
||||||
|
Module implementing message related functions and classes.
|
||||||
|
"""
|
||||||
|
import pathlib
|
||||||
|
import yaml
|
||||||
|
import tempfile
|
||||||
|
import shutil
|
||||||
|
import io
|
||||||
|
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple
|
||||||
|
from typing import Generator, Iterator
|
||||||
|
from typing import get_args as typing_get_args
|
||||||
|
from dataclasses import dataclass, asdict, field
|
||||||
|
from .tags import Tag, TagLine, TagError, match_tags, rename_tags
|
||||||
|
|
||||||
|
# type variables bound to the classes of this module
QuestionInst = TypeVar('QuestionInst', bound='Question')
AnswerInst = TypeVar('AnswerInst', bound='Answer')
MessageInst = TypeVar('MessageInst', bound='Message')
AILineInst = TypeVar('AILineInst', bound='AILine')
ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine')
# dict type with string keys and question / answer / tag set values
# (presumably the YAML representation of a message -- TODO confirm)
YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]]
# the names of the supported message formats
MessageFormat = Literal['txt', 'yaml']
# all valid formats as a tuple (derived from 'MessageFormat')
message_valid_formats: Final[Tuple[MessageFormat, ...]] = typing_get_args(MessageFormat)
message_default_format: Final[MessageFormat] = 'txt'
|
||||||
|
|
||||||
|
|
||||||
|
class MessageError(Exception):
    """
    Error raised by the message related functions and classes of this module.
    """
|
||||||
|
|
||||||
|
|
||||||
|
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
    """
    YAML representer for strings: strings that consist of more than one
    line are dumped in literal block style ('|'), all others in the
    dumper's default style.
    """
    # 'splitlines()' ignores a single trailing newline, so "abc\n" still
    # counts as a one-line string
    extra_args = {'style': '|'} if len(data.splitlines()) > 1 else {}
    return dumper.represent_scalar('tag:yaml.org,2002:str', data, **extra_args)
|
||||||
|
|
||||||
|
|
||||||
|
# Register 'str_presenter' as the YAML representer for 'str'. This is a
# module-level side effect that takes place when the module is imported.
yaml.add_representer(str, str_presenter)
|
||||||
|
|
||||||
|
|
||||||
|
def source_code(text: str, include_delims: bool = False) -> list[str]:
    """
    Collect all source code sections of the given text, i. e. the lines
    enclosed by a pair of lines starting with '```' (leading whitespace
    is ignored when looking for the delimiter).

    Every section is returned as one string terminated by a newline, in
    the order of appearance. With 'include_delims' set, the opening and
    closing delimiter lines are part of the section. A section that is
    never closed is not returned.
    """
    sections: list[str] = []
    pending: list[str] = []
    inside = False

    for row in str(text).split('\n'):
        if not row.strip().startswith('```'):
            # ordinary line: only relevant while a section is open
            if inside:
                pending.append(row)
            continue
        # delimiter line: either opens or closes a section
        if include_delims:
            pending.append(row)
        if inside:
            sections.append('\n'.join(pending) + '\n')
            pending = []
        inside = not inside

    return sections
|
||||||
|
|
||||||
|
|
||||||
|
def message_in(message: MessageInst, messages: Iterable[MessageInst]) -> bool:
    """
    Tell whether 'messages' contains a message whose file name
    (Message.file_path.name) equals the file name of 'message'.
    Messages without a file_path never match; if 'message' itself has
    no file_path, the result is always False.
    """
    if not message.file_path:
        return False
    wanted = message.file_path.name
    return any(bool(candidate.file_path) and candidate.file_path.name == wanted
               for candidate in messages)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(kw_only=True)
class MessageFilter:
    """
    Various filters for a Message.

    All attributes are keyword-only and default to None, which disables
    the respective filter. The filter is evaluated by 'Message.match()'.
    """
    # tag sets that are passed to 'match_tags()' together with the message tags
    tags_or: Optional[set[Tag]] = None
    tags_and: Optional[set[Tag]] = None
    tags_not: Optional[set[Tag]] = None
    # the message's 'ai' / 'model' must be equal to the given value
    ai: Optional[str] = None
    model: Optional[str] = None
    # string that must be contained in the question / answer
    question_contains: Optional[str] = None
    answer_contains: Optional[str] = None
    # 'available': the message must have the respective attribute set,
    # 'missing': the message must not have it set
    answer_state: Optional[Literal['available', 'missing']] = None
    ai_state: Optional[Literal['available', 'missing']] = None
    model_state: Optional[Literal['available', 'missing']] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AILine(str):
    """
    The line of a 'txt' message file that carries the AI name.
    Such a line always starts with 'AILine.prefix'.
    """
    prefix: Final[str] = 'AI:'

    def __new__(cls: Type[AILineInst], string: str) -> AILineInst:
        """
        Create an AILine from a complete line (including the prefix).
        Raises MessageError if the prefix is missing.
        """
        if string.startswith(cls.prefix):
            return super().__new__(cls, string)
        raise MessageError(f"AILine '{string}' is missing prefix '{cls.prefix}'")

    def ai(self) -> str:
        """
        Return the AI name: the text after the prefix without surrounding whitespace.
        """
        remainder = self[len(self.prefix):]
        return remainder.strip()

    @classmethod
    def from_ai(cls: Type[AILineInst], ai: str) -> AILineInst:
        """
        Create an AILine from a bare AI name.
        """
        return cls(cls.prefix + ' ' + ai)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLine(str):
    """
    The line of a 'txt' message file that carries the model name.
    Such a line always starts with 'ModelLine.prefix'.
    """
    prefix: Final[str] = 'MODEL:'

    def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst:
        """
        Create a ModelLine from a complete line (including the prefix).
        Raises MessageError if the prefix is missing.
        """
        if string.startswith(cls.prefix):
            return super().__new__(cls, string)
        raise MessageError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'")

    def model(self) -> str:
        """
        Return the model name: the text after the prefix without surrounding whitespace.
        """
        remainder = self[len(self.prefix):]
        return remainder.strip()

    @classmethod
    def from_model(cls: Type[ModelLineInst], model: str) -> ModelLineInst:
        """
        Create a ModelLine from a bare model name.
        """
        return cls(cls.prefix + ' ' + model)
|
||||||
|
|
||||||
|
|
||||||
|
class Answer(str):
    """
    A single answer with a defined header.

    The content is given either as a string or as a generator of string
    pieces. Pieces are consumed lazily while iterating over the object and
    are collected in an internal buffer; 'str()' consumes whatever is left
    and returns the complete content.

    NOTE: only the methods overridden below operate on that buffer. Other
    methods inherited from 'str' see the value 'str' itself was created with.
    """
    tokens: int = 0  # tokens used by this answer
    txt_header: ClassVar[str] = '==== ANSWER ===='
    yaml_key: ClassVar[str] = 'answer'

    def __init__(self, data: Union[str, Generator[str, None, None]]) -> None:
        # becomes True as soon as all of 'data' has been consumed
        self.is_exhausted: bool = False
        # lazy source of the content pieces
        self.iterator: Iterator[str] = self._init_data(data)
        # holds every piece that has been consumed so far
        self.buffer: io.StringIO = io.StringIO()

    def _init_data(self, data: Union[str, Generator[str, None, None]]) -> Iterator[str]:
        """
        Turn the input (a string or a string generator) into an iterator of pieces.
        A string is delivered as one single piece.
        """
        pieces = (data,) if isinstance(data, str) else data
        yield from pieces

    def __str__(self) -> str:
        """
        Return the complete content, consuming all outstanding pieces first.
        """
        for _piece in self:
            continue
        return self.buffer.getvalue()

    def __repr__(self) -> str:
        return repr(str(self))

    def __iter__(self) -> Generator[str, None, None]:
        """
        Iterate over the content: piece by piece while the input is still
        being consumed, as one single string once it is exhausted.
        """
        if self.is_exhausted:
            yield self.buffer.getvalue()
            return
        yield from self.generator_iter()

    def generator_iter(self) -> Generator[str, None, None]:
        """
        Consume the input, store every piece in the buffer and pass it on.
        After the last piece, the content is checked for the answer header
        and a MessageError is raised if it starts a line of the content.
        """
        for piece in self.iterator:
            self.buffer.write(piece)
            yield piece
        self.is_exhausted = True
        content = self.buffer.getvalue()
        if content.startswith(self.txt_header) or f'\n{self.txt_header}' in content:
            raise MessageError(f"Answer {repr(self.buffer.getvalue())} contains the header {repr(Answer.txt_header)}")

    def __eq__(self, other: object) -> bool:
        """
        Compare the complete content with a string (or another str-based object).
        """
        if not isinstance(other, str):
            # default behavior for everything that is not a string
            return super().__eq__(other)
        return str(self) == other

    def __hash__(self) -> int:
        """
        Hash of the complete content, consistent with '__eq__'.
        """
        return hash(str(self))

    def __format__(self, format_spec: str) -> str:
        """
        Format the complete content according to the given format specification.
        """
        content = str(self)
        return content.__format__(format_spec)

    @classmethod
    def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
        """
        Build an Answer from a list of lines. The lines are joined with
        newlines lazily, i. e. the header check takes place when the
        content is consumed.
        """
        def _joined() -> Generator[str, None, None]:
            for index, text in enumerate(strings):
                if index > 0:
                    yield '\n'
                yield text
        return cls(_joined())

    def source_code(self, include_delims: bool = False) -> list[str]:
        """
        Extract and return all source code sections of the complete content.
        """
        return source_code(str(self), include_delims)
|
||||||
|
|
||||||
|
|
||||||
|
class Question(str):
    """
    A single question with a defined header.

    A Question never contains the question header or the answer header as
    a whole line, so it is always clear where the question and the answer
    start in a 'txt' message file.
    """
    tokens: int = 0  # tokens used by this question
    txt_header: ClassVar[str] = '=== QUESTION ==='
    yaml_key: ClassVar[str] = 'question'

    def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
        """
        Make sure the question string does not contain the header as a whole line
        (also not that from 'Answer', so it's always clear where the answer starts).
        Raises MessageError otherwise.
        """
        string_lines = string.split('\n')
        for header in (cls.txt_header, Answer.txt_header):
            if header in string_lines:
                raise MessageError(f"Question '{string}' contains the header '{header}'")
        return super().__new__(cls, string)

    @classmethod
    def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst:
        """
        Build Question from a list of strings (lines). The lines are joined
        with newlines and surrounding whitespace is stripped.
        Raises MessageError if the result contains one of the headers.
        """
        if cls.txt_header in strings:
            raise MessageError(f"Question contains the header '{cls.txt_header}'")
        # Create the instance via 'cls(...)' instead of 'str.__new__' so that
        # the checks of '__new__' (answer header, headers inside of multi-line
        # elements or uncovered by 'strip()') are applied here as well.
        return cls('\n'.join(strings).strip())

    def source_code(self, include_delims: bool = False) -> list[str]:
        """
        Extract and return all source code sections.
        """
        return source_code(self, include_delims)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Message():
    """
    Single message. Consists of a question and optionally an answer, a set of tags
    and a file path.

    Only 'question' and 'answer' take part in '==' and 'hash()'; use
    'equals()' to compare the metadata as well.
    """
    question: Question
    answer: Optional[Answer] = None
    # metadata, ignored when comparing messages
    tags: Optional[set[Tag]] = field(default=None, compare=False)
    ai: Optional[str] = field(default=None, compare=False)
    model: Optional[str] = field(default=None, compare=False)
    file_path: Optional[pathlib.Path] = field(default=None, compare=False)
    # class variables
    file_suffixes_read: ClassVar[list[str]] = ['.msg', '.txt', '.yaml']
    file_suffix_write: ClassVar[str] = '.msg'
    default_format: ClassVar[MessageFormat] = message_default_format
    tags_yaml_key: ClassVar[str] = 'tags'
    file_yaml_key: ClassVar[str] = 'file_path'
    ai_yaml_key: ClassVar[str] = 'ai'
    model_yaml_key: ClassVar[str] = 'model'

    def __post_init__(self) -> None:
        # convert some types that are often set wrong
        if self.tags is not None and not isinstance(self.tags, set):
            self.tags = set(self.tags)
        if self.file_path is not None and not isinstance(self.file_path, pathlib.Path):
            self.file_path = pathlib.Path(self.file_path)

    def __hash__(self) -> int:
        """
        The hash value is computed based on immutable members.
        """
        return hash((self.question, self.answer))

    def equals(self, other: MessageInst, tags: bool = True, ai: bool = True,
               model: bool = True, file_path: bool = True, verbose: bool = False) -> bool:
        """
        Compare this message with another one, including the metadata.
        Each metadata comparison can be switched off with the respective
        parameter. With 'verbose' set, both messages are printed if they differ.
        Return True if everything is identical, False otherwise.
        """
        equal: bool = ((not tags or (self.tags == other.tags))
                       and (not ai or (self.ai == other.ai))  # noqa: W503
                       and (not model or (self.model == other.model))  # noqa: W503
                       and (not file_path or (self.file_path == other.file_path))  # noqa: W503
                       and (self == other))  # noqa: W503
        if not equal and verbose:
            print("Messages not equal:")
            print(self)
            print(other)
        return equal

    @classmethod
    def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
        """
        Create a Message from the given dict. The question key is mandatory
        (KeyError if missing), all other keys are optional.
        """
        return cls(question=data[Question.yaml_key],
                   answer=data.get(Answer.yaml_key, None),
                   tags=set(data.get(cls.tags_yaml_key, [])),
                   ai=data.get(cls.ai_yaml_key, None),
                   model=data.get(cls.model_yaml_key, None),
                   file_path=data.get(cls.file_yaml_key, None))

    @classmethod
    def tags_from_file(cls: Type[MessageInst],
                       file_path: pathlib.Path,
                       prefix: Optional[str] = None,
                       contain: Optional[str] = None) -> set[Tag]:
        """
        Return only the tags from the given Message file,
        optionally filtered based on prefix or contained string.

        Raises MessageError if the file does not exist or has an unsupported
        suffix. If the file does not contain a valid message, the error is
        printed and an empty set is returned.
        """
        if not file_path.exists():
            raise MessageError(f"Message file '{file_path}' does not exist")
        if file_path.suffix not in cls.file_suffixes_read:
            raise MessageError(f"File type '{file_path.suffix}' is not supported")
        try:
            message = cls.from_file(file_path)
        except MessageError as e:
            # Report the problem and return no tags. (The result must not be
            # read from a variable that is only assigned on success.)
            print(f"Error processing message in '{file_path}': {str(e)}")
            return set()
        if message is None:
            return set()
        return message.filter_tags(prefix=prefix, contain=contain)

    @classmethod
    def tags_from_dir(cls: Type[MessageInst],
                      path: pathlib.Path,
                      glob: Optional[str] = None,
                      prefix: Optional[str] = None,
                      contain: Optional[str] = None) -> set[Tag]:

        """
        Return only the tags from message files in the given directory.
        The files can be filtered using 'glob', the tags by using 'prefix'
        and 'contain'. Files that can't be processed are reported and skipped.
        """
        tags: set[Tag] = set()
        file_iter = path.glob(glob) if glob else path.iterdir()
        for file_path in sorted(file_iter):
            if file_path.is_file():
                try:
                    tags |= cls.tags_from_file(file_path, prefix, contain)
                except MessageError as e:
                    print(f"Error processing message in '{file_path}': {str(e)}")
        return tags

    @classmethod
    def from_file(cls: Type[MessageInst], file_path: pathlib.Path,
                  mfilter: Optional[MessageFilter] = None) -> Optional[MessageInst]:
        """
        Create a Message from the given file. Returns 'None' if the message does
        not fulfill the filter requirements. For TXT files, the tags are matched
        before building the whole message. The other filters are applied afterwards.

        Raises MessageError if the file does not exist, has an unsupported
        suffix or does not contain a valid message.
        """
        if not file_path.exists():
            raise MessageError(f"Message file '{file_path}' does not exist")
        if file_path.suffix not in cls.file_suffixes_read:
            raise MessageError(f"File type '{file_path.suffix}' is not supported")
        # try TXT first
        try:
            message = cls.__from_file_txt(file_path,
                                          mfilter.tags_or if mfilter else None,
                                          mfilter.tags_and if mfilter else None,
                                          mfilter.tags_not if mfilter else None)
        # then YAML
        except MessageError:
            message = cls.__from_file_yaml(file_path)
        if message and (mfilter is None or message.match(mfilter)):
            return message
        else:
            return None

    @classmethod
    def __from_file_txt(cls: Type[MessageInst], file_path: pathlib.Path,  # noqa: 11
                        tags_or: Optional[set[Tag]] = None,
                        tags_and: Optional[set[Tag]] = None,
                        tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]:
        """
        Create a Message from the given TXT file. Expects the following file structures:
        For '.txt':
         * TagLine [Optional]
         * AI [Optional]
         * Model [Optional]
         * Question.txt_header
         * Question
         * Answer.txt_header [Optional]
         * Answer [Optional]

        Returns 'None' if the message does not fulfill the tag requirements.
        Raises MessageError if the file contains no question header.
        """
        tags: set[Tag] = set()
        question: Question
        answer: Optional[Answer] = None
        ai: Optional[str] = None
        model: Optional[str] = None
        with open(file_path, "r") as fd:
            # Each optional header line is read tentatively: if it does not
            # parse, the file position is restored so that the line is seen
            # again by the next step.
            # TagLine (Optional)
            try:
                pos = fd.tell()
                tags = TagLine(fd.readline()).tags()
            except TagError:
                fd.seek(pos)
            # AILine (Optional)
            try:
                pos = fd.tell()
                ai = AILine(fd.readline()).ai()
            except MessageError:
                fd.seek(pos)
            # ModelLine (Optional)
            try:
                pos = fd.tell()
                model = ModelLine(fd.readline()).model()
            except MessageError:
                fd.seek(pos)
            # Question and Answer
            text = fd.read().strip().split('\n')
        try:
            question_idx = text.index(Question.txt_header) + 1
        except ValueError:
            raise MessageError(f"'{file_path}' does not contain a valid message")
        try:
            answer_idx = text.index(Answer.txt_header)
            question = Question.from_list(text[question_idx:answer_idx])
            answer = Answer.from_list(text[answer_idx + 1:])
        except ValueError:
            # no answer header: everything after the question header is the question
            question = Question.from_list(text[question_idx:])
        # match tags AFTER reading the whole file
        # -> make sure it's a valid 'txt' file format
        if tags_or or tags_and or tags_not:
            # match with an empty set if the file has no tags
            if not match_tags(tags, tags_or, tags_and, tags_not):
                return None
        return cls(question, answer, tags, ai, model, file_path)

    @classmethod
    def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst:
        """
        Create a Message from the given YAML file. Expects the following file structures:
         * Question.yaml_key: single or multiline string
         * Answer.yaml_key: single or multiline string [Optional]
         * Message.tags_yaml_key: list of strings [Optional]
         * Message.ai_yaml_key: str [Optional]
         * Message.model_yaml_key: str [Optional]

        Raises MessageError if the file can't be parsed into a Message.
        """
        with open(file_path, "r") as fd:
            try:
                # NOTE(review): 'FullLoader' is used here; consider
                # 'yaml.safe_load' if message files may come from untrusted sources.
                data = yaml.load(fd, Loader=yaml.FullLoader)
                data[cls.file_yaml_key] = file_path
                return cls.from_dict(data)
            except Exception as err:
                # keep the original cause attached for debugging
                raise MessageError(f"'{file_path}' does not contain a valid message") from err

    def to_str(self, with_metadata: bool = False, source_code_only: bool = False) -> str:
        """
        Return the current Message as a string.

        With 'source_code_only' set, only the source code sections of the
        answer (including their delimiters) are returned. With 'with_metadata'
        set, tags, file path, AI and model precede the question.
        """
        output: list[str] = []
        if source_code_only:
            # use the source code from answer only
            if self.answer:
                output.extend(self.answer.source_code(include_delims=True))
            return '\n'.join(output) if len(output) > 0 else ''
        if with_metadata:
            output.append(self.tags_str())
            output.append('FILE: ' + str(self.file_path))
            output.append('AI: ' + str(self.ai))
            output.append('MODEL: ' + str(self.model))
        output.append(Question.txt_header)
        output.append(self.question)
        if self.answer:
            output.append(Answer.txt_header)
            output.append(str(self.answer))
        return '\n'.join(output)

    def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None:  # noqa: 11
        """
        Write a Message to the given file. Supported message file formats are 'txt' and 'yaml'.
        Suffix is always '.msg'.

        If 'file_path' is given, it replaces 'self.file_path'; otherwise the
        message is written to 'self.file_path'. Raises MessageError if there
        is no path at all, if the format is unknown or if the path has a
        suffix other than '.msg'.
        """
        if file_path:
            self.file_path = file_path
        if not self.file_path:
            raise MessageError("Got no valid path to write message")
        if mformat not in message_valid_formats:
            raise MessageError(f"File format '{mformat}' is not supported")
        # check for valid suffix
        # -> add one if it's empty
        # -> refuse old or otherwise unsupported suffixes
        if not self.file_path.suffix:
            self.file_path = self.file_path.with_suffix(self.file_suffix_write)
        elif self.file_path.suffix != self.file_suffix_write:
            raise MessageError(f"File suffix '{self.file_path.suffix}' is not supported")
        # TXT
        if mformat == 'txt':
            return self.__to_file_txt(self.file_path)
        # YAML
        elif mformat == 'yaml':
            return self.__to_file_yaml(self.file_path)

    def __to_file_txt(self, file_path: pathlib.Path) -> None:
        """
        Write a Message to the given file in TXT format.
        Creates the following file structures:
         * TagLine [Optional]
         * AI [Optional]
         * Model [Optional]
         * Question.txt_header
         * Question
         * Answer.txt_header [Optional]
         * Answer [Optional]
        """
        # write to a temporary file in the target directory first and move it
        # to its final name afterwards
        with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
            temp_file_path = pathlib.Path(temp_fd.name)
            if self.tags:
                temp_fd.write(f'{TagLine.from_set(self.tags)}\n')
            if self.ai:
                temp_fd.write(f'{AILine.from_ai(self.ai)}\n')
            if self.model:
                temp_fd.write(f'{ModelLine.from_model(self.model)}\n')
            temp_fd.write(f'{Question.txt_header}\n{self.question}\n')
            if self.answer:
                temp_fd.write(f'{Answer.txt_header}\n{str(self.answer)}\n')
        shutil.move(temp_file_path, file_path)

    def __to_file_yaml(self, file_path: pathlib.Path) -> None:
        """
        Write a Message to the given file in YAML format.
        Creates the following file structures:
         * Question.yaml_key: single or multiline string
         * Answer.yaml_key: single or multiline string [Optional]
         * Message.ai_yaml_key: str [Optional]
         * Message.model_yaml_key: str [Optional]
         * Message.tags_yaml_key: sorted list of strings [Optional]
        """
        # write to a temporary file in the target directory first and move it
        # to its final name afterwards
        with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
            temp_file_path = pathlib.Path(temp_fd.name)
            data: YamlDict = {Question.yaml_key: str(self.question)}
            if self.answer:
                data[Answer.yaml_key] = str(self.answer)
            if self.ai:
                data[self.ai_yaml_key] = self.ai
            if self.model:
                data[self.model_yaml_key] = self.model
            if self.tags:
                data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
            yaml.dump(data, temp_fd, sort_keys=False)
        shutil.move(temp_file_path, file_path)

    def rm_file(self) -> None:
        """
        Delete the message file. Ignore empty file_path and not existing files.
        """
        if self.file_path is not None:
            self.file_path.unlink(missing_ok=True)

    def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
        """
        Filter tags based on their prefix (i. e. the tag starts with a given string)
        or some contained string. If both are given, a tag must fulfill both.
        Returns a new set; the tags of the message are not modified.
        """
        if not self.tags:
            return set()
        res_tags = self.tags.copy()
        if prefix and len(prefix) > 0:
            res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)}
        if contain and len(contain) > 0:
            res_tags -= {tag for tag in res_tags if contain not in tag}
        return res_tags

    def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str:
        """
        Returns all tags as a string with the TagLine prefix. Optionally filtered
        using 'Message.filter_tags()'.
        """
        if self.tags:
            return str(TagLine.from_set(self.filter_tags(prefix, contain)))
        else:
            return str(TagLine.from_set(set()))

    def match(self, mfilter: MessageFilter) -> bool:  # noqa: 13
        """
        Matches the current Message to the given filter attributes.
        Return True if all attributes match, else False.
        """
        mytags = self.tags or set()
        if (((mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None)
             and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not))  # noqa: W503
           or (mfilter.ai and (not self.ai or mfilter.ai != self.ai))  # noqa: W503
           or (mfilter.model and (not self.model or mfilter.model != self.model))  # noqa: W503
           or (mfilter.question_contains and mfilter.question_contains not in self.question)  # noqa: W503
           or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in str(self.answer)))  # noqa: W503
           or (mfilter.answer_state == 'available' and not self.answer)  # noqa: W503
           or (mfilter.ai_state == 'available' and not self.ai)  # noqa: W503
           or (mfilter.model_state == 'available' and not self.model)  # noqa: W503
           or (mfilter.answer_state == 'missing' and self.answer)  # noqa: W503
           or (mfilter.ai_state == 'missing' and self.ai)  # noqa: W503
           or (mfilter.model_state == 'missing' and self.model)):  # noqa: W503
            return False
        return True

    def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> None:
        """
        Renames the given tags. The first tuple element is the old name,
        the second one is the new name.
        """
        if self.tags:
            self.tags = rename_tags(self.tags, tags_rename)

    def clear_answer(self) -> None:
        """
        Remove the answer from this message.
        """
        self.answer = None

    def msg_id(self) -> str:
        """
        Returns an ID that is unique throughout all messages in the same (DB) directory.
        Currently this is the file name without suffix. The ID is also used for sorting
        messages. Raises MessageError if the message has no file path.
        """
        if self.file_path:
            return self.file_path.stem
        else:
            raise MessageError("Can't create file ID without a file path")

    def as_dict(self) -> dict[str, Any]:
        """
        Return the message as a dict (see 'dataclasses.asdict()').
        """
        # NOTE(review): 'asdict()' deep-copies the field values; presumably that
        # fails for an Answer that holds a generator - TODO confirm.
        return asdict(self)

    def tokens(self) -> int:
        """
        Returns the nr. of AI language tokens used by this message.
        If unknown, 0 is returned.
        """
        if self.answer:
            return self.question.tokens + self.answer.tokens
        else:
            return self.question.tokens
|
||||||
@@ -1,117 +0,0 @@
|
|||||||
import yaml
|
|
||||||
import io
|
|
||||||
import pathlib
|
|
||||||
from .utils import terminal_width, append_message, message_to_chat
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
|
|
||||||
|
|
||||||
def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]:
    """
    Read a '.txt' message file with the following structure:
     * tag line: '<prefix>: tag1, tag2, ...'
     * '=== QUESTION ===' followed by the question
     * '==== ANSWER ====' followed by the answer

    Returns a dict with the keys 'question', 'answer', 'tags' and 'file'
    (the file name). With 'tags_only' set, only the first line is read and
    the dict contains nothing but 'tags'.

    Raises IndexError if the first line contains no ':' and ValueError if
    one of the headers is missing.
    """
    def _split_tags(tag_line: str) -> list[str]:
        # Split at the FIRST colon only, so that tags which contain a ':'
        # themselves (and all tags following them) are not cut off.
        return [tag.strip() for tag in tag_line.split(':', 1)[1].strip().split(',')]

    with open(fname, "r") as fd:
        if tags_only:
            return {"tags": _split_tags(fd.readline().strip())}
        text = fd.read().strip().split('\n')
        tags = _split_tags(text.pop(0))
        question_idx = text.index("=== QUESTION ===") + 1
        answer_idx = text.index("==== ANSWER ====")
        question = "\n".join(text[question_idx:answer_idx]).strip()
        answer = "\n".join(text[answer_idx + 1:]).strip()
        return {"question": question, "answer": answer, "tags": tags,
                "file": fname.name}
|
|
||||||
|
|
||||||
|
|
||||||
def dump_data(data: Dict[str, Any]) -> str:
    """
    Return the given message data ('tags', 'question', 'answer') as a
    string in the '.txt' file layout: tag line, question header with
    question, answer header with answer.
    """
    parts = [
        f'TAGS: {", ".join(data["tags"])}\n',
        f'=== QUESTION ===\n{data["question"]}\n',
        f'==== ANSWER ====\n{data["answer"]}\n',
    ]
    return ''.join(parts)
|
|
||||||
|
|
||||||
|
|
||||||
def write_file(fname: str, data: Dict[str, Any]) -> None:
    """
    Write the given message data ('tags', 'question', 'answer') to the
    file 'fname' in the '.txt' layout: tag line, question header with
    question, answer header with answer. An existing file is overwritten.
    """
    with open(fname, "w") as out:
        emit = out.write
        emit(f'TAGS: {", ".join(data["tags"])}\n')
        emit(f'=== QUESTION ===\n{data["question"]}\n')
        emit(f'==== ANSWER ====\n{data["answer"]}\n')
|
|
||||||
|
|
||||||
|
|
||||||
def save_answers(question: str,
                 answers: list[str],
                 tags: list[str],
                 otags: Optional[list[str]],
                 config: Dict[str, Any]
                 ) -> None:
    """
    Print every answer and store it, together with the question, in a new
    numbered '.txt' file ('NNNN.txt') in the DB directory (config['db']).

    The files are tagged with 'otags' if given (and not empty), else with
    'tags'. The number of the last written file is kept in the file
    '.next' inside of the DB directory.
    """
    wtags = otags or tags
    db_dir = pathlib.Path(config['db'])
    next_fname = db_dir / '.next'
    num = 0
    try:
        with open(next_fname, 'r') as f:
            num = int(f.read())
    except (OSError, ValueError):
        # missing, unreadable or malformed counter file: start with number 1
        pass
    for inum, answer in enumerate(answers, start=1):
        num += 1
        title = f'-- ANSWER {inum} '
        title_end = '-' * (terminal_width() - len(title))
        print(f'{title}{title_end}')
        print(answer)
        # Store the file in the DB directory: that's where the counter lives
        # and where messages are read from (not the current working directory).
        write_file(str(db_dir / f"{num:04d}.txt"),
                   {"question": question, "answer": answer, "tags": wtags})
    with open(next_fname, 'w') as f:
        f.write(f'{num}')
|
|
||||||
|
|
||||||
|
|
||||||
def create_chat(question: Optional[str],
                tags: Optional[List[str]],
                extags: Optional[List[str]],
                config: Dict[str, Any],
                match_all_tags: bool = False,
                with_tags: bool = False,
                with_file: bool = False
                ) -> List[Dict[str, str]]:
    """
    Build a chat from the messages stored in the DB directory (config['db']).

    The chat starts with the system message (config['system']), followed by
    all '.yaml' and '.txt' messages (sorted by path) that match the tags:
     * 'tags': the message must have at least one of them (all of them
       if 'match_all_tags' is set); no restriction if empty
     * 'extags': the message must have none of them
    Finally, 'question' is appended as user message if it's not empty.
    'with_tags' and 'with_file' are passed on to 'message_to_chat()'.
    """
    chat: List[Dict[str, str]] = []
    append_message(chat, 'system', config['system'].strip())

    for entry in sorted(pathlib.Path(config['db']).iterdir()):
        suffix = entry.suffix
        if suffix not in ('.yaml', '.txt'):
            continue
        if suffix == '.yaml':
            with open(entry, 'r') as f:
                data = yaml.load(f, Loader=yaml.FullLoader)
            data['file'] = entry.name
        else:
            data = read_file(entry)

        data_tags = set(data.get('tags', []))
        # positive selection
        if not tags:
            included = True
        elif match_all_tags:
            included = set(tags) <= data_tags
        else:
            included = not data_tags.isdisjoint(tags)
        # negative selection
        excluded = bool(extags) and not data_tags.isdisjoint(extags)
        if included and not excluded:
            message_to_chat(data, chat, with_tags, with_file)

    if question:
        append_message(chat, 'user', question)
    return chat
|
|
||||||
|
|
||||||
|
|
||||||
def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]:
    """
    Collect the tags of all '.yaml' and '.txt' files in the DB directory
    'config["db"]' (files are visited in sorted order, duplicates are kept).
    If 'prefix' is not empty, only tags starting with it are returned.
    """
    collected: List[str] = []
    for path in sorted(pathlib.Path(config['db']).iterdir()):
        if path.suffix == '.txt':
            data = read_file(path, tags_only=True)
        elif path.suffix == '.yaml':
            with open(path, 'r') as fd:
                data = yaml.load(fd, Loader=yaml.FullLoader)
        else:
            continue  # not a message file
        file_tags = data.get('tags', [])
        if prefix:
            collected.extend(tag for tag in file_tags if tag.startswith(prefix))
        else:
            collected.extend(file_tags)
    return collected
|
|
||||||
|
|
||||||
|
|
||||||
def get_tags_unique(config: Dict[str, Any], prefix: Optional[str]) -> List[str]:
    """
    Like 'get_tags()', but every tag is returned only once (in no particular order).
    """
    unique_tags = set(get_tags(config, prefix))
    return list(unique_tags)
|
|
||||||
@@ -0,0 +1,184 @@
|
|||||||
|
"""
|
||||||
|
Module implementing tag related functions and classes.
|
||||||
|
"""
|
||||||
|
from typing import Type, TypeVar, Optional, Final
|
||||||
|
|
||||||
|
TagInst = TypeVar('TagInst', bound='Tag')
|
||||||
|
TagLineInst = TypeVar('TagLineInst', bound='TagLine')
|
||||||
|
|
||||||
|
|
||||||
|
class TagError(Exception):
    """
    Exception raised for invalid tags or tag lines.
    """
|
||||||
|
|
||||||
|
|
||||||
|
class Tag(str):
    """
    A single tag: a string that may contain anything except the default separator (' ').
    """
    # default separator
    default_separator: Final[str] = ' '
    # alternative separators (e. g. for backwards compatibility)
    alternative_separators: Final[list[str]] = [',']

    def __new__(cls: Type[TagInst], string: str) -> TagInst:
        """
        Create the tag. Raises a TagError if the string contains the default separator.
        """
        separator = cls.default_separator
        if separator not in string:
            return super().__new__(cls, string)
        raise TagError(f"Tag '{string}' contains the separator char '{separator}'")
|
||||||
|
|
||||||
|
|
||||||
|
def delete_tags(tags: set[Tag], tags_delete: set[Tag]) -> set[Tag]:
    """
    Returns a new set containing the elements of 'tags' that are not in
    'tags_delete'. The given sets are left untouched.
    """
    remaining = set(tags)
    remaining.difference_update(tags_delete)
    return remaining
|
||||||
|
|
||||||
|
|
||||||
|
def add_tags(tags: set[Tag], tags_add: set[Tag]) -> set[Tag]:
    """
    Adds the given tags and returns a new set (the union of 'tags' and
    'tags_add'). The given sets are left untouched.
    """
    # A set has no order, so sorting the union before converting it back
    # to a set (as done previously) had no effect on the result.
    return tags | tags_add
|
||||||
|
|
||||||
|
|
||||||
|
def merge_tags(tags: set[Tag], tags_merge: list[set[Tag]]) -> set[Tag]:
    """
    Merges the tags in 'tags_merge' into the current one and returns a new set.
    Neither 'tags' nor the sets in 'tags_merge' are modified.
    """
    # work on a copy: an in-place '|=' would modify the caller's set
    merged = set(tags)
    merged.update(*tags_merge)
    return merged
|
||||||
|
|
||||||
|
|
||||||
|
def rename_tags(tags: set[Tag], tags_rename: set[tuple[Tag, Tag]]) -> set[Tag]:
    """
    Renames the given tags and returns a new set. The first tuple element
    is the old name, the second one is the new name. Tuples whose old name
    is not in 'tags' are ignored. The given set 'tags' is not modified.
    """
    # work on a copy: removing / adding on 'tags' itself would modify the caller's set
    renamed = set(tags)
    for old_name, new_name in tags_rename:
        if old_name in renamed:
            renamed.remove(old_name)
            renamed.add(new_name)
    return renamed
|
||||||
|
|
||||||
|
|
||||||
|
def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]],
               tags_not: Optional[set[Tag]]) -> bool:
    """
    Checks if the given set 'tags' fulfills the given tag requirements:
    - 'tags_or' : fulfilled if 'tags' contains ANY of those tags
    - 'tags_and': fulfilled if 'tags' contains ALL of those tags
    - 'tags_not': fulfilled if 'tags' contains NONE of those tags

    Fulfilling one of 'tags_or' or 'tags_and' is sufficient, but a tag from
    'tags_not' must never be present. If both 'tags_or' and 'tags_and' are
    'None', every set of tags is accepted (the exclusion via 'tags_not' still
    applies if it is not 'None'). An empty set (set()) for 'tags_or' or
    'tags_and' is never fulfilled.
    """
    # exclusion always wins
    if tags_not is not None and any(tag in tags for tag in tags_not):
        return False
    # no requirements at all -> everything matches
    if tags_or is None and tags_and is None:
        return True
    if tags_or and any(tag in tags for tag in tags_or):
        return True
    return bool(tags_and) and all(tag in tags for tag in tags_and)
|
||||||
|
|
||||||
|
|
||||||
|
class TagLine(str):
    """
    A line of tags in a '.txt' file. It starts with a prefix ('TAGS:'), followed by
    a list of tags, separated by the default separator (' '). Any operations on a
    TagLine will sort the tags.
    """
    # the prefix
    prefix: Final[str] = 'TAGS:'

    def __new__(cls: Type[TagLineInst], string: str) -> TagLineInst:
        """
        Make sure the tagline string starts with the prefix. Also replace newlines
        and multiple spaces with ' ', in order to support multiline TagLines.
        Raises a TagError if the prefix is missing.
        """
        if not string.startswith(cls.prefix):
            raise TagError(f"TagLine '{string}' is missing prefix '{cls.prefix}'")
        string = ' '.join(string.split())
        instance = super().__new__(cls, string)
        return instance

    @classmethod
    def from_set(cls: Type[TagLineInst], tags: set[Tag]) -> TagLineInst:
        """
        Create a new TagLine from a set of tags. The tags are written in sorted order.
        """
        return cls(' '.join([cls.prefix] + sorted(tags)))

    def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
        """
        Returns all tags contained in this line as a set, optionally
        filtered based on prefix or contained string. Raises a TagError
        if a tag contains the default separator (this can only happen
        if an alternative separator is used in the line).
        """
        tagstr = self[len(self.prefix):].strip()
        if tagstr == '':
            return set()  # no tags, only prefix
        separator = Tag.default_separator
        # look for alternative separators and use the first one found
        # -> we don't support different separators in the same TagLine
        for s in Tag.alternative_separators:
            if s in tagstr:
                separator = s
                break
        # Skip empty fragments: a doubled or trailing alternative separator
        # (e. g. 'a,,b' or 'a,b,') must not produce an empty tag.
        res_tags = {Tag(t.strip()) for t in tagstr.split(separator) if t.strip()}
        if prefix:
            res_tags = {tag for tag in res_tags if tag.startswith(prefix)}
        if contain:
            res_tags = {tag for tag in res_tags if contain in tag}
        return res_tags

    def merge(self, taglines: set['TagLine']) -> 'TagLine':
        """
        Merges the tags of all given taglines into the current one and returns a new TagLine.
        """
        tags_merge = [tl.tags() for tl in taglines]
        return self.from_set(merge_tags(self.tags(), tags_merge))

    def delete_tags(self, tags_delete: set[Tag]) -> 'TagLine':
        """
        Deletes the given tags and returns a new TagLine.
        """
        return self.from_set(delete_tags(self.tags(), tags_delete))

    def add_tags(self, tags_add: set[Tag]) -> 'TagLine':
        """
        Adds the given tags and returns a new TagLine.
        """
        return self.from_set(add_tags(self.tags(), tags_add))

    def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> 'TagLine':
        """
        Renames the given tags and returns a new TagLine. The first
        tuple element is the old name, the second one is the new name.
        """
        return self.from_set(rename_tags(self.tags(), tags_rename))

    def match_tags(self, tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]],
                   tags_not: Optional[set[Tag]]) -> bool:
        """
        Checks if the current TagLine matches the given tag requirements:
        - 'tags_or' : matches if this TagLine contains ANY of those tags
        - 'tags_and': matches if this TagLine contains ALL of those tags
        - 'tags_not': matches if this TagLine contains NONE of those tags

        Note that it's sufficient if the TagLine matches one of 'tags_or' or 'tags_and',
        i. e. you can select a TagLine if it either contains one of the tags in 'tags_or'
        or all of the tags in 'tags_and' but it must never contain any of the tags in
        'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag
        exclusion is still done if 'tags_not' is not 'None'). If they are empty (set()),
        they match no tags.
        """
        return match_tags(self.tags(), tags_or, tags_and, tags_not)
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
import shutil
|
|
||||||
from pprint import PrettyPrinter
|
|
||||||
from typing import List, Dict
|
|
||||||
|
|
||||||
|
|
||||||
def terminal_width() -> int:
    """
    Return the number of columns reported by 'shutil.get_terminal_size()'.
    """
    columns, _lines = shutil.get_terminal_size()
    return columns
|
|
||||||
|
|
||||||
|
|
||||||
def pp(*args, **kwargs) -> None:
    """
    Pretty-print the given object using the full width of the terminal.
    All arguments are passed on to 'PrettyPrinter.pprint()'.
    """
    printer = PrettyPrinter(width=terminal_width())
    return printer.pprint(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def process_tags(tags: list[str], extags: list[str], otags: list[str]) -> None:
    """
    Print a summary of the given tag lists: one line per non-empty list,
    followed by an empty line. Nothing is printed if all lists are empty.
    """
    summary = []
    for label, values in (("Tags", tags),
                          ("Excluding tags", extags),
                          ("Output tags", otags)):
        if values:
            summary.append(f"{label}: {', '.join(values)}")
    if not summary:
        return
    print("\n".join(summary))
    print()
|
|
||||||
|
|
||||||
|
|
||||||
def append_message(chat: List[Dict[str, str]],
                   role: str,
                   content: str
                   ) -> None:
    """
    Append a message with the given role and content to 'chat'.
    Doubled single quotes ('') in the content are replaced by one (').
    """
    entry = {'role': role, 'content': content.replace("''", "'")}
    chat.append(entry)
|
|
||||||
|
|
||||||
|
|
||||||
def message_to_chat(message: Dict[str, str],
                    chat: List[Dict[str, str]],
                    with_tags: bool = False,
                    with_file: bool = False
                    ) -> None:
    """
    Append the question ('user') and the answer ('assistant') of 'message'
    to 'chat'. Optionally also append the comma separated tags (role 'tags')
    and the file name (role 'file') of the message.
    """
    for role, key in (('user', 'question'), ('assistant', 'answer')):
        append_message(chat, role, message[key])
    if with_tags:
        append_message(chat, 'tags', ", ".join(message['tags']))
    if with_file:
        append_message(chat, 'file', message['file'])
|
|
||||||
|
|
||||||
|
|
||||||
def display_source_code(content: str) -> None:
    """
    Print the text between the first and the last '```' marker of 'content'
    (stripped). Nothing is printed if there are no two distinct markers or
    if there is nothing between them.
    """
    fence = '```'
    first = content.find(fence)
    if first < 0:
        return  # no marker at all
    last = content.rfind(fence)
    inner_start = first + len(fence)
    if inner_start < last:
        print(content[inner_start:last].strip())
|
|
||||||
|
|
||||||
|
|
||||||
def display_chat(chat, dump=False, source_code=False) -> None:
    """
    Print the given chat.
    - 'dump': pretty-print the raw chat structure instead of formatting it
    - 'source_code': only print the source code contained in each message

    Otherwise each message is printed as 'ROLE: content'. The content is put
    on its own line if it doesn't fit into the terminal width, and a separator
    line is printed before each 'user' message.
    """
    if dump:
        pp(chat)
        return
    for entry in chat:
        needs_own_line = len(entry['content']) > terminal_width() - len(entry['role']) - 2
        if source_code:
            display_source_code(entry['content'])
            continue
        role = entry['role']
        if role == 'user':
            print('-' * terminal_width())
        header = f"{role.upper()}:"
        if needs_own_line:
            print(header)
            print(entry['content'])
        else:
            print(f"{header} {entry['content']}")
|
|
||||||
|
|
||||||
|
|
||||||
def display_tags_frequency(tags: List[str], dump=False) -> None:
    """
    Print how often each tag occurs in 'tags', one line ('- tag: count') per
    unique tag, in order of first appearance. If 'dump' is set, the raw list
    is pretty-printed instead.
    """
    from collections import Counter

    if dump:
        pp(tags)
        return
    # Counter counts all tags in a single pass (instead of one 'tags.count()'
    # scan per unique tag) and keeps the order of first appearance, which makes
    # the output deterministic.
    for tag, count in Counter(tags).items():
        print(f"- {tag}: {count}")
|
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
<?php

// Webhook receiver: validates an incoming request (POST, JSON content type,
// non-empty payload, matching HMAC-SHA256 signature, decodable JSON) and then
// runs the push hook script. Every failed check is logged and ends the request.

// Shared secret used to sign the payload; it must match the secret configured
// for the webhook on the sending side.
// NOTE(review): hard-coded placeholder value -- consider reading it from the
// environment or from a file outside of the repository.
$secret_key = '123';

// check for POST request
if ($_SERVER['REQUEST_METHOD'] != 'POST') {
    error_log('FAILED - not POST - '. $_SERVER['REQUEST_METHOD']);
    exit();
}

// get content type
$content_type = isset($_SERVER['CONTENT_TYPE']) ? strtolower(trim($_SERVER['CONTENT_TYPE'])) : '';

if ($content_type != 'application/json') {
    error_log('FAILED - not application/json - '. $content_type);
    exit();
}

// get payload
$payload = trim(file_get_contents("php://input"));

if (empty($payload)) {
    error_log('FAILED - no payload');
    exit();
}

// get header signature
$header_signature = isset($_SERVER['HTTP_X_GITEA_SIGNATURE']) ? $_SERVER['HTTP_X_GITEA_SIGNATURE'] : '';

if (empty($header_signature)) {
    error_log('FAILED - header signature missing');
    exit();
}

// calculate payload signature
$payload_signature = hash_hmac('sha256', $payload, $secret_key, false);

// check payload signature against header signature
// (hash_equals() compares in constant time, so the comparison does not reveal
// how much of the expected signature a forged header got right)
if (!hash_equals($payload_signature, $header_signature)) {
    error_log('FAILED - payload signature');
    exit();
}

// convert json to array
$decoded = json_decode($payload, true);

// check for json decode errors
if (json_last_error() !== JSON_ERROR_NONE) {
    error_log('FAILED - json decode - '. json_last_error());
    exit();
}

// success, do something
$output = shell_exec('/home/kaizen/repos/ChatMastermind/hooks/push_hook.sh');
echo "$output";
?>
|
||||||
Executable
+8
@@ -0,0 +1,8 @@
|
|||||||
|
#!/usr/bin/bash

# Push hook: update the local checkout and run the checks and the tests.

# load the user's shell setup (before enabling 'errexit')
source /home/kaizen/.bashrc

# abort on the first failing command
set -o errexit

cd /home/kaizen/repos/ChatMastermind
git pull
pre-commit run --all-files
pytest
|
||||||
@@ -5,3 +5,4 @@ strict_optional = True
|
|||||||
warn_unused_ignores = False
|
warn_unused_ignores = False
|
||||||
warn_redundant_casts = True
|
warn_redundant_casts = True
|
||||||
warn_unused_configs = True
|
warn_unused_configs = True
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|||||||
@@ -2,3 +2,4 @@ openai
|
|||||||
PyYAML
|
PyYAML
|
||||||
argcomplete
|
argcomplete
|
||||||
pytest
|
pytest
|
||||||
|
tiktoken
|
||||||
|
|||||||
@@ -12,23 +12,29 @@ setup(
|
|||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
url="https://github.com/ok2/ChatMastermind",
|
url="https://github.com/ok2/ChatMastermind",
|
||||||
packages=find_packages(),
|
packages=find_packages() + ["chatmastermind.ais", "chatmastermind.commands"],
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 3 - Alpha",
|
"Development Status :: 3 - Alpha",
|
||||||
|
"Environment :: Console",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
"License :: OSI Approved :: MIT License",
|
"Intended Audience :: End Users/Desktop",
|
||||||
|
"Intended Audience :: Science/Research",
|
||||||
"Operating System :: OS Independent",
|
"Operating System :: OS Independent",
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Programming Language :: Python :: 3.12",
|
||||||
|
"Topic :: Utilities",
|
||||||
|
"Topic :: Text Processing",
|
||||||
],
|
],
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"openai",
|
"openai",
|
||||||
"PyYAML",
|
"PyYAML",
|
||||||
"argcomplete",
|
"argcomplete",
|
||||||
"pytest"
|
"pytest",
|
||||||
],
|
],
|
||||||
python_requires=">=3.10",
|
python_requires=">=3.9",
|
||||||
test_suite="tests",
|
test_suite="tests",
|
||||||
entry_points={
|
entry_points={
|
||||||
"console_scripts": [
|
"console_scripts": [
|
||||||
|
|||||||
@@ -0,0 +1,48 @@
|
|||||||
|
import argparse
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from chatmastermind.ai_factory import create_ai
|
||||||
|
from chatmastermind.configuration import Config
|
||||||
|
from chatmastermind.ai import AIError
|
||||||
|
from chatmastermind.ais.openai import OpenAI
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateAI(unittest.TestCase):
    """
    Tests for the 'create_ai()' factory function.
    """

    def setUp(self) -> None:
        # fake command line arguments: an AI ID, but no parameter overrides
        args = MagicMock(spec=argparse.Namespace)
        args.AI = 'myopenai'
        for option in ('model', 'max_tokens', 'temperature'):
            setattr(args, option, None)
        self.args = args

    def test_create_ai_from_args(self) -> None:
        # the AI ID is given by the arguments, the configuration is the default one
        self.args.AI = 'myopenai'
        created = create_ai(self.args, Config())
        self.assertIsInstance(created, OpenAI)

    def test_create_ai_from_default(self) -> None:
        # no AI ID in the arguments, the configuration is the default one
        self.args.AI = None
        created = create_ai(self.args, Config())
        self.assertIsInstance(created, OpenAI)

    def test_create_empty_ai_error(self) -> None:
        # no AI ID in the arguments and no AIs in the configuration -> AIError
        self.args.AI = None
        empty_config = Config()
        empty_config.ais = {}
        self.assertRaises(AIError, create_ai, self.args, empty_config)

    def test_create_unsupported_ai_error(self) -> None:
        # an AI ID that is not used in this test's default configuration -> AIError
        self.args.AI = 'invalid_ai'
        default_config = Config()
        self.assertRaises(AIError, create_ai, self.args, default_config)
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
import unittest
|
||||||
|
from unittest import mock
|
||||||
|
from chatmastermind.ais.openai import OpenAI
|
||||||
|
from chatmastermind.message import Message, Question, Answer
|
||||||
|
from chatmastermind.chat import Chat
|
||||||
|
from chatmastermind.ai import AIResponse, Tokens
|
||||||
|
from chatmastermind.configuration import OpenAIConfig
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAITest(unittest.TestCase):
    """
    Tests for the 'OpenAI' class, with 'OpenAI._completions' replaced by a mock.
    """

    @mock.patch('chatmastermind.ais.openai.OpenAI._completions')
    def test_request(self, mock_create: mock.MagicMock) -> None:
        # Checks that 'request()' assembles one answer per choice from the
        # streamed chunks, reports the token counts and passes the expected
        # arguments to the (mocked) completions call.

        # Create a test instance of OpenAI
        config = OpenAIConfig()
        openai = OpenAI(config)

        # Set up the streamed response returned by the patched 'OpenAI._completions':
        # the first chunk carries the content of both answers, the second chunk
        # only carries the 'stop' finish reason of both choices (it has no 'delta')
        class mock_obj:
            pass
        mock_chunk1 = mock_obj()
        mock_chunk1.choices = [mock_obj(), mock_obj()]  # type: ignore
        mock_chunk1.choices[0].index = 0  # type: ignore
        mock_chunk1.choices[0].delta = mock_obj()  # type: ignore
        mock_chunk1.choices[0].delta.content = 'Answer 1'  # type: ignore
        mock_chunk1.choices[0].finish_reason = None  # type: ignore
        mock_chunk1.choices[1].index = 1  # type: ignore
        mock_chunk1.choices[1].delta = mock_obj()  # type: ignore
        mock_chunk1.choices[1].delta.content = 'Answer 2'  # type: ignore
        mock_chunk1.choices[1].finish_reason = None  # type: ignore
        mock_chunk2 = mock_obj()
        mock_chunk2.choices = [mock_obj(), mock_obj()]  # type: ignore
        mock_chunk2.choices[0].index = 0  # type: ignore
        mock_chunk2.choices[0].finish_reason = 'stop'  # type: ignore
        mock_chunk2.choices[1].index = 1  # type: ignore
        mock_chunk2.choices[1].finish_reason = 'stop'  # type: ignore
        mock_create.return_value = iter([mock_chunk1, mock_chunk2])

        # Create test data
        question = Message(Question('Question'))
        chat = Chat([
            Message(Question('Question 1'), answer=Answer('Answer 1')),
            Message(Question('Question 2'), answer=Answer('Answer 2')),
            # add message without an answer -> expect to be skipped
            Message(Question('Question 3'))
        ])

        # Make the request
        response = openai.request(question, chat, num_answers=2)

        # Assert the AIResponse
        self.assertIsInstance(response, AIResponse)
        self.assertEqual(len(response.messages), 2)
        self.assertEqual(response.messages[0].answer, 'Answer 1')
        self.assertEqual(response.messages[1].answer, 'Answer 2')
        self.assertIsNotNone(response.tokens)
        self.assertIsInstance(response.tokens, Tokens)
        # plain assert: lets the type checker know that 'tokens' is not None below
        assert response.tokens
        self.assertEqual(response.tokens.prompt, 53)
        self.assertEqual(response.tokens.completion, 6)
        self.assertEqual(response.tokens.total, 59)

        # Assert the call to the patched 'OpenAI._completions': the history contains
        # the system message, the two answered messages and the new question
        mock_create.assert_called_once_with(
            model=f'{config.model}',
            messages=[
                {'role': 'system', 'content': f'{config.system}'},
                {'role': 'user', 'content': 'Question 1'},
                {'role': 'assistant', 'content': 'Answer 1'},
                {'role': 'user', 'content': 'Question 2'},
                {'role': 'assistant', 'content': 'Answer 2'},
                {'role': 'user', 'content': 'Question'}
            ],
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            top_p=config.top_p,
            n=2,
            stream=True,
            frequency_penalty=config.frequency_penalty,
            presence_penalty=config.presence_penalty
        )
|
||||||
@@ -0,0 +1,687 @@
|
|||||||
|
import unittest
|
||||||
|
import pathlib
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
import yaml
|
||||||
|
from io import StringIO
|
||||||
|
from unittest.mock import patch
|
||||||
|
from chatmastermind.tags import TagLine
|
||||||
|
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
|
||||||
|
from chatmastermind.chat import Chat, ChatDB, ChatError, msg_location
|
||||||
|
|
||||||
|
|
||||||
|
msg_suffix: str = Message.file_suffix_write
|
||||||
|
|
||||||
|
|
||||||
|
def msg_to_file_force_suffix(msg: Message) -> None:
    """
    Force writing a message file with illegal suffixes.

    Temporarily sets 'Message.file_suffix_write' to the suffix of the
    message's file path, writes the message and restores the previous
    suffix afterwards. The message must have a file path.
    """
    def_suffix = Message.file_suffix_write
    assert msg.file_path
    Message.file_suffix_write = msg.file_path.suffix
    try:
        msg.to_file()
    finally:
        # always restore the class attribute, otherwise a failing write
        # would leak the forced suffix into all following tests
        Message.file_suffix_write = def_suffix
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatBase(unittest.TestCase):
    """
    Base class providing a helper for comparing lists of messages.
    """

    def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
        """
        Assert that both lists have the same length and that the messages are
        pairwise equal. The comparison is done by 'Message.equals()', i. e.
        with more than just Question and Answer.
        """
        self.assertEqual(len(msg1), len(msg2))
        for left, right in zip(msg1, msg2):
            # the file_path is excluded, only Q, A and metadata are compared
            same = left.equals(right, file_path=False, verbose=True)
            self.assertTrue(same)
|
||||||
|
|
||||||
|
|
||||||
|
class TestChat(TestChatBase):
    """
    Tests for the 'Chat' class, using an initially empty chat and two
    messages with tags, AI, model and file path set.
    """

    def setUp(self) -> None:
        # start every test with an empty chat; the messages are added by the tests
        self.chat = Chat([])
        self.message1 = Message(Question('Question 1'),
                                Answer('Answer 1'),
                                {Tag('atag1'), Tag('btag2')},
                                ai='FakeAI',
                                model='FakeModel',
                                file_path=pathlib.Path(f'0001{msg_suffix}'))
        self.message2 = Message(Question('Question 2'),
                                Answer('Answer 2'),
                                {Tag('btag2')},
                                ai='FakeAI',
                                model='FakeModel',
                                file_path=pathlib.Path(f'0002{msg_suffix}'))
        # always show the full diff of failed comparisons
        self.maxDiff = None

    def test_unique_id(self) -> None:
        # test with two identical messages
        self.chat.msg_add([self.message1, self.message1])
        self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
        self.chat.msg_unique_id()
        self.assert_messages_equal(self.chat.messages, [self.message1])
        # test with two different messages
        self.chat.msg_add([self.message2])
        self.chat.msg_unique_id()
        self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])

    def test_unique_content(self) -> None:
        # test with two identical messages
        self.chat.msg_add([self.message1, self.message1])
        self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
        self.chat.msg_unique_content()
        self.assert_messages_equal(self.chat.messages, [self.message1])
        # test with two different messages
        self.chat.msg_add([self.message2])
        self.chat.msg_unique_content()
        self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])

    def test_filter(self) -> None:
        # only message1 has an answer containing 'Answer 1'
        self.chat.msg_add([self.message1, self.message2])
        self.chat.msg_filter(MessageFilter(answer_contains='Answer 1'))

        self.assertEqual(len(self.chat.messages), 1)
        self.assertEqual(self.chat.messages[0].question, 'Question 1')

    def test_sort(self) -> None:
        # add the messages in reverse order, then sort ascending and descending
        self.chat.msg_add([self.message2, self.message1])
        self.chat.msg_sort()
        self.assertEqual(self.chat.messages[0].question, 'Question 1')
        self.assertEqual(self.chat.messages[1].question, 'Question 2')
        self.chat.msg_sort(reverse=True)
        self.assertEqual(self.chat.messages[0].question, 'Question 2')
        self.assertEqual(self.chat.messages[1].question, 'Question 1')

    def test_clear(self) -> None:
        self.chat.msg_add([self.message1])
        self.chat.msg_clear()
        self.assertEqual(len(self.chat.messages), 0)

    def test_add_messages(self) -> None:
        # the messages are expected in the order they were added
        self.chat.msg_add([self.message1, self.message2])
        self.assertEqual(len(self.chat.messages), 2)
        self.assertEqual(self.chat.messages[0].question, 'Question 1')
        self.assertEqual(self.chat.messages[1].question, 'Question 2')

    def test_tags(self) -> None:
        # all tags, tags filtered by prefix and tags filtered by contained string
        self.chat.msg_add([self.message1, self.message2])
        tags_all = self.chat.msg_tags()
        self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
        tags_pref = self.chat.msg_tags(prefix='a')
        self.assertSetEqual(tags_pref, {Tag('atag1')})
        tags_cont = self.chat.msg_tags(contain='2')
        self.assertSetEqual(tags_cont, {Tag('btag2')})

    def test_tags_frequency(self) -> None:
        # 'btag2' is used by both messages, 'atag1' only by message1
        self.chat.msg_add([self.message1, self.message2])
        tags_freq = self.chat.msg_tags_frequency()
        self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})

    def test_find_remove_messages(self) -> None:
        # find messages by their ID (file name without suffix)
        self.chat.msg_add([self.message1, self.message2])
        msgs = self.chat.msg_find(['0001'])
        self.assertListEqual(msgs, [self.message1])
        msgs = self.chat.msg_find(['0001', '0002'])
        self.assertListEqual(msgs, [self.message1, self.message2])
        # add new Message with full path
        message3 = Message(Question('Question 2'),
                           Answer('Answer 2'),
                           {Tag('btag2')},
                           file_path=pathlib.Path(f'/foo/bla/0003{msg_suffix}'))
        self.chat.msg_add([message3])
        # find new Message by full path
        msgs = self.chat.msg_find([f'/foo/bla/0003{msg_suffix}'])
        self.assertListEqual(msgs, [message3])
        # find Message with full path only by filename
        msgs = self.chat.msg_find([f'0003{msg_suffix}'])
        self.assertListEqual(msgs, [message3])
        # remove last message
        self.chat.msg_remove(['0003'])
        self.assertListEqual(self.chat.messages, [self.message1, self.message2])

    def test_latest_message(self) -> None:
        # an empty chat has no latest message
        self.assertIsNone(self.chat.msg_latest())
        self.chat.msg_add([self.message1])
        self.assertEqual(self.chat.msg_latest(), self.message1)
        self.chat.msg_add([self.message2])
        self.assertEqual(self.chat.msg_latest(), self.message2)

    @patch('sys.stdout', new_callable=StringIO)
    def test_print(self, mock_stdout: StringIO) -> None:
        # stdout is captured and compared with the expected (tight) output
        self.chat.msg_add([self.message1, self.message2])
        self.chat.print(paged=False, tight=True)
        expected_output = f"""{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
"""
        self.assertEqual(mock_stdout.getvalue(), expected_output)

    @patch('sys.stdout', new_callable=StringIO)
    def test_print_with_metadata(self, mock_stdout: StringIO) -> None:
        # with metadata, each message is preceded by its tags, file, AI and model
        self.chat.msg_add([self.message1, self.message2])
        self.chat.print(paged=False, with_metadata=True, tight=True)
        expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: 0001{msg_suffix}
AI: FakeAI
MODEL: FakeModel
{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{TagLine.prefix} btag2
FILE: 0002{msg_suffix}
AI: FakeAI
MODEL: FakeModel
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
"""
        self.assertEqual(mock_stdout.getvalue(), expected_output)
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatDB(TestChatBase):
    """
    Tests for 'ChatDB' that run against a temporary DB directory and a
    temporary cache directory (both created in 'setUp()').
    """

    def setUp(self) -> None:
        """
        Create the temporary directories, write four messages into the
        DB directory and add some non-message files next to them.
        """
        self.db_path = tempfile.TemporaryDirectory()
        self.cache_path = tempfile.TemporaryDirectory()

        self.message1 = Message(Question('Question 1'),
                                Answer('Answer 1'),
                                {Tag('tag1')})
        self.message2 = Message(Question('Question 2'),
                                Answer('Answer 2'),
                                {Tag('tag2')})
        self.message3 = Message(Question('Question 3'),
                                Answer('Answer 3'),
                                {Tag('tag3')})
        self.message4 = Message(Question('Question 4'),
                                Answer('Answer 4'),
                                {Tag('tag4')})

        # NOTE(review): the paths are given without a suffix, the tests below
        # look the files up as '000N{msg_suffix}' -> presumably 'to_file()'
        # appends the suffix and updates 'Message.file_path' - confirm
        self.message1.to_file(pathlib.Path(self.db_path.name, '0001'), mformat='txt')
        self.message2.to_file(pathlib.Path(self.db_path.name, '0002'), mformat='yaml')
        self.message3.to_file(pathlib.Path(self.db_path.name, '0003'), mformat='txt')
        self.message4.to_file(pathlib.Path(self.db_path.name, '0004'), mformat='yaml')
        # make the next FID match the current state
        next_fname = pathlib.Path(self.db_path.name) / '.next'
        with open(next_fname, 'w') as f:
            f.write('4')
        # add some "trash" in order to test if it's correctly handled / ignored
        self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt', 'fubar.msg']
        for file in self.trash_files:
            with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
                f.write('test trash')
        # also create a file with actual yaml content
        with open(pathlib.Path(self.db_path.name) / 'content.yaml', 'w') as f:
            yaml.dump({'key': 'value'}, f)
        self.trash_files.append('content.yaml')
        self.maxDiff = None

    def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
        """
        List all Message files in the given TemporaryDirectory.
        """
        # the glob only matches names containing '.t', '.y' or '.m' (so the
        # '.next' file is never matched), the trash files created in 'setUp()'
        # are removed by name
        return [f for f in pathlib.Path(tmp_dir.name).glob('*.[tym]*') if f.name not in self.trash_files]

    def tearDown(self) -> None:
        """
        Remove the temporary directories created in 'setUp()'.
        """
        self.db_path.cleanup()
        self.cache_path.cleanup()

    def _load_chat_db(self) -> ChatDB:
        """
        Create a ChatDB from the cache and DB directory of this test
        (no glob, no filter).
        """
        return ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                               pathlib.Path(self.db_path.name))

    def _assert_file_paths(self,
                           chat_db: ChatDB,
                           directory: tempfile.TemporaryDirectory,
                           msg_ids: list[str],
                           first_index: int = 0) -> None:
        """
        Assert that 'chat_db.messages[first_index:]' have, in the given order,
        the file paths '<directory>/<msg_id><msg_suffix>'.
        """
        for offset, msg_id in enumerate(msg_ids):
            self.assertEqual(chat_db.messages[first_index + offset].file_path,
                             pathlib.Path(directory.name, f'{msg_id}{msg_suffix}'))

    def _assert_files_listed(self,
                             directory: tempfile.TemporaryDirectory,
                             msg_ids: list[str],
                             files: list[pathlib.Path]) -> None:
        """
        Assert that '<directory>/<msg_id><msg_suffix>' is contained in 'files'
        for each of the given message IDs.
        """
        for msg_id in msg_ids:
            self.assertIn(pathlib.Path(directory.name, f'{msg_id}{msg_suffix}'), files)

    def test_validate(self) -> None:
        """
        A second file for message '0004' in the DB directory must make
        'ChatDB.from_dir()' fail with "Validation failed".
        """
        duplicate_message = Message(Question('Question 4'),
                                    Answer('Answer 4'),
                                    {Tag('tag4')},
                                    file_path=pathlib.Path(self.db_path.name, '0004.txt'))
        msg_to_file_force_suffix(duplicate_message)
        with self.assertRaises(ChatError) as cm:
            ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                            pathlib.Path(self.db_path.name),
                            glob='*')
        self.assertEqual(str(cm.exception), "Validation failed")

    def test_file_path_ID_exists(self) -> None:
        """
        Tests if the ChatDB chooses another ID when writing to the cache
        if a file with the next ID already exists there.
        """
        # create a new and empty ChatDB (with its own directories, which
        # are removed again when the test is finished)
        db_path = tempfile.TemporaryDirectory()
        self.addCleanup(db_path.cleanup)
        cache_path = tempfile.TemporaryDirectory()
        self.addCleanup(cache_path.cleanup)
        chat_db = ChatDB.from_dir(pathlib.Path(cache_path.name),
                                  pathlib.Path(db_path.name))
        # add a message file
        message = Message(Question('What?'),
                          file_path=pathlib.Path(cache_path.name) / f'0001{msg_suffix}')
        message.to_file()
        message1 = Message(Question('Where?'))
        chat_db.cache_write([message1])
        self.assertEqual(message1.msg_id(), '0002')

    def test_from_dir(self) -> None:
        """
        'from_dir()' without glob / filter reads the four messages in sorted order.
        """
        chat_db = self._load_chat_db()
        self.assertEqual(len(chat_db.messages), 4)
        self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
        self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
        # check that the files are sorted
        self._assert_file_paths(chat_db, self.db_path, ['0001', '0002', '0003', '0004'])

    def test_from_dir_glob(self) -> None:
        """
        'from_dir()' with the glob '*1.*' reads only message '0001'.
        """
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name),
                                  glob='*1.*')
        self.assertEqual(len(chat_db.messages), 1)
        self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
        self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
        self._assert_file_paths(chat_db, self.db_path, ['0001'])

    def test_from_dir_filter_tags(self) -> None:
        """
        'from_dir()' with a 'tags_or' filter for 'tag1' reads only message '0001'.
        """
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name),
                                  mfilter=MessageFilter(tags_or={Tag('tag1')}))
        self.assertEqual(len(chat_db.messages), 1)
        self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
        self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
        self._assert_file_paths(chat_db, self.db_path, ['0001'])

    def test_from_dir_filter_tags_empty(self) -> None:
        """
        'from_dir()' with a filter consisting of empty tag sets reads no message.
        """
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name),
                                  mfilter=MessageFilter(tags_or=set(),
                                                        tags_and=set(),
                                                        tags_not=set()))
        self.assertEqual(len(chat_db.messages), 0)

    def test_from_dir_filter_answer(self) -> None:
        """
        'from_dir()' with an 'answer_contains' filter reads only message '0002'.
        """
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name),
                                  mfilter=MessageFilter(answer_contains='Answer 2'))
        self.assertEqual(len(chat_db.messages), 1)
        self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
        self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
        self._assert_file_paths(chat_db, self.db_path, ['0002'])
        self.assertEqual(chat_db.messages[0].answer, 'Answer 2')

    def test_from_messages(self) -> None:
        """
        'from_messages()' takes over the given messages and both directory paths.
        """
        chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
                                       pathlib.Path(self.db_path.name),
                                       messages=[self.message1, self.message2,
                                                 self.message3, self.message4])
        self.assertEqual(len(chat_db.messages), 4)
        self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
        self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))

    def test_fids(self) -> None:
        """
        'get_next_fid()' continues after the '4' written to '.next' in
        'setUp()' and stores the last returned value in 'next_path'.
        """
        chat_db = self._load_chat_db()
        self.assertEqual(chat_db.get_next_fid(), 5)
        self.assertEqual(chat_db.get_next_fid(), 6)
        self.assertEqual(chat_db.get_next_fid(), 7)
        with open(chat_db.next_path, 'r') as f:
            self.assertEqual(f.read(), '7')

    def test_msg_in_db_or_cache(self) -> None:
        """
        'msg_in_db()' / 'msg_in_cache()' accept a message, a file path string
        or a message ID and distinguish DB messages from cache messages.
        """
        # create a new ChatDB instance
        chat_db = self._load_chat_db()
        self.assertTrue(chat_db.msg_in_db(self.message1))
        self.assertTrue(chat_db.msg_in_db(str(self.message1.file_path)))
        self.assertTrue(chat_db.msg_in_db(self.message1.msg_id()))
        self.assertFalse(chat_db.msg_in_cache(self.message1))
        self.assertFalse(chat_db.msg_in_cache(str(self.message1.file_path)))
        self.assertFalse(chat_db.msg_in_cache(self.message1.msg_id()))
        # add new message to the cache dir
        cache_message = Message(question=Question("Question 1"),
                                answer=Answer("Answer 1"))
        chat_db.cache_add([cache_message])
        self.assertTrue(chat_db.msg_in_cache(cache_message))
        self.assertTrue(chat_db.msg_in_cache(cache_message.msg_id()))
        self.assertFalse(chat_db.msg_in_db(cache_message))
        self.assertFalse(chat_db.msg_in_db(str(cache_message.file_path)))

    def test_db_write(self) -> None:
        """
        'cache_write()' and 'db_write()' write all messages to the respective
        directory and update 'Message.file_path' accordingly.
        """
        msg_ids = ['0001', '0002', '0003', '0004']
        # create a new ChatDB instance
        chat_db = self._load_chat_db()
        # check that Message.file_path is correct
        self._assert_file_paths(chat_db, self.db_path, msg_ids)

        # write the messages to the cache directory
        chat_db.cache_write()
        # check if the written files are in the cache directory
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 4)
        self._assert_files_listed(self.cache_path, msg_ids, cache_dir_files)
        # check that Message.file_path has been correctly updated
        self._assert_file_paths(chat_db, self.cache_path, msg_ids)

        # check the timestamp of the files in the DB directory
        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
        # overwrite the messages in the db directory
        # (short pause, so that the rewritten files can get a later mtime)
        time.sleep(0.05)
        chat_db.db_write()
        # check if the written files are in the DB directory
        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        self._assert_files_listed(self.db_path, msg_ids, db_dir_files)
        # check if all files in the DB dir have actually been overwritten
        for file in db_dir_files:
            self.assertGreater(file.stat().st_mtime, old_timestamps[file])
        # check that Message.file_path has been correctly updated (again)
        self._assert_file_paths(chat_db, self.db_path, msg_ids)

    def test_db_read(self) -> None:
        """
        'db_read()' and 'cache_read()' add new files to the message list and
        'db_read()' replaces messages that already exist in memory.
        """
        # create a new ChatDB instance
        chat_db = self._load_chat_db()
        self.assertEqual(len(chat_db.messages), 4)

        # create 2 new files in the DB directory
        new_message1 = Message(Question('Question 5'),
                               Answer('Answer 5'),
                               {Tag('tag5')})
        new_message2 = Message(Question('Question 6'),
                               Answer('Answer 6'),
                               {Tag('tag6')})
        new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
        new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
        # read and check them
        chat_db.db_read()
        self.assertEqual(len(chat_db.messages), 6)
        self._assert_file_paths(chat_db, self.db_path, ['0005', '0006'], first_index=4)

        # create 2 new files in the cache directory
        new_message3 = Message(Question('Question 7'),
                               Answer('Answer 7'),
                               {Tag('tag7')})
        new_message4 = Message(Question('Question 8'),
                               Answer('Answer 8'),
                               {Tag('tag8')})
        new_message3.to_file(pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'), mformat='txt')
        new_message4.to_file(pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'), mformat='yaml')
        # read and check them
        chat_db.cache_read()
        self.assertEqual(len(chat_db.messages), 8)
        # check that the new messages have the cache dir path
        self._assert_file_paths(chat_db, self.cache_path, ['0007', '0008'], first_index=6)
        # and the old ones keep their path (since they have not been replaced)
        self._assert_file_paths(chat_db, self.db_path, ['0005', '0006'], first_index=4)

        # now overwrite two messages in the DB directory
        new_message1.question = Question('New Question 1')
        new_message2.question = Question('New Question 2')
        new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
        new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
        # read from the DB dir and check if the modified messages have been updated
        chat_db.db_read()
        self.assertEqual(len(chat_db.messages), 8)
        self.assertEqual(chat_db.messages[4].question, 'New Question 1')
        self.assertEqual(chat_db.messages[5].question, 'New Question 2')
        self._assert_file_paths(chat_db, self.db_path, ['0005', '0006'], first_index=4)

        # now write the messages from the cache to the DB directory
        new_message3.to_file(pathlib.Path(self.db_path.name, f'0007{msg_suffix}'))
        new_message4.to_file(pathlib.Path(self.db_path.name, f'0008{msg_suffix}'))
        # read and check them
        chat_db.db_read()
        self.assertEqual(len(chat_db.messages), 8)
        # check that they now have the DB path
        self._assert_file_paths(chat_db, self.db_path, ['0007', '0008'], first_index=6)

    def test_cache_clear(self) -> None:
        """
        'cache_clear()' empties the cache directory and drops the messages
        with a cache path, all other messages stay untouched.
        """
        msg_ids = ['0001', '0002', '0003', '0004']
        # create a new ChatDB instance
        chat_db = self._load_chat_db()
        # check that Message.file_path is correct
        self._assert_file_paths(chat_db, self.db_path, msg_ids)

        # write the messages to the cache directory
        chat_db.cache_write()
        # check if the written files are in the cache directory
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 4)

        # now rewrite them to the DB dir and check for modified paths
        chat_db.db_write()
        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        self._assert_files_listed(self.db_path, msg_ids, db_dir_files)

        # add a new message with empty file_path
        message_empty = Message(question=Question("What the hell am I doing here?"),
                                answer=Answer("You don't belong here!"))
        # and one for the cache dir
        message_cache = Message(question=Question("What the hell am I doing here?"),
                                answer=Answer("You're a creep!"),
                                file_path=pathlib.Path(self.cache_path.name, '0005'))
        chat_db.msg_add([message_empty, message_cache])

        # clear the cache and check the cache dir
        chat_db.cache_clear()
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 0)
        # make sure that the DB messages (and the new message) are still there
        self.assertEqual(len(chat_db.messages), 5)
        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        # but not the message with the cache dir path
        self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))

    def test_add(self) -> None:
        """
        'cache_add()' / 'db_add()' write new messages to the respective
        directory and set their 'file_path'.
        """
        # create a new ChatDB instance
        chat_db = self._load_chat_db()

        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)

        # add new messages to the cache dir
        message1 = Message(question=Question("Question 1"),
                           answer=Answer("Answer 1"))
        chat_db.cache_add([message1])
        # check if the file_path has been correctly set
        self.assertIsNotNone(message1.file_path)
        self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name))  # type: ignore [union-attr]
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 1)

        # add new messages to the DB dir
        message2 = Message(question=Question("Question 2"),
                           answer=Answer("Answer 2"))
        chat_db.db_add([message2])
        # check if the file_path has been correctly set
        self.assertIsNotNone(message2.file_path)
        self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name))  # type: ignore [union-attr]
        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 5)

        # a message that already has a file_path is expected to be rejected
        with self.assertRaises(ChatError):
            chat_db.cache_add([Message(Question("?"), file_path=pathlib.Path("foo"))])

    def test_msg_write(self) -> None:
        """
        'msg_write()' raises a ChatError for a message without file_path
        and writes a message that has one.
        """
        # create a new ChatDB instance
        chat_db = self._load_chat_db()

        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 0)

        # try to write a message without a valid file_path
        message = Message(question=Question("Question 1"),
                          answer=Answer("Answer 1"))
        with self.assertRaises(ChatError):
            chat_db.msg_write([message])

        # write a message with a valid file_path
        message.file_path = pathlib.Path(self.cache_path.name) / '123456'
        chat_db.msg_write([message])
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 1)
        self._assert_files_listed(self.cache_path, ['123456'], cache_dir_files)

    def test_msg_update(self) -> None:
        """
        'msg_update()' changes the message in memory and, with 'write=True',
        also on disk; a message without file_path raises a ChatError.
        """
        # create a new ChatDB instance
        chat_db = self._load_chat_db()

        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 0)

        message = chat_db.messages[0]
        message.answer = Answer("New answer")
        # update message without writing
        chat_db.msg_update([message], write=False)
        self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
        # re-read the message and check for old content
        chat_db.db_read()
        self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1"))
        # now check with writing (message should be overwritten)
        chat_db.msg_update([message], write=True)
        chat_db.db_read()
        self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
        # test without file_path -> expect error
        message1 = Message(question=Question("Question 1"),
                           answer=Answer("Answer 1"))
        with self.assertRaises(ChatError):
            chat_db.msg_update([message1])

    def test_msg_find(self) -> None:
        """
        'msg_find()' finds messages by full path, message ID or file name
        (with and without suffix) in the requested location.
        """
        chat_db = self._load_chat_db()
        # search for a DB file in memory
        self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc=msg_location.MEM), [self.message1])
        self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc=msg_location.MEM), [self.message1])
        self.assertEqual(chat_db.msg_find(['0001.msg'], loc=msg_location.MEM), [self.message1])
        self.assertEqual(chat_db.msg_find(['0001'], loc=msg_location.MEM), [self.message1])
        # and on disk
        self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc=msg_location.DB), [self.message2])
        self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc=msg_location.DB), [self.message2])
        self.assertEqual(chat_db.msg_find(['0002.msg'], loc=msg_location.DB), [self.message2])
        self.assertEqual(chat_db.msg_find(['0002'], loc=msg_location.DB), [self.message2])
        # now search the cache -> expect empty result
        self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc=msg_location.CACHE), [])
        self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc=msg_location.CACHE), [])
        self.assertEqual(chat_db.msg_find(['0003.msg'], loc=msg_location.CACHE), [])
        self.assertEqual(chat_db.msg_find(['0003'], loc=msg_location.CACHE), [])
        # search for multiple messages
        # -> search one twice, expect result to be unique
        search_names = ['0001', '0002.msg', self.message3.msg_id(), str(self.message3.file_path)]
        expected_result = [self.message1, self.message2, self.message3]
        result = chat_db.msg_find(search_names, loc=msg_location.ALL)
        self.assert_messages_equal(result, expected_result)

    def test_msg_latest(self) -> None:
        """
        'msg_latest()' returns the latest message of the requested location
        (None for an empty cache).
        """
        chat_db = self._load_chat_db()
        self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), self.message4)
        self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4)
        self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), self.message4)
        self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), self.message4)
        # the cache is currently empty:
        self.assertIsNone(chat_db.msg_latest(loc=msg_location.CACHE))
        # add new messages to the cache dir
        new_message = Message(question=Question("New Question"),
                              answer=Answer("New Answer"))
        chat_db.cache_add([new_message])
        self.assertEqual(chat_db.msg_latest(loc=msg_location.CACHE), new_message)
        self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), new_message)
        self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), new_message)
        self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), new_message)
        # the DB does not contain the new message
        self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4)

    def test_msg_gather(self) -> None:
        """
        'msg_gather()' collects the messages of the requested location,
        optionally restricted by a MessageFilter.
        """
        chat_db = self._load_chat_db()
        all_messages = [self.message1, self.message2, self.message3, self.message4]
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
        # add a new message, but only to the internal list
        new_message = Message(Question("What?"))
        all_messages_mem = all_messages + [new_message]
        chat_db.msg_add([new_message])
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages_mem)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages_mem)
        # the nr. of messages on disk did not change -> expect old result
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
        # test with MessageFilter
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL, mfilter=MessageFilter(tags_or={Tag('tag1')})),
                                   [self.message1])
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK, mfilter=MessageFilter(tags_or={Tag('tag2')})),
                                   [self.message2])
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE, mfilter=MessageFilter(tags_or={Tag('tag3')})),
                                   [])
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM, mfilter=MessageFilter(question_contains="What")),
                                   [new_message])

    def test_msg_move_and_gather(self) -> None:
        """
        'cache_move()' / 'db_move()' move a message between the DB and the
        cache directory, 'msg_gather()' reflects the new location.
        """
        chat_db = self._load_chat_db()
        all_messages = [self.message1, self.message2, self.message3, self.message4]
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
        # move first message to the cache
        chat_db.cache_move(self.message1)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [self.message1])
        self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name))  # type: ignore [union-attr]
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), [self.message2, self.message3, self.message4])
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages)
        # now move first message back to the DB
        chat_db.db_move(self.message1)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
        self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name))  # type: ignore [union-attr]
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
|
||||||
@@ -0,0 +1,100 @@
|
|||||||
|
import unittest
|
||||||
|
import argparse
|
||||||
|
from typing import Union, Optional
|
||||||
|
from chatmastermind.configuration import Config, AIConfig
|
||||||
|
from chatmastermind.tags import Tag
|
||||||
|
from chatmastermind.message import Message, Answer
|
||||||
|
from chatmastermind.chat import Chat
|
||||||
|
from chatmastermind.ai import AI, AIResponse, Tokens, AIError
|
||||||
|
|
||||||
|
|
||||||
|
class FakeAI(AI):
    """
    A mocked version of the 'AI' class.

    'request()' does not contact any service: it either raises an 'AIError'
    (if the instance was created with 'error=True') or returns generated
    answers ("Answer 0", "Answer 1", ...).
    """
    ID: str
    name: str
    config: AIConfig

    def __init__(self, ID: str, model: str, error: bool = False):
        """
        Store the AI ID, the model name and whether 'request()' should fail.
        """
        # NOTE(review): 'name' and 'config' are declared above but never set here
        self.ID = ID
        self.model = model
        self.error = error

    def models(self) -> list[str]:
        """
        Not supported by the mock.
        """
        raise NotImplementedError

    def tokens(self, data: Union[Message, Chat]) -> int:
        """
        Return a fixed token count, independent of the given data.
        """
        return 123

    def print(self) -> None:
        """
        Do nothing.
        """
        pass

    def print_models(self) -> None:
        """
        Do nothing.
        """
        pass

    def request(self,
                question: Message,
                chat: Chat,
                num_answers: int = 1,
                otags: Optional[set[Tag]] = None) -> AIResponse:
        """
        Mock the 'ai.request()' function by either returning fake
        answers or raising an exception.

        The given 'question' message is modified in place and becomes the
        first answer ("Answer 0"); for 'num_answers' > 1 additional messages
        with the same question and the answers "Answer 1", "Answer 2", ...
        are created. 'chat' is not used. The returned token numbers are fixed.
        """
        if self.error:
            raise AIError
        question.answer = Answer("Answer 0")
        question.tags = set(otags) if otags is not None else None
        question.ai = self.ID
        question.model = self.model
        answers: list[Message] = [question]
        for n in range(1, num_answers):
            answers.append(Message(question=question.question,
                                   answer=Answer(f"Answer {n}"),
                                   # each answer gets its own copy of the tags (like the
                                   # first one), so the messages never share a set object
                                   tags=set(otags) if otags is not None else None,
                                   ai=self.ID,
                                   model=self.model))
        return AIResponse(answers, Tokens(10, 10, 20))
|
||||||
|
|
||||||
|
|
||||||
|
class TestWithFakeAI(unittest.TestCase):
    """
    Base class for all tests that need to use the FakeAI.

    Provides helpers for comparing message lists and replacements for
    'create_ai' that return a 'FakeAI'.
    """

    def assert_msgs_equal_except_file_path(self, msg1: list[Message], msg2: list[Message]) -> None:
        """
        Compare messages using Question, Answer and all metadata except for the file_path.
        """
        self.assertEqual(len(msg1), len(msg2))
        for left, right in zip(msg1, msg2):
            # the file_path is excluded, only Q, A and metadata are compared
            same = left.equals(right, file_path=False, verbose=True)
            self.assertTrue(same)

    def assert_msgs_all_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
        """
        Compare messages using 'Message.equals()' with its default settings
        (apart from 'verbose=True').
        """
        self.assertEqual(len(msg1), len(msg2))
        for left, right in zip(msg1, msg2):
            same = left.equals(right, verbose=True)
            self.assertTrue(same)

    def assert_msgs_content_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
        """
        Compare messages pairwise using '==' (presumably only Question
        and Answer are compared by 'Message' - see its '__eq__').
        """
        self.assertEqual(len(msg1), len(msg2))
        for left, right in zip(msg1, msg2):
            self.assertEqual(left, right)

    def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI:
        """
        Mocked 'create_ai' that returns a 'FakeAI' instance built from
        'args.AI' and 'args.model' ('config' is ignored).
        """
        fake_ai = FakeAI(args.AI, args.model)
        return fake_ai

    def mock_create_ai_with_error(self, args: argparse.Namespace, config: Config) -> AI:
        """
        Mocked 'create_ai' that returns a 'FakeAI' instance created
        with 'error=True' ('config' is ignored).
        """
        failing_ai = FakeAI(args.AI, args.model, error=True)
        return failing_ai
|
||||||
@@ -0,0 +1,167 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
import yaml
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import cast
|
||||||
|
from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config
|
||||||
|
|
||||||
|
|
||||||
|
class TestAIConfigInstance(unittest.TestCase):
    """
    Tests for the 'ai_config_instance()' factory.
    """

    def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None:
        """
        Without a configuration dict, every checked attribute must match
        a default-constructed 'OpenAIConfig'.
        """
        created = cast(OpenAIConfig, ai_config_instance('openai'))
        reference = OpenAIConfig()
        checked_attributes = ('ID', 'name', 'api_key', 'system', 'model',
                              'temperature', 'max_tokens', 'top_p',
                              'frequency_penalty', 'presence_penalty')
        for attribute in checked_attributes:
            self.assertEqual(getattr(created, attribute), getattr(reference, attribute))

    def test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None:
        """
        Values given in the configuration dict must show up on the instance.
        """
        custom_settings = dict(
            system='Custom system',
            api_key='9876543210',
            model='custom_model',
            max_tokens=5000,
            temperature=0.5,
            top_p=0.8,
            frequency_penalty=0.7,
            presence_penalty=0.2,
        )
        created = cast(OpenAIConfig, ai_config_instance('openai', custom_settings))
        # exact comparison for strings and integers
        self.assertEqual(created.system, 'Custom system')
        self.assertEqual(created.api_key, '9876543210')
        self.assertEqual(created.model, 'custom_model')
        self.assertEqual(created.max_tokens, 5000)
        # approximate comparison for floats
        self.assertAlmostEqual(created.temperature, 0.5)
        self.assertAlmostEqual(created.top_p, 0.8)
        self.assertAlmostEqual(created.frequency_penalty, 0.7)
        self.assertAlmostEqual(created.presence_penalty, 0.2)

    def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None:
        """
        An unknown AI name must be rejected with a 'ConfigError'.
        """
        self.assertRaises(ConfigError, ai_config_instance, 'invalid_name')
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfig(unittest.TestCase):
    """
    Tests for building 'Config' objects from dicts and files and for
    writing them back to disk.
    """

    @staticmethod
    def _openai_settings() -> dict[str, object]:
        """
        Return a fresh copy of the custom AI settings shared by several tests.
        """
        return {
            'name': 'openai',
            'system': 'Custom system',
            'api_key': '9876543210',
            'model': 'custom_model',
            'max_tokens': 5000,
            'temperature': 0.5,
            'top_p': 0.8,
            'frequency_penalty': 0.7,
            'presence_penalty': 0.2
        }

    def setUp(self) -> None:
        """
        Create an empty temporary file; the tests only use its name.
        """
        self.test_file = NamedTemporaryFile(delete=False)
        # Close the handle right away: the tests re-open the file by name,
        # which is not possible on all platforms while it is still open,
        # and an unclosed handle would leak until garbage collection.
        self.test_file.close()

    def tearDown(self) -> None:
        """
        Remove the temporary file created in 'setUp'.
        """
        os.remove(self.test_file.name)

    def test_from_dict_should_create_config_from_dict(self) -> None:
        """
        'Config.from_dict()' must take over all values and add the 'ID'.
        """
        source_dict = {
            'cache': '.',
            'db': './test_db/',
            'ais': {
                'myopenai': self._openai_settings()
            }
        }
        config = Config.from_dict(source_dict)
        self.assertEqual(config.cache, '.')
        self.assertEqual(config.db, './test_db/')
        self.assertEqual(len(config.ais), 1)
        self.assertEqual(config.ais['myopenai'].name, 'openai')
        self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system')
        # check that 'ID' has been added
        self.assertEqual(config.ais['myopenai'].ID, 'myopenai')

    def test_create_default_should_create_default_config(self) -> None:
        """
        The file written by 'Config.create_default()' must contain the
        same 'db' value as a default-constructed 'Config'.
        """
        Config.create_default(Path(self.test_file.name))
        with open(self.test_file.name, 'r') as f:
            default_config = yaml.load(f, Loader=yaml.FullLoader)
        config_reference = Config()
        self.assertEqual(default_config['db'], config_reference.db)

    def test_from_file_should_load_config_from_file(self) -> None:
        """
        A YAML file written by the test must be loaded by 'Config.from_file()'.
        """
        source_dict = {
            'cache': './test_cache/',
            'db': './test_db/',
            'ais': {
                'default': self._openai_settings()
            }
        }
        with open(self.test_file.name, 'w') as f:
            yaml.dump(source_dict, f)
        config = Config.from_file(self.test_file.name)
        self.assertIsInstance(config, Config)
        self.assertEqual(config.cache, './test_cache/')
        self.assertEqual(config.db, './test_db/')
        self.assertEqual(len(config.ais), 1)
        self.assertIsInstance(config.ais['default'], AIConfig)
        self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')

    def test_to_file_should_save_config_to_file(self) -> None:
        """
        'Config.to_file()' must write a YAML file containing the given values.
        """
        config = Config(
            cache='./test_cache/',
            db='./test_db/',
            ais={
                'myopenai': OpenAIConfig(
                    ID='myopenai',
                    system='Custom system',
                    api_key='9876543210',
                    model='custom_model',
                    max_tokens=5000,
                    temperature=0.5,
                    top_p=0.8,
                    frequency_penalty=0.7,
                    presence_penalty=0.2
                )
            }
        )
        config.to_file(Path(self.test_file.name))
        with open(self.test_file.name, 'r') as f:
            saved_config = yaml.load(f, Loader=yaml.FullLoader)
        self.assertEqual(saved_config['cache'], './test_cache/')
        self.assertEqual(saved_config['db'], './test_db/')
        self.assertEqual(len(saved_config['ais']), 1)
        self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system')

    def test_from_file_error_unknown_ai(self) -> None:
        """
        Loading a file that names an unknown AI ('foobla') must raise 'ConfigError'.
        """
        source_dict = {
            'cache': './test_cache/',
            'db': './test_db/',
            'ais': {
                # same settings as in the other tests, but with an unknown AI name
                'default': {**self._openai_settings(), 'name': 'foobla'}
            }
        }
        with open(self.test_file.name, 'w') as f:
            yaml.dump(source_dict, f)
        with self.assertRaises(ConfigError):
            Config.from_file(self.test_file.name)
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
import unittest
|
||||||
|
import argparse
|
||||||
|
import tempfile
|
||||||
|
import yaml
|
||||||
|
from pathlib import Path
|
||||||
|
from chatmastermind.message import Message, Question
|
||||||
|
from chatmastermind.chat import ChatDB, ChatError, msg_location
|
||||||
|
from chatmastermind.configuration import Config
|
||||||
|
from chatmastermind.commands.hist import convert_messages
|
||||||
|
|
||||||
|
|
||||||
|
msg_suffix = Message.file_suffix_write
|
||||||
|
|
||||||
|
|
||||||
|
class TestConvertMessages(unittest.TestCase):
    """
    Tests for 'convert_messages()' of the 'hist' command.
    """

    def setUp(self) -> None:
        """
        Create temporary DB and cache directories, write six messages
        (two to the DB, four to the cache) and give the two DB messages
        the suffixes '.txt' and '.yaml'.
        """
        self.db_dir = tempfile.TemporaryDirectory()
        self.cache_dir = tempfile.TemporaryDirectory()
        self.db_path = Path(self.db_dir.name)
        self.cache_path = Path(self.cache_dir.name)
        self.args = argparse.Namespace()
        self.config = Config()
        self.config.cache = self.cache_dir.name
        self.config.db = self.db_dir.name
        # Populate DB and cache
        self.chat = ChatDB.from_dir(self.cache_path, self.db_path)
        self.messages = [Message(Question(f'Question {i}')) for i in range(6)]
        self.chat.db_write(self.messages[:2])
        self.chat.cache_write(self.messages[2:])
        # Rename the two DB message files to other suffixes
        first, second = self.messages[0], self.messages[1]
        assert first.file_path
        assert second.file_path
        first.file_path.rename(first.file_path.with_suffix('.txt'))
        second.file_path.rename(second.file_path.with_suffix('.yaml'))

    def tearDown(self) -> None:
        """
        Remove the temporary directories.
        """
        self.db_dir.cleanup()
        self.cache_dir.cleanup()

    def test_convert_messages(self) -> None:
        """
        After converting to 'yaml', the message count and IDs are unchanged,
        every file has the write suffix and can be loaded as YAML.
        """
        self.args.convert = 'yaml'
        convert_messages(self.args, self.config)
        converted = self.chat.msg_gather(loc=msg_location.DISK, glob='*.*')
        # Same number of messages as before
        self.assertEqual(len(converted), len(self.messages))
        # Every message carries the requested suffix
        for message in converted:
            assert message.file_path
            self.assertEqual(message.file_path.suffix, msg_suffix)
        # Message IDs are maintained
        for after, before in zip(converted, self.messages):
            self.assertEqual(after.msg_id(), before.msg_id())
        # Every file is loadable in the new format
        for message in converted:
            with open(str(message.file_path), "r") as fd:
                yaml.load(fd, Loader=yaml.FullLoader)

    def test_convert_messages_wrong_format(self) -> None:
        """
        An unknown target format must raise 'ChatError'.
        """
        self.args.convert = 'foo'
        self.assertRaises(ChatError, convert_messages, self.args, self.config)
|
||||||
@@ -1,219 +0,0 @@
|
|||||||
import unittest
|
|
||||||
import io
|
|
||||||
import pathlib
|
|
||||||
import argparse
|
|
||||||
from chatmastermind.utils import terminal_width
|
|
||||||
from chatmastermind.main import create_parser, handle_question
|
|
||||||
from chatmastermind.api_client import ai
|
|
||||||
from chatmastermind.storage import create_chat, save_answers, dump_data
|
|
||||||
from unittest import mock
|
|
||||||
from unittest.mock import patch, MagicMock, Mock
|
|
||||||
|
|
||||||
|
|
||||||
class TestCreateChat(unittest.TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.config = {
|
|
||||||
'system': 'System text',
|
|
||||||
'db': 'test_files'
|
|
||||||
}
|
|
||||||
self.question = "test question"
|
|
||||||
self.tags = ['test_tag']
|
|
||||||
|
|
||||||
@patch('os.listdir')
|
|
||||||
@patch('pathlib.Path.iterdir')
|
|
||||||
@patch('builtins.open')
|
|
||||||
def test_create_chat_with_tags(self, open_mock, iterdir_mock, listdir_mock):
|
|
||||||
listdir_mock.return_value = ['testfile.txt']
|
|
||||||
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
|
|
||||||
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
|
|
||||||
{'question': 'test_content', 'answer': 'some answer',
|
|
||||||
'tags': ['test_tag']}))
|
|
||||||
|
|
||||||
test_chat = create_chat(self.question, self.tags, None, self.config)
|
|
||||||
|
|
||||||
self.assertEqual(len(test_chat), 4)
|
|
||||||
self.assertEqual(test_chat[0],
|
|
||||||
{'role': 'system', 'content': self.config['system']})
|
|
||||||
self.assertEqual(test_chat[1],
|
|
||||||
{'role': 'user', 'content': 'test_content'})
|
|
||||||
self.assertEqual(test_chat[2],
|
|
||||||
{'role': 'assistant', 'content': 'some answer'})
|
|
||||||
self.assertEqual(test_chat[3],
|
|
||||||
{'role': 'user', 'content': self.question})
|
|
||||||
|
|
||||||
@patch('os.listdir')
|
|
||||||
@patch('pathlib.Path.iterdir')
|
|
||||||
@patch('builtins.open')
|
|
||||||
def test_create_chat_with_other_tags(self, open_mock, iterdir_mock, listdir_mock):
|
|
||||||
listdir_mock.return_value = ['testfile.txt']
|
|
||||||
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
|
|
||||||
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
|
|
||||||
{'question': 'test_content', 'answer': 'some answer',
|
|
||||||
'tags': ['other_tag']}))
|
|
||||||
|
|
||||||
test_chat = create_chat(self.question, self.tags, None, self.config)
|
|
||||||
|
|
||||||
self.assertEqual(len(test_chat), 2)
|
|
||||||
self.assertEqual(test_chat[0],
|
|
||||||
{'role': 'system', 'content': self.config['system']})
|
|
||||||
self.assertEqual(test_chat[1],
|
|
||||||
{'role': 'user', 'content': self.question})
|
|
||||||
|
|
||||||
@patch('os.listdir')
|
|
||||||
@patch('pathlib.Path.iterdir')
|
|
||||||
@patch('builtins.open')
|
|
||||||
def test_create_chat_without_tags(self, open_mock, iterdir_mock, listdir_mock):
|
|
||||||
listdir_mock.return_value = ['testfile.txt', 'testfile2.txt']
|
|
||||||
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
|
|
||||||
open_mock.side_effect = (
|
|
||||||
io.StringIO(dump_data({'question': 'test_content',
|
|
||||||
'answer': 'some answer',
|
|
||||||
'tags': ['test_tag']})),
|
|
||||||
io.StringIO(dump_data({'question': 'test_content2',
|
|
||||||
'answer': 'some answer2',
|
|
||||||
'tags': ['test_tag2']})),
|
|
||||||
)
|
|
||||||
|
|
||||||
test_chat = create_chat(self.question, [], None, self.config)
|
|
||||||
|
|
||||||
self.assertEqual(len(test_chat), 6)
|
|
||||||
self.assertEqual(test_chat[0],
|
|
||||||
{'role': 'system', 'content': self.config['system']})
|
|
||||||
self.assertEqual(test_chat[1],
|
|
||||||
{'role': 'user', 'content': 'test_content'})
|
|
||||||
self.assertEqual(test_chat[2],
|
|
||||||
{'role': 'assistant', 'content': 'some answer'})
|
|
||||||
self.assertEqual(test_chat[3],
|
|
||||||
{'role': 'user', 'content': 'test_content2'})
|
|
||||||
self.assertEqual(test_chat[4],
|
|
||||||
{'role': 'assistant', 'content': 'some answer2'})
|
|
||||||
|
|
||||||
|
|
||||||
class TestHandleQuestion(unittest.TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.question = "test question"
|
|
||||||
self.args = argparse.Namespace(
|
|
||||||
tags=['tag1'],
|
|
||||||
extags=['extag1'],
|
|
||||||
output_tags=None,
|
|
||||||
question=[self.question],
|
|
||||||
source=None,
|
|
||||||
only_source_code=False,
|
|
||||||
number=3,
|
|
||||||
match_all_tags=False,
|
|
||||||
with_tags=False,
|
|
||||||
with_file=False,
|
|
||||||
)
|
|
||||||
self.config = {
|
|
||||||
'db': 'test_files',
|
|
||||||
'setting1': 'value1',
|
|
||||||
'setting2': 'value2'
|
|
||||||
}
|
|
||||||
|
|
||||||
@patch("chatmastermind.main.create_chat", return_value="test_chat")
|
|
||||||
@patch("chatmastermind.main.process_tags")
|
|
||||||
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
|
|
||||||
@patch("chatmastermind.utils.pp")
|
|
||||||
@patch("builtins.print")
|
|
||||||
def test_handle_question(self, mock_print, mock_pp, mock_ai,
|
|
||||||
mock_process_tags, mock_create_chat):
|
|
||||||
open_mock = MagicMock()
|
|
||||||
with patch("chatmastermind.storage.open", open_mock):
|
|
||||||
handle_question(self.args, self.config, True)
|
|
||||||
mock_process_tags.assert_called_once_with(self.args.tags,
|
|
||||||
self.args.extags,
|
|
||||||
[])
|
|
||||||
mock_create_chat.assert_called_once_with(self.question,
|
|
||||||
self.args.tags,
|
|
||||||
self.args.extags,
|
|
||||||
self.config,
|
|
||||||
False, False, False)
|
|
||||||
mock_pp.assert_called_once_with("test_chat")
|
|
||||||
mock_ai.assert_called_with("test_chat",
|
|
||||||
self.config,
|
|
||||||
self.args.number)
|
|
||||||
expected_calls = []
|
|
||||||
for num, answer in enumerate(mock_ai.return_value[0], start=1):
|
|
||||||
title = f'-- ANSWER {num} '
|
|
||||||
title_end = '-' * (terminal_width() - len(title))
|
|
||||||
expected_calls.append(((f'{title}{title_end}',),))
|
|
||||||
expected_calls.append(((answer,),))
|
|
||||||
expected_calls.append((("-" * terminal_width(),),))
|
|
||||||
expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
|
|
||||||
self.assertEqual(mock_print.call_args_list, expected_calls)
|
|
||||||
open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)])
|
|
||||||
open_mock.assert_has_calls(open_expected_calls, any_order=True)
|
|
||||||
|
|
||||||
|
|
||||||
class TestSaveAnswers(unittest.TestCase):
|
|
||||||
@mock.patch('builtins.open')
|
|
||||||
@mock.patch('chatmastermind.storage.print')
|
|
||||||
def test_save_answers(self, print_mock, open_mock):
|
|
||||||
question = "Test question?"
|
|
||||||
answers = ["Answer 1", "Answer 2"]
|
|
||||||
tags = ["tag1", "tag2"]
|
|
||||||
otags = ["otag1", "otag2"]
|
|
||||||
config = {'db': 'test_db'}
|
|
||||||
|
|
||||||
with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \
|
|
||||||
mock.patch('chatmastermind.storage.yaml.dump'), \
|
|
||||||
mock.patch('io.StringIO') as stringio_mock:
|
|
||||||
stringio_instance = stringio_mock.return_value
|
|
||||||
stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"]
|
|
||||||
save_answers(question, answers, tags, otags, config)
|
|
||||||
|
|
||||||
open_calls = [
|
|
||||||
mock.call(pathlib.Path('test_db/.next'), 'r'),
|
|
||||||
mock.call(pathlib.Path('test_db/.next'), 'w'),
|
|
||||||
]
|
|
||||||
open_mock.assert_has_calls(open_calls, any_order=True)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAI(unittest.TestCase):
|
|
||||||
|
|
||||||
@patch("openai.ChatCompletion.create")
|
|
||||||
def test_ai(self, mock_create: MagicMock):
|
|
||||||
mock_create.return_value = {
|
|
||||||
'choices': [
|
|
||||||
{'message': {'content': 'response_text_1'}},
|
|
||||||
{'message': {'content': 'response_text_2'}}
|
|
||||||
],
|
|
||||||
'usage': {'tokens': 10}
|
|
||||||
}
|
|
||||||
|
|
||||||
number = 2
|
|
||||||
chat = [{"role": "system", "content": "hello ai"}]
|
|
||||||
config = {
|
|
||||||
"openai": {
|
|
||||||
"model": "text-davinci-002",
|
|
||||||
"temperature": 0.5,
|
|
||||||
"max_tokens": 150,
|
|
||||||
"top_p": 1,
|
|
||||||
"n": number,
|
|
||||||
"frequency_penalty": 0,
|
|
||||||
"presence_penalty": 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result = ai(chat, config, number)
|
|
||||||
expected_result = (['response_text_1', 'response_text_2'],
|
|
||||||
{'tokens': 10})
|
|
||||||
self.assertEqual(result, expected_result)
|
|
||||||
|
|
||||||
|
|
||||||
class TestCreateParser(unittest.TestCase):
|
|
||||||
def test_create_parser(self):
|
|
||||||
with patch('argparse.ArgumentParser.add_mutually_exclusive_group') as mock_add_mutually_exclusive_group:
|
|
||||||
mock_group = Mock()
|
|
||||||
mock_add_mutually_exclusive_group.return_value = mock_group
|
|
||||||
parser = create_parser()
|
|
||||||
self.assertIsInstance(parser, argparse.ArgumentParser)
|
|
||||||
mock_add_mutually_exclusive_group.assert_called_once_with(required=True)
|
|
||||||
mock_group.add_argument.assert_any_call('-p', '--print', help='File to print')
|
|
||||||
mock_group.add_argument.assert_any_call('-q', '--question', nargs='*', help='Question to ask')
|
|
||||||
mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat history as Python structure", action='store_true')
|
|
||||||
mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat history as readable text", action='store_true')
|
|
||||||
self.assertTrue('.config.yaml' in parser.get_default('config'))
|
|
||||||
self.assertEqual(parser.get_default('number'), 1)
|
|
||||||
@@ -0,0 +1,899 @@
|
|||||||
|
import unittest
|
||||||
|
import pathlib
|
||||||
|
import tempfile
|
||||||
|
import itertools
|
||||||
|
from typing import cast
|
||||||
|
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine,\
|
||||||
|
MessageFilter, message_in, message_valid_formats
|
||||||
|
from chatmastermind.tags import Tag, TagLine
|
||||||
|
|
||||||
|
|
||||||
|
msg_suffix: str = Message.file_suffix_write
|
||||||
|
|
||||||
|
|
||||||
|
class SourceCodeTestCase(unittest.TestCase):
|
||||||
|
def test_source_code_with_include_delims(self) -> None:
|
||||||
|
text = """
|
||||||
|
Some text before the code block
|
||||||
|
```python
|
||||||
|
print("Hello, World!")
|
||||||
|
```
|
||||||
|
Some text after the code block
|
||||||
|
```python
|
||||||
|
x = 10
|
||||||
|
y = 20
|
||||||
|
print(x + y)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
expected_result = [
|
||||||
|
" ```python\n print(\"Hello, World!\")\n ```\n",
|
||||||
|
" ```python\n x = 10\n y = 20\n print(x + y)\n ```\n"
|
||||||
|
]
|
||||||
|
result = source_code(text, include_delims=True)
|
||||||
|
self.assertEqual(result, expected_result)
|
||||||
|
|
||||||
|
def test_source_code_without_include_delims(self) -> None:
|
||||||
|
text = """
|
||||||
|
Some text before the code block
|
||||||
|
```python
|
||||||
|
print("Hello, World!")
|
||||||
|
```
|
||||||
|
Some text after the code block
|
||||||
|
```python
|
||||||
|
x = 10
|
||||||
|
y = 20
|
||||||
|
print(x + y)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
expected_result = [
|
||||||
|
" print(\"Hello, World!\")\n",
|
||||||
|
" x = 10\n y = 20\n print(x + y)\n"
|
||||||
|
]
|
||||||
|
result = source_code(text, include_delims=False)
|
||||||
|
self.assertEqual(result, expected_result)
|
||||||
|
|
||||||
|
def test_source_code_with_single_code_block(self) -> None:
|
||||||
|
text = "```python\nprint(\"Hello, World!\")\n```"
|
||||||
|
expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"]
|
||||||
|
result = source_code(text, include_delims=True)
|
||||||
|
self.assertEqual(result, expected_result)
|
||||||
|
|
||||||
|
def test_source_code_with_no_code_blocks(self) -> None:
|
||||||
|
text = "Some text without any code blocks"
|
||||||
|
expected_result: list[str] = []
|
||||||
|
result = source_code(text, include_delims=True)
|
||||||
|
self.assertEqual(result, expected_result)
|
||||||
|
|
||||||
|
|
||||||
|
class QuestionTestCase(unittest.TestCase):
    """
    Tests for the header checks done when creating a 'Question'.
    """

    def test_question_with_header(self) -> None:
        """
        A question starting with the question header must be rejected.
        """
        text = f"{Question.txt_header}\nWhat is your name?"
        self.assertRaises(MessageError, Question, text)

    def test_question_with_answer_header(self) -> None:
        """
        A question starting with the answer header must be rejected.
        """
        text = f"{Answer.txt_header}\nBob"
        self.assertRaises(MessageError, Question, text)

    def test_question_with_legal_header(self) -> None:
        """
        A header that is only a part of a line is accepted.
        """
        text = f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?"
        created = Question(text)
        self.assertIsInstance(created, Question)
        self.assertEqual(created, text)

    def test_question_without_header(self) -> None:
        """
        A plain question is accepted and compares equal to its text.
        """
        text = "What is your favorite color?"
        created = Question(text)
        self.assertIsInstance(created, Question)
        self.assertEqual(created, text)
|
||||||
|
|
||||||
|
|
||||||
|
class AnswerTestCase(unittest.TestCase):
    """
    Tests for the header checks done when creating an 'Answer'.
    """

    def test_answer_with_header(self) -> None:
        """
        An answer starting with the answer header must be rejected.
        """
        text = f"{Answer.txt_header}\nno"
        with self.assertRaises(MessageError):
            # creation and string conversion both happen inside the check
            str(Answer(text))

    def test_answer_with_legal_header(self) -> None:
        """
        A header that is only a part of a line is accepted.
        """
        text = f"This is a line contaning '{Answer.txt_header}'\nIt is what it is."
        created = Answer(text)
        self.assertIsInstance(created, Answer)
        self.assertEqual(created, text)

    def test_answer_without_header(self) -> None:
        """
        A plain answer is accepted and compares equal to its text.
        """
        text = "No"
        created = Answer(text)
        self.assertIsInstance(created, Answer)
        self.assertEqual(created, text)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageToFileTxtTestCase(unittest.TestCase):
    """
    Tests for 'Message.to_file()' using the 'txt' format.
    """

    def setUp(self) -> None:
        """
        Create a temporary message file plus a complete and a minimal message
        that both point to it.
        """
        self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
        self.file_path = pathlib.Path(self.file.name)
        self.message_complete = Message(
            Question('This is a question.'),
            Answer('This is an answer.'),
            {Tag('tag1'), Tag('tag2')},
            ai='ChatGPT',
            model='gpt-3.5-turbo',
            file_path=self.file_path,
        )
        self.message_min = Message(
            Question('This is a question.'),
            file_path=self.file_path,
        )

    def tearDown(self) -> None:
        """
        Close and delete the temporary message file.
        """
        self.file.close()
        self.file_path.unlink()

    def test_to_file_txt_complete(self) -> None:
        """
        A complete message is written with tag, AI and model lines
        followed by the question and answer sections.
        """
        self.message_complete.to_file(self.file_path, mformat='txt')

        with open(self.file_path, "r") as fd:
            written = fd.read()
        expected_lines = [
            f"{TagLine.prefix} tag1 tag2",
            f"{AILine.prefix} ChatGPT",
            f"{ModelLine.prefix} gpt-3.5-turbo",
            f"{Question.txt_header}",
            "This is a question.",
            f"{Answer.txt_header}",
            "This is an answer.",
        ]
        # every line, including the last one, is terminated by a newline
        expected = "".join(f"{line}\n" for line in expected_lines)
        self.assertEqual(written, expected)

    def test_to_file_txt_min(self) -> None:
        """
        A minimal message is written as question header plus question only.
        """
        self.message_min.to_file(self.file_path, mformat='txt')

        with open(self.file_path, "r") as fd:
            written = fd.read()
        expected = f"{Question.txt_header}\nThis is a question.\n"
        self.assertEqual(written, expected)

    def test_to_file_unsupported_file_suffix(self) -> None:
        """
        Writing to a path with an unsupported suffix raises 'MessageError'.
        """
        bad_suffix_path = pathlib.Path("example.doc")
        with self.assertRaises(MessageError) as cm:
            self.message_complete.to_file(bad_suffix_path)
        raised_text = str(cm.exception)
        self.assertEqual(raised_text, "File suffix '.doc' is not supported")

    def test_to_file_unsupported_file_format(self) -> None:
        """
        Writing with an unsupported format name raises 'MessageError'.
        """
        valid_suffix_path = pathlib.Path(f"example{msg_suffix}")
        with self.assertRaises(MessageError) as cm:
            self.message_complete.to_file(valid_suffix_path, mformat='doc')  # type: ignore [arg-type]
        raised_text = str(cm.exception)
        self.assertEqual(raised_text, "File format 'doc' is not supported")

    def test_to_file_no_file_path(self) -> None:
        """
        Writing without any path (neither argument nor member) raises
        'MessageError'.
        """
        with self.assertRaises(MessageError) as cm:
            # drop the internal file_path so that no path is left at all
            self.message_complete.file_path = None
            self.message_complete.to_file(None)
        raised_text = str(cm.exception)
        self.assertEqual(raised_text, "Got no valid path to write message")
        # restore the internal file_path
        self.message_complete.file_path = self.file_path

    def test_to_file_txt_auto_suffix(self) -> None:
        """
        A missing suffix is added automatically, both for the 'file_path'
        member and for an explicitly given path.
        """
        bare_path = self.file_path.with_suffix('')
        # path taken from the file_path member
        self.message_min.file_path = bare_path
        self.message_min.to_file(mformat='txt')
        self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
        # path given explicitly
        self.message_min.file_path = bare_path
        self.message_min.to_file(file_path=bare_path, mformat='txt')
        self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageToFileYamlTestCase(unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||||
|
self.file_path = pathlib.Path(self.file.name)
|
||||||
|
self.message_complete = Message(Question('This is a question.'),
|
||||||
|
Answer('This is an answer.'),
|
||||||
|
{Tag('tag1'), Tag('tag2')},
|
||||||
|
ai='ChatGPT',
|
||||||
|
model='gpt-3.5-turbo',
|
||||||
|
file_path=self.file_path)
|
||||||
|
self.message_multiline = Message(Question('This is a\nmultiline question.'),
|
||||||
|
Answer('This is a\nmultiline answer.'),
|
||||||
|
{Tag('tag1'), Tag('tag2')},
|
||||||
|
ai='ChatGPT',
|
||||||
|
model='gpt-3.5-turbo',
|
||||||
|
file_path=self.file_path)
|
||||||
|
self.message_min = Message(Question('This is a question.'),
|
||||||
|
file_path=self.file_path)
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.file.close()
|
||||||
|
self.file_path.unlink()
|
||||||
|
|
||||||
|
def test_to_file_yaml_complete(self) -> None:
|
||||||
|
self.message_complete.to_file(self.file_path, mformat='yaml')
|
||||||
|
|
||||||
|
with open(self.file_path, "r") as fd:
|
||||||
|
content = fd.read()
|
||||||
|
expected_content = f"""{Question.yaml_key}: This is a question.
|
||||||
|
{Answer.yaml_key}: This is an answer.
|
||||||
|
{Message.ai_yaml_key}: ChatGPT
|
||||||
|
{Message.model_yaml_key}: gpt-3.5-turbo
|
||||||
|
{Message.tags_yaml_key}:
|
||||||
|
- tag1
|
||||||
|
- tag2
|
||||||
|
"""
|
||||||
|
self.assertEqual(content, expected_content)
|
||||||
|
|
||||||
|
def test_to_file_yaml_multiline(self) -> None:
|
||||||
|
self.message_multiline.to_file(self.file_path, mformat='yaml')
|
||||||
|
|
||||||
|
with open(self.file_path, "r") as fd:
|
||||||
|
content = fd.read()
|
||||||
|
expected_content = f"""{Question.yaml_key}: |-
|
||||||
|
This is a
|
||||||
|
multiline question.
|
||||||
|
{Answer.yaml_key}: |-
|
||||||
|
This is a
|
||||||
|
multiline answer.
|
||||||
|
{Message.ai_yaml_key}: ChatGPT
|
||||||
|
{Message.model_yaml_key}: gpt-3.5-turbo
|
||||||
|
{Message.tags_yaml_key}:
|
||||||
|
- tag1
|
||||||
|
- tag2
|
||||||
|
"""
|
||||||
|
self.assertEqual(content, expected_content)
|
||||||
|
|
||||||
|
def test_to_file_yaml_min(self) -> None:
|
||||||
|
self.message_min.to_file(self.file_path, mformat='yaml')
|
||||||
|
|
||||||
|
with open(self.file_path, "r") as fd:
|
||||||
|
content = fd.read()
|
||||||
|
expected_content = f"{Question.yaml_key}: This is a question.\n"
|
||||||
|
self.assertEqual(content, expected_content)
|
||||||
|
|
||||||
|
def test_to_file_yaml_auto_suffix(self) -> None:
|
||||||
|
"""
|
||||||
|
Test if suffix is auto-generated if omitted.
|
||||||
|
"""
|
||||||
|
file_path_no_suffix = self.file_path.with_suffix('')
|
||||||
|
# test with file_path member
|
||||||
|
self.message_min.file_path = file_path_no_suffix
|
||||||
|
self.message_min.to_file(mformat='yaml')
|
||||||
|
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
|
||||||
|
# test with explicit file_path
|
||||||
|
self.message_min.file_path = file_path_no_suffix
|
||||||
|
self.message_min.to_file(file_path=file_path_no_suffix, mformat='yaml')
|
||||||
|
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageFromFileTxtTestCase(unittest.TestCase):
    """
    Exercise Message.from_file() on text-format message files, both
    unfiltered and combined with the individual MessageFilter criteria.
    """

    def setUp(self) -> None:
        # Complete message: tags, AI, model, question and answer.
        self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
        self.file_path = pathlib.Path(self.file.name)
        complete_text = f"""{TagLine.prefix} tag1 tag2
{AILine.prefix} ChatGPT
{ModelLine.prefix} gpt-3.5-turbo
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
"""
        with open(self.file_path, "w") as out:
            out.write(complete_text)

        # Minimal message: nothing but the question.
        self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
        self.file_path_min = pathlib.Path(self.file_min.name)
        minimal_text = f"""{Question.txt_header}
This is a question.
"""
        with open(self.file_path_min, "w") as out:
            out.write(minimal_text)

    def tearDown(self) -> None:
        # Close both handles first, then delete the files behind them.
        for handle in (self.file, self.file_min):
            handle.close()
        for path in (self.file_path, self.file_path_min):
            path.unlink()

    def test_from_file_txt_complete(self) -> None:
        """
        Read a message that carries every optional value.
        """
        msg = Message.from_file(self.file_path)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)
        assert msg
        self.assertEqual(msg.question, 'This is a question.')
        self.assertEqual(msg.answer, 'This is an answer.')
        self.assertSetEqual(cast(set[Tag], msg.tags), {Tag('tag1'), Tag('tag2')})
        self.assertEqual(msg.ai, 'ChatGPT')
        self.assertEqual(msg.model, 'gpt-3.5-turbo')
        self.assertEqual(msg.file_path, self.file_path)

    def test_from_file_txt_min(self) -> None:
        """
        Read a message that only contains the required values.
        """
        msg = Message.from_file(self.file_path_min)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)
        assert msg
        self.assertEqual(msg.question, 'This is a question.')
        self.assertEqual(msg.file_path, self.file_path_min)
        self.assertIsNone(msg.answer)

    def test_from_file_txt_tags_match(self) -> None:
        """A 'tags_or' filter naming one of the message's tags lets it pass."""
        wanted = MessageFilter(tags_or={Tag('tag1')})
        msg = Message.from_file(self.file_path, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)
        assert msg
        self.assertEqual(msg.question, 'This is a question.')
        self.assertEqual(msg.answer, 'This is an answer.')
        self.assertSetEqual(cast(set[Tag], msg.tags), {Tag('tag1'), Tag('tag2')})
        self.assertEqual(msg.file_path, self.file_path)

    def test_from_file_txt_tags_dont_match(self) -> None:
        """A 'tags_or' filter naming a foreign tag yields no message."""
        wanted = MessageFilter(tags_or={Tag('tag3')})
        self.assertIsNone(Message.from_file(self.file_path, wanted))

    def test_from_file_txt_no_tags_dont_match(self) -> None:
        """A message without tags is rejected by a 'tags_or' filter."""
        wanted = MessageFilter(tags_or={Tag('tag1')})
        self.assertIsNone(Message.from_file(self.file_path_min, wanted))

    def test_from_file_txt_empty_tags_dont_match(self) -> None:
        """Empty 'tags_or' / 'tags_and' sets reject a message without tags."""
        wanted = MessageFilter(tags_or=set(),
                               tags_and=set())
        self.assertIsNone(Message.from_file(self.file_path_min, wanted))

    def test_from_file_txt_no_tags_match_tags_not(self) -> None:
        """A message without tags passes a 'tags_not' filter."""
        wanted = MessageFilter(tags_not={Tag('tag1')})
        msg = Message.from_file(self.file_path_min, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)
        assert msg
        self.assertEqual(msg.question, 'This is a question.')
        self.assertSetEqual(cast(set[Tag], msg.tags), set())
        self.assertEqual(msg.file_path, self.file_path_min)

    def test_from_file_not_exists(self) -> None:
        """Reading a missing file raises MessageError with a fixed text."""
        missing = pathlib.Path(f"example{msg_suffix}")
        with self.assertRaises(MessageError) as cm:
            Message.from_file(missing)
        self.assertEqual(str(cm.exception), f"Message file '{missing}' does not exist")

    def test_from_file_txt_question_match(self) -> None:
        """'question_contains' matching the question text lets the message pass."""
        wanted = MessageFilter(question_contains='question')
        msg = Message.from_file(self.file_path, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)

    def test_from_file_txt_answer_match(self) -> None:
        """'answer_contains' matching the answer text lets the message pass."""
        wanted = MessageFilter(answer_contains='answer')
        msg = Message.from_file(self.file_path, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)

    def test_from_file_txt_answer_available(self) -> None:
        """answer_state='available' accepts a message that has an answer."""
        wanted = MessageFilter(answer_state='available')
        msg = Message.from_file(self.file_path, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)

    def test_from_file_txt_answer_missing(self) -> None:
        """answer_state='missing' accepts a message without an answer."""
        wanted = MessageFilter(answer_state='missing')
        msg = Message.from_file(self.file_path_min, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)

    def test_from_file_txt_question_doesnt_match(self) -> None:
        """'question_contains' not found in the question yields no message."""
        wanted = MessageFilter(question_contains='answer')
        self.assertIsNone(Message.from_file(self.file_path, wanted))

    def test_from_file_txt_answer_doesnt_match(self) -> None:
        """'answer_contains' not found in the answer yields no message."""
        wanted = MessageFilter(answer_contains='question')
        self.assertIsNone(Message.from_file(self.file_path, wanted))

    def test_from_file_txt_answer_not_exists(self) -> None:
        """'answer_contains' rejects a message that has no answer at all."""
        wanted = MessageFilter(answer_contains='answer')
        self.assertIsNone(Message.from_file(self.file_path_min, wanted))

    def test_from_file_txt_answer_not_available(self) -> None:
        """answer_state='available' rejects a message without an answer."""
        wanted = MessageFilter(answer_state='available')
        self.assertIsNone(Message.from_file(self.file_path_min, wanted))

    def test_from_file_txt_answer_not_missing(self) -> None:
        """answer_state='missing' rejects a message that has an answer."""
        wanted = MessageFilter(answer_state='missing')
        self.assertIsNone(Message.from_file(self.file_path, wanted))

    def test_from_file_txt_ai_match(self) -> None:
        """An 'ai' filter equal to the message's AI lets the message pass."""
        wanted = MessageFilter(ai='ChatGPT')
        msg = Message.from_file(self.file_path, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)

    def test_from_file_txt_ai_doesnt_match(self) -> None:
        """An 'ai' filter naming another AI yields no message."""
        wanted = MessageFilter(ai='Foo')
        self.assertIsNone(Message.from_file(self.file_path, wanted))

    def test_from_file_txt_model_match(self) -> None:
        """A 'model' filter equal to the message's model lets the message pass."""
        wanted = MessageFilter(model='gpt-3.5-turbo')
        msg = Message.from_file(self.file_path, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)

    def test_from_file_txt_model_doesnt_match(self) -> None:
        """A 'model' filter naming another model yields no message."""
        wanted = MessageFilter(model='Bar')
        self.assertIsNone(Message.from_file(self.file_path, wanted))
|
||||||
|
|
||||||
|
|
||||||
|
class MessageFromFileYamlTestCase(unittest.TestCase):
    """
    Exercise Message.from_file() on YAML-format message files, both
    unfiltered and combined with the individual MessageFilter criteria.
    """

    def setUp(self) -> None:
        # Complete message: question, answer, AI, model and tags.
        # The block scalars and list items are indented below their keys.
        self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
        self.file_path = pathlib.Path(self.file.name)
        complete_yaml = f"""
{Question.yaml_key}: |-
  This is a question.
{Answer.yaml_key}: |-
  This is an answer.
{Message.ai_yaml_key}: ChatGPT
{Message.model_yaml_key}: gpt-3.5-turbo
{Message.tags_yaml_key}:
  - tag1
  - tag2
"""
        with open(self.file_path, "w") as out:
            out.write(complete_yaml)

        # Minimal message: nothing but the question.
        self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
        self.file_path_min = pathlib.Path(self.file_min.name)
        minimal_yaml = f"""
{Question.yaml_key}: |-
  This is a question.
"""
        with open(self.file_path_min, "w") as out:
            out.write(minimal_yaml)

    def tearDown(self) -> None:
        # Each handle is closed right before its file is deleted.
        cleanup_pairs = ((self.file, self.file_path),
                         (self.file_min, self.file_path_min))
        for handle, path in cleanup_pairs:
            handle.close()
            path.unlink()

    def test_from_file_yaml_complete(self) -> None:
        """
        Read a message that carries every optional value.
        """
        msg = Message.from_file(self.file_path)
        self.assertIsInstance(msg, Message)
        self.assertIsNotNone(msg)
        assert msg
        self.assertEqual(msg.question, 'This is a question.')
        self.assertEqual(msg.answer, 'This is an answer.')
        self.assertSetEqual(cast(set[Tag], msg.tags), {Tag('tag1'), Tag('tag2')})
        self.assertEqual(msg.ai, 'ChatGPT')
        self.assertEqual(msg.model, 'gpt-3.5-turbo')
        self.assertEqual(msg.file_path, self.file_path)

    def test_from_file_yaml_min(self) -> None:
        """
        Read a message that only contains the required values.
        """
        msg = Message.from_file(self.file_path_min)
        self.assertIsInstance(msg, Message)
        self.assertIsNotNone(msg)
        assert msg
        self.assertEqual(msg.question, 'This is a question.')
        self.assertSetEqual(cast(set[Tag], msg.tags), set())
        self.assertEqual(msg.file_path, self.file_path_min)
        self.assertIsNone(msg.answer)

    def test_from_file_not_exists(self) -> None:
        """Reading a missing file raises MessageError with a fixed text."""
        missing = pathlib.Path(f"example{msg_suffix}")
        with self.assertRaises(MessageError) as cm:
            Message.from_file(missing)
        self.assertEqual(str(cm.exception), f"Message file '{missing}' does not exist")

    def test_from_file_yaml_tags_match(self) -> None:
        """A 'tags_or' filter naming one of the message's tags lets it pass."""
        wanted = MessageFilter(tags_or={Tag('tag1')})
        msg = Message.from_file(self.file_path, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)
        assert msg
        self.assertEqual(msg.question, 'This is a question.')
        self.assertEqual(msg.answer, 'This is an answer.')
        self.assertSetEqual(cast(set[Tag], msg.tags), {Tag('tag1'), Tag('tag2')})
        self.assertEqual(msg.file_path, self.file_path)

    def test_from_file_yaml_tags_dont_match(self) -> None:
        """A 'tags_or' filter naming a foreign tag yields no message."""
        wanted = MessageFilter(tags_or={Tag('tag3')})
        self.assertIsNone(Message.from_file(self.file_path, wanted))

    def test_from_file_yaml_no_tags_dont_match(self) -> None:
        """A message without tags is rejected by a 'tags_or' filter."""
        wanted = MessageFilter(tags_or={Tag('tag1')})
        self.assertIsNone(Message.from_file(self.file_path_min, wanted))

    def test_from_file_yaml_no_tags_match_tags_not(self) -> None:
        """A message without tags passes a 'tags_not' filter."""
        wanted = MessageFilter(tags_not={Tag('tag1')})
        msg = Message.from_file(self.file_path_min, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)
        assert msg
        self.assertEqual(msg.question, 'This is a question.')
        self.assertSetEqual(cast(set[Tag], msg.tags), set())
        self.assertEqual(msg.file_path, self.file_path_min)

    def test_from_file_yaml_question_match(self) -> None:
        """'question_contains' matching the question text lets the message pass."""
        wanted = MessageFilter(question_contains='question')
        msg = Message.from_file(self.file_path, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)

    def test_from_file_yaml_answer_match(self) -> None:
        """'answer_contains' matching the answer text lets the message pass."""
        wanted = MessageFilter(answer_contains='answer')
        msg = Message.from_file(self.file_path, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)

    def test_from_file_yaml_answer_available(self) -> None:
        """answer_state='available' accepts a message that has an answer."""
        wanted = MessageFilter(answer_state='available')
        msg = Message.from_file(self.file_path, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)

    def test_from_file_yaml_answer_missing(self) -> None:
        """answer_state='missing' accepts a message without an answer."""
        wanted = MessageFilter(answer_state='missing')
        msg = Message.from_file(self.file_path_min, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)

    def test_from_file_yaml_question_doesnt_match(self) -> None:
        """'question_contains' not found in the question yields no message."""
        wanted = MessageFilter(question_contains='answer')
        self.assertIsNone(Message.from_file(self.file_path, wanted))

    def test_from_file_yaml_answer_doesnt_match(self) -> None:
        """'answer_contains' not found in the answer yields no message."""
        wanted = MessageFilter(answer_contains='question')
        self.assertIsNone(Message.from_file(self.file_path, wanted))

    def test_from_file_yaml_answer_not_exists(self) -> None:
        """'answer_contains' rejects a message that has no answer at all."""
        wanted = MessageFilter(answer_contains='answer')
        self.assertIsNone(Message.from_file(self.file_path_min, wanted))

    def test_from_file_yaml_answer_not_available(self) -> None:
        """answer_state='available' rejects a message without an answer."""
        wanted = MessageFilter(answer_state='available')
        self.assertIsNone(Message.from_file(self.file_path_min, wanted))

    def test_from_file_yaml_answer_not_missing(self) -> None:
        """answer_state='missing' rejects a message that has an answer."""
        wanted = MessageFilter(answer_state='missing')
        self.assertIsNone(Message.from_file(self.file_path, wanted))

    def test_from_file_yaml_ai_match(self) -> None:
        """An 'ai' filter equal to the message's AI lets the message pass."""
        wanted = MessageFilter(ai='ChatGPT')
        msg = Message.from_file(self.file_path, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)

    def test_from_file_yaml_ai_doesnt_match(self) -> None:
        """An 'ai' filter naming another AI yields no message."""
        wanted = MessageFilter(ai='Foo')
        self.assertIsNone(Message.from_file(self.file_path, wanted))

    def test_from_file_yaml_model_match(self) -> None:
        """A 'model' filter equal to the message's model lets the message pass."""
        wanted = MessageFilter(model='gpt-3.5-turbo')
        msg = Message.from_file(self.file_path, wanted)
        self.assertIsNotNone(msg)
        self.assertIsInstance(msg, Message)

    def test_from_file_yaml_model_doesnt_match(self) -> None:
        """A 'model' filter naming another model yields no message."""
        wanted = MessageFilter(model='Bar')
        self.assertIsNone(Message.from_file(self.file_path, wanted))
|
||||||
|
|
||||||
|
|
||||||
|
class TagsFromFileTestCase(unittest.TestCase):
    """
    Tests for Message.tags_from_file() on text and YAML message files,
    with and without tags and with the 'prefix' / 'contain' selectors.
    """

    def setUp(self) -> None:
        """Create five temporary message files covering the tag variants."""
        # Text file with three tags.
        self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
        self.file_path_txt = pathlib.Path(self.file_txt.name)
        with open(self.file_path_txt, "w") as fd:
            fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
""")
        # Text file without any tag line.
        self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
        self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name)
        with open(self.file_path_txt_no_tags, "w") as fd:
            fd.write(f"""{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
""")
        # Text file with a tag line that lists no tags.
        self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
        self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name)
        with open(self.file_path_txt_tags_empty, "w") as fd:
            fd.write(f"""TAGS:
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
""")
        # YAML file with three tags (nested content indented below its key).
        self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
        self.file_path_yaml = pathlib.Path(self.file_yaml.name)
        with open(self.file_path_yaml, "w") as fd:
            fd.write(f"""
{Question.yaml_key}: |-
  This is a question.
{Answer.yaml_key}: |-
  This is an answer.
{Message.tags_yaml_key}:
  - tag1
  - tag2
  - ptag3
""")
        # YAML file without a tags entry.
        self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
        self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name)
        with open(self.file_path_yaml_no_tags, "w") as fd:
            fd.write(f"""
{Question.yaml_key}: |-
  This is a question.
{Answer.yaml_key}: |-
  This is an answer.
""")

    def tearDown(self) -> None:
        """Close every temporary file handle and delete the file behind it."""
        # The original referenced `.close` without calling it for the
        # 'no_tags' and 'tags_empty' text files, so those handles stayed
        # open; every handle is now really closed before its file is removed.
        self.file_txt.close()
        self.file_path_txt.unlink()
        self.file_yaml.close()
        self.file_path_yaml.unlink()
        self.file_txt_no_tags.close()
        self.file_path_txt_no_tags.unlink()
        self.file_txt_tags_empty.close()
        self.file_path_txt_tags_empty.unlink()
        self.file_yaml_no_tags.close()
        self.file_path_yaml_no_tags.unlink()

    def test_tags_from_file_txt(self) -> None:
        """All tags of a text file are returned."""
        tags = Message.tags_from_file(self.file_path_txt)
        self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')})

    def test_tags_from_file_txt_no_tags(self) -> None:
        """A text file without tag line yields an empty set."""
        tags = Message.tags_from_file(self.file_path_txt_no_tags)
        self.assertSetEqual(tags, set())

    def test_tags_from_file_txt_tags_empty(self) -> None:
        """A text file with an empty tag line yields an empty set."""
        tags = Message.tags_from_file(self.file_path_txt_tags_empty)
        self.assertSetEqual(tags, set())

    def test_tags_from_file_yaml(self) -> None:
        """All tags of a YAML file are returned."""
        tags = Message.tags_from_file(self.file_path_yaml)
        self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')})

    def test_tags_from_file_yaml_no_tags(self) -> None:
        """A YAML file without tags entry yields an empty set."""
        tags = Message.tags_from_file(self.file_path_yaml_no_tags)
        self.assertSetEqual(tags, set())

    def test_tags_from_file_txt_prefix(self) -> None:
        """'prefix' restricts the tags of a text file to matching ones."""
        tags = Message.tags_from_file(self.file_path_txt, prefix='p')
        self.assertSetEqual(tags, {Tag('ptag3')})
        tags = Message.tags_from_file(self.file_path_txt, prefix='R')
        self.assertSetEqual(tags, set())

    def test_tags_from_file_yaml_prefix(self) -> None:
        """'prefix' restricts the tags of a YAML file to matching ones."""
        tags = Message.tags_from_file(self.file_path_yaml, prefix='p')
        self.assertSetEqual(tags, {Tag('ptag3')})
        tags = Message.tags_from_file(self.file_path_yaml, prefix='R')
        self.assertSetEqual(tags, set())

    def test_tags_from_file_txt_contain(self) -> None:
        """'contain' restricts the tags of a text file to matching ones."""
        tags = Message.tags_from_file(self.file_path_txt, contain='3')
        self.assertSetEqual(tags, {Tag('ptag3')})
        tags = Message.tags_from_file(self.file_path_txt, contain='R')
        self.assertSetEqual(tags, set())

    def test_tags_from_file_yaml_contain(self) -> None:
        """'contain' restricts the tags of a YAML file to matching ones."""
        tags = Message.tags_from_file(self.file_path_yaml, contain='3')
        self.assertSetEqual(tags, {Tag('ptag3')})
        tags = Message.tags_from_file(self.file_path_yaml, contain='R')
        self.assertSetEqual(tags, set())
|
||||||
|
|
||||||
|
|
||||||
|
class TagsFromDirTestCase(unittest.TestCase):
    """
    Tests for Message.tags_from_dir() on one directory whose messages
    carry tags and on one directory whose messages have none.
    """

    def setUp(self) -> None:
        self.temp_dir = tempfile.TemporaryDirectory()
        self.temp_dir_no_tags = tempfile.TemporaryDirectory()
        self.tag_sets = [
            {Tag('atag1'), Tag('atag2')},
            {Tag('btag3'), Tag('btag4')},
            {Tag('ctag5'), Tag('ctag6')}
        ]
        tagged_dir = self.temp_dir.name
        untagged_dir = self.temp_dir_no_tags.name
        self.files = [pathlib.Path(tagged_dir, f'file{number}{msg_suffix}')
                      for number in (1, 2, 3)]
        self.files_no_tags = [pathlib.Path(untagged_dir, f'file{number}{msg_suffix}')
                              for number in (4, 5, 6)]
        # One shared cycle: the output formats keep alternating across
        # both groups of files.
        mformats = itertools.cycle(message_valid_formats)
        for target, tag_set in zip(self.files, self.tag_sets):
            tagged = Message(Question('This is a question.'),
                             Answer('This is an answer.'),
                             tag_set)
            tagged.to_file(target, next(mformats))
        for target in self.files_no_tags:
            untagged = Message(Question('This is a question.'),
                               Answer('This is an answer.'))
            untagged.to_file(target, next(mformats))

    def tearDown(self) -> None:
        self.temp_dir.cleanup()
        self.temp_dir_no_tags.cleanup()

    def test_tags_from_dir(self) -> None:
        """Without selector the tags of all files are merged."""
        found = Message.tags_from_dir(pathlib.Path(self.temp_dir.name))
        first, second, third = self.tag_sets
        self.assertEqual(found, first | second | third)

    def test_tags_from_dir_prefix(self) -> None:
        """With prefix 'a' only the first file's tags are returned."""
        found = Message.tags_from_dir(pathlib.Path(self.temp_dir.name), prefix='a')
        self.assertEqual(found, self.tag_sets[0])

    def test_tags_from_dir_no_tags(self) -> None:
        """A directory of untagged messages yields an empty set."""
        found = Message.tags_from_dir(pathlib.Path(self.temp_dir_no_tags.name))
        self.assertSetEqual(found, set())
|
||||||
|
|
||||||
|
|
||||||
|
class MessageIDTestCase(unittest.TestCase):
    """
    Tests for Message.msg_id() with and without a file path.
    """

    def setUp(self) -> None:
        self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
        self.file_path = pathlib.Path(self.file.name)
        question_text = 'This is a question.'
        self.message = Message(Question(question_text),
                               file_path=self.file_path)
        self.message_no_file_path = Message(Question(question_text))

    def tearDown(self) -> None:
        self.file.close()
        self.file_path.unlink()

    def test_msg_id_txt(self) -> None:
        """The message ID equals the stem of the message's file path."""
        expected_id = self.file_path.stem
        self.assertEqual(self.message.msg_id(), expected_id)

    def test_msg_id_txt_exception(self) -> None:
        """Asking a message without file path for its ID raises MessageError."""
        with self.assertRaises(MessageError):
            self.message_no_file_path.msg_id()
|
||||||
|
|
||||||
|
|
||||||
|
class MessageHashTestCase(unittest.TestCase):
    """
    Check how Message instances behave as members of a set.
    """

    def setUp(self) -> None:
        shared_path = pathlib.Path('/tmp/foo/bla')
        self.message1 = Message(Question('This is a question.'),
                                tags={Tag('tag1')},
                                file_path=shared_path)
        self.message2 = Message(Question('This is a new question.'),
                                file_path=pathlib.Path('/tmp/foo/bla'))
        self.message3 = Message(Question('This is a question.'),
                                Answer('This is an answer.'),
                                file_path=pathlib.Path('/tmp/foo/bla'))
        # message4 duplicates message1: hashing and comparison only look
        # at question and answer, so tags, AI and file path are ignored.
        self.message4 = Message(Question('This is a question.'),
                                tags={Tag('tag1'), Tag('tag2')},
                                ai='Blabla',
                                file_path=pathlib.Path('foobla'))

    def test_set_hashing(self) -> None:
        """Four messages collapse to three set members."""
        distinct = (self.message1, self.message2, self.message3)
        msgs: set[Message] = {*distinct, self.message4}
        self.assertEqual(len(msgs), 3)
        for candidate in distinct:
            self.assertIn(candidate, msgs)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageTagsStrTestCase(unittest.TestCase):
    """
    Test for Message.tags_str().
    """

    def setUp(self) -> None:
        location = pathlib.Path('/tmp/foo/bla')
        self.message = Message(Question('This is a question.'),
                               tags={Tag('tag1')},
                               file_path=location)

    def test_tags_str(self) -> None:
        """The tag string is the tag line prefix followed by the tag."""
        expected_line = f'{TagLine.prefix} tag1'
        self.assertEqual(self.message.tags_str(), expected_line)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageFilterTagsTestCase(unittest.TestCase):
    """
    Test for Message.filter_tags() with and without selectors.
    """

    def setUp(self) -> None:
        location = pathlib.Path('/tmp/foo/bla')
        self.message = Message(Question('This is a question.'),
                               tags={Tag('atag1'), Tag('btag2')},
                               file_path=location)

    def test_filter_tags(self) -> None:
        """No selector returns all tags; 'prefix' and 'contain' narrow them."""
        # no selector
        self.assertSetEqual(self.message.filter_tags(),
                            {Tag('atag1'), Tag('btag2')})
        # selection by prefix
        self.assertSetEqual(self.message.filter_tags(prefix='a'),
                            {Tag('atag1')})
        # selection by substring
        self.assertSetEqual(self.message.filter_tags(contain='2'),
                            {Tag('btag2')})
|
||||||
|
|
||||||
|
|
||||||
|
class MessageInTestCase(unittest.TestCase):
    """
    Test for message_in() using two messages that differ only in
    their file path.
    """

    def setUp(self) -> None:
        question_text = 'This is a question.'
        self.message1 = Message(Question(question_text),
                                tags={Tag('atag1'), Tag('btag2')},
                                file_path=pathlib.Path('/tmp/foo/bla'))
        self.message2 = Message(Question(question_text),
                                tags={Tag('atag1'), Tag('btag2')},
                                file_path=pathlib.Path('/tmp/bla/foo'))

    def test_message_in(self) -> None:
        """A message is found in a list holding itself, not in the other one."""
        found_in_own = message_in(self.message1, [self.message1])
        self.assertTrue(found_in_own)
        found_in_other = message_in(self.message1, [self.message2])
        self.assertFalse(found_in_other)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageRenameTagsTestCase(unittest.TestCase):
    """
    Test for Message.rename_tags().
    """

    def setUp(self) -> None:
        location = pathlib.Path('/tmp/foo/bla')
        self.message = Message(Question('This is a question.'),
                               tags={Tag('atag1'), Tag('btag2')},
                               file_path=location)

    def test_rename_tags(self) -> None:
        """Both tags are replaced according to the (old, new) pairs."""
        renames = {(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))}
        self.message.rename_tags(renames)
        self.assertIsNotNone(self.message.tags)
        self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')})  # type: ignore [arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
class MessageToStrTestCase(unittest.TestCase):
    """
    Tests for Message.to_str() with and without metadata.
    """

    def setUp(self) -> None:
        self.message = Message(Question('This is a question.'),
                               Answer('This is an answer.'),
                               ai='FakeAI',
                               model='FakeModel',
                               tags={Tag('atag1'), Tag('btag2')},
                               file_path=pathlib.Path('/tmp/foo/bla'))

    def test_to_str(self) -> None:
        """By default only the question and answer sections are printed."""
        expected_lines = [
            f"{Question.txt_header}",
            "This is a question.",
            f"{Answer.txt_header}",
            "This is an answer.",
        ]
        self.assertEqual(self.message.to_str(), "\n".join(expected_lines))

    def test_to_str_with_tags_and_file(self) -> None:
        """With metadata, tags, file, AI and model lines come first."""
        expected_lines = [
            f"{TagLine.prefix} atag1 btag2",
            "FILE: /tmp/foo/bla",
            "AI: FakeAI",
            "MODEL: FakeModel",
            f"{Question.txt_header}",
            "This is a question.",
            f"{Answer.txt_header}",
            "This is an answer.",
        ]
        self.assertEqual(self.message.to_str(with_metadata=True),
                         "\n".join(expected_lines))
|
||||||
|
|
||||||
|
|
||||||
|
class MessageRmFileTestCase(unittest.TestCase):
    """
    Test for Message.rm_file().
    """

    def setUp(self) -> None:
        self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
        self.file_path = pathlib.Path(self.file.name)
        self.message = Message(Question('This is a question.'),
                               file_path=self.file_path)
        # Store the message so that there is a file to remove.
        self.message.to_file()

    def tearDown(self) -> None:
        self.file.close()
        # The test itself normally removes the file already.
        self.file_path.unlink(missing_ok=True)

    def test_rm_file(self) -> None:
        """The message file exists before rm_file() and is gone afterwards."""
        assert self.message.file_path
        target = self.message.file_path
        self.assertTrue(target.exists())
        self.message.rm_file()
        self.assertFalse(self.message.file_path.exists())
|
||||||
@@ -0,0 +1,595 @@
|
|||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import tempfile
|
||||||
|
from copy import copy
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest import mock
|
||||||
|
from unittest.mock import MagicMock, call
|
||||||
|
from chatmastermind.configuration import Config
|
||||||
|
from chatmastermind.commands.question import create_message, question_cmd
|
||||||
|
from chatmastermind.tags import Tag
|
||||||
|
from chatmastermind.message import Message, Question, Answer
|
||||||
|
from chatmastermind.chat import Chat, ChatDB, msg_location
|
||||||
|
from chatmastermind.ai import AIError
|
||||||
|
from .test_common import TestWithFakeAI
|
||||||
|
|
||||||
|
|
||||||
|
# Suffix of the message files handled by these tests (see message_list(),
# which globs for it); taken from Message.file_suffix_write.
msg_suffix = Message.file_suffix_write
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageCreate(TestWithFakeAI):
|
||||||
|
"""
|
||||||
|
Test if messages created by the 'question' command have
|
||||||
|
the correct format.
|
||||||
|
"""
|
||||||
|
def setUp(self) -> None:
    """
    Build a ChatDB on temporary directories, a mocked argument
    namespace and four temporary source files with different mixes
    of text and code blocks.
    """
    # ChatDB structure with two stored messages
    self.db_dir = tempfile.TemporaryDirectory()
    self.cache_dir = tempfile.TemporaryDirectory()
    self.chat = ChatDB.from_dir(cache_path=Path(self.cache_dir.name),
                                db_path=Path(self.db_dir.name))
    self.message_text = Message(Question("What is this?"),
                                Answer("It is pure text"))
    self.message_code = Message(Question("What is this?"),
                                Answer("Text\n```\nIt is embedded code\n```\ntext"))
    self.chat.db_add([self.message_text, self.message_code])

    # Mocked command line arguments, all options unset by default
    self.args = MagicMock(spec=argparse.Namespace)
    for option in ('source_text', 'source_code', 'AI', 'model',
                   'output_tags', 'ask', 'create'):
        setattr(self.args, option, None)

    # File 1: plain text, no source code block
    self.source_file1_content = """This is just text.
No source code.
Nope. Go look elsewhere!"""
    self.source_file1 = tempfile.NamedTemporaryFile(delete=False)
    with open(self.source_file1.name, 'w') as out:
        out.write(self.source_file1_content)

    # File 2: text with one embedded source code block
    self.source_file2_content = """This is just text.
```
This is embedded source code.
```
And some text again."""
    self.source_file2 = tempfile.NamedTemporaryFile(delete=False)
    with open(self.source_file2.name, 'w') as out:
        out.write(self.source_file2_content)

    # File 3: everything is source code
    self.source_file3_content = """This is all source code.
Yes, really.
Language is called 'brainfart'."""
    self.source_file3 = tempfile.NamedTemporaryFile(delete=False)
    with open(self.source_file3.name, 'w') as out:
        out.write(self.source_file3_content)

    # File 4: text with two embedded source code blocks
    self.source_file4_content = """This is just text.
```
This is embedded source code.
```
And some text again.
```
This is embedded source code.
```
Aaaand again some text."""
    self.source_file4 = tempfile.NamedTemporaryFile(delete=False)
    with open(self.source_file4.name, 'w') as out:
        out.write(self.source_file4_content)
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
    """
    Release everything setUp() created: the four temporary source
    files and the two temporary ChatDB directories.
    """
    source_files = (self.source_file1, self.source_file2,
                    self.source_file3, self.source_file4)
    for source_file in source_files:
        # The files were created with delete=False, so the handle has to
        # be closed and the file removed explicitly.
        source_file.close()
        os.remove(source_file.name)
    # Remove the temporary directories explicitly rather than leaving them
    # to implicit cleanup when the TemporaryDirectory objects are collected.
    self.db_dir.cleanup()
    self.cache_dir.cleanup()
|
||||||
|
|
||||||
|
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
    """
    Return the paths in 'tmp_dir' whose names end with the message
    suffix (other entries such as '.next' are left out).
    """
    directory = Path(tmp_dir.name)
    pattern = f'*{msg_suffix}'
    return list(directory.glob(pattern))
|
||||||
|
|
||||||
|
def test_message_file_created(self) -> None:
    """
    create_message() adds exactly one message file to the empty cache
    directory, and that file holds the asked question.
    """
    self.args.ask = ["What is this?"]
    self.assertEqual(len(self.message_list(self.cache_dir)), 0)
    create_message(self.chat, self.args)
    created_files = self.message_list(self.cache_dir)
    self.assertEqual(len(created_files), 1)
    stored = Message.from_file(created_files[0])
    self.assertIsInstance(stored, Message)
    self.assertEqual(stored.question, Question("What is this?"))  # type: ignore [union-attr]
|
||||||
|
|
||||||
|
def test_single_question(self) -> None:
    """A single ask argument becomes the question and has no source code."""
    self.args.ask = ["What is this?"]
    created = create_message(self.chat, self.args)
    self.assertIsInstance(created, Message)
    question = created.question
    self.assertEqual(question, Question("What is this?"))
    self.assertEqual(len(question.source_code()), 0)
|
||||||
|
|
||||||
|
def test_multipart_question(self) -> None:
    """Several ask arguments end up in one question, separated by blank lines."""
    self.args.ask = ["What is this", "'bard' thing?", "Is it good?"]
    expected = Question("""What is this

'bard' thing?

Is it good?""")
    created = create_message(self.chat, self.args)
    self.assertIsInstance(created, Message)
    self.assertEqual(created.question, expected)
|
||||||
|
|
||||||
|
def test_single_question_with_text_only_file(self) -> None:
    """A text-only file given as source text is appended to the question."""
    self.args.ask = ["What is this?"]
    self.args.source_text = [f"{self.source_file1.name}"]
    created = create_message(self.chat, self.args)
    self.assertIsInstance(created, Message)
    question = created.question
    # file contains no source code (only text)
    # -> don't expect any in the question
    self.assertEqual(len(question.source_code()), 0)
    expected = Question(f"""What is this?

{self.source_file1_content}""")
    self.assertEqual(question, expected)
|
||||||
|
|
||||||
|
def test_single_question_with_text_file_and_embedded_code(self) -> None:
    """Only the embedded code block of a mixed file is added as source code."""
    self.args.ask = ["What is this?"]
    self.args.source_code = [f"{self.source_file2.name}"]
    created = create_message(self.chat, self.args)
    self.assertIsInstance(created, Message)
    question = created.question
    # file contains 1 source code block
    # -> expect it in the question
    self.assertEqual(len(question.source_code()), 1)
    expected = Question("""What is this?

```
This is embedded source code.
```
""")
    self.assertEqual(question, expected)
|
||||||
|
|
||||||
|
def test_single_question_with_code_only_file(self) -> None:
    """A file without code fences is added as one single code block."""
    self.args.ask = ["What is this?"]
    self.args.source_code = [f"{self.source_file3.name}"]
    created = create_message(self.chat, self.args)
    self.assertIsInstance(created, Message)
    question = created.question
    # file is complete source code
    self.assertEqual(len(question.source_code()), 1)
    expected = Question(f"""What is this?

```
{self.source_file3_content}
```""")
    self.assertEqual(question, expected)
|
||||||
|
|
||||||
|
def test_single_question_with_text_file_and_multi_embedded_code(self) -> None:
    """Both embedded code blocks of a mixed file are added as source code."""
    self.args.ask = ["What is this?"]
    self.args.source_code = [f"{self.source_file4.name}"]
    created = create_message(self.chat, self.args)
    self.assertIsInstance(created, Message)
    question = created.question
    # file contains 2 source code blocks
    # -> expect them in the question
    self.assertEqual(len(question.source_code()), 2)
    expected = Question("""What is this?

```
This is embedded source code.
```


```
This is embedded source code.
```
""")
    self.assertEqual(question, expected)
|
||||||
|
|
||||||
|
def test_single_question_with_text_only_message(self) -> None:
    """The answer of a text-only message file is appended as source text."""
    self.args.ask = ["What is this?"]
    self.args.source_text = [f"{self.chat.messages[0].file_path}"]
    created = create_message(self.chat, self.args)
    self.assertIsInstance(created, Message)
    question = created.question
    # file contains no source code (only text)
    # -> don't expect any in the question
    self.assertEqual(len(question.source_code()), 0)
    expected = Question(f"""What is this?

{self.message_text.answer}""")
    self.assertEqual(question, expected)
|
||||||
|
|
||||||
|
def test_single_question_with_message_and_embedded_code(self) -> None:
    """The code block embedded in a message's answer is added as source code."""
    self.args.ask = ["What is this?"]
    self.args.source_code = [f"{self.chat.messages[1].file_path}"]
    created = create_message(self.chat, self.args)
    self.assertIsInstance(created, Message)
    question = created.question
    # answer contains 1 source code block
    # -> expect it in the question
    self.assertEqual(len(question.source_code()), 1)
    expected = Question("""What is this?

```
It is embedded code
```
""")
    self.assertEqual(question, expected)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateOption(TestMessageCreate):
    """Message creation driven by the ``create`` argument instead of ``ask``."""

    def test_message_file_created(self) -> None:
        """A question passed via ``create`` is written to the cache as one file."""
        self.args.ask = None
        self.args.create = ["How does question --create work?"]
        # the cache starts out without any message files
        self.assertEqual(len(self.message_list(self.cache_dir)), 0)
        create_message(self.chat, self.args)
        # afterwards it holds a single file that can be read back
        created_files = self.message_list(self.cache_dir)
        self.assertEqual(len(created_files), 1)
        stored = Message.from_file(created_files[0])
        self.assertIsInstance(stored, Message)
        self.assertEqual(stored.question, Question("How does question --create work?"))  # type: ignore [union-attr]
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuestionCmd(TestWithFakeAI):
    """Shared fixture for the question command tests.

    Provides temporary cache/DB directories, a configuration pointing at
    them and a default argument namespace for asking a single question.
    """

    def setUp(self) -> None:
        # create DB and cache
        # -> register explicit cleanups so the directories are removed after
        #    each test instead of being left to the garbage collector
        self.db_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.db_dir.cleanup)
        self.cache_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.cache_dir.cleanup)
        # create configuration
        self.config = Config()
        self.config.cache = self.cache_dir.name
        self.config.db = self.db_dir.name
        # create a mock argparse.Namespace
        self.args = argparse.Namespace(
            ask=['What is the meaning of life?'],
            glob=None,
            location='db',
            num_answers=1,
            output_tags=['science'],
            AI='FakeAI',
            model='FakeModel',
            or_tags=None,
            and_tags=None,
            exclude_tags=None,
            source_text=None,
            source_code=None,
            create=None,
            repeat=None,
            process=None,
            overwrite=None
        )

    def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
        """Return the sorted message files located directly inside *tmp_dir*.

        Only file names ending in ``msg_suffix`` are matched.
        """
        # exclude '.next'
        return sorted(Path(tmp_dir.name).glob(f'*{msg_suffix}'))
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuestionCmdAsk(TestQuestionCmd):
    """Tests for asking a new question with the question command."""

    def _expected_question(self) -> Message:
        """Build the unanswered message matching the current arguments."""
        return Message(Question(self.args.ask[0]),
                       tags=set(self.args.output_tags),
                       ai=self.args.AI,
                       model=self.args.model,
                       file_path=Path('<NOT COMPARED>'))

    def _expected_responses(self, question: Message) -> list[Message]:
        """Return the messages a fresh fake AI produces for *question*."""
        fake_ai = self.mock_create_ai(self.args, self.config)
        response = fake_ai.request(question,
                                   Chat([]),
                                   self.args.num_answers,
                                   self.args.output_tags)
        return response.messages

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
        """
        Test single answer with no errors.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        expected_question = self._expected_question()
        expected_responses = self._expected_responses(expected_question)

        # execute the command
        question_cmd(self.args, self.config)

        # check for the expected message files
        cache_path = Path(self.cache_dir.name)
        db_path = Path(self.db_dir.name)
        chat = ChatDB.from_dir(cache_path, db_path)
        cached_msg = chat.msg_gather(loc=msg_location.CACHE)
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)

    @mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None:
        """
        Test single answer with no errors (mocked ChatDB version).
        """
        chat = MagicMock(spec=ChatDB)
        mock_from_dir.return_value = chat

        mock_create_ai.side_effect = self.mock_create_ai
        expected_question = self._expected_question()
        expected_responses = self._expected_responses(expected_question)

        # execute the command
        question_cmd(self.args, self.config)

        # check for the correct ChatDB calls:
        # - initial question has been written (prior to the actual request)
        # - responses have been written (after the request)
        expected_writes = [call([expected_question]),
                           call(expected_responses)]
        chat.cache_write.assert_has_calls(expected_writes, any_order=False)

        # check that the messages have not been added to the internal message list
        chat.cache_add.assert_not_called()

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_ask_with_error(self, mock_create_ai: MagicMock) -> None:
        """
        Provoke an error during the AI request and verify that the question
        has been correctly stored in the cache.
        """
        mock_create_ai.side_effect = self.mock_create_ai_with_error
        expected_question = self._expected_question()

        # execute the command
        with self.assertRaises(AIError):
            question_cmd(self.args, self.config)

        # check for the expected message files
        cache_path = Path(self.cache_dir.name)
        db_path = Path(self.db_dir.name)
        chat = ChatDB.from_dir(cache_path, db_path)
        cached_msg = chat.msg_gather(loc=msg_location.CACHE)
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_msgs_equal_except_file_path(cached_msg, [expected_question])
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuestionCmdRepeat(TestQuestionCmd):
    """Tests for repeating stored questions with the question command."""

    def _open_chat(self) -> ChatDB:
        """Open a ChatDB on the cache and DB directories of this test."""
        return ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))

    def _store_question(self, chat: ChatDB, *, answered: bool = True) -> Message:
        """Write the default question as cache message '0001' and return it.

        With ``answered=False`` the message is created without an answer,
        which is the state the cache is in after a failed request.
        """
        details = dict(tags=set(self.args.output_tags),
                       ai=self.args.AI,
                       model=self.args.model,
                       file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
        question = Question(self.args.ask[0])
        if answered:
            message = Message(question, Answer('Old Answer'), **details)
        else:
            message = Message(question, **details)
        chat.msg_write([message])
        return message

    def _select_last_question(self, *, overwrite: bool) -> None:
        """Switch the arguments from asking to repeating the last question."""
        self.args.ask = None
        self.args.repeat = []
        self.args.overwrite = overwrite

    def _use_new_arguments(self) -> None:
        """Replace tags, AI and model in the arguments with new values."""
        self.args.output_tags = ['newtag']
        self.args.AI = 'newai'
        self.args.model = 'newmodel'

    def _repeated(self, message: Message) -> Message:
        """Return *message* as expected after a repeat with unchanged arguments."""
        return Message(Question(message.question),
                       Answer('Answer 0'),
                       ai=message.ai,
                       model=message.model,
                       tags=message.tags,
                       file_path=Path('<NOT COMPARED>'))

    def _repeated_with_new_arguments(self, message: Message) -> Message:
        """Return *message* as expected after a repeat with the new arguments."""
        return Message(Question(message.question),
                       Answer('Answer 0'),
                       ai='newai',
                       model='newmodel',
                       tags={Tag('newtag')},
                       file_path=Path('<NOT COMPARED>'))

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        chat = self._open_chat()
        # create a message
        message = self._store_question(chat)

        # repeat the last question (without overwriting)
        # -> expect the original message + the one with the new response
        self._select_last_question(overwrite=False)
        expected_responses = [message, self._repeated(message)]
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc=msg_location.CACHE)
        self.assertEqual(len(self.message_list(self.cache_dir)), 2)
        self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question and overwrite the old one.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        chat = self._open_chat()
        # create a message
        message = self._store_question(chat)
        cached_msg = chat.msg_gather(loc=msg_location.CACHE)
        assert cached_msg[0].file_path
        cached_msg_file_id = cached_msg[0].file_path.stem

        # repeat the last question (WITH overwriting)
        # -> expect a single message afterwards (with a new answer)
        self._select_last_question(overwrite=True)
        expected_response = self._repeated(message)
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc=msg_location.CACHE)
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
        # also check that the file ID has not been changed
        assert cached_msg[0].file_path
        self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question_after_error(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question after an error.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        chat = self._open_chat()
        # create a question WITHOUT an answer
        # -> just like after an error (see TestQuestionCmdAsk.test_ask_with_error)
        message = self._store_question(chat, answered=False)
        cached_msg = chat.msg_gather(loc=msg_location.CACHE)
        assert cached_msg[0].file_path
        cached_msg_file_id = cached_msg[0].file_path.stem

        # repeat the last question (without overwriting)
        # -> expect a single message because if the original has
        #    no answer, it should be overwritten by default
        self._select_last_question(overwrite=False)
        expected_response = self._repeated(message)
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc=msg_location.CACHE)
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
        # also check that the file ID has not been changed
        assert cached_msg[0].file_path
        self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question_new_args(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question with new arguments.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        chat = self._open_chat()
        # create a message
        message = self._store_question(chat)
        cached_msg = chat.msg_gather(loc=msg_location.CACHE)
        assert cached_msg[0].file_path

        # repeat the last question with new arguments (without overwriting)
        # -> expect two messages with identical question but different metadata and new answer
        self._select_last_question(overwrite=False)
        self._use_new_arguments()
        new_expected_response = self._repeated_with_new_arguments(message)
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc=msg_location.CACHE)
        self.assertEqual(len(self.message_list(self.cache_dir)), 2)
        self.assert_msgs_equal_except_file_path(cached_msg, [message, new_expected_response])

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question_new_args_overwrite(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question with new arguments, overwriting the old one.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        chat = self._open_chat()
        # create a message
        message = self._store_question(chat)
        cached_msg = chat.msg_gather(loc=msg_location.CACHE)
        assert cached_msg[0].file_path

        # repeat the last question with new arguments
        self._select_last_question(overwrite=True)
        self._use_new_arguments()
        new_expected_response = self._repeated_with_new_arguments(message)
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc=msg_location.CACHE)
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response])

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat multiple questions.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        chat = self._open_chat()
        # 1. === create three questions ===
        # NOTE(review): the tags are passed as a list here, whereas the other
        # tests of this class pass set(...) -- confirm this is intentional
        # cached message without an answer
        message1 = Message(Question(self.args.ask[0]),
                           tags=self.args.output_tags,
                           ai=self.args.AI,
                           model=self.args.model,
                           file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
        # cached message with an answer
        message2 = Message(Question(self.args.ask[0]),
                           Answer('Old Answer'),
                           tags=self.args.output_tags,
                           ai=self.args.AI,
                           model=self.args.model,
                           file_path=Path(self.cache_dir.name) / f'0002{msg_suffix}')
        # DB message without an answer
        message3 = Message(Question(self.args.ask[0]),
                           tags=self.args.output_tags,
                           ai=self.args.AI,
                           model=self.args.model,
                           file_path=Path(self.db_dir.name) / f'0003{msg_suffix}')
        questions = [message1, message2, message3]
        chat.msg_write([message1, message2, message3])
        expected_responses: list[Message] = []
        fake_ai = self.mock_create_ai(self.args, self.config)
        for question in questions:
            # since the message's answer is modified, we use a copy
            # -> the original is used for comparison below
            response = fake_ai.request(copy(question),
                                       Chat([]),
                                       self.args.num_answers,
                                       set(self.args.output_tags))
            expected_responses += response.messages

        # 2. === repeat all three questions (without overwriting) ===
        self.args.ask = None
        self.args.repeat = ['0001', '0002', '0003']
        self.args.overwrite = False
        question_cmd(self.args, self.config)
        # two new files should be in the cache directory
        # * the repeated cached message with answer
        # * the repeated DB message
        # -> the cached message without answer should be overwritten
        self.assertEqual(len(self.message_list(self.cache_dir)), 4)
        self.assertEqual(len(self.message_list(self.db_dir)), 1)
        expected_cache_messages = [expected_responses[0],
                                   message2,
                                   expected_responses[1],
                                   expected_responses[2]]
        cached_msg = chat.msg_gather(loc=msg_location.CACHE)
        self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
        # check that the DB message has not been modified at all
        db_msg = chat.msg_gather(loc=msg_location.DB)
        self.assert_msgs_all_equal(db_msg, [message3])
|
||||||
@@ -0,0 +1,163 @@
|
|||||||
|
import unittest
|
||||||
|
from chatmastermind.tags import Tag, TagLine, TagError
|
||||||
|
|
||||||
|
|
||||||
|
class TestTag(unittest.TestCase):
    """Tests for the construction and class-level constants of Tag."""

    def test_valid_tag(self) -> None:
        """A tag without separators compares equal to its plain string."""
        self.assertEqual(Tag('mytag'), 'mytag')

    def test_invalid_tag(self) -> None:
        """A tag containing a space is rejected with a TagError."""
        self.assertRaises(TagError, Tag, 'tag with space')

    def test_default_separator(self) -> None:
        """The default separator is a single space."""
        separator = Tag.default_separator
        self.assertEqual(separator, ' ')

    def test_alternative_separators(self) -> None:
        """The comma is the only alternative separator."""
        separators = Tag.alternative_separators
        self.assertEqual(separators, [','])
|
||||||
|
|
||||||
|
|
||||||
|
class TestTagLine(unittest.TestCase):
    """Tests for parsing, querying and modifying TagLine instances."""

    def test_valid_tagline(self) -> None:
        """A line starting with the prefix is kept unchanged."""
        self.assertEqual(TagLine('TAGS: tag1 tag2'), 'TAGS: tag1 tag2')

    def test_valid_tagline_with_newline(self) -> None:
        """A newline inside the tag line is collapsed into a single line."""
        self.assertEqual(TagLine('TAGS: tag1\n tag2'), 'TAGS: tag1 tag2')

    def test_invalid_tagline(self) -> None:
        """A line without the prefix is rejected with a TagError."""
        self.assertRaises(TagError, TagLine, 'tag1 tag2')

    def test_prefix(self) -> None:
        """The tag line prefix is 'TAGS:'."""
        prefix = TagLine.prefix
        self.assertEqual(prefix, 'TAGS:')

    def test_from_set(self) -> None:
        """A tag line can be built from a set of tags."""
        built = TagLine.from_set({Tag('tag1'), Tag('tag2')})
        self.assertEqual(built, 'TAGS: tag1 tag2')

    def test_tags(self) -> None:
        """tags() returns all tags of the line as a set."""
        found = TagLine('TAGS: atag1 btag2').tags()
        self.assertEqual(found, {Tag('atag1'), Tag('btag2')})

    def test_tags_empty(self) -> None:
        """A line consisting of the prefix only has no tags."""
        found = TagLine('TAGS:').tags()
        self.assertSetEqual(found, set())

    def test_tags_with_newline(self) -> None:
        """Tags separated by a newline are all returned."""
        found = TagLine('TAGS: tag1\n tag2').tags()
        self.assertEqual(found, {Tag('tag1'), Tag('tag2')})

    def test_tags_prefix(self) -> None:
        """tags(prefix=...) returns only the tags starting with the prefix."""
        tagline = TagLine('TAGS: atag1 stag2 stag3')
        self.assertSetEqual(tagline.tags(prefix='a'), {Tag('atag1')})
        self.assertSetEqual(tagline.tags(prefix='s'), {Tag('stag2'), Tag('stag3')})
        self.assertSetEqual(tagline.tags(prefix='R'), set())

    def test_tags_contain(self) -> None:
        """tags(contain=...) returns only the tags containing the substring."""
        tagline = TagLine('TAGS: atag1 stag2 stag3')
        self.assertSetEqual(tagline.tags(contain='t'), {Tag('atag1'), Tag('stag2'), Tag('stag3')})
        self.assertSetEqual(tagline.tags(contain='1'), {Tag('atag1')})
        self.assertSetEqual(tagline.tags(contain='R'), set())

    def test_merge(self) -> None:
        """Merging tag lines yields the union of their tags."""
        first = TagLine('TAGS: tag1 tag2')
        second = TagLine('TAGS: tag2 tag3')
        self.assertEqual(first.merge({second}), 'TAGS: tag1 tag2 tag3')

    def test_delete_tags(self) -> None:
        """delete_tags() returns a line without the given tags."""
        remaining = TagLine('TAGS: tag1 tag2 tag3').delete_tags({Tag('tag1'), Tag('tag3')})
        self.assertEqual(remaining, 'TAGS: tag2')

    def test_add_tags(self) -> None:
        """add_tags() returns a line that also contains the given tags."""
        extended = TagLine('TAGS: tag1').add_tags({Tag('tag2'), Tag('tag3')})
        self.assertEqual(extended, 'TAGS: tag1 tag2 tag3')

    def test_rename_tags(self) -> None:
        """rename_tags() replaces each old tag by its new name."""
        renames = {(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))}
        renamed = TagLine('TAGS: old1 old2').rename_tags(renames)
        self.assertEqual(renamed, 'TAGS: new1 new2')

    def test_match_tags(self) -> None:
        """match_tags() combines the 'or', 'and' and 'not' tag sets."""
        tagline = TagLine('TAGS: tag1 tag2 tag3')
        nothing: set[Tag] = set()

        # Test case 1: Match any tag in 'tags_or'
        self.assertTrue(tagline.match_tags({Tag('tag1'), Tag('tag4')}, nothing, nothing))

        # Test case 2: Match all tags in 'tags_and'
        self.assertTrue(tagline.match_tags(set(), {Tag('tag1'), Tag('tag2'), Tag('tag3')}, set()))

        # Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and'
        self.assertTrue(tagline.match_tags({Tag('tag1'), Tag('tag4')},
                                           {Tag('tag1'), Tag('tag2')},
                                           set()))

        # Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not'
        self.assertTrue(tagline.match_tags({Tag('tag1'), Tag('tag4')},
                                           {Tag('tag1'), Tag('tag2')},
                                           {Tag('tag5')}))

        # Test case 5: No matching tags in 'tags_or'
        self.assertFalse(tagline.match_tags({Tag('tag4'), Tag('tag5')}, set(), set()))

        # Test case 6: Not all tags in 'tags_and' are present
        self.assertFalse(tagline.match_tags(set(),
                                            {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')},
                                            set()))

        # Test case 7: Some tags in 'tags_not' are present
        self.assertFalse(tagline.match_tags({Tag('tag1')}, set(), {Tag('tag2')}))

        # Test case 8: 'tags_or' and 'tags_and' are None, match all tags
        self.assertTrue(tagline.match_tags(None, None, set()))

        # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags
        self.assertFalse(tagline.match_tags(None, None, {Tag('tag2')}))

        # Test case 10: 'tags_or' and 'tags_and' are empty, match no tags
        self.assertFalse(tagline.match_tags(set(), set(), None))

        # Test case 11: 'tags_or' is empty, match no tags
        self.assertFalse(tagline.match_tags(set(), None, None))

        # Test case 12: 'tags_and' is empty, match no tags
        self.assertFalse(tagline.match_tags(None, set(), None))

        # Test case 13: 'tags_or' is None, match 'tags_and'
        self.assertTrue(tagline.match_tags(None, {Tag('tag1'), Tag('tag2')}, None))

        # Test case 14: 'tags_and' is None, match 'tags_or'
        self.assertTrue(tagline.match_tags({Tag('tag1'), Tag('tag2')}, None, None))
|
||||||
Reference in New Issue
Block a user