Compare commits

76 Commits

| SHA1 |
|---|
| 17a0264025 |
| 7f4a16894e |
| 26e3d38afb |
| b5af751193 |
| a7345cbc41 |
| 310cb9421e |
| 1ec3d6fcda |
| 544bf0bf06 |
| f96e82bdd7 |
| 2b62cb8c4b |
| a895c1fc6a |
| ddfcc71510 |
| 17de0b9967 |
| 33023d29f9 |
| 481f9ecf7c |
| 22fa187e5f |
| b840ebd792 |
| 66908f5fed |
| 2e08ccf606 |
| 595ff8e294 |
| faac42d3c2 |
| 864ab7aeb1 |
| cc76da2ab3 |
| f99cd3ed41 |
| 6f3ea98425 |
| 54ece6efeb |
| 86eebc39ea |
| 3eca53998b |
| c4f7bcc94e |
| c52713c833 |
| ecb6994783 |
| 61e710a4b1 |
| 21d39c6c66 |
| 6a4cc7a65d |
| d6bb5800b1 |
| 034e4093f1 |
| 7d15452242 |
| 823d3bf7dc |
| 4bd144c4d7 |
| e186afbef0 |
| 5e4ec70072 |
| 4c378dde85 |
| 8923a13352 |
| e1414835c8 |
| abb7fdacb6 |
| 2e2228bd60 |
| 713b55482a |
| d35de86c67 |
| aba3eb783d |
| 8e63831701 |
| c318b99671 |
| 48c8e951e1 |
| b22a4b07ed |
| 33565d351d |
| 6737fa98c7 |
| 815a21893c |
| 64893949a4 |
| a093f9b867 |
| dc3f3dc168 |
| 74c39070d6 |
| fde0ae4652 |
| 238dbbee60 |
| 17f7b2fb45 |
| 9c2598a4b8 |
| acec5f1d55 |
| c0f50bace5 |
| 30ccec2462 |
| 09da312657 |
| 33567df15f |
| 264979a60d |
| 061e5f8682 |
| 2d456e68f1 |
| 8bd659e888 |
| 3ef1339cc0 |
| ed567afbea |
| 6e447018d5 |
@@ -37,63 +37,95 @@ cmm [global options] command [command options]

 ### Global Options

-- `-c`, `--config`: Config file name (defaults to `.config.yaml`).
+- `-C`, `--config`: Config file name (defaults to `.config.yaml`).

-### Commands
-
-- `ask`: Ask a question.
-- `hist`: Print chat history.
-- `tag`: Manage tags.
-- `config`: Manage configuration.
-- `print`: Print files.
-
 ### Command Options

-#### `ask` Command Options
-
-- `-q`, `--question`: Question to ask (required).
-- `-m`, `--max-tokens`: Max tokens to use.
-- `-T`, `--temperature`: Temperature to use.
-- `-M`, `--model`: Model to use.
-- `-n`, `--number`: Number of answers to produce (default is 3).
-- `-s`, `--source`: Add content of a file to the query.
-- `-S`, `--only-source-code`: Add pure source code to the chat history.
-- `-t`, `--tags`: List of tag names.
-- `-e`, `--extags`: List of tag names to exclude.
-- `-o`, `--output-tags`: List of output tag names (default is the input tags).
-- `-a`, `--match-all-tags`: All given tags must match when selecting chat history entries.
-
-#### `hist` Command Options
-
-- `-d`, `--dump`: Print chat history as Python structure.
-- `-w`, `--with-tags`: Print chat history with tags.
-- `-W`, `--with-files`: Print chat history with filenames.
-- `-S`, `--only-source-code`: Print only source code.
-- `-t`, `--tags`: List of tag names.
-- `-e`, `--extags`: List of tag names to exclude.
-- `-a`, `--match-all-tags`: All given tags must match when selecting chat history entries.
-
-#### `tag` Command Options
-
-- `-l`, `--list`: List all tags and their frequency.
-
-#### `config` Command Options
-
-- `-l`, `--list-models`: List all available models.
-- `-m`, `--print-model`: Print the currently configured model.
-- `-M`, `--model`: Set model in the config file.
-
-#### `print` Command Options
-
-- `-f`, `--file`: File to print (required).
-- `-S`, `--only-source-code`: Print only source code.
+#### Question
+
+The `question` command is used to ask, create, and process questions.
+
+```bash
+cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI] [-M MODEL] [-n NUM] [-m MAX] [-T TEMP] (-a ASK | -c CREATE | -r REPEAT | -p PROCESS) [-O] [-s SOURCE]... [-S SOURCE]...
+```
+
+* `-t, --or-tags OTAGS` : List of tags (one must match)
+* `-k, --and-tags ATAGS` : List of tags (all must match)
+* `-x, --exclude-tags XTAGS` : List of tags to exclude
+* `-o, --output-tags OUTTAGS` : List of output tags (default: use input tags)
+* `-A, --AI AI` : AI ID to use
+* `-M, --model MODEL` : Model to use
+* `-n, --num-answers NUM` : Number of answers to request
+* `-m, --max-tokens MAX` : Max. number of tokens
+* `-T, --temperature TEMP` : Temperature value
+* `-a, --ask ASK` : Ask a question
+* `-c, --create CREATE` : Create a question
+* `-r, --repeat REPEAT` : Repeat a question
+* `-p, --process PROCESS` : Process existing questions
+* `-O, --overwrite` : Overwrite existing messages when repeating them
+* `-s, --source-text SOURCE` : Add content of a file to the query
+* `-S, --source-code SOURCE` : Add source code file content to the chat history
+
+#### Hist
+
+The `hist` command is used to print the chat history.
+
+```bash
+cmm hist [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A ANSWER] [-Q QUESTION]
+```
+
+* `-t, --or-tags OTAGS` : List of tags (one must match)
+* `-k, --and-tags ATAGS` : List of tags (all must match)
+* `-x, --exclude-tags XTAGS` : List of tags to exclude
+* `-w, --with-tags` : Print chat history with tags
+* `-W, --with-files` : Print chat history with filenames
+* `-S, --source-code-only` : Print only source code
+* `-A, --answer ANSWER` : Search for answer substring
+* `-Q, --question QUESTION` : Search for question substring
+
+#### Tags
+
+The `tags` command is used to manage tags.
+
+```bash
+cmm tags (-l | -p PREFIX | -c CONTENT)
+```
+
+* `-l, --list` : List all tags and their frequency
+* `-p, --prefix PREFIX` : Filter tags by prefix
+* `-c, --contain CONTENT` : Filter tags by contained substring
+
+#### Config
+
+The `config` command is used to manage the configuration.
+
+```bash
+cmm config (-l | -m | -c CREATE)
+```
+
+* `-l, --list-models` : List all available models
+* `-m, --print-model` : Print the currently configured model
+* `-c, --create CREATE` : Create config with default settings in the given file
+
+#### Print
+
+The `print` command is used to print message files.
+
+```bash
+cmm print -f FILE [-q | -a | -S]
+```
+
+* `-f, --file FILE` : File to print
+* `-q, --question` : Print only question
+* `-a, --answer` : Print only answer
+* `-S, --only-source-code` : Print only source code

 ### Examples

 1. Ask a question:

 ```bash
-cmm ask -q "What is the meaning of life?" -t philosophy -e religion
+cmm question -a "What is the meaning of life?" -t philosophy -x religion
 ```

 2. Display the chat history:
@@ -105,19 +137,19 @@ cmm hist
 ```

 3. Filter chat history by tags:

 ```bash
-cmm hist -t tag1 tag2
+cmm hist --or-tags tag1 tag2
 ```

 4. Exclude chat history by tags:

 ```bash
-cmm hist -e tag3 tag4
+cmm hist --exclude-tags tag3 tag4
 ```

 5. List all tags and their frequency:

 ```bash
-cmm tag -l
+cmm tags -l
 ```

 6. Print the contents of a file:
@@ -59,6 +59,12 @@ class AI(Protocol):
         """
         raise NotImplementedError

+    def print_models(self) -> None:
+        """
+        Print all models supported by this AI.
+        """
+        raise NotImplementedError
+
     def tokens(self, data: Union[Message, Chat]) -> int:
         """
         Computes the nr. of AI language tokens for the given message
@@ -17,7 +17,7 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI:  # noqa: 11
     is not found, it uses the first AI in the list.
     """
    ai_conf: AIConfig
-    if args.AI:
+    if hasattr(args, 'AI') and args.AI:
        try:
            ai_conf = config.ais[args.AI]
        except KeyError:
@@ -32,11 +32,11 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI:  # noqa: 11

     if ai_conf.name == 'openai':
         ai = OpenAI(cast(OpenAIConfig, ai_conf))
-        if args.model:
+        if hasattr(args, 'model') and args.model:
             ai.config.model = args.model
-        if args.max_tokens:
+        if hasattr(args, 'max_tokens') and args.max_tokens:
             ai.config.max_tokens = args.max_tokens
-        if args.temperature:
+        if hasattr(args, 'temperature') and args.temperature:
             ai.config.temperature = args.temperature
         return ai
     else:
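The `hasattr()` guards added in `create_ai` above reflect how argparse subparsers work: each subcommand's `Namespace` only carries attributes for the options that subcommand defines. A standalone sketch with a hypothetical mini-parser (not the project's real one):

```python
import argparse

# Each subparser only defines attributes for its own options, so a Namespace
# produced by the 'config' subcommand has no 'model' attribute at all --
# accessing args.model there would raise AttributeError without a guard.
parser = argparse.ArgumentParser(prog='cmm')
subparsers = parser.add_subparsers(dest='command')

question = subparsers.add_parser('question')
question.add_argument('-M', '--model')

subparsers.add_parser('config')  # defines no --model option

args = parser.parse_args(['config'])
print(hasattr(args, 'model'))  # False

args = parser.parse_args(['question', '-M', 'gpt-4'])
print(args.model)  # gpt-4
```

An alternative with the same effect is `getattr(args, 'model', None)`, which avoids the two-step check.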
@@ -62,7 +62,12 @@ class OpenAI(AI):
         """
         Return all models supported by this AI.
         """
-        raise NotImplementedError
+        ret = []
+        for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
+            if engine['ready']:
+                ret.append(engine['id'])
+        ret.sort()
+        return ret

     def print_models(self) -> None:
         """
@@ -204,7 +204,6 @@ class Chat:
                 output.append(message.to_str(source_code_only=True))
                 continue
             output.append(message.to_str(with_tags, with_files))
-            output.append('\n' + ('-' * terminal_width()) + '\n')
         if paged:
             print_paged('\n'.join(output))
         else:
@@ -1,6 +1,8 @@
 import argparse
 from pathlib import Path
 from ..configuration import Config
+from ..ai import AI
+from ..ai_factory import create_ai


 def config_cmd(args: argparse.Namespace) -> None:
@@ -9,3 +11,10 @@ def config_cmd(args: argparse.Namespace) -> None:
     """
     if args.create:
         Config.create_default(Path(args.create))
+    elif args.list_models or args.print_model:
+        config: Config = Config.from_file(args.config)
+        ai: AI = create_ai(args, config)
+        if args.list_models:
+            ai.print_models()
+        else:
+            print(ai.config.model)
@@ -3,11 +3,52 @@ from pathlib import Path
 from itertools import zip_longest
 from ..configuration import Config
 from ..chat import ChatDB
-from ..message import Message, MessageFilter, Question, source_code
+from ..message import Message, MessageFilter, MessageError, Question, source_code
 from ..ai_factory import create_ai
 from ..ai import AI, AIResponse


+def add_file_as_text(question_parts: list[str], file: str) -> None:
+    """
+    Add the given file as plain text to the question part list.
+    If the file is a Message, add the answer.
+    """
+    file_path = Path(file)
+    content: str
+    try:
+        message = Message.from_file(file_path)
+        if message and message.answer:
+            content = message.answer
+    except MessageError:
+        with open(file) as r:
+            content = r.read().strip()
+    if len(content) > 0:
+        question_parts.append(content)
+
+
+def add_file_as_code(question_parts: list[str], file: str) -> None:
+    """
+    Add all source code from the given file. If no code segments can be extracted,
+    the whole content is added as source code segment. If the file is a Message,
+    extract the source code from the answer.
+    """
+    file_path = Path(file)
+    content: str
+    try:
+        message = Message.from_file(file_path)
+        if message and message.answer:
+            content = message.answer
+    except MessageError:
+        with open(file) as r:
+            content = r.read().strip()
+    # extract and add source code
+    code_parts = source_code(content, include_delims=True)
+    if len(code_parts) > 0:
+        question_parts += code_parts
+    else:
+        question_parts.append(f"```\n{content}\n```")
+
+
 def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
     """
     Creates (and writes) a new message from the given arguments.
@@ -17,26 +58,13 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
     text_files = args.source_text if args.source_text is not None else []
     code_files = args.source_code if args.source_code is not None else []

-    for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None):
+    for question, text_file, code_file in zip_longest(question_list, text_files, code_files, fillvalue=None):
         if question is not None and len(question.strip()) > 0:
             question_parts.append(question)
-        if source is not None and len(source) > 0:
-            with open(source) as r:
-                content = r.read().strip()
-            if len(content) > 0:
-                question_parts.append(content)
-        if code is not None and len(code) > 0:
-            with open(code) as r:
-                content = r.read().strip()
-            if len(content) == 0:
-                continue
-            # try to extract and add source code
-            code_parts = source_code(content, include_delims=True)
-            if len(code_parts) > 0:
-                question_parts += code_parts
-            # if there's none, add the whole file
-            else:
-                question_parts.append(f"```\n{content}\n```")
+        if text_file is not None and len(text_file) > 0:
+            add_file_as_text(question_parts, text_file)
+        if code_file is not None and len(code_file) > 0:
+            add_file_as_code(question_parts, code_file)

     full_question = '\n\n'.join(question_parts)
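The loop in `create_message` pairs each question with an optional text file and code file via `itertools.zip_longest`. A standalone illustration with sample data (not project files):

```python
from itertools import zip_longest

# zip_longest pads the shorter lists with the fillvalue, so every question is
# still processed even when it has no accompanying text or code file.
questions = ['Q1', 'Q2', 'Q3']
text_files = ['notes.txt']
code_files = ['a.py', 'b.py']

pairs = list(zip_longest(questions, text_files, code_files, fillvalue=None))
print(pairs)
# → [('Q1', 'notes.txt', 'a.py'), ('Q2', None, 'b.py'), ('Q3', None, None)]
```

Plain `zip()` would instead stop at the shortest list and silently drop `Q2` and `Q3` here.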
@@ -39,6 +39,7 @@ class AIConfig:
     name: ClassVar[str]
     # a user-defined ID for an AI configuration entry
     ID: str
+    model: str = 'n/a'

     # the name must not be changed
     def __setattr__(self, name: str, value: Any) -> None:
@@ -100,6 +100,7 @@ def create_parser() -> argparse.ArgumentParser:
                                                help="Manage configuration",
                                                aliases=['c'])
     config_cmd_parser.set_defaults(func=config_cmd)
+    config_cmd_parser.add_argument('-A', '--AI', help='AI ID to use')
     config_group = config_cmd_parser.add_mutually_exclusive_group(required=True)
     config_group.add_argument('-l', '--list-models', help="List all available models",
                               action='store_true')
@@ -0,0 +1,56 @@
+<?php
+
+$secret_key = '123';
+
+// check for POST request
+if ($_SERVER['REQUEST_METHOD'] != 'POST') {
+    error_log('FAILED - not POST - '. $_SERVER['REQUEST_METHOD']);
+    exit();
+}
+
+// get content type
+$content_type = isset($_SERVER['CONTENT_TYPE']) ? strtolower(trim($_SERVER['CONTENT_TYPE'])) : '';
+
+if ($content_type != 'application/json') {
+    error_log('FAILED - not application/json - '. $content_type);
+    exit();
+}
+
+// get payload
+$payload = trim(file_get_contents("php://input"));
+
+if (empty($payload)) {
+    error_log('FAILED - no payload');
+    exit();
+}
+
+// get header signature
+$header_signature = isset($_SERVER['HTTP_X_GITEA_SIGNATURE']) ? $_SERVER['HTTP_X_GITEA_SIGNATURE'] : '';
+
+if (empty($header_signature)) {
+    error_log('FAILED - header signature missing');
+    exit();
+}
+
+// calculate payload signature
+$payload_signature = hash_hmac('sha256', $payload, $secret_key, false);
+
+// check payload signature against header signature
+if ($header_signature !== $payload_signature) {
+    error_log('FAILED - payload signature');
+    exit();
+}
+
+// convert json to array
+$decoded = json_decode($payload, true);
+
+// check for json decode errors
+if (json_last_error() !== JSON_ERROR_NONE) {
+    error_log('FAILED - json decode - '. json_last_error());
+    exit();
+}
+
+// success, do something
+$output = shell_exec('/home/kaizen/repos/ChatMastermind/hooks/push_hook.sh');
+echo "$output";
+?>
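The PHP hook above authenticates Gitea webhooks by comparing the `X-Gitea-Signature` header against a hex-encoded HMAC-SHA256 of the raw request body. A minimal Python sketch of how a client or test would compute the matching header value (the secret and payload here are placeholders, not the deployed values):

```python
import hashlib
import hmac
import json

# Hex-encoded HMAC-SHA256 of the raw body, keyed with the shared secret --
# the same value hash_hmac('sha256', $payload, $secret_key, false) produces.
secret_key = b'123'
payload = json.dumps({'ref': 'refs/heads/main'}).encode()

signature = hmac.new(secret_key, payload, hashlib.sha256).hexdigest()
print(signature)  # 64 lowercase hex characters
```

Note that the PHP side compares with `!==`; for untrusted input, a constant-time comparison (`hash_equals` in PHP, `hmac.compare_digest` in Python) is the usual hardening.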
Executable
+8
@@ -0,0 +1,8 @@
+#!/usr/bin/bash
+
+. /home/kaizen/.bashrc
+set -e
+cd /home/kaizen/repos/ChatMastermind
+git pull
+pre-commit run -a
+pytest
@@ -0,0 +1,81 @@
+import unittest
+from unittest import mock
+
+from chatmastermind.ais.openai import OpenAI
+from chatmastermind.message import Message, Question, Answer
+from chatmastermind.chat import Chat
+from chatmastermind.ai import AIResponse, Tokens
+from chatmastermind.configuration import OpenAIConfig
+
+
+class OpenAITest(unittest.TestCase):
+
+    @mock.patch('openai.ChatCompletion.create')
+    def test_request(self, mock_create: mock.MagicMock) -> None:
+        # Create a test instance of OpenAI
+        config = OpenAIConfig()
+        openai = OpenAI(config)
+
+        # Set up the mock response from openai.ChatCompletion.create
+        mock_response = {
+            'choices': [
+                {
+                    'message': {
+                        'content': 'Answer 1'
+                    }
+                },
+                {
+                    'message': {
+                        'content': 'Answer 2'
+                    }
+                }
+            ],
+            'usage': {
+                'prompt_tokens': 10,
+                'completion_tokens': 20,
+                'total_tokens': 30
+            }
+        }
+        mock_create.return_value = mock_response
+
+        # Create test data
+        question = Message(Question('Question'))
+        chat = Chat([
+            Message(Question('Question 1'), answer=Answer('Answer 1')),
+            Message(Question('Question 2'), answer=Answer('Answer 2')),
+            # add message without an answer -> expect to be skipped
+            Message(Question('Question 3'))
+        ])
+
+        # Make the request
+        response = openai.request(question, chat, num_answers=2)
+
+        # Assert the AIResponse
+        self.assertIsInstance(response, AIResponse)
+        self.assertEqual(len(response.messages), 2)
+        self.assertEqual(response.messages[0].answer, 'Answer 1')
+        self.assertEqual(response.messages[1].answer, 'Answer 2')
+        self.assertIsNotNone(response.tokens)
+        self.assertIsInstance(response.tokens, Tokens)
+        assert response.tokens
+        self.assertEqual(response.tokens.prompt, 10)
+        self.assertEqual(response.tokens.completion, 20)
+        self.assertEqual(response.tokens.total, 30)
+
+        # Assert the mock call to openai.ChatCompletion.create
+        mock_create.assert_called_once_with(
+            model=f'{config.model}',
+            messages=[
+                {'role': 'system', 'content': f'{config.system}'},
+                {'role': 'user', 'content': 'Question 1'},
+                {'role': 'assistant', 'content': 'Answer 1'},
+                {'role': 'user', 'content': 'Question 2'},
+                {'role': 'assistant', 'content': 'Answer 2'},
+                {'role': 'user', 'content': 'Question'}
+            ],
+            temperature=config.temperature,
+            max_tokens=config.max_tokens,
+            top_p=config.top_p,
+            n=2,
+            frequency_penalty=config.frequency_penalty,
+            presence_penalty=config.presence_penalty
+        )
+1
-13
@@ -6,7 +6,7 @@ from io import StringIO
 from unittest.mock import patch
 from chatmastermind.tags import TagLine
 from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
-from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError
+from chatmastermind.chat import Chat, ChatDB, ChatError


 class TestChat(unittest.TestCase):
@@ -92,16 +92,10 @@ class TestChat(unittest.TestCase):
 Question 1
 {Answer.txt_header}
 Answer 1
-
-{'-'*terminal_width()}
-
 {Question.txt_header}
 Question 2
 {Answer.txt_header}
 Answer 2
-
-{'-'*terminal_width()}
-
 """
         self.assertEqual(mock_stdout.getvalue(), expected_output)
@@ -115,18 +109,12 @@ FILE: 0001.txt
 Question 1
 {Answer.txt_header}
 Answer 1
-
-{'-'*terminal_width()}
-
 {TagLine.prefix} btag2
 FILE: 0002.txt
 {Question.txt_header}
 Question 2
 {Answer.txt_header}
 Answer 2
-
-{'-'*terminal_width()}
-
 """
         self.assertEqual(mock_stdout.getvalue(), expected_output)
@@ -5,7 +5,7 @@ import tempfile
 from pathlib import Path
 from unittest.mock import MagicMock
 from chatmastermind.commands.question import create_message
-from chatmastermind.message import Message, Question
+from chatmastermind.message import Message, Question, Answer
 from chatmastermind.chat import ChatDB

@@ -20,6 +20,12 @@ class TestMessageCreate(unittest.TestCase):
         self.cache_path = tempfile.TemporaryDirectory()
         self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name),
                                     db_path=Path(self.db_path.name))
+        # create some messages
+        self.message_text = Message(Question("What is this?"),
+                                    Answer("It is pure text"))
+        self.message_code = Message(Question("What is this?"),
+                                    Answer("Text\n```\nIt is embedded code\n```\ntext"))
+        self.chat.add_to_db([self.message_text, self.message_code])
         # create arguments mock
         self.args = MagicMock(spec=argparse.Namespace)
         self.args.source_text = None
@@ -159,4 +165,31 @@ This is embedded source code.
 ```
 This is embedded source code.
 ```
+"""))
+
+    def test_single_question_with_text_only_message(self) -> None:
+        self.args.ask = ["What is this?"]
+        self.args.source_text = [f"{self.chat.messages[0].file_path}"]
+        message = create_message(self.chat, self.args)
+        self.assertIsInstance(message, Message)
+        # file contains no source code (only text)
+        # -> don't expect any in the question
+        self.assertEqual(len(message.question.source_code()), 0)
+        self.assertEqual(message.question, Question(f"""What is this?
+
+{self.message_text.answer}"""))
+
+    def test_single_question_with_message_and_embedded_code(self) -> None:
+        self.args.ask = ["What is this?"]
+        self.args.source_code = [f"{self.chat.messages[1].file_path}"]
+        message = create_message(self.chat, self.args)
+        self.assertIsInstance(message, Message)
+        # answer contains 1 source code block
+        # -> expect it in the question
+        self.assertEqual(len(message.question.source_code()), 1)
+        self.assertEqual(message.question, Question("""What is this?
+
+```
+It is embedded code
+```
+"""))
 """))