88 Commits

Author SHA1 Message Date
juk0de 25fffb6fea chat: db_read() and cache_read() now also support globbing and filtering 2023-09-17 10:59:29 +02:00
juk0de cf572e1882 chat: added functions db_move() and chat_move() (and tests) 2023-09-17 10:59:29 +02:00
juk0de 2fb7410b43 chat: added functions msg_in_cache() and msg_in_db(), also tests 2023-09-17 10:59:29 +02:00
juk0de 33ae27f00e chat: msg_remove() now supports multiple locations 2023-09-17 10:59:29 +02:00
juk0de f6a6e6036b chat: added validation during initialization 2023-09-17 10:59:29 +02:00
juk0de 525cdb92a1 message / chat: 'msg_id()' now returns 'file_path.stem' (removed suffix) 2023-09-17 10:59:29 +02:00
juk0de fc82f85b7c chat: added new functions: msg_unique_id(), msg_unique_content() and tests 2023-09-17 10:59:24 +02:00
juk0de d90845b58b chat: added new functions to ChatDB: msg_gather(), msg_find(), msg_remove() 2023-09-17 10:58:26 +02:00
juk0de 98777295d6 refactor: renamed (almost) all Chat/ChatDB functions 2023-09-17 10:58:26 +02:00
juk0de f6109949c8 chat: ChatDB now correctly ignores files that contain no valid messages 2023-09-17 10:58:10 +02:00
juk0de 071871f929 chat et al: '.next' and '.config.yaml' are now ignored by ChatDB 2023-09-14 16:06:00 +02:00
juk0de 5cb88dad1b chat: implemented special version of 'latest_message()' for the ChatDB class 2023-09-14 16:05:49 +02:00
juk0de 17a0264025 question_cmd: now also accepts Messages as source files 2023-09-13 17:44:39 +02:00
Oleksandr Kozachuk 7f4a16894e Add pre-commit checks into push webhook. 2023-09-13 11:08:02 +02:00
Oleksandr Kozachuk 26e3d38afb Add the Gitea web hooks. 2023-09-13 10:53:12 +02:00
juk0de b5af751193 openai: added test module 2023-09-13 09:01:00 +02:00
juk0de a7345cbc41 ai_factory: fixed argument parsing bug 2023-09-13 07:52:05 +02:00
juk0de 310cb9421e Merge pull request 'Cleanup after merge of restructurings #8' (#10) from cleanup into main
Reviewed-on: #10
2023-09-12 20:23:08 +02:00
Oleksandr Kozachuk 1ec3d6fcda Make it possible to specify the AI in config command. 2023-09-12 16:37:50 +02:00
Oleksandr Kozachuk 544bf0bf06 Improve README.md 2023-09-12 16:34:39 +02:00
Oleksandr Kozachuk f96e82bdd7 Implement the config -l and config -m commands. 2023-09-12 16:34:17 +02:00
Oleksandr Kozachuk 2b62cb8c4b Remove the -*terminal_width() to save space on screen. 2023-09-12 13:48:28 +02:00
juk0de a895c1fc6a Merge pull request 'ChatMasterMind Application Refactor and Enhancement' (#8) from restructurings into main
Reviewed-on: #8
2023-09-12 07:36:04 +02:00
Oleksandr Kozachuk ddfcc71510 Merge branch 'restructurings.main' into restructurings 2023-09-11 13:28:56 +02:00
Oleksandr Kozachuk 17de0b9967 Remove old code. 2023-09-11 13:17:59 +02:00
juk0de 33023d29f9 configuration: made 'default' AI ID optional 2023-09-11 13:09:45 +02:00
juk0de 481f9ecf7c configuration: improved config file format 2023-09-11 13:09:45 +02:00
juk0de 22fa187e5f question_cmd: when no tags are specified, no tags are selected 2023-09-11 13:09:45 +02:00
juk0de b840ebd792 message: to_file() now uses intermediate temporary file 2023-09-11 13:09:45 +02:00
juk0de 66908f5fed message: fixed matching with empty tag sets 2023-09-11 13:09:45 +02:00
juk0de 2e08ccf606 openai: stores AI.ID instead of AI.name in message 2023-09-11 13:09:44 +02:00
juk0de 595ff8e294 question_cmd: added message filtering by tags 2023-09-11 13:09:44 +02:00
juk0de faac42d3c2 question_cmd: fixed '--ask' command 2023-09-11 13:09:44 +02:00
juk0de 864ab7aeb1 chat: added check for existing files when creating new filenames 2023-09-11 13:09:44 +02:00
juk0de cc76da2ab3 chat: added 'update_messages()' function and test 2023-09-11 13:09:44 +02:00
juk0de f99cd3ed41 question_cmd: fixed source code extraction and added a testcase 2023-09-11 13:09:44 +02:00
Oleksandr Kozachuk 6f3ea98425 Small fixes. 2023-09-11 13:09:44 +02:00
Oleksandr Kozachuk 54ece6efeb Port print arguments -q/-a/-S from main to restructuring. 2023-09-11 13:09:44 +02:00
Oleksandr Kozachuk 86eebc39ea Allow in question -s for just sourcing file and -S to source file with ``` encapsulation. 2023-09-11 13:09:44 +02:00
juk0de 3eca53998b question cmd: added tests 2023-09-11 13:09:44 +02:00
juk0de c4f7bcc94e question_cmd: fixes 2023-09-11 13:09:44 +02:00
juk0de c52713c833 configuration: added tests 2023-09-11 13:09:44 +02:00
juk0de ecb6994783 configuration et al: implemented new Config format 2023-09-11 13:09:44 +02:00
juk0de 61e710a4b1 cmm: splitted commands into separate modules (and more cleanup) 2023-09-11 13:09:41 +02:00
juk0de 21d39c6c66 cmm: removed all the old code and modules 2023-09-11 13:08:45 +02:00
juk0de 6a4cc7a65d setup: added 'ais' subfolder 2023-09-11 13:07:46 +02:00
juk0de d6bb5800b1 test_main: temporarily disabled all testcases 2023-09-11 13:07:46 +02:00
juk0de 034e4093f1 cmm: added 'question' command 2023-09-11 13:07:46 +02:00
juk0de 7d15452242 added new module 'ai_factory' 2023-09-11 13:07:46 +02:00
juk0de 823d3bf7dc added new module 'openai.py' 2023-09-11 13:07:46 +02:00
juk0de 4bd144c4d7 added new module 'ai.py' 2023-09-11 13:07:46 +02:00
juk0de e186afbef0 cmm: the 'print' command now uses 'Message.from_file()' 2023-09-11 13:07:43 +02:00
juk0de 5e4ec70072 cmm: tags completion now uses 'Message.tags_from_dir' (fixes tag completion for me) 2023-09-11 13:06:22 +02:00
juk0de 4c378dde85 cmm: the 'hist' command now uses the new 'ChatDB' 2023-09-11 13:05:33 +02:00
juk0de 8923a13352 cmm: the 'tags' command now uses the new 'ChatDB' 2023-09-11 13:04:08 +02:00
juk0de e1414835c8 chat: added functions for finding and deleting messages 2023-09-11 13:04:08 +02:00
juk0de abb7fdacb6 message / chat: output improvements 2023-09-11 13:04:08 +02:00
juk0de 2e2228bd60 chat: new possibilites for adding messages and better tests 2023-09-11 13:04:08 +02:00
juk0de 713b55482a message: added rename_tags() function and test 2023-09-11 13:04:08 +02:00
juk0de d35de86c67 message: fixed Answer header for TXT format 2023-09-11 13:04:08 +02:00
juk0de aba3eb783d message: improved robustness of Question and Answer content checks and tests 2023-09-11 13:04:08 +02:00
juk0de 8e63831701 chat: added clear_cache() function and test 2023-09-11 13:04:08 +02:00
juk0de c318b99671 chat: improved history printing 2023-09-11 13:04:08 +02:00
juk0de 48c8e951e1 chat: fixed handling of unsupported files in DB and chache dir 2023-09-11 13:04:08 +02:00
juk0de b22a4b07ed chat: added tags_frequency() function and test 2023-09-11 13:04:08 +02:00
juk0de 33565d351d configuration: added AIConfig class 2023-09-11 13:04:08 +02:00
juk0de 6737fa98c7 added tokens() function to Message and Chat 2023-09-11 13:04:08 +02:00
juk0de 815a21893c added tests for 'chat.py' 2023-09-11 13:04:08 +02:00
juk0de 64893949a4 added new module 'chat.py' 2023-09-11 13:04:08 +02:00
juk0de a093f9b867 tags: some clarification and new tests 2023-09-11 13:04:08 +02:00
juk0de dc3f3dc168 added 'message_in()' function and test 2023-09-11 13:04:08 +02:00
juk0de 74c39070d6 fixed Message.filter_tags 2023-09-11 13:04:08 +02:00
juk0de fde0ae4652 fixed test case file cleanup 2023-09-11 13:04:08 +02:00
juk0de 238dbbee60 fixed handling empty tags in TXT file 2023-09-11 13:04:08 +02:00
juk0de 17f7b2fb45 Added tags filtering (prefix and contained string) to TagLine and Message 2023-09-11 13:04:08 +02:00
juk0de 9c2598a4b8 tests: added testcases for Message.from/to_file() and others 2023-09-11 13:04:08 +02:00
juk0de acec5f1d55 tests: splitted 'test_main.py' into 3 modules 2023-09-11 13:04:08 +02:00
juk0de c0f50bace5 gitignore: added vim session file 2023-09-11 13:04:08 +02:00
juk0de 30ccec2462 tags: TagLine constructor now supports multiline taglines and multiple spaces 2023-09-11 13:04:08 +02:00
juk0de 09da312657 configuration: added 'as_dict()' as an instance function 2023-09-11 13:04:08 +02:00
juk0de 33567df15f added testcases for messages.py 2023-09-11 13:04:08 +02:00
juk0de 264979a60d added new module 'message.py' 2023-09-11 13:04:08 +02:00
juk0de 061e5f8682 tags.py: converted most TagLine functions to module functions 2023-09-11 13:04:08 +02:00
juk0de 2d456e68f1 added testcases for Tag and TagLine classes 2023-09-11 13:04:08 +02:00
juk0de 8bd659e888 added new module 'tags.py' with classes 'Tag' and 'TagLine' 2023-09-11 13:04:08 +02:00
Oleksandr Kozachuk 3ef1339cc0 Fix extracting source file with type specification. 2023-09-09 11:53:32 +02:00
Oleksandr Kozachuk ed567afbea Make it possible to print just question or answer on printing files. 2023-09-08 15:54:29 +02:00
Oleksandr Kozachuk 6e447018d5 Fix tags_completter. 2023-09-07 18:11:32 +02:00
17 changed files with 899 additions and 265 deletions
+74 -42
View File
@@ -37,63 +37,95 @@ cmm [global options] command [command options]
### Global Options
- `-c`, `--config`: Config file name (defaults to `.config.yaml`).
### Commands
- `ask`: Ask a question.
- `hist`: Print chat history.
- `tag`: Manage tags.
- `config`: Manage configuration.
- `print`: Print files.
- `-C`, `--config`: Config file name (defaults to `.config.yaml`).
### Command Options
#### `ask` Command Options
#### Question
- `-q`, `--question`: Question to ask (required).
- `-m`, `--max-tokens`: Max tokens to use.
- `-T`, `--temperature`: Temperature to use.
- `-M`, `--model`: Model to use.
- `-n`, `--number`: Number of answers to produce (default is 3).
- `-s`, `--source`: Add content of a file to the query.
- `-S`, `--only-source-code`: Add pure source code to the chat history.
- `-t`, `--tags`: List of tag names.
- `-e`, `--extags`: List of tag names to exclude.
- `-o`, `--output-tags`: List of output tag names (default is the input tags).
- `-a`, `--match-all-tags`: All given tags must match when selecting chat history entries.
The `question` command is used to ask, create, and process questions.
#### `hist` Command Options
```bash
cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI] [-M MODEL] [-n NUM] [-m MAX] [-T TEMP] (-a ASK | -c CREATE | -r REPEAT | -p PROCESS) [-O] [-s SOURCE]... [-S SOURCE]...
```
- `-d`, `--dump`: Print chat history as Python structure.
- `-w`, `--with-tags`: Print chat history with tags.
- `-W`, `--with-files`: Print chat history with filenames.
- `-S`, `--only-source-code`: Print only source code.
- `-t`, `--tags`: List of tag names.
- `-e`, `--extags`: List of tag names to exclude.
- `-a`, `--match-all-tags`: All given tags must match when selecting chat history entries.
* `-t, --or-tags OTAGS` : List of tags (one must match)
* `-k, --and-tags ATAGS` : List of tags (all must match)
* `-x, --exclude-tags XTAGS` : List of tags to exclude
* `-o, --output-tags OUTTAGS` : List of output tags (default: use input tags)
* `-A, --AI AI` : AI ID to use
* `-M, --model MODEL` : Model to use
* `-n, --num-answers NUM` : Number of answers to request
* `-m, --max-tokens MAX` : Max. number of tokens
* `-T, --temperature TEMP` : Temperature value
* `-a, --ask ASK` : Ask a question
* `-c, --create CREATE` : Create a question
* `-r, --repeat REPEAT` : Repeat a question
* `-p, --process PROCESS` : Process existing questions
* `-O, --overwrite` : Overwrite existing messages when repeating them
* `-s, --source-text SOURCE` : Add content of a file to the query
* `-S, --source-code SOURCE` : Add source code file content to the chat history
#### `tag` Command Options
#### Hist
- `-l`, `--list`: List all tags and their frequency.
The `hist` command is used to print the chat history.
#### `config` Command Options
```bash
cmm hist [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A ANSWER] [-Q QUESTION]
```
- `-l`, `--list-models`: List all available models.
- `-m`, `--print-model`: Print the currently configured model.
- `-M`, `--model`: Set model in the config file.
* `-t, --or-tags OTAGS` : List of tags (one must match)
* `-k, --and-tags ATAGS` : List of tags (all must match)
* `-x, --exclude-tags XTAGS` : List of tags to exclude
* `-w, --with-tags` : Print chat history with tags
* `-W, --with-files` : Print chat history with filenames
* `-S, --source-code-only` : Print only source code
* `-A, --answer ANSWER` : Search for answer substring
* `-Q, --question QUESTION` : Search for question substring
#### `print` Command Options
#### Tags
- `-f`, `--file`: File to print (required).
- `-S`, `--only-source-code`: Print only source code.
The `tags` command is used to manage tags.
```bash
cmm tags (-l | -p PREFIX | -c CONTENT)
```
* `-l, --list` : List all tags and their frequency
* `-p, --prefix PREFIX` : Filter tags by prefix
* `-c, --contain CONTENT` : Filter tags by contained substring
#### Config
The `config` command is used to manage the configuration.
```bash
cmm config (-l | -m | -c CREATE)
```
* `-l, --list-models` : List all available models
* `-m, --print-model` : Print the currently configured model
* `-c, --create CREATE` : Create config with default settings in the given file
#### Print
The `print` command is used to print message files.
```bash
cmm print -f FILE [-q | -a | -S]
```
* `-f, --file FILE` : File to print
* `-q, --question` : Print only question
* `-a, --answer` : Print only answer
* `-S, --only-source-code` : Print only source code
### Examples
1. Ask a question:
```bash
cmm ask -q "What is the meaning of life?" -t philosophy -e religion
cmm question -a "What is the meaning of life?" -t philosophy -x religion
```
2. Display the chat history:
@@ -105,19 +137,19 @@ cmm hist
3. Filter chat history by tags:
```bash
cmm hist -t tag1 tag2
cmm hist --or-tags tag1 tag2
```
4. Exclude chat history by tags:
```bash
cmm hist -e tag3 tag4
cmm hist --exclude-tags tag3 tag4
```
5. List all tags and their frequency:
```bash
cmm tag -l
cmm tags -l
```
6. Print the contents of a file:
+6
View File
@@ -59,6 +59,12 @@ class AI(Protocol):
"""
raise NotImplementedError
def print_models(self) -> None:
"""
Print all models supported by this AI.
"""
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int:
"""
Computes the nr. of AI language tokens for the given message
+4 -4
View File
@@ -17,7 +17,7 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
is not found, it uses the first AI in the list.
"""
ai_conf: AIConfig
if args.AI:
if hasattr(args, 'AI') and args.AI:
try:
ai_conf = config.ais[args.AI]
except KeyError:
@@ -32,11 +32,11 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
if ai_conf.name == 'openai':
ai = OpenAI(cast(OpenAIConfig, ai_conf))
if args.model:
if hasattr(args, 'model') and args.model:
ai.config.model = args.model
if args.max_tokens:
if hasattr(args, 'max_tokens') and args.max_tokens:
ai.config.max_tokens = args.max_tokens
if args.temperature:
if hasattr(args, 'temperature') and args.temperature:
ai.config.temperature = args.temperature
return ai
else:
+6 -1
View File
@@ -62,7 +62,12 @@ class OpenAI(AI):
"""
Return all models supported by this AI.
"""
raise NotImplementedError
ret = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']:
ret.append(engine['id'])
ret.sort()
return ret
def print_models(self) -> None:
"""
+327 -111
View File
@@ -6,13 +6,18 @@ from pathlib import Path
from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal, Union
from .configuration import default_config_file
from .message import Message, MessageFilter, MessageError, message_in
from .tags import Tag
ChatInst = TypeVar('ChatInst', bound='Chat')
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next'
ignored_files = [db_next_file, default_config_file]
msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
class ChatError(Exception):
pass
@@ -45,13 +50,15 @@ def read_dir(dir_path: Path,
messages: list[Message] = []
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in sorted(file_iter):
if file_path.is_file() and file_path.suffix in Message.file_suffixes:
if (file_path.is_file()
and file_path.name not in ignored_files # noqa: W503
and file_path.suffix in Message.file_suffixes): # noqa: W503
try:
message = Message.from_file(file_path, mfilter)
if message:
messages.append(message)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
print(f"WARNING: Skipping message in '{file_path}': {str(e)}")
return messages
@@ -100,7 +107,9 @@ def clear_dir(dir_path: Path,
"""
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in file_iter:
if file_path.is_file() and file_path.suffix in Message.file_suffixes:
if (file_path.is_file()
and file_path.name not in ignored_files # noqa: W503
and file_path.suffix in Message.file_suffixes): # noqa: W503
file_path.unlink(missing_ok=True)
@@ -112,14 +121,43 @@ class Chat:
messages: list[Message]
def filter(self, mfilter: MessageFilter) -> None:
def __post_init__(self) -> None:
self.validate()
def validate(self) -> None:
"""
Validate this Chat instance.
"""
def msg_paths(stem: str) -> list[str]:
return [str(fp) for fp in file_paths if fp.stem == stem]
file_paths: set[Path] = {m.file_path for m in self.messages if m.file_path is not None}
file_stems = [m.file_path.stem for m in self.messages if m.file_path is not None]
error = False
for fp in file_paths:
if file_stems.count(fp.stem) > 1:
print(f"ERROR: Found multiple copies of message '{fp.stem}': {msg_paths(fp.stem)}")
error = True
if error:
raise ChatError("Validation failed")
def msg_name_matches(self, file_path: Path, name: str) -> bool:
"""
Return True if the given name matches the given file_path.
Matching is True if:
* 'name' matches the full 'file_path'
* 'name' matches 'file_path.name' (i. e. including the suffix)
* 'name' matches 'file_path.stem' (i. e. without a suffix)
"""
return Path(name) == file_path or name == file_path.name or name == file_path.stem
def msg_filter(self, mfilter: MessageFilter) -> None:
"""
Use 'Message.match(mfilter) to remove all messages that
don't fulfill the filter requirements.
"""
self.messages = [m for m in self.messages if m.match(mfilter)]
def sort(self, reverse: bool = False) -> None:
def msg_sort(self, reverse: bool = False) -> None:
"""
Sort the messages according to 'Message.msg_id()'.
"""
@@ -129,48 +167,71 @@ class Chat:
except MessageError:
pass
def clear(self) -> None:
def msg_unique_id(self) -> None:
"""
Remove duplicates from the internal messages, based on the msg_id (i. e. file_path).
Messages without a file_path are kept.
"""
old_msgs = self.messages.copy()
self.messages = []
for m in old_msgs:
if not message_in(m, self.messages):
self.messages.append(m)
self.msg_sort()
def msg_unique_content(self) -> None:
"""
Remove duplicates from the internal messages, based on the content (i. e. question + answer).
"""
self.messages = list(set(self.messages))
self.msg_sort()
def msg_clear(self) -> None:
"""
Delete all messages.
"""
self.messages = []
def add_messages(self, messages: list[Message]) -> None:
def msg_add(self, messages: list[Message]) -> None:
"""
Add new messages and sort them if possible.
"""
self.messages += messages
self.sort()
self.msg_sort()
def latest_message(self) -> Optional[Message]:
def msg_latest(self, mfilter: Optional[MessageFilter] = None) -> Optional[Message]:
"""
Returns the last added message (according to the file ID).
Return the last added message (according to the file ID) that matches the given filter.
When containing messages without a valid file_path, it returns the latest message in
the internal list.
"""
if len(self.messages) > 0:
self.sort()
return self.messages[-1]
else:
return None
self.msg_sort()
for m in reversed(self.messages):
if mfilter is None or m.match(mfilter):
return m
return None
def find_messages(self, msg_names: list[str]) -> list[Message]:
def msg_find(self, msg_names: list[str]) -> list[Message]:
"""
Search and return the messages with the given names. Names can either be filenames
(incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the
caller should check the result if he requires all messages).
(with or without suffix), full paths or Message.msg_id(). Messages that can't be
found are ignored (i. e. the caller should check the result if they require all
messages).
"""
return [m for m in self.messages
if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def remove_messages(self, msg_names: list[str]) -> None:
def msg_remove(self, msg_names: list[str]) -> None:
"""
Remove the messages with the given names. Names can either be filenames
(incl. the suffix) or full paths.
(with or without suffix), full paths or Message.msg_id().
"""
self.messages = [m for m in self.messages
if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
self.sort()
if not any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
self.msg_sort()
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
def msg_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
"""
Get the tags of all messages, optionally filtered by prefix or substring.
"""
@@ -179,7 +240,7 @@ class Chat:
tags |= m.filter_tags(prefix, contain)
return set(sorted(tags))
def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
def msg_tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
"""
Get the frequency of all tags of all messages, optionally filtered by prefix or substring.
"""
@@ -204,7 +265,6 @@ class Chat:
output.append(message.to_str(source_code_only=True))
continue
output.append(message.to_str(with_tags, with_files))
output.append('\n' + ('-' * terminal_width()) + '\n')
if paged:
print_paged('\n'.join(output))
else:
@@ -233,10 +293,11 @@ class ChatDB(Chat):
def __post_init__(self) -> None:
# contains the latest message ID
self.next_fname = self.db_path / '.next'
self.next_path = self.db_path / db_next_file
# make all paths absolute
self.cache_path = self.cache_path.absolute()
self.db_path = self.db_path.absolute()
self.validate()
@classmethod
def from_dir(cls: Type[ChatDBInst],
@@ -272,7 +333,7 @@ class ChatDB(Chat):
def get_next_fid(self) -> int:
try:
with open(self.next_fname, 'r') as f:
with open(self.next_path, 'r') as f:
next_fid = int(f.read()) + 1
self.set_next_fid(next_fid)
return next_fid
@@ -281,87 +342,188 @@ class ChatDB(Chat):
return 1
def set_next_fid(self, fid: int) -> None:
with open(self.next_fname, 'w') as f:
with open(self.next_path, 'w') as f:
f.write(f'{fid}')
def read_db(self) -> None:
def msg_write(self, messages: Optional[list[Message]] = None) -> None:
"""
Reads new messages from the DB directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message with
Write either the given messages or the internal ones to their CURRENT file_path.
If messages are given, they all must have a valid file_path. When writing the
internal messages, the ones with a valid file_path are written, the others
are ignored.
"""
if messages and any(m.file_path is None for m in messages):
raise ChatError("Can't write files without a valid file_path")
msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)):
m.to_file()
def msg_update(self, messages: list[Message], write: bool = True) -> None:
"""
Update EXISTING messages. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list.
Only accepts existing messages.
"""
new_messages = read_dir(self.db_path, self.glob, self.mfilter)
if any(not message_in(m, self.messages) for m in messages):
raise ChatError("Can't update messages that are not in the internal list")
# remove old versions and add new ones
self.messages = [m for m in self.messages if not message_in(m, messages)]
self.messages += messages
self.msg_sort()
# write the UPDATED messages if requested
if write:
self.msg_write(messages)
def msg_gather(self,
loc: msg_location,
require_file_path: bool = False,
mfilter: Optional[MessageFilter] = None) -> list[Message]:
"""
Gather and return messages from the given locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
If 'require_file_path' is True, return only files with a valid file_path.
"""
loc_messages: list[Message] = []
if loc in ['mem', 'all']:
if require_file_path:
loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
else:
loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
if loc in ['cache', 'disk', 'all']:
loc_messages += read_dir(self.cache_path, mfilter=mfilter)
if loc in ['db', 'disk', 'all']:
loc_messages += read_dir(self.db_path, mfilter=mfilter)
# remove_duplicates and sort the list
unique_messages: list[Message] = []
for m in loc_messages:
if not message_in(m, unique_messages):
unique_messages.append(m)
try:
unique_messages.sort(key=lambda m: m.msg_id())
# messages in 'mem' can have an empty file_path
except MessageError:
pass
return unique_messages
def msg_find(self,
msg_names: list[str],
loc: msg_location = 'mem',
) -> list[Message]:
"""
Search and return the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Messages that can't be
found are ignored (i. e. the caller should check the result if they require all
messages).
Searches one of the following locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
loc_messages = self.msg_gather(loc, require_file_path=True)
return [m for m in loc_messages
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def msg_remove(self, msg_names: list[str], loc: msg_location = 'mem') -> None:
"""
Remove the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Also deletes the
files of all given messages with a valid file_path.
Delete files from one of the following locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
if loc != 'mem':
# delete the message files first
rm_messages = self.msg_find(msg_names, loc=loc)
for m in rm_messages:
if (m.file_path):
m.file_path.unlink()
# then remove them from the internal list
super().msg_remove(msg_names)
def msg_latest(self,
mfilter: Optional[MessageFilter] = None,
loc: msg_location = 'mem') -> Optional[Message]:
"""
Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if loc is 'mem').
Searches one of the following locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
# only consider messages with a valid file_path so they can be sorted
loc_messages = self.msg_gather(loc, require_file_path=True)
loc_messages.sort(key=lambda m: m.msg_id(), reverse=True)
for m in loc_messages:
if mfilter is None or m.match(mfilter):
return m
return None
def msg_in_cache(self, message: Union[Message, str]) -> bool:
"""
Return true if the given Message (or filename or Message.msg_id())
is located in the cache directory. False otherwise.
"""
if isinstance(message, Message):
return (message.file_path is not None
and message.file_path.parent.samefile(self.cache_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc='cache')) > 0
def msg_in_db(self, message: Union[Message, str]) -> bool:
"""
Return true if the given Message (or filename or Message.msg_id())
is located in the DB directory. False otherwise.
"""
if isinstance(message, Message):
return (message.file_path is not None
and message.file_path.parent.samefile(self.db_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc='db')) > 0
def cache_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None:
"""
Read messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message
with the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.cache_path, glob, mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.sort()
self.msg_sort()
def read_cache(self) -> None:
"""
Reads new messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.cache_path, self.glob, self.mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.sort()
def write_db(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the DB directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point
to the DB directory.
"""
write_dir(self.db_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def write_cache(self, messages: Optional[list[Message]] = None) -> None:
def cache_write(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the cache directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point to
the cache directory.
Does NOT add the messages to the internal list (use 'cache_add()' for that)!
"""
write_dir(self.cache_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def clear_cache(self) -> None:
def cache_add(self, messages: list[Message], write: bool = True) -> None:
"""
Deletes all Message files from the cache dir and removes those messages from
the internal list.
"""
clear_dir(self.cache_path, self.glob)
# only keep messages from DB dir (or those that have not yet been written)
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
def add_to_db(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the DB directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.db_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.sort()
def add_to_cache(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the cache directory.
Add NEW messages and set the file_path to the cache directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
@@ -375,33 +537,87 @@ class ChatDB(Chat):
for m in messages:
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.sort()
self.msg_sort()
def write_messages(self, messages: Optional[list[Message]] = None) -> None:
def cache_clear(self, glob: Optional[str] = None) -> None:
"""
Write either the given messages or the internal ones to their current file_path.
If messages are given, they all must have a valid file_path. When writing the
internal messages, the ones with a valid file_path are written, the others
are ignored.
Delete all message files from the cache dir and remove them from the internal list.
"""
if messages and any(m.file_path is None for m in messages):
raise ChatError("Can't write files without a valid file_path")
msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)):
m.to_file()
clear_dir(self.cache_path, glob)
# only keep messages from DB dir (or those that have not yet been written)
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
def update_messages(self, messages: list[Message], write: bool = True) -> None:
def cache_move(self, message: Message) -> None:
"""
Update existing messages. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list. Only accepts
existing messages.
Moves the given messages to the cache directory.
"""
if any(not message_in(m, self.messages) for m in messages):
raise ChatError("Can't update messages that are not in the internal list")
# remove old versions and add new ones
self.messages = [m for m in self.messages if not message_in(m, messages)]
self.messages += messages
self.sort()
# write the UPDATED messages if requested
# remember the old path (if any)
old_path: Optional[Path] = None
if message.file_path:
old_path = message.file_path
# write message to the new destination
self.cache_write([message])
# remove the old one (if any)
if old_path:
self.msg_remove([str(old_path)], loc='db')
# (re)add it to the internal list
self.msg_add([message])
def db_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None:
"""
Read messages from the DB directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message
with the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.db_path, self.glob, self.mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.msg_sort()
def db_write(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the DB directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point
to the DB directory.
Does NOT add the messages to the internal list (use 'db_add()' for that)!
"""
write_dir(self.db_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def db_add(self, messages: list[Message], write: bool = True) -> None:
"""
Add NEW messages and set the file_path to the DB directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
self.write_messages(messages)
write_dir(self.db_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.msg_sort()
def db_move(self, message: Message) -> None:
"""
Moves the given messages to the db directory.
"""
# remember the old path (if any)
old_path: Optional[Path] = None
if message.file_path:
old_path = message.file_path
# write message to the new destination
self.db_write([message])
# remove the old one (if any)
if old_path:
self.msg_remove([str(old_path)], loc='cache')
# (re)add it to the internal list
self.msg_add([message])
+9
View File
@@ -1,6 +1,8 @@
import argparse
from pathlib import Path
from ..configuration import Config
from ..ai import AI
from ..ai_factory import create_ai
def config_cmd(args: argparse.Namespace) -> None:
@@ -9,3 +11,10 @@ def config_cmd(args: argparse.Namespace) -> None:
"""
if args.create:
Config.create_default(Path(args.create))
elif args.list_models or args.print_model:
config: Config = Config.from_file(args.config)
ai: AI = create_ai(args, config)
if args.list_models:
ai.print_models()
else:
print(ai.config.model)
+53 -24
View File
@@ -3,40 +3,69 @@ from pathlib import Path
from itertools import zip_longest
from ..configuration import Config
from ..chat import ChatDB
from ..message import Message, MessageFilter, Question, source_code
from ..message import Message, MessageFilter, MessageError, Question, source_code
from ..ai_factory import create_ai
from ..ai import AI, AIResponse
def add_file_as_text(question_parts: list[str], file: str) -> None:
"""
Add the given file as plain text to the question part list.
If the file is a Message, add the answer.
"""
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
with open(file) as r:
content = r.read().strip()
if len(content) > 0:
question_parts.append(content)
def add_file_as_code(question_parts: list[str], file: str) -> None:
"""
Add all source code from the given file. If no code segments can be extracted,
the whole content is added as source code segment. If the file is a Message,
extract the source code from the answer.
"""
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
with open(file) as r:
content = r.read().strip()
# extract and add source code
code_parts = source_code(content, include_delims=True)
if len(code_parts) > 0:
question_parts += code_parts
else:
question_parts.append(f"```\n{content}\n```")
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
"""
Creates (and writes) a new message from the given arguments.
Creates a new message from the given arguments and writes it
to the cache directory.
"""
question_parts = []
question_list = args.ask if args.ask is not None else []
text_files = args.source_text if args.source_text is not None else []
code_files = args.source_code if args.source_code is not None else []
for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None):
for question, text_file, code_file in zip_longest(question_list, text_files, code_files, fillvalue=None):
if question is not None and len(question.strip()) > 0:
question_parts.append(question)
if source is not None and len(source) > 0:
with open(source) as r:
content = r.read().strip()
if len(content) > 0:
question_parts.append(content)
if code is not None and len(code) > 0:
with open(code) as r:
content = r.read().strip()
if len(content) == 0:
continue
# try to extract and add source code
code_parts = source_code(content, include_delims=True)
if len(code_parts) > 0:
question_parts += code_parts
# if there's none, add the whole file
else:
question_parts.append(f"```\n{content}\n```")
if text_file is not None and len(text_file) > 0:
add_file_as_text(question_parts, text_file)
if code_file is not None and len(code_file) > 0:
add_file_as_code(question_parts, code_file)
full_question = '\n\n'.join(question_parts)
@@ -44,7 +73,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
tags=args.output_tags, # FIXME
ai=args.AI,
model=args.model)
chat.add_to_cache([message])
chat.cache_add([message])
return message
@@ -73,8 +102,8 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
chat,
args.num_answers, # FIXME
args.output_tags) # FIXME
chat.update_messages([response.messages[0]])
chat.add_to_cache(response.messages[1:])
chat.msg_update([response.messages[0]])
chat.cache_add(response.messages[1:])
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
@@ -82,7 +111,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
print("===============")
print(response.tokens)
elif args.repeat is not None:
lmessage = chat.latest_message()
lmessage = chat.msg_latest()
assert lmessage
# TODO: repeat either the last question or the
# one(s) given in 'args.repeat' (overwrite
+1 -1
View File
@@ -11,7 +11,7 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None:
chat = ChatDB.from_dir(cache_path=Path('.'),
db_path=Path(config.db))
if args.list:
tags_freq = chat.tags_frequency(args.prefix, args.contain)
tags_freq = chat.msg_tags_frequency(args.prefix, args.contain)
for tag, freq in tags_freq.items():
print(f"- {tag}: {freq}")
# TODO: add renaming
+2 -1
View File
@@ -9,7 +9,7 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
supported_ais: list[str] = ['openai']
default_config_path = '.config.yaml'
default_config_file = '.config.yaml'
class ConfigError(Exception):
@@ -39,6 +39,7 @@ class AIConfig:
name: ClassVar[str]
# a user-defined ID for an AI configuration entry
ID: str
model: str = 'n/a'
# the name must not be changed
def __setattr__(self, name: str, value: Any) -> None:
+3 -2
View File
@@ -7,7 +7,7 @@ import argcomplete
import argparse
from pathlib import Path
from typing import Any
from .configuration import Config, default_config_path
from .configuration import Config, default_config_file
from .message import Message
from .commands.question import question_cmd
from .commands.tags import tags_cmd
@@ -24,7 +24,7 @@ def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="ChatMastermind is a Python application that automates conversation with AI")
parser.add_argument('-C', '--config', help='Config file name.', default=default_config_path)
parser.add_argument('-C', '--config', help='Config file name.', default=default_config_file)
# subcommand-parser
cmdparser = parser.add_subparsers(dest='command',
@@ -100,6 +100,7 @@ def create_parser() -> argparse.ArgumentParser:
help="Manage configuration",
aliases=['c'])
config_cmd_parser.set_defaults(func=config_cmd)
config_cmd_parser.add_argument('-A', '--AI', help='AI ID to use')
config_group = config_cmd_parser.add_mutually_exclusive_group(required=True)
config_group.add_argument('-l', '--list-models', help="List all available models",
action='store_true')
+10 -6
View File
@@ -370,7 +370,7 @@ class Message():
try:
question_idx = text.index(Question.txt_header) + 1
except ValueError:
raise MessageError(f"Question header '{Question.txt_header}' not found in '{file_path}'")
raise MessageError(f"'{file_path}' does not contain a valid message")
try:
answer_idx = text.index(Answer.txt_header)
question = Question.from_list(text[question_idx:answer_idx])
@@ -390,9 +390,12 @@ class Message():
* Message.model_yaml_key: str [Optional]
"""
with open(file_path, "r") as fd:
data = yaml.load(fd, Loader=yaml.FullLoader)
data[cls.file_yaml_key] = file_path
return cls.from_dict(data)
try:
data = yaml.load(fd, Loader=yaml.FullLoader)
data[cls.file_yaml_key] = file_path
return cls.from_dict(data)
except Exception:
raise MessageError(f"'{file_path}' does not contain a valid message")
def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str:
"""
@@ -540,10 +543,11 @@ class Message():
def msg_id(self) -> str:
"""
Returns an ID that is unique throughout all messages in the same (DB) directory.
Currently this is the file name. The ID is also used for sorting messages.
Currently this is the file name without suffix. The ID is also used for sorting
messages.
"""
if self.file_path:
return self.file_path.name
return self.file_path.stem
else:
raise MessageError("Can't create file ID without a file path")
+56
View File
@@ -0,0 +1,56 @@
<?php
$secret_key = '123';
// check for POST request
if ($_SERVER['REQUEST_METHOD'] != 'POST') {
error_log('FAILED - not POST - '. $_SERVER['REQUEST_METHOD']);
exit();
}
// get content type
$content_type = isset($_SERVER['CONTENT_TYPE']) ? strtolower(trim($_SERVER['CONTENT_TYPE'])) : '';
if ($content_type != 'application/json') {
error_log('FAILED - not application/json - '. $content_type);
exit();
}
// get payload
$payload = trim(file_get_contents("php://input"));
if (empty($payload)) {
error_log('FAILED - no payload');
exit();
}
// get header signature
$header_signature = isset($_SERVER['HTTP_X_GITEA_SIGNATURE']) ? $_SERVER['HTTP_X_GITEA_SIGNATURE'] : '';
if (empty($header_signature)) {
error_log('FAILED - header signature missing');
exit();
}
// calculate payload signature
$payload_signature = hash_hmac('sha256', $payload, $secret_key, false);
// check payload signature against header signature
if ($header_signature !== $payload_signature) {
error_log('FAILED - payload signature');
exit();
}
// convert json to array
$decoded = json_decode($payload, true);
// check for json decode errors
if (json_last_error() !== JSON_ERROR_NONE) {
error_log('FAILED - json decode - '. json_last_error());
exit();
}
// success, do something
$output = shell_exec('/home/kaizen/repos/ChatMastermind/hooks/push_hook.sh');
echo "$output";
?>
+8
View File
@@ -0,0 +1,8 @@
#!/usr/bin/bash
. /home/kaizen/.bashrc
set -e
cd /home/kaizen/repos/ChatMastermind
git pull
pre-commit run -a
pytest
+81
View File
@@ -0,0 +1,81 @@
import unittest
from unittest import mock
from chatmastermind.ais.openai import OpenAI
from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat
from chatmastermind.ai import AIResponse, Tokens
from chatmastermind.configuration import OpenAIConfig
class OpenAITest(unittest.TestCase):
@mock.patch('openai.ChatCompletion.create')
def test_request(self, mock_create: mock.MagicMock) -> None:
# Create a test instance of OpenAI
config = OpenAIConfig()
openai = OpenAI(config)
# Set up the mock response from openai.ChatCompletion.create
mock_response = {
'choices': [
{
'message': {
'content': 'Answer 1'
}
},
{
'message': {
'content': 'Answer 2'
}
}
],
'usage': {
'prompt_tokens': 10,
'completion_tokens': 20,
'total_tokens': 30
}
}
mock_create.return_value = mock_response
# Create test data
question = Message(Question('Question'))
chat = Chat([
Message(Question('Question 1'), answer=Answer('Answer 1')),
Message(Question('Question 2'), answer=Answer('Answer 2')),
# add message without an answer -> expect to be skipped
Message(Question('Question 3'))
])
# Make the request
response = openai.request(question, chat, num_answers=2)
# Assert the AIResponse
self.assertIsInstance(response, AIResponse)
self.assertEqual(len(response.messages), 2)
self.assertEqual(response.messages[0].answer, 'Answer 1')
self.assertEqual(response.messages[1].answer, 'Answer 2')
self.assertIsNotNone(response.tokens)
self.assertIsInstance(response.tokens, Tokens)
assert response.tokens
self.assertEqual(response.tokens.prompt, 10)
self.assertEqual(response.tokens.completion, 20)
self.assertEqual(response.tokens.total, 30)
# Assert the mock call to openai.ChatCompletion.create
mock_create.assert_called_once_with(
model=f'{config.model}',
messages=[
{'role': 'system', 'content': f'{config.system}'},
{'role': 'user', 'content': 'Question 1'},
{'role': 'assistant', 'content': 'Answer 1'},
{'role': 'user', 'content': 'Question 2'},
{'role': 'assistant', 'content': 'Answer 2'},
{'role': 'user', 'content': 'Question'}
],
temperature=config.temperature,
max_tokens=config.max_tokens,
top_p=config.top_p,
n=2,
frequency_penalty=config.frequency_penalty,
presence_penalty=config.presence_penalty
)
+224 -71
View File
@@ -2,11 +2,12 @@ import unittest
import pathlib
import tempfile
import time
import yaml
from io import StringIO
from unittest.mock import patch
from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError
from chatmastermind.chat import Chat, ChatDB, ChatError
class TestChat(unittest.TestCase):
@@ -20,94 +21,118 @@ class TestChat(unittest.TestCase):
Answer('Answer 2'),
{Tag('btag2')},
file_path=pathlib.Path('0002.txt'))
self.maxDiff = None
def test_unique_id(self) -> None:
# test with two identical messages
self.chat.msg_add([self.message1, self.message1])
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_id()
self.assertSequenceEqual(self.chat.messages, [self.message1])
# test with two different messages
self.chat.msg_add([self.message2])
self.chat.msg_unique_id()
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2])
def test_unique_content(self) -> None:
# test with two identical messages
self.chat.msg_add([self.message1, self.message1])
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_content()
self.assertSequenceEqual(self.chat.messages, [self.message1])
# test with two different messages
self.chat.msg_add([self.message2])
self.chat.msg_unique_content()
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2])
def test_filter(self) -> None:
self.chat.add_messages([self.message1, self.message2])
self.chat.filter(MessageFilter(answer_contains='Answer 1'))
self.chat.msg_add([self.message1, self.message2])
self.chat.msg_filter(MessageFilter(answer_contains='Answer 1'))
self.assertEqual(len(self.chat.messages), 1)
self.assertEqual(self.chat.messages[0].question, 'Question 1')
def test_sort(self) -> None:
self.chat.add_messages([self.message2, self.message1])
self.chat.sort()
self.chat.msg_add([self.message2, self.message1])
self.chat.msg_sort()
self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2')
self.chat.sort(reverse=True)
self.chat.msg_sort(reverse=True)
self.assertEqual(self.chat.messages[0].question, 'Question 2')
self.assertEqual(self.chat.messages[1].question, 'Question 1')
def test_clear(self) -> None:
self.chat.add_messages([self.message1])
self.chat.clear()
self.chat.msg_add([self.message1])
self.chat.msg_clear()
self.assertEqual(len(self.chat.messages), 0)
def test_add_messages(self) -> None:
self.chat.add_messages([self.message1, self.message2])
self.chat.msg_add([self.message1, self.message2])
self.assertEqual(len(self.chat.messages), 2)
self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2')
def test_tags(self) -> None:
self.chat.add_messages([self.message1, self.message2])
tags_all = self.chat.tags()
self.chat.msg_add([self.message1, self.message2])
tags_all = self.chat.msg_tags()
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
tags_pref = self.chat.tags(prefix='a')
tags_pref = self.chat.msg_tags(prefix='a')
self.assertSetEqual(tags_pref, {Tag('atag1')})
tags_cont = self.chat.tags(contain='2')
tags_cont = self.chat.msg_tags(contain='2')
self.assertSetEqual(tags_cont, {Tag('btag2')})
def test_tags_frequency(self) -> None:
self.chat.add_messages([self.message1, self.message2])
tags_freq = self.chat.tags_frequency()
self.chat.msg_add([self.message1, self.message2])
tags_freq = self.chat.msg_tags_frequency()
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
def test_find_remove_messages(self) -> None:
self.chat.add_messages([self.message1, self.message2])
msgs = self.chat.find_messages(['0001.txt'])
self.chat.msg_add([self.message1, self.message2])
msgs = self.chat.msg_find(['0001.txt'])
self.assertListEqual(msgs, [self.message1])
msgs = self.chat.find_messages(['0001.txt', '0002.txt'])
msgs = self.chat.msg_find(['0001.txt', '0002.txt'])
self.assertListEqual(msgs, [self.message1, self.message2])
# add new Message with full path
message3 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('btag2')},
file_path=pathlib.Path('/foo/bla/0003.txt'))
self.chat.add_messages([message3])
self.chat.msg_add([message3])
# find new Message by full path
msgs = self.chat.find_messages(['/foo/bla/0003.txt'])
msgs = self.chat.msg_find(['/foo/bla/0003.txt'])
self.assertListEqual(msgs, [message3])
# find Message with full path only by filename
msgs = self.chat.find_messages(['0003.txt'])
msgs = self.chat.msg_find(['0003.txt'])
self.assertListEqual(msgs, [message3])
# remove last message
self.chat.remove_messages(['0003.txt'])
self.chat.msg_remove(['0003.txt'])
self.assertListEqual(self.chat.messages, [self.message1, self.message2])
def test_latest_message(self) -> None:
self.assertIsNone(self.chat.msg_latest())
self.chat.msg_add([self.message1])
self.assertEqual(self.chat.msg_latest(), self.message1)
self.chat.msg_add([self.message2])
self.assertEqual(self.chat.msg_latest(), self.message2)
@patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None:
self.chat.add_messages([self.message1, self.message2])
self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False)
expected_output = f"""{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{'-'*terminal_width()}
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
{'-'*terminal_width()}
"""
self.assertEqual(mock_stdout.getvalue(), expected_output)
@patch('sys.stdout', new_callable=StringIO)
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
self.chat.add_messages([self.message1, self.message2])
self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False, with_tags=True, with_files=True)
expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: 0001.txt
@@ -115,18 +140,12 @@ FILE: 0001.txt
Question 1
{Answer.txt_header}
Answer 1
{'-'*terminal_width()}
{TagLine.prefix} btag2
FILE: 0002.txt
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
{'-'*terminal_width()}
"""
self.assertEqual(mock_stdout.getvalue(), expected_output)
@@ -161,20 +180,41 @@ class TestChatDB(unittest.TestCase):
next_fname = pathlib.Path(self.db_path.name) / '.next'
with open(next_fname, 'w') as f:
f.write('4')
# add some "trash" in order to test if it's correctly handled / ignored
self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt']
for file in self.trash_files:
with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
f.write('test trash')
# also create a file with actual yaml content
with open(pathlib.Path(self.db_path.name) / 'content.yaml', 'w') as f:
yaml.dump({'key': 'value'}, f)
self.trash_files.append('content.yaml')
self.maxDiff = None
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
"""
List all Message files in the given TemporaryDirectory.
"""
# exclude '.next'
return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*'))
return [f for f in pathlib.Path(tmp_dir.name).glob('*.[ty]*') if f.name not in self.trash_files]
def tearDown(self) -> None:
self.db_path.cleanup()
self.cache_path.cleanup()
pass
def test_chat_db_from_dir(self) -> None:
def test_validate(self) -> None:
duplicate_message = Message(Question('Question 4'),
Answer('Answer 4'),
{Tag('tag4')},
file_path=pathlib.Path('0004.txt'))
duplicate_message.to_file(pathlib.Path(self.db_path.name, '0004.txt'))
with self.assertRaises(ChatError) as cm:
ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(str(cm.exception), "Validation failed")
def test_from_dir(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(len(chat_db.messages), 4)
@@ -190,7 +230,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[3].file_path,
pathlib.Path(self.db_path.name, '0004.yaml'))
def test_chat_db_from_dir_glob(self) -> None:
def test_from_dir_glob(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
glob='*.txt')
@@ -202,7 +242,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0003.txt'))
def test_chat_db_from_dir_filter_tags(self) -> None:
def test_from_dir_filter_tags(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or={Tag('tag1')}))
@@ -212,7 +252,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
def test_chat_db_from_dir_filter_tags_empty(self) -> None:
def test_from_dir_filter_tags_empty(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or=set(),
@@ -220,7 +260,7 @@ class TestChatDB(unittest.TestCase):
tags_not=set()))
self.assertEqual(len(chat_db.messages), 0)
def test_chat_db_from_dir_filter_answer(self) -> None:
def test_from_dir_filter_answer(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(answer_contains='Answer 2'))
@@ -231,7 +271,7 @@ class TestChatDB(unittest.TestCase):
pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
def test_chat_db_from_messages(self) -> None:
def test_from_messages(self) -> None:
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
messages=[self.message1, self.message2,
@@ -240,16 +280,35 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
def test_chat_db_fids(self) -> None:
def test_fids(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.get_next_fid(), 5)
self.assertEqual(chat_db.get_next_fid(), 6)
self.assertEqual(chat_db.get_next_fid(), 7)
with open(chat_db.next_fname, 'r') as f:
with open(chat_db.next_path, 'r') as f:
self.assertEqual(f.read(), '7')
def test_chat_db_write(self) -> None:
def test_msg_in_db_or_cache(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertTrue(chat_db.msg_in_db(self.message1))
self.assertTrue(chat_db.msg_in_db(str(self.message1.file_path)))
self.assertTrue(chat_db.msg_in_db(self.message1.msg_id()))
self.assertFalse(chat_db.msg_in_cache(self.message1))
self.assertFalse(chat_db.msg_in_cache(str(self.message1.file_path)))
self.assertFalse(chat_db.msg_in_cache(self.message1.msg_id()))
# add new message to the cache dir
cache_message = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
chat_db.cache_add([cache_message])
self.assertTrue(chat_db.msg_in_cache(cache_message))
self.assertTrue(chat_db.msg_in_cache(cache_message.msg_id()))
self.assertFalse(chat_db.msg_in_db(cache_message))
self.assertFalse(chat_db.msg_in_db(str(cache_message.file_path)))
def test_db_write(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
@@ -260,7 +319,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory
chat_db.write_cache()
chat_db.cache_write()
# check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4)
@@ -280,7 +339,7 @@ class TestChatDB(unittest.TestCase):
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
# overwrite the messages in the db directory
time.sleep(0.05)
chat_db.write_db()
chat_db.db_write()
# check if the written files are in the DB directory
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
@@ -297,7 +356,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
def test_chat_db_read(self) -> None:
def test_db_read(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
@@ -313,7 +372,7 @@ class TestChatDB(unittest.TestCase):
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read and check them
chat_db.read_db()
chat_db.db_read()
self.assertEqual(len(chat_db.messages), 6)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
@@ -328,7 +387,7 @@ class TestChatDB(unittest.TestCase):
new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml'))
# read and check them
chat_db.read_cache()
chat_db.cache_read()
self.assertEqual(len(chat_db.messages), 8)
# check that the new message have the cache dir path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt'))
@@ -343,7 +402,7 @@ class TestChatDB(unittest.TestCase):
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read from the DB dir and check if the modified messages have been updated
chat_db.read_db()
chat_db.db_read()
self.assertEqual(len(chat_db.messages), 8)
self.assertEqual(chat_db.messages[4].question, 'New Question 1')
self.assertEqual(chat_db.messages[5].question, 'New Question 2')
@@ -354,13 +413,13 @@ class TestChatDB(unittest.TestCase):
new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml'))
# read and check them
chat_db.read_db()
chat_db.db_read()
self.assertEqual(len(chat_db.messages), 8)
# check that they now have the DB path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml'))
def test_chat_db_clear(self) -> None:
def test_cache_clear(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
@@ -371,13 +430,13 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory
chat_db.write_cache()
chat_db.cache_write()
# check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4)
# now rewrite them to the DB dir and check for modified paths
chat_db.write_db()
chat_db.db_write()
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
@@ -392,10 +451,10 @@ class TestChatDB(unittest.TestCase):
message_cache = Message(question=Question("What the hell am I doing here?"),
answer=Answer("You're a creep!"),
file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
chat_db.add_messages([message_empty, message_cache])
chat_db.msg_add([message_empty, message_cache])
# clear the cache and check the cache dir
chat_db.clear_cache()
chat_db.cache_clear()
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
# make sure that the DB messages (and the new message) are still there
@@ -405,7 +464,7 @@ class TestChatDB(unittest.TestCase):
# but not the message with the cache dir path
self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))
def test_chat_db_add(self) -> None:
def test_add(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
@@ -416,7 +475,7 @@ class TestChatDB(unittest.TestCase):
# add new messages to the cache dir
message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
chat_db.add_to_cache([message1])
chat_db.cache_add([message1])
# check if the file_path has been correctly set
self.assertIsNotNone(message1.file_path)
self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
@@ -426,7 +485,7 @@ class TestChatDB(unittest.TestCase):
# add new messages to the DB dir
message2 = Message(question=Question("Question 2"),
answer=Answer("Answer 2"))
chat_db.add_to_db([message2])
chat_db.db_add([message2])
# check if the file_path has been correctly set
self.assertIsNotNone(message2.file_path)
self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
@@ -434,9 +493,9 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(len(db_dir_files), 5)
with self.assertRaises(ChatError):
chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))])
chat_db.cache_add([Message(Question("?"), file_path=pathlib.Path("foo"))])
def test_chat_db_write_messages(self) -> None:
def test_msg_write(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
@@ -450,16 +509,16 @@ class TestChatDB(unittest.TestCase):
message = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
with self.assertRaises(ChatError):
chat_db.write_messages([message])
chat_db.msg_write([message])
# write a message with a valid file_path
message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt'
chat_db.write_messages([message])
chat_db.msg_write([message])
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
def test_chat_db_update_messages(self) -> None:
def test_msg_update(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
@@ -472,17 +531,111 @@ class TestChatDB(unittest.TestCase):
message = chat_db.messages[0]
message.answer = Answer("New answer")
# update message without writing
chat_db.update_messages([message], write=False)
chat_db.msg_update([message], write=False)
self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# re-read the message and check for old content
chat_db.read_db()
chat_db.db_read()
self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1"))
# now check with writing (message should be overwritten)
chat_db.update_messages([message], write=True)
chat_db.read_db()
chat_db.msg_update([message], write=True)
chat_db.db_read()
self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# test without file_path -> expect error
message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
with self.assertRaises(ChatError):
chat_db.update_messages([message1])
chat_db.msg_update([message1])
def test_msg_find(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# search for a DB file in memory
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.txt'], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1])
# and on disk
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.yaml'], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2])
# now search the cache -> expect empty result
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003.txt'], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), [])
# search for multiple messages
# -> search one twice, expect result to be unique
search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, loc='all')
self.assertSequenceEqual(result, expected_result)
def test_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.msg_latest(loc='mem'), self.message4)
self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)
self.assertEqual(chat_db.msg_latest(loc='disk'), self.message4)
self.assertEqual(chat_db.msg_latest(loc='all'), self.message4)
# the cache is currently empty:
self.assertIsNone(chat_db.msg_latest(loc='cache'))
# add new messages to the cache dir
new_message = Message(question=Question("New Question"),
answer=Answer("New Answer"))
chat_db.cache_add([new_message])
self.assertEqual(chat_db.msg_latest(loc='cache'), new_message)
self.assertEqual(chat_db.msg_latest(loc='mem'), new_message)
self.assertEqual(chat_db.msg_latest(loc='disk'), new_message)
self.assertEqual(chat_db.msg_latest(loc='all'), new_message)
# the DB does not contain the new message
self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)
def test_msg_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
# add a new message, but only to the internal list
new_message = Message(Question("What?"))
all_messages_mem = all_messages + [new_message]
chat_db.msg_add([new_message])
self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages_mem)
self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages_mem)
# the nr. of messages on disk did not change -> expect old result
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
# test with MessageFilter
self.assertSequenceEqual(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})),
[self.message1])
self.assertSequenceEqual(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})),
[self.message2])
self.assertSequenceEqual(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})),
[])
self.assertSequenceEqual(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")),
[new_message])
def test_msg_move_and_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
# move first message to the cache
chat_db.cache_move(self.message1)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [self.message1])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4])
self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages)
# now move first message back to the DB
chat_db.db_move(self.message1)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)
+1 -1
View File
@@ -730,7 +730,7 @@ class MessageIDTestCase(unittest.TestCase):
self.file_path.unlink()
def test_msg_id_txt(self) -> None:
self.assertEqual(self.message.msg_id(), self.file_path.name)
self.assertEqual(self.message.msg_id(), self.file_path.stem)
def test_msg_id_txt_exception(self) -> None:
with self.assertRaises(MessageError):
+34 -1
View File
@@ -5,7 +5,7 @@ import tempfile
from pathlib import Path
from unittest.mock import MagicMock
from chatmastermind.commands.question import create_message
from chatmastermind.message import Message, Question
from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import ChatDB
@@ -20,6 +20,12 @@ class TestMessageCreate(unittest.TestCase):
self.cache_path = tempfile.TemporaryDirectory()
self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name),
db_path=Path(self.db_path.name))
# create some messages
self.message_text = Message(Question("What is this?"),
Answer("It is pure text"))
self.message_code = Message(Question("What is this?"),
Answer("Text\n```\nIt is embedded code\n```\ntext"))
self.chat.db_add([self.message_text, self.message_code])
# create arguments mock
self.args = MagicMock(spec=argparse.Namespace)
self.args.source_text = None
@@ -159,4 +165,31 @@ This is embedded source code.
```
This is embedded source code.
```
"""))
def test_single_question_with_text_only_message(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_text = [f"{self.chat.messages[0].file_path}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains no source code (only text)
# -> don't expect any in the question
self.assertEqual(len(message.question.source_code()), 0)
self.assertEqual(message.question, Question(f"""What is this?
{self.message_text.answer}"""))
def test_single_question_with_message_and_embedded_code(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.chat.messages[1].file_path}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# answer contains 1 source code block
# -> expect it in the question
self.assertEqual(len(message.question.source_code()), 1)
self.assertEqual(message.question, Question("""What is this?
```
It is embedded code
```
"""))