9 Commits

25 changed files with 639 additions and 2134 deletions
-1
View File
@@ -106,7 +106,6 @@ celerybeat.pid
.venv
env/
venv/
.old/
ENV/
env.bak/
venv.bak/
+6 -11
View File
@@ -65,28 +65,23 @@ cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI_ID
* `-O, --overwrite`: Overwrite existing messages when repeating them
* `-s, --source-text FILE`: Add content of a file to the query
* `-S, --source-code FILE`: Add source code file content to the chat history
* `-l, --location {cache,db,all}`: Use given location when building the chat history (default: 'db')
* `-g, --glob GLOB`: Filter message files using the given glob pattern
#### Hist
The `hist` command is used to print and manage the chat history.
The `hist` command is used to print the chat history.
```bash
cmm hist [--print | --convert FORMAT] [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A SUBSTRING] [-Q SUBSTRING]
cmm hist [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A SUBSTRING] [-Q SUBSTRING]
```
* `-p, --print`: Print the DB chat history
* `-c, --convert FORMAT`: Convert all messages to the given format
* `-t, --or-tags OTAGS`: List of tags (one must match)
* `-k, --and-tags ATAGS`: List of tags (all must match)
* `-x, --exclude-tags XTAGS`: List of tags to exclude
* `-w, --with-metadata`: Print chat history with metadata (tags, filenames, AI, etc.)
* `-w, --with-tags`: Print chat history with tags
* `-W, --with-files`: Print chat history with filenames
* `-S, --source-code-only`: Only print embedded source code
* `-A, --answer SUBSTRING`: Filter for answer substring
* `-Q, --question SUBSTRING`: Filter for question substring
* `-l, --location {cache,db,all}`: Use given location when building the chat history (default: 'db')
* `-g, --glob GLOB`: Filter message files using the given glob pattern
* `-A, --answer SUBSTRING`: Search for answer substring
* `-Q, --question SUBSTRING`: Search for question substring
#### Tags
+6 -12
View File
@@ -3,20 +3,18 @@ Creates different AI instances, based on the given configuration.
"""
import argparse
from typing import cast, Optional
from typing import cast
from .configuration import Config, AIConfig, OpenAIConfig
from .ai import AI, AIError
from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config, # noqa: 11
def_ai: Optional[str] = None,
def_model: Optional[str] = None) -> AI:
def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
"""
Creates an AI subclass instance from the given arguments and configuration file.
If AI has not been set in the arguments, it searches for the ID 'default'. If
that is not found, it uses the first AI in the list. It's also possible to
specify a default AI and model using 'def_ai' and 'def_model'.
Creates an AI subclass instance from the given arguments
and configuration file. If AI has not been set in the
arguments, it searches for the ID 'default'. If that
is not found, it uses the first AI in the list.
"""
ai_conf: AIConfig
if hasattr(args, 'AI') and args.AI:
@@ -24,8 +22,6 @@ def create_ai(args: argparse.Namespace, config: Config, # noqa: 11
ai_conf = config.ais[args.AI]
except KeyError:
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
elif def_ai:
ai_conf = config.ais[def_ai]
elif 'default' in config.ais:
ai_conf = config.ais['default']
else:
@@ -38,8 +34,6 @@ def create_ai(args: argparse.Namespace, config: Config, # noqa: 11
ai = OpenAI(cast(OpenAIConfig, ai_conf))
if hasattr(args, 'model') and args.model:
ai.config.model = args.model
elif def_model:
ai.config.model = def_model
if hasattr(args, 'max_tokens') and args.max_tokens:
ai.config.max_tokens = args.max_tokens
if hasattr(args, 'temperature') and args.temperature:
+17 -70
View File
@@ -2,8 +2,7 @@
Implements the OpenAI client classes and functions.
"""
import openai
import tiktoken
from typing import Optional, Union, Generator
from typing import Optional, Union
from ..tags import Tag
from ..message import Message, Answer
from ..chat import Chat
@@ -13,52 +12,6 @@ from ..configuration import OpenAIConfig
ChatType = list[dict[str, str]]
class OpenAIAnswer:
def __init__(self,
idx: int,
streams: dict[int, 'OpenAIAnswer'],
response: openai.ChatCompletion,
tokens: Tokens,
encoding: tiktoken.core.Encoding) -> None:
self.idx = idx
self.streams = streams
self.response = response
self.position: int = 0
self.encoding = encoding
self.data: list[str] = []
self.finished: bool = False
self.tokens = tokens
def stream(self) -> Generator[str, None, None]:
while True:
if not self.next():
continue
if len(self.data) <= self.position:
break
yield self.data[self.position]
self.position += 1
def next(self) -> bool:
if self.finished:
return True
try:
chunk = next(self.response)
except StopIteration:
self.finished = True
if not self.finished:
found_choice = False
for choice in chunk['choices']:
if not choice['finish_reason']:
self.streams[choice['index']].data.append(choice['delta']['content'])
self.tokens.completion += len(self.encoding.encode(choice['delta']['content']))
self.tokens.total = self.tokens.prompt + self.tokens.completion
if choice['index'] == self.idx:
found_choice = True
if not found_choice:
return False
return True
class OpenAI(AI):
"""
The OpenAI AI client.
@@ -68,7 +21,7 @@ class OpenAI(AI):
self.ID = config.ID
self.name = config.name
self.config = config
openai.api_key = self.config.api_key
openai.api_key = config.api_key
def request(self,
question: Message,
@@ -80,9 +33,7 @@ class OpenAI(AI):
chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'.
"""
self.encoding = tiktoken.encoding_for_model(self.config.model)
oai_chat, prompt_tokens = self.openai_chat(chat, self.config.system, question)
tokens: Tokens = Tokens(prompt_tokens, 0, prompt_tokens)
oai_chat = self.openai_chat(chat, self.config.system, question)
response = openai.ChatCompletion.create(
model=self.config.model,
messages=oai_chat,
@@ -90,24 +41,22 @@ class OpenAI(AI):
max_tokens=self.config.max_tokens,
top_p=self.config.top_p,
n=num_answers,
stream=True,
frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty)
streams: dict[int, OpenAIAnswer] = {}
for n in range(num_answers):
streams[n] = OpenAIAnswer(n, streams, response, tokens, self.encoding)
question.answer = Answer(streams[0].stream())
question.tags = set(otags) if otags is not None else None
question.answer = Answer(response['choices'][0]['message']['content'])
question.tags = otags
question.ai = self.ID
question.model = self.config.model
answers: list[Message] = [question]
for idx in range(1, num_answers):
for choice in response['choices'][1:]: # type: ignore
answers.append(Message(question=question.question,
answer=Answer(streams[idx].stream()),
answer=Answer(choice['message']['content']),
tags=otags,
ai=self.ID,
model=self.config.model))
return AIResponse(answers, tokens)
return AIResponse(answers, Tokens(response['usage']['prompt_tokens'],
response['usage']['completion_tokens'],
response['usage']['total_tokens']))
def models(self) -> list[str]:
"""
@@ -134,26 +83,24 @@ class OpenAI(AI):
print('\nNot ready: ' + ', '.join(not_ready))
def openai_chat(self, chat: Chat, system: str,
question: Optional[Message] = None) -> tuple[ChatType, int]:
question: Optional[Message] = None) -> ChatType:
"""
Create a chat history with system message in OpenAI format.
Optionally append a new question.
"""
oai_chat: ChatType = []
prompt_tokens: int = 0
def append(role: str, content: str) -> int:
def append(role: str, content: str) -> None:
oai_chat.append({'role': role, 'content': content.replace("''", "'")})
return len(self.encoding.encode(', '.join(['role:', oai_chat[-1]['role'], 'content:', oai_chat[-1]['content']])))
prompt_tokens += append('system', system)
append('system', system)
for message in chat.messages:
if message.answer:
prompt_tokens += append('user', message.question)
prompt_tokens += append('assistant', str(message.answer))
append('user', message.question)
append('assistant', message.answer)
if question:
prompt_tokens += append('user', question.question)
return oai_chat, prompt_tokens
append('user', question.question)
return oai_chat
def tokens(self, data: Union[Message, Chat]) -> int:
raise NotImplementedError
+60 -86
View File
@@ -6,10 +6,9 @@ from pathlib import Path
from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass
from enum import Enum
from typing import TypeVar, Type, Optional, Any, Callable, Union
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal, Union
from .configuration import default_config_file
from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats
from .message import Message, MessageFilter, MessageError, message_in
from .tags import Tag
ChatInst = TypeVar('ChatInst', bound='Chat')
@@ -17,15 +16,7 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next'
ignored_files = [db_next_file, default_config_file]
msg_suffix = Message.file_suffix_write
class msg_location(Enum):
MEM = 'mem'
DISK = 'disk'
CACHE = 'cache'
DB = 'db'
ALL = 'all'
msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
class ChatError(Exception):
@@ -52,16 +43,16 @@ def read_dir(dir_path: Path,
Parameters:
* 'dir_path': source directory
* 'glob': if specified, files will be filtered using 'path.glob()',
otherwise it reads all files with the default message suffix
otherwise it uses 'path.iterdir()'.
* 'mfilter': use with 'Message.from_file()' to filter messages
when reading them.
"""
messages: list[Message] = []
file_iter = dir_path.glob(glob) if glob else dir_path.glob(f'*{msg_suffix}')
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in sorted(file_iter):
if (file_path.is_file()
and file_path.name not in ignored_files # noqa: W503
and file_path.suffix in Message.file_suffixes_read): # noqa: W503
and file_path.suffix in Message.file_suffixes): # noqa: W503
try:
message = Message.from_file(file_path, mfilter)
if message:
@@ -72,20 +63,22 @@ def read_dir(dir_path: Path,
def make_file_path(dir_path: Path,
file_suffix: str,
next_fid: Callable[[], int]) -> Path:
"""
Create a file_path for the given directory using the given ID generator function.
Create a file_path for the given directory using the
given file_suffix and ID generator function.
"""
file_path = dir_path / f"{next_fid():04d}{msg_suffix}"
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
while file_path.exists():
file_path = dir_path / f"{next_fid():04d}{msg_suffix}"
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
return file_path
def write_dir(dir_path: Path,
messages: list[Message],
next_fid: Callable[[], int],
mformat: MessageFormat = Message.default_format) -> None:
file_suffix: str,
next_fid: Callable[[], int]) -> None:
"""
Write all messages to the given directory. If a message has no file_path,
a new one will be created. If message.file_path exists, it will be modified
@@ -93,17 +86,18 @@ def write_dir(dir_path: Path,
Parameters:
* 'dir_path': destination directory
* 'messages': list of messages to write
* 'file_suffix': suffix for the message files ['.txt'|'.yaml']
* 'next_fid': callable that returns the next file ID
"""
for message in messages:
file_path = message.file_path
# message has no file_path: create one
if not file_path:
file_path = make_file_path(dir_path, next_fid)
file_path = make_file_path(dir_path, file_suffix, next_fid)
# file_path does not point to given directory: modify it
elif not file_path.parent.samefile(dir_path):
file_path = dir_path / file_path.name
message.to_file(file_path, mformat=mformat)
message.to_file(file_path)
def clear_dir(dir_path: Path,
@@ -115,7 +109,7 @@ def clear_dir(dir_path: Path,
for file_path in file_iter:
if (file_path.is_file()
and file_path.name not in ignored_files # noqa: W503
and file_path.suffix in Message.file_suffixes_read): # noqa: W503
and file_path.suffix in Message.file_suffixes): # noqa: W503
file_path.unlink(missing_ok=True)
@@ -152,7 +146,7 @@ class Chat:
Matching is True if:
* 'name' matches the full 'file_path'
* 'name' matches 'file_path.name' (i. e. including the suffix)
* 'name' matches 'file_path.stem' (i. e. without the suffix)
* 'name' matches 'file_path.stem' (i. e. without a suffix)
"""
return Path(name) == file_path or name == file_path.name or name == file_path.stem
@@ -263,17 +257,14 @@ class Chat:
return sum(m.tokens() for m in self.messages)
def print(self, source_code_only: bool = False,
with_metadata: bool = False,
paged: bool = True,
tight: bool = False) -> None:
with_tags: bool = False, with_files: bool = False,
paged: bool = True) -> None:
output: list[str] = []
for message in self.messages:
if source_code_only:
output.append(message.to_str(source_code_only=True))
continue
output.append(message.to_str(with_metadata))
if not tight:
output.append('\n' + ('-' * terminal_width()) + '\n')
output.append(message.to_str(with_tags, with_files))
if paged:
print_paged('\n'.join(output))
else:
@@ -290,14 +281,15 @@ class ChatDB(Chat):
persistently.
"""
default_file_suffix: ClassVar[str] = '.txt'
cache_path: Path
db_path: Path
# a MessageFilter that all messages must match (if given)
mfilter: Optional[MessageFilter] = None
file_suffix: str = default_file_suffix
# the glob pattern for all messages
glob: str = f'*{msg_suffix}'
# message format (for writing)
mformat: MessageFormat = Message.default_format
glob: Optional[str] = None
def __post_init__(self) -> None:
# contains the latest message ID
@@ -311,29 +303,22 @@ class ChatDB(Chat):
def from_dir(cls: Type[ChatDBInst],
cache_path: Path,
db_path: Path,
glob: str = f'*{msg_suffix}',
mfilter: Optional[MessageFilter] = None,
loc: msg_location = msg_location.DB) -> ChatDBInst:
glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
"""
Create a 'ChatDB' instance from the given directory structure.
Reads all messages from 'db_path' into the local message list.
Parameters:
* 'cache_path': path to the directory for temporary messages
* 'db_path': path to the directory for persistent messages
* 'glob': if specified, files will be filtered using 'path.glob()'
* 'glob': if specified, files will be filtered using 'path.glob()',
otherwise it uses 'path.iterdir()'.
* 'mfilter': use with 'Message.from_file()' to filter messages
when reading them.
* 'loc': read messages from given location instead of 'db_path'
"""
if loc == msg_location.MEM:
raise ChatError(f"Can't build ChatDB from message location '{loc}'")
messages: list[Message] = []
if loc in [msg_location.DB, msg_location.DISK, msg_location.ALL]:
messages.extend(read_dir(db_path, glob, mfilter))
if loc in [msg_location.CACHE, msg_location.DISK, msg_location.ALL]:
messages.extend(read_dir(cache_path, glob, mfilter))
messages.sort(key=lambda x: x.msg_id())
return cls(messages, cache_path, db_path, mfilter, glob)
messages = read_dir(db_path, glob, mfilter)
return cls(messages, cache_path, db_path, mfilter,
cls.default_file_suffix, glob)
@classmethod
def from_messages(cls: Type[ChatDBInst],
@@ -360,17 +345,7 @@ class ChatDB(Chat):
with open(self.next_path, 'w') as f:
f.write(f'{fid}')
def set_msg_format(self, mformat: MessageFormat) -> None:
"""
Set message format for writing messages.
"""
if mformat not in message_valid_formats:
raise ChatError(f"Message format '{mformat}' is not supported")
self.mformat = mformat
def msg_write(self,
messages: Optional[list[Message]] = None,
mformat: Optional[MessageFormat] = None) -> None:
def msg_write(self, messages: Optional[list[Message]] = None) -> None:
"""
Write either the given messages or the internal ones to their CURRENT file_path.
If messages are given, they all must have a valid file_path. When writing the
@@ -381,7 +356,7 @@ class ChatDB(Chat):
raise ChatError("Can't write files without a valid file_path")
msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)):
m.to_file(mformat=mformat if mformat else self.mformat)
m.to_file()
def msg_update(self, messages: list[Message], write: bool = True) -> None:
"""
@@ -402,7 +377,6 @@ class ChatDB(Chat):
def msg_gather(self,
loc: msg_location,
require_file_path: bool = False,
glob: str = f'*{msg_suffix}',
mfilter: Optional[MessageFilter] = None) -> list[Message]:
"""
Gather and return messages from the given locations:
@@ -415,15 +389,15 @@ class ChatDB(Chat):
If 'require_file_path' is True, return only files with a valid file_path.
"""
loc_messages: list[Message] = []
if loc in [msg_location.MEM, msg_location.ALL]:
if loc in ['mem', 'all']:
if require_file_path:
loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
else:
loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
if loc in [msg_location.CACHE, msg_location.DISK, msg_location.ALL]:
loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter)
if loc in [msg_location.DB, msg_location.DISK, msg_location.ALL]:
loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter)
if loc in ['cache', 'disk', 'all']:
loc_messages += read_dir(self.cache_path, mfilter=mfilter)
if loc in ['db', 'disk', 'all']:
loc_messages += read_dir(self.db_path, mfilter=mfilter)
# remove_duplicates and sort the list
unique_messages: list[Message] = []
for m in loc_messages:
@@ -438,7 +412,7 @@ class ChatDB(Chat):
def msg_find(self,
msg_names: list[str],
loc: msg_location = msg_location.MEM,
loc: msg_location = 'mem',
) -> list[Message]:
"""
Search and return the messages with the given names. Names can either be filenames
@@ -456,7 +430,7 @@ class ChatDB(Chat):
return [m for m in loc_messages
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def msg_remove(self, msg_names: list[str], loc: msg_location = msg_location.MEM) -> None:
def msg_remove(self, msg_names: list[str], loc: msg_location = 'mem') -> None:
"""
Remove the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Also deletes the
@@ -468,7 +442,7 @@ class ChatDB(Chat):
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
if loc != msg_location.MEM:
if loc != 'mem':
# delete the message files first
rm_messages = self.msg_find(msg_names, loc=loc)
for m in rm_messages:
@@ -479,7 +453,7 @@ class ChatDB(Chat):
def msg_latest(self,
mfilter: Optional[MessageFilter] = None,
loc: msg_location = msg_location.MEM) -> Optional[Message]:
loc: msg_location = 'mem') -> Optional[Message]:
"""
Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if loc is 'mem').
@@ -508,7 +482,7 @@ class ChatDB(Chat):
and message.file_path.parent.samefile(self.cache_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc=msg_location.CACHE)) > 0
return len(self.msg_find([message], loc='cache')) > 0
def msg_in_db(self, message: Union[Message, str]) -> bool:
"""
@@ -520,9 +494,9 @@ class ChatDB(Chat):
and message.file_path.parent.samefile(self.db_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc=msg_location.DB)) > 0
return len(self.msg_find([message], loc='db')) > 0
def cache_read(self, glob: str = f'*{msg_suffix}', mfilter: Optional[MessageFilter] = None) -> None:
def cache_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None:
"""
Read messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message
@@ -544,8 +518,8 @@ class ChatDB(Chat):
"""
write_dir(self.cache_path,
messages if messages else self.messages,
self.get_next_fid,
self.mformat)
self.file_suffix,
self.get_next_fid)
def cache_add(self, messages: list[Message], write: bool = True) -> None:
"""
@@ -557,15 +531,15 @@ class ChatDB(Chat):
if write:
write_dir(self.cache_path,
messages,
self.get_next_fid,
self.mformat)
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.cache_path, self.get_next_fid)
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.msg_sort()
def cache_clear(self, glob: str = f'*{msg_suffix}') -> None:
def cache_clear(self, glob: Optional[str] = None) -> None:
"""
Delete all message files from the cache dir and remove them from the internal list.
"""
@@ -585,11 +559,11 @@ class ChatDB(Chat):
self.cache_write([message])
# remove the old one (if any)
if old_path:
self.msg_remove([str(old_path)], loc=msg_location.DB)
self.msg_remove([str(old_path)], loc='db')
# (re)add it to the internal list
self.msg_add([message])
def db_read(self, glob: str = f'*{msg_suffix}', mfilter: Optional[MessageFilter] = None) -> None:
def db_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None:
"""
Read messages from the DB directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message
@@ -611,8 +585,8 @@ class ChatDB(Chat):
"""
write_dir(self.db_path,
messages if messages else self.messages,
self.get_next_fid,
self.mformat)
self.file_suffix,
self.get_next_fid)
def db_add(self, messages: list[Message], write: bool = True) -> None:
"""
@@ -624,11 +598,11 @@ class ChatDB(Chat):
if write:
write_dir(self.db_path,
messages,
self.get_next_fid,
self.mformat)
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.db_path, self.get_next_fid)
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.msg_sort()
@@ -644,6 +618,6 @@ class ChatDB(Chat):
self.db_write([message])
# remove the old one (if any)
if old_path:
self.msg_remove([str(old_path)], loc=msg_location.CACHE)
self.msg_remove([str(old_path)], loc='cache')
# (re)add it to the internal list
self.msg_add([message])
-69
View File
@@ -1,69 +0,0 @@
"""
Contains shared functions for the various CMM subcommands.
"""
import argparse
from pathlib import Path
from ..message import Message, MessageError, source_code
def read_text_file(file: Path) -> str:
with open(file) as r:
content = r.read().strip()
return content
def add_file_as_text(question_parts: list[str], file: str) -> None:
"""
Add the given file as plain text to the question part list.
If the file is a Message, add the answer.
"""
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
content = read_text_file(Path(file))
if len(content) > 0:
question_parts.append(content)
def add_file_as_code(question_parts: list[str], file: str) -> None:
"""
Add all source code from the given file. If no code segments can be extracted,
the whole content is added as source code segment. If the file is a Message,
extract the source code from the answer.
"""
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
with open(file) as r:
content = r.read().strip()
# extract and add source code
code_parts = source_code(content, include_delims=True)
if len(code_parts) > 0:
question_parts += code_parts
else:
question_parts.append(f"```\n{content}\n```")
def invert_input_tag_args(args: argparse.Namespace) -> None:
"""
Changes the semantics of the INPUT tags for this command:
* not tags specified on the CLI -> no tags are selected
* empty tags specified on the CLI -> all tags are selected
"""
if args.or_tags is None:
args.or_tags = set()
elif len(args.or_tags) == 0:
args.or_tags = None
if args.and_tags is None:
args.and_tags = set()
elif len(args.and_tags) == 0:
args.and_tags = None
-95
View File
@@ -1,95 +0,0 @@
import sys
import argparse
from pathlib import Path
from pydoc import pager
from ..configuration import Config
from ..glossary import Glossary
class GlossaryCmdError(Exception):
pass
def print_paged(text: str) -> None:
pager(text)
def get_glossary_file_path(name: str, config: Config) -> Path:
"""
Get the complete filename for a glossary with the given path.
"""
if not config.glossaries:
raise GlossaryCmdError("Can't create glossary name without a glossary directory")
return Path(config.glossaries, name).with_suffix(Glossary.file_suffix).absolute()
def list_glossaries(args: argparse.Namespace, config: Config) -> None:
"""
List existing glossaries in the 'glossaries' directory.
"""
if not config.glossaries:
raise GlossaryCmdError("Glossaries directory missing in the configuration file")
glossaries = Path(config.glossaries).glob(f'*{Glossary.file_suffix}')
for glo in sorted(glossaries):
print(Glossary.from_file(glo).to_str())
def print_glossary(args: argparse.Namespace, config: Config) -> None:
"""
Print an existing glossary.
"""
# sanity checks
if args.name is None:
raise GlossaryCmdError("Missing glossary name")
if config.glossaries is None and args.file is None:
raise GlossaryCmdError("Glossaries directory missing in the configuration file")
# create file path or use the given one
glo_file = Path(args.file) if args.file else get_glossary_file_path(args.name, config)
if not glo_file.exists():
raise GlossaryCmdError(f"Glossary '{glo_file}' does not exist")
# read glossary
glo = Glossary.from_file(glo_file)
print_paged(glo.to_str(with_entries=True))
def create_glossary(args: argparse.Namespace, config: Config) -> None:
"""
Create a new glossary and write it either to the glossaries directory
or the given file.
"""
# sanity checks
if args.name is None:
raise GlossaryCmdError("Missing glossary name")
if args.source_lang is None:
raise GlossaryCmdError("Missing source language")
if args.target_lang is None:
raise GlossaryCmdError("Missing target language")
if config.glossaries is None and args.file is None:
raise GlossaryCmdError("Glossaries directory missing in the configuration file")
# create file path or use the given one
glo_file = Path(args.file) if args.file else get_glossary_file_path(args.name, config)
if glo_file.exists():
raise GlossaryCmdError(f"Glossary '{glo_file}' already exists")
glo = Glossary(name=args.name,
source_lang=args.source_lang,
target_lang=args.target_lang,
desc=args.description,
file_path=glo_file)
glo.to_file()
print(f"Successfully created new glossary '{glo_file}'.")
def glossary_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'glossary' command.
"""
try:
if args.create:
create_glossary(args, config)
elif args.list:
list_glossaries(args, config)
elif args.print:
print_glossary(args, config)
except GlossaryCmdError as err:
print(f"Error: {err}")
sys.exit(1)
+7 -59
View File
@@ -1,52 +1,13 @@
import sys
import argparse
from pathlib import Path
from ..configuration import Config
from ..chat import ChatDB, msg_location
from ..message import MessageFilter, Message
from ..chat import ChatDB
from ..message import MessageFilter
msg_suffix = Message.file_suffix_write # currently '.msg'
def convert_messages(args: argparse.Namespace, config: Config) -> None:
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Convert messages to a new format. Also used to change old suffixes
('.txt', '.yaml') to the latest default message file suffix ('.msg').
"""
chat = ChatDB.from_dir(Path(config.cache),
Path(config.db),
glob='*')
# read all known message files
msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*')
# make a set of all message IDs
msg_ids = set([m.msg_id() for m in msgs])
# set requested format and write all messages
chat.set_msg_format(args.convert)
# delete the current suffix
# -> a new one will automatically be created
for m in msgs:
if m.file_path:
m.file_path = m.file_path.with_suffix('')
chat.msg_write(msgs)
# read all messages with the current default suffix
msgs = chat.msg_gather(loc=msg_location.DISK, glob=f'*{msg_suffix}')
# make sure we converted all of the original messages
for mid in msg_ids:
if not any(mid == m.msg_id() for m in msgs):
print(f"Message '{mid}' has not been found after conversion. Aborting.")
sys.exit(1)
# delete messages with old suffixes
msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*')
for m in msgs:
if m.file_path and m.file_path.suffix != msg_suffix:
m.rm_file()
print(f"Successfully converted {len(msg_ids)} messages.")
def print_chat(args: argparse.Namespace, config: Config) -> None:
"""
Print the DB chat history.
Handler for the 'hist' command.
"""
mfilter = MessageFilter(tags_or=args.or_tags,
@@ -56,20 +17,7 @@ def print_chat(args: argparse.Namespace, config: Config) -> None:
answer_contains=args.answer)
chat = ChatDB.from_dir(Path(config.cache),
Path(config.db),
mfilter=mfilter,
loc=msg_location(args.location),
glob=args.glob)
mfilter=mfilter)
chat.print(args.source_code_only,
args.with_metadata,
paged=not args.no_paging,
tight=args.tight)
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'hist' command.
"""
if args.print:
print_chat(args, config)
elif args.convert:
convert_messages(args, config)
args.with_tags,
args.with_files)
+2 -2
View File
@@ -3,7 +3,7 @@ import argparse
from pathlib import Path
from ..configuration import Config
from ..message import Message, MessageError
from ..chat import ChatDB, msg_location
from ..chat import ChatDB
def print_message(message: Message, args: argparse.Namespace) -> None:
@@ -38,7 +38,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None:
# print latest message
elif args.latest:
chat = ChatDB.from_dir(Path(config.cache), Path(config.db))
latest = chat.msg_latest(loc=msg_location.DISK)
latest = chat.msg_latest(loc='disk')
if not latest:
print("No message found!")
sys.exit(1)
+61 -76
View File
@@ -2,40 +2,52 @@ import sys
import argparse
from pathlib import Path
from itertools import zip_longest
from copy import deepcopy
from .common import invert_input_tag_args, add_file_as_code, add_file_as_text
from ..configuration import Config
from ..chat import ChatDB, msg_location
from ..message import Message, MessageFilter, Question
from ..chat import ChatDB
from ..message import Message, MessageFilter, MessageError, Question, source_code
from ..ai_factory import create_ai
from ..ai import AI, AIResponse
class QuestionCmdError(Exception):
pass
def add_file_as_text(question_parts: list[str], file: str) -> None:
"""
Add the given file as plain text to the question part list.
If the file is a Message, add the answer.
"""
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
with open(file) as r:
content = r.read().strip()
if len(content) > 0:
question_parts.append(content)
def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
def add_file_as_code(question_parts: list[str], file: str) -> None:
"""
Takes an existing message and CLI arguments, and returns modified args based
on the members of the given message. Used e.g. when repeating messages, where
it's necessary to determine the correct AI, module and output tags to use
(either from the existing message or the given args).
Add all source code from the given file. If no code segments can be extracted,
the whole content is added as source code segment. If the file is a Message,
extract the source code from the answer.
"""
msg_args = args
# if AI, model or output tags have not been specified,
# use those from the original message
if (args.AI is None
or args.model is None # noqa: W503
or args.output_tags is None): # noqa: W503
msg_args = deepcopy(args)
if args.AI is None and msg.ai is not None:
msg_args.AI = msg.ai
if args.model is None and msg.model is not None:
msg_args.model = msg.model
if args.output_tags is None and msg.tags is not None:
msg_args.output_tags = msg.tags
return msg_args
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
with open(file) as r:
content = r.read().strip()
# extract and add source code
code_parts = source_code(content, include_delims=True)
if len(code_parts) > 0:
question_parts += code_parts
else:
question_parts.append(f"```\n{content}\n```")
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
@@ -44,12 +56,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
to the cache directory.
"""
question_parts = []
if args.create is not None:
question_list = args.create
elif args.ask is not None:
question_list = args.ask
else:
raise QuestionCmdError("No question found")
question_list = args.ask if args.ask is not None else []
text_files = args.source_text if args.source_text is not None else []
code_files = args.source_code if args.source_code is not None else []
@@ -61,7 +68,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
if code_file is not None and len(code_file) > 0:
add_file_as_code(question_parts, code_file)
full_question = '\n\n'.join([str(s) for s in question_parts])
full_question = '\n\n'.join(question_parts)
message = Message(question=Question(full_question),
tags=args.output_tags,
@@ -89,77 +96,55 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
args.output_tags)
# only write the response messages to the cache,
# don't add them to the internal list
chat.cache_write(response.messages)
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===", flush=True)
if msg.answer:
for piece in msg.answer:
print(piece, end='', flush=True)
print()
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
chat.cache_write(response.messages)
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
"""
Repeat the given messages using the given arguments.
"""
ai: AI
for msg in messages:
msg_args = create_msg_args(msg, args)
ai = create_ai(msg_args, config)
print(f"--------- Repeating message '{msg.msg_id()}': ---------")
# overwrite the latest message if requested or empty
# -> but not if it's in the DB!
if ((msg.answer is None or msg_args.overwrite is True)
and (not chat.msg_in_db(msg))): # noqa: W503
msg.clear_answer()
make_request(ai, chat, msg, msg_args)
# otherwise create a new one
else:
msg_args.ask = [msg.question]
message = create_message(chat, msg_args)
make_request(ai, chat, message, msg_args)
def question_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'question' command.
"""
invert_input_tag_args(args)
mfilter = MessageFilter(tags_or=args.or_tags,
tags_and=args.and_tags,
tags_not=args.exclude_tags)
mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(),
tags_and=args.and_tags if args.and_tags is not None else set(),
tags_not=args.exclude_tags if args.exclude_tags is not None else set())
chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db),
mfilter=mfilter,
glob=args.glob,
loc=msg_location(args.location))
mfilter=mfilter)
# if it's a new question, create and store it immediately
if args.ask or args.create:
message = create_message(chat, args)
if args.create:
return
# create the correct AI instance
ai: AI = create_ai(args, config)
# === ASK ===
if args.ask:
ai: AI = create_ai(args, config)
make_request(ai, chat, message, args)
# === REPEAT ===
elif args.repeat is not None:
repeat_msgs: list[Message] = []
# repeat latest message
if len(args.repeat) == 0:
lmessage = chat.msg_latest(loc=msg_location.CACHE)
lmessage = chat.msg_latest(loc='cache')
if lmessage is None:
print("No message found to repeat!")
sys.exit(1)
repeat_msgs.append(lmessage)
# repeat given message(s)
else:
repeat_msgs = chat.msg_find(args.repeat, loc=msg_location.DISK)
repeat_messages(repeat_msgs, chat, args, config)
print(f"Repeating message '{lmessage.msg_id()}':")
# overwrite the latest message if requested or empty
if lmessage.answer is None or args.overwrite is True:
lmessage.clear_answer()
make_request(ai, chat, lmessage, args)
# otherwise create a new one
else:
args.ask = [lmessage.question]
message = create_message(chat, args)
make_request(ai, chat, message, args)
# === PROCESS ===
elif args.process is not None:
# TODO: process either all questions without an
-105
View File
@@ -1,105 +0,0 @@
import argparse
import mimetypes
from pathlib import Path
from .common import invert_input_tag_args, read_text_file
from ..configuration import Config
from ..message import MessageFilter, Message, Question
from ..chat import ChatDB, msg_location
class TranslationCmdError(Exception):
pass
text_separator: str = 'TEXT:'
def assert_document_type_supported_openai(document_file: Path) -> None:
doctype = mimetypes.guess_type(document_file)
if doctype != 'text/plain':
raise TranslationCmdError("AI 'OpenAI' only supports document type 'text/plain''")
def translation_prompt_openai(source_lang: str, target_lang: str) -> str:
"""
Return the prompt for GPT that tells it to do the translation.
"""
return f"Translate the text below the line {text_separator} from {source_lang} to {target_lang}."
def create_message_openai(chat: ChatDB, args: argparse.Namespace) -> Message:
"""
Create a new message from the given arguments and write it to the cache directory.
Message format
1. Translation prompt (tells GPT to do a translation)
2. Glossary (if specified as an argument)
3. User provided prompt enhancements
4. Translation separator
5. User provided text to be translated
The text to be translated is determined as a follows:
- if a document is provided in the arguments, translate its content
- if no document is provided, translate the last text argument
The other text arguments will be put into the "header" and can be used
to improve the translation prompt.
"""
text_args: list[str] = []
if args.create is not None:
text_args = args.create
elif args.ask is not None:
text_args = args.ask
else:
raise TranslationCmdError("No input text found")
# extract user prompt and user text to be translated
user_text: str
user_prompt: str
if args.input_document is not None:
assert_document_type_supported_openai(Path(args.input_document))
user_text = read_text_file(Path(args.input_document))
user_prompt = '\n\n'.join([str(s) for s in text_args])
else:
user_text = text_args[-1]
user_prompt = '\n\n'.join([str(s) for s in text_args[:-1]])
# build full question string
# FIXME: add glossaries if given
question_text: str = '\n\n'.join([translation_prompt_openai(args.source_lang, args.target_lang),
user_prompt,
text_separator,
user_text])
# create and write the message
message = Message(question=Question(question_text),
tags=args.output_tags,
ai=args.AI,
model=args.model)
# only write the new message to the cache,
# don't add it to the internal list
chat.cache_write([message])
return message
def translation_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'translation' command. Creates and executes translation
requests based on the input and selected AI. Depending on the AI, the
whole process may be significantly different (e.g. DeepL vs OpenAI).
"""
invert_input_tag_args(args)
mfilter = MessageFilter(tags_or=args.or_tags,
tags_and=args.and_tags,
tags_not=args.exclude_tags)
chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db),
mfilter=mfilter,
glob=args.glob,
loc=msg_location(args.location))
# if it's a new translation, create and store it immediately
# FIXME: check AI type
if args.ask or args.create:
# message = create_message(chat, args)
create_message_openai(chat, args)
if args.create:
return
+1 -5
View File
@@ -118,7 +118,6 @@ class Config:
# a default configuration
cache: str = '.'
db: str = './db/'
glossaries: str | None = './glossaries/'
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
@classmethod
@@ -136,8 +135,7 @@ class Config:
return cls(
cache=str(source['cache']) if 'cache' in source else '.',
db=str(source['db']),
ais=ais,
glossaries=str(source['glossaries']) if 'glossaries' in source else None
ais=ais
)
@classmethod
@@ -150,8 +148,6 @@ class Config:
@classmethod
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
if not Path(path).exists():
raise ConfigError(f"Configuration file '{path}' not found. Use 'cmm config --create' to create one.")
with open(path, 'r') as f:
source = yaml.load(f, Loader=yaml.FullLoader)
return cls.from_dict(source)
-165
View File
@@ -1,165 +0,0 @@
"""
Module implementing glossaries for translations.
"""
import yaml
import tempfile
import shutil
import csv
from pathlib import Path
from dataclasses import dataclass, field
from typing import Type, TypeVar, ClassVar
GlossaryInst = TypeVar('GlossaryInst', bound='Glossary')
class GlossaryError(Exception):
pass
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
"""
Changes the YAML dump style to multiline syntax for multiline strings.
"""
if len(data.splitlines()) > 1:
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
return dumper.represent_scalar('tag:yaml.org,2002:str', data)
@dataclass
class Glossary:
"""
A glossary consists of the following parameters:
- Name (freely selectable)
- Path (full file path, suffix is automatically generated)
- Source language
- Target language
- Description (optional)
- Entries (pairs of source lang and target lang terms)
- ID (automatically generated / modified, required by DeepL)
"""
name: str
source_lang: str
target_lang: str
file_path: Path | None = None
desc: str | None = None
entries: dict[str, str] = field(default_factory=lambda: dict())
ID: str | None = None
file_suffix: ClassVar[str] = '.glo'
def __post_init__(self) -> None:
# FIXME: check for valid languages
pass
@classmethod
def from_file(cls: Type[GlossaryInst], file_path: Path) -> GlossaryInst:
"""
Create a glossary from the given file.
"""
if not file_path.exists():
raise GlossaryError(f"Glossary file '{file_path}' does not exist")
if file_path.suffix != cls.file_suffix:
raise GlossaryError(f"File type '{file_path.suffix}' is not supported")
with open(file_path, "r") as fd:
try:
# use BaseLoader so every entry is read as a string
# - disables automatic conversions
# - makes it possible to omit quoting for YAML keywords in entries (e. g. 'yes')
# - also correctly reads quoted entries
data = yaml.load(fd, Loader=yaml.BaseLoader)
clean_entries = data['Entries']
return cls(name=data['Name'],
source_lang=data['SourceLang'],
target_lang=data['TargetLang'],
file_path=file_path,
desc=data['Description'],
entries=clean_entries,
ID=data['ID'] if data['ID'] != 'None' else None)
except Exception:
raise GlossaryError(f"'{file_path}' does not contain a valid glossary")
def to_file(self, file_path: Path | None = None) -> None:
"""
Write glossary to given file.
"""
if file_path:
self.file_path = file_path
if not self.file_path:
raise GlossaryError("Got no valid path to write glossary")
# check / add valid suffix
if not self.file_path.suffix:
self.file_path = self.file_path.with_suffix(self.file_suffix)
elif self.file_path.suffix != self.file_suffix:
raise GlossaryError(f"File suffix '{self.file_path.suffix}' is not supported")
# write YAML
with tempfile.NamedTemporaryFile(dir=self.file_path.parent, prefix=self.file_path.name, mode="w", delete=False) as temp_fd:
temp_file_path = Path(temp_fd.name)
data = {'Name': self.name,
'Description': str(self.desc),
'ID': str(self.ID),
'SourceLang': self.source_lang,
'TargetLang': self.target_lang,
'Entries': self.entries}
yaml.dump(data, temp_fd, sort_keys=False)
shutil.move(temp_file_path, self.file_path)
def export_csv(self, dictionary: dict[str, str], file_path: Path) -> None:
"""
Export the 'entries' of this glossary to a file in CSV format (compatible with DeepL).
"""
with open(file_path, 'w', newline='', encoding='utf-8') as csvfile:
writer = csv.writer(csvfile, delimiter=',', quotechar='"', quoting=csv.QUOTE_ALL)
for source_entry, target_entry in self.entries.items():
writer.writerow([source_entry, target_entry])
def export_tsv(self, entries: dict[str, str], file_path: Path) -> None:
"""
Export the 'entries' of this glossary to a file in TSV format (compatible with DeepL).
"""
with open(file_path, 'w', encoding='utf-8') as file:
for source_entry, target_entry in self.entries.items():
file.write(f"{source_entry}\t{target_entry}\n")
def import_csv(self, file_path: Path) -> None:
"""
Import the entries from the given CSV file to those of the current glossary.
Existing entries are overwritten.
"""
try:
with open(file_path, mode='r', encoding='utf-8') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
self.entries = {rows[0]: rows[1] for rows in reader if len(rows) >= 2}
except Exception as e:
raise GlossaryError(f"Error importing CSV: {e}")
def import_tsv(self, file_path: Path) -> None:
"""
Import the entries from the given CSV file to those of the current glossary.
Existing entries are overwritten.
"""
try:
with open(file_path, mode='r', encoding='utf-8') as tsvfile:
self.entries = {}
for line in tsvfile:
parts = line.strip().split('\t')
if len(parts) == 2:
self.entries[parts[0]] = parts[1]
except Exception as e:
raise GlossaryError(f"Error importing TSV: {e}")
def to_str(self, with_entries: bool = False) -> str:
"""
Return the current glossary as a string.
"""
output: list[str] = []
output.append(f'{self.name} (ID: {self.ID}):')
if self.desc and self.desc != 'None':
output.append('- ' + self.desc)
output.append(f'- Languages: {self.source_lang} -> {self.target_lang}')
if with_entries:
output.append('- Entries:')
for source, target in self.entries.items():
output.append(f' {source} : {target}')
else:
output.append(f'- Entries: {len(self.entries)}')
return '\n'.join(output)
+12 -104
View File
@@ -3,21 +3,17 @@
# vim: set fileencoding=utf-8 :
import sys
import os
import argcomplete
import argparse
from pathlib import Path
from typing import Any
from .configuration import Config, default_config_file, ConfigError
from .configuration import Config, default_config_file
from .message import Message
from .commands.question import question_cmd
from .commands.tags import tags_cmd
from .commands.config import config_cmd
from .commands.hist import hist_cmd
from .commands.print import print_cmd
from .commands.translation import translation_cmd
from .commands.glossary import glossary_cmd
from .chat import msg_location
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
@@ -38,13 +34,13 @@ def create_parser() -> argparse.ArgumentParser:
# a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False)
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='*',
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+',
help='List of tags (one must match)', metavar='OTAGS')
tag_arg.completer = tags_completer # type: ignore
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='*',
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+',
help='List of tags (all must match)', metavar='ATAGS')
atag_arg.completer = tags_completer # type: ignore
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='*',
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+',
help='List of tags to exclude', metavar='XTAGS')
etag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
@@ -55,7 +51,7 @@ def create_parser() -> argparse.ArgumentParser:
ai_parser = argparse.ArgumentParser(add_help=False)
ai_parser.add_argument('-A', '--AI', help='AI ID to use', metavar='AI_ID')
ai_parser.add_argument('-M', '--model', help='Model to use', metavar='MODEL')
ai_parser.add_argument('-N', '--num-answers', help='Number of answers to request', type=int, default=1)
ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1)
ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int)
ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float)
@@ -69,11 +65,6 @@ def create_parser() -> argparse.ArgumentParser:
question_group.add_argument('-c', '--create', nargs='+', help='Create a question', metavar='QUESTION')
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question', metavar='MESSAGE')
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions', metavar='MESSAGE')
question_cmd_parser.add_argument('-l', '--location',
choices=[x.value for x in msg_location if x not in [msg_location.MEM, msg_location.DISK]],
default='db',
help='Use given location when building the chat history (default: \'db\')')
question_cmd_parser.add_argument('-g', '--glob', help='Filter message files using the given glob pattern')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true')
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query', metavar='FILE')
@@ -82,30 +73,22 @@ def create_parser() -> argparse.ArgumentParser:
# 'hist' command parser
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
help="Print and manage chat history.",
help="Print chat history.",
aliases=['h'])
hist_cmd_parser.set_defaults(func=hist_cmd)
hist_group = hist_cmd_parser.add_mutually_exclusive_group(required=True)
hist_group.add_argument('-p', '--print', help='Print the DB chat history', action='store_true')
hist_group.add_argument('-c', '--convert', help='Convert all message files to the given format [txt|yaml]', metavar='FORMAT')
hist_cmd_parser.add_argument('-w', '--with-metadata', help="Print chat history with metadata (tags, filename, AI, etc.).",
hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.",
action='store_true')
hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.",
action='store_true')
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Only print embedded source code',
action='store_true')
hist_cmd_parser.add_argument('-A', '--answer', help='Print only answers with given substring', metavar='SUBSTRING')
hist_cmd_parser.add_argument('-Q', '--question', help='Print only questions with given substring', metavar='SUBSTRING')
hist_cmd_parser.add_argument('-d', '--tight', help='Print without message separators', action='store_true')
hist_cmd_parser.add_argument('-P', '--no-paging', help='Print without paging', action='store_true')
hist_cmd_parser.add_argument('-l', '--location',
choices=[x.value for x in msg_location if x not in [msg_location.MEM, msg_location.DISK]],
default='db',
help='Use given location when building the chat history (default: \'db\')')
hist_cmd_parser.add_argument('-g', '--glob', help='Filter message files using the given glob pattern')
hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring', metavar='SUBSTRING')
hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring', metavar='SUBSTRING')
# 'tags' command parser
tags_cmd_parser = cmdparser.add_parser('tags',
help="Manage tags.",
aliases=['T'])
aliases=['t'])
tags_cmd_parser.set_defaults(func=tags_cmd)
tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True)
tags_group.add_argument('-l', '--list', help="List all tags and their frequency",
@@ -139,80 +122,10 @@ def create_parser() -> argparse.ArgumentParser:
print_cmd_modes.add_argument('-a', '--answer', help='Only print the answer', action='store_true')
print_cmd_modes.add_argument('-S', '--only-source-code', help='Only print embedded source code', action='store_true')
# 'translation' command parser
translation_cmd_parser = cmdparser.add_parser('translation', parents=[ai_parser, tag_parser],
help="Ask, create and repeat translations.",
aliases=['t'])
translation_cmd_parser.set_defaults(func=translation_cmd)
translation_group = translation_cmd_parser.add_mutually_exclusive_group(required=True)
translation_group.add_argument('-a', '--ask', nargs='+', help='Ask to translate the given text', metavar='TEXT')
translation_group.add_argument('-c', '--create', nargs='+', help='Create a translation', metavar='TEXT')
translation_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a translation', metavar='MESSAGE')
translation_cmd_parser.add_argument('-l', '--source-lang', help="Source language", metavar="LANGUAGE", required=True)
translation_cmd_parser.add_argument('-L', '--target-lang', help="Target language", metavar="LANGUAGE", required=True)
translation_cmd_parser.add_argument('-G', '--glossaries', nargs='+', help="List of glossary names", metavar="GLOSSARY")
translation_cmd_parser.add_argument('-d', '--input-document', help="Document to translate", metavar="FILE")
translation_cmd_parser.add_argument('-D', '--output-document', help="Path for the translated document", metavar="FILE")
# 'glossary' command parser
glossary_cmd_parser = cmdparser.add_parser('glossary', parents=[ai_parser],
help="Manage glossaries.",
aliases=['g'])
glossary_cmd_parser.set_defaults(func=glossary_cmd)
glossary_group = glossary_cmd_parser.add_mutually_exclusive_group(required=True)
glossary_group.add_argument('-c', '--create', help='Create a glossary', action='store_true')
glossary_cmd_parser.add_argument('-n', '--name', help="Glossary name (not ID)", metavar="NAME")
glossary_cmd_parser.add_argument('-l', '--source-lang', help="Source language", metavar="LANGUAGE")
glossary_cmd_parser.add_argument('-L', '--target-lang', help="Target language", metavar="LANGUAGE")
glossary_cmd_parser.add_argument('-f', '--file', help='File path of the goven glossary', metavar='GLOSSARY_FILE')
glossary_cmd_parser.add_argument('-D', '--description', help="Glossary description", metavar="DESCRIPTION")
glossary_group.add_argument('-i', '--list', help='List existing glossaries', action='store_true')
glossary_group.add_argument('-p', '--print', help='Print an existing glossary', action='store_true')
argcomplete.autocomplete(parser)
return parser
def create_directories(config: Config) -> None: # noqa: 11
"""
Create the directories in the given configuration if they don't exist.
"""
def make_dir(path: Path) -> None:
try:
os.makedirs(path.absolute())
except Exception as e:
print(f"Creating directory '{path.absolute()}' failed with: {e}")
sys.exit(1)
# Cache
cache_path = Path(config.cache)
if not cache_path.exists():
answer = input(f"Cache directory '{cache_path}' does not exist. Create it? [y/n]")
if answer.lower() in ['y', 'yes']:
make_dir(cache_path.absolute())
else:
print("Can't continue without a valid cache directory!")
sys.exit(1)
# DB
db_path = Path(config.db)
if not db_path.exists():
answer = input(f"DB directory '{db_path}' does not exist. Create it? [y/n]")
if answer.lower() in ['y', 'yes']:
make_dir(db_path.absolute())
else:
print("Can't continue without a valid DB directory!")
sys.exit(1)
# Glossaries
if config.glossaries:
glossaries_path = Path(config.glossaries)
if not glossaries_path.exists():
answer = input(f"Glossaries directory '{glossaries_path}' does not exist. Create it? [y/n]")
if answer.lower() in ['y', 'yes']:
make_dir(glossaries_path.absolute())
else:
print("Can't continue without a valid glossaries directory. Create it or remove it from the configuration.")
sys.exit(1)
def main() -> int:
parser = create_parser()
args = parser.parse_args()
@@ -221,12 +134,7 @@ def main() -> int:
if command.func == config_cmd:
command.func(command)
else:
try:
config = Config.from_file(args.config)
except ConfigError as err:
print(f"{err}")
return 1
create_directories(config)
command.func(command, config)
return 0
+49 -138
View File
@@ -5,10 +5,7 @@ import pathlib
import yaml
import tempfile
import shutil
import io
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple
from typing import Generator, Iterator
from typing import get_args as typing_get_args
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags
@@ -18,9 +15,6 @@ MessageInst = TypeVar('MessageInst', bound='Message')
AILineInst = TypeVar('AILineInst', bound='AILine')
ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine')
YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]]
MessageFormat = Literal['txt', 'yaml']
message_valid_formats: Final[Tuple[MessageFormat, ...]] = typing_get_args(MessageFormat)
message_default_format: Final[MessageFormat] = 'txt'
class MessageError(Exception):
@@ -51,7 +45,7 @@ def source_code(text: str, include_delims: bool = False) -> list[str]:
code_lines: list[str] = []
in_code_block = False
for line in str(text).split('\n'):
for line in text.split('\n'):
if line.strip().startswith('```'):
if include_delims:
code_lines.append(line)
@@ -98,7 +92,7 @@ class MessageFilter:
class AILine(str):
"""
A line that represents the AI name in the 'txt' format.
A line that represents the AI name in a '.txt' file..
"""
prefix: Final[str] = 'AI:'
@@ -118,7 +112,7 @@ class AILine(str):
class ModelLine(str):
"""
A line that represents the model name in the 'txt' format.
A line that represents the model name in a '.txt' file..
"""
prefix: Final[str] = 'MODEL:'
@@ -144,100 +138,30 @@ class Answer(str):
txt_header: ClassVar[str] = '==== ANSWER ===='
yaml_key: ClassVar[str] = 'answer'
def __init__(self, data: Union[str, Generator[str, None, None]]) -> None:
# Indicator of whether all of data has been processed
self.is_exhausted: bool = False
# Initialize data
self.iterator: Iterator[str] = self._init_data(data)
# Set up the buffer to hold the 'Answer' content
self.buffer: io.StringIO = io.StringIO()
def _init_data(self, data: Union[str, Generator[str, None, None]]) -> Iterator[str]:
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
"""
Process input data (either a string or a string generator)
Make sure the answer string does not contain the header as a whole line.
"""
if isinstance(data, str):
yield data
else:
yield from data
def __str__(self) -> str:
"""
Output all content when converted into a string
"""
# Ensure all data has been processed
for _ in self:
pass
# Return the 'Answer' content
return self.buffer.getvalue()
def __repr__(self) -> str:
return repr(str(self))
def __iter__(self) -> Generator[str, None, None]:
"""
Allows the object to be iterable
"""
# Generate content if not all data has been processed
if not self.is_exhausted:
yield from self.generator_iter()
else:
yield self.buffer.getvalue()
def generator_iter(self) -> Generator[str, None, None]:
"""
Main generator method to process data
"""
for piece in self.iterator:
# Write to buffer and yield piece for the iterator
self.buffer.write(piece)
yield piece
self.is_exhausted = True # Set the flag that all data has been processed
# If the header occurs in the 'Answer' content, raise an error
if f'\n{self.txt_header}' in self.buffer.getvalue() or self.buffer.getvalue().startswith(self.txt_header):
raise MessageError(f"Answer {repr(self.buffer.getvalue())} contains the header {repr(Answer.txt_header)}")
def __eq__(self, other: object) -> bool:
"""
Comparing the object to a string or another object
"""
if isinstance(other, str):
return str(self) == other # Compare the string value of this object to the other string
# Default behavior for comparing non-string objects
return super().__eq__(other)
def __hash__(self) -> int:
"""
Generate a hash for the object based on its string representation.
"""
return hash(str(self))
def __format__(self, format_spec: str) -> str:
"""
Return a formatted version of the string as per the format specification.
"""
return str(self).__format__(format_spec)
if cls.txt_header in string.split('\n'):
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
instance = super().__new__(cls, string)
return instance
@classmethod
def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
"""
Build Answer from a list of strings. Make sure strings do not contain the header.
Build Question from a list of strings. Make sure strings do not contain the header.
"""
def _gen() -> Generator[str, None, None]:
if len(strings) > 0:
yield strings[0]
for s in strings[1:]:
yield '\n'
yield s
return cls(_gen())
if cls.txt_header in strings:
raise MessageError(f"Question contains the header '{cls.txt_header}'")
instance = super().__new__(cls, '\n'.join(strings).strip())
return instance
def source_code(self, include_delims: bool = False) -> list[str]:
"""
Extract and return all source code sections.
"""
return source_code(str(self), include_delims)
return source_code(self, include_delims)
class Question(str):
@@ -292,9 +216,7 @@ class Message():
model: Optional[str] = field(default=None, compare=False)
file_path: Optional[pathlib.Path] = field(default=None, compare=False)
# class variables
file_suffixes_read: ClassVar[list[str]] = ['.msg', '.txt', '.yaml']
file_suffix_write: ClassVar[str] = '.msg'
default_format: ClassVar[MessageFormat] = message_default_format
file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml']
tags_yaml_key: ClassVar[str] = 'tags'
file_yaml_key: ClassVar[str] = 'file_path'
ai_yaml_key: ClassVar[str] = 'ai'
@@ -354,8 +276,16 @@ class Message():
tags: set[Tag] = set()
if not file_path.exists():
raise MessageError(f"Message file '{file_path}' does not exist")
if file_path.suffix not in cls.file_suffixes_read:
if file_path.suffix not in cls.file_suffixes:
raise MessageError(f"File type '{file_path.suffix}' is not supported")
# for TXT, it's enough to read the TagLine
if file_path.suffix == '.txt':
with open(file_path, "r") as fd:
try:
tags = TagLine(fd.readline()).tags(prefix, contain)
except TagError:
pass # message without tags
else: # '.yaml'
try:
message = cls.from_file(file_path)
if message:
@@ -398,16 +328,15 @@ class Message():
"""
if not file_path.exists():
raise MessageError(f"Message file '{file_path}' does not exist")
if file_path.suffix not in cls.file_suffixes_read:
if file_path.suffix not in cls.file_suffixes:
raise MessageError(f"File type '{file_path.suffix}' is not supported")
# try TXT first
try:
if file_path.suffix == '.txt':
message = cls.__from_file_txt(file_path,
mfilter.tags_or if mfilter else None,
mfilter.tags_and if mfilter else None,
mfilter.tags_not if mfilter else None)
# then YAML
except MessageError:
else:
message = cls.__from_file_yaml(file_path)
if message and (mfilter is None or message.match(mfilter)):
return message
@@ -444,6 +373,10 @@ class Message():
tags = TagLine(fd.readline()).tags()
except TagError:
fd.seek(pos)
if tags_or or tags_and or tags_not:
# match with an empty set if the file has no tags
if not match_tags(tags, tags_or, tags_and, tags_not):
return None
# AILine (Optional)
try:
pos = fd.tell()
@@ -468,12 +401,6 @@ class Message():
answer = Answer.from_list(text[answer_idx + 1:])
except ValueError:
question = Question.from_list(text[question_idx:])
# match tags AFTER reading the whole file
# -> make sure it's a valid 'txt' file format
if tags_or or tags_and or tags_not:
# match with an empty set if the file has no tags
if not match_tags(tags, tags_or, tags_and, tags_not):
return None
return cls(question, answer, tags, ai, model, file_path)
@classmethod
@@ -494,7 +421,7 @@ class Message():
except Exception:
raise MessageError(f"'{file_path}' does not contain a valid message")
def to_str(self, with_metadata: bool = False, source_code_only: bool = False) -> str:
def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str:
"""
Return the current Message as a string.
"""
@@ -504,41 +431,32 @@ class Message():
if self.answer:
output.extend(self.answer.source_code(include_delims=True))
return '\n'.join(output) if len(output) > 0 else ''
if with_metadata:
if with_tags:
output.append(self.tags_str())
if with_file:
output.append('FILE: ' + str(self.file_path))
output.append('AI: ' + str(self.ai))
output.append('MODEL: ' + str(self.model))
output.append(Question.txt_header)
output.append(self.question)
if self.answer:
output.append(Answer.txt_header)
output.append(str(self.answer))
output.append(self.answer)
return '\n'.join(output)
def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
"""
Write a Message to the given file. Supported message file formats are 'txt' and 'yaml'.
Suffix is always '.msg'.
Write a Message to the given file. Type is determined based on the suffix.
Currently supported suffixes: ['.txt', '.yaml']
"""
if file_path:
self.file_path = file_path
if not self.file_path:
raise MessageError("Got no valid path to write message")
if mformat not in message_valid_formats:
raise MessageError(f"File format '{mformat}' is not supported")
# check for valid suffix
# -> add one if it's empty
# -> refuse old or otherwise unsupported suffixes
if not self.file_path.suffix:
self.file_path = self.file_path.with_suffix(self.file_suffix_write)
elif self.file_path.suffix != self.file_suffix_write:
raise MessageError(f"File suffix '{self.file_path.suffix}' is not supported")
if self.file_path.suffix not in self.file_suffixes:
raise MessageError(f"File type '{self.file_path.suffix}' is not supported")
# TXT
if mformat == 'txt':
if self.file_path.suffix == '.txt':
return self.__to_file_txt(self.file_path)
# YAML
elif mformat == 'yaml':
elif self.file_path.suffix == '.yaml':
return self.__to_file_yaml(self.file_path)
def __to_file_txt(self, file_path: pathlib.Path) -> None:
@@ -550,8 +468,8 @@ class Message():
* Model [Optional]
* Question.txt_header
* Question
* Answer.txt_header [Optional]
* Answer [Optional]
* Answer.txt_header
* Answer
"""
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
temp_file_path = pathlib.Path(temp_fd.name)
@@ -563,7 +481,7 @@ class Message():
temp_fd.write(f'{ModelLine.from_model(self.model)}\n')
temp_fd.write(f'{Question.txt_header}\n{self.question}\n')
if self.answer:
temp_fd.write(f'{Answer.txt_header}\n{str(self.answer)}\n')
temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n')
shutil.move(temp_file_path, file_path)
def __to_file_yaml(self, file_path: pathlib.Path) -> None:
@@ -590,13 +508,6 @@ class Message():
yaml.dump(data, temp_fd, sort_keys=False)
shutil.move(temp_file_path, file_path)
def rm_file(self) -> None:
"""
Delete the message file. Ignore empty file_path and not existing files.
"""
if self.file_path is not None:
self.file_path.unlink(missing_ok=True)
def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
"""
Filter tags based on their prefix (i. e. the tag starts with a given string)
@@ -632,7 +543,7 @@ class Message():
or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503
or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503
or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503
or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in str(self.answer))) # noqa: W503
or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503
or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503
or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503
or (mfilter.model_state == 'available' and not self.model) # noqa: W503
-1
View File
@@ -2,4 +2,3 @@ openai
PyYAML
argcomplete
pytest
tiktoken
+13 -25
View File
@@ -16,37 +16,26 @@ class OpenAITest(unittest.TestCase):
openai = OpenAI(config)
# Set up the mock response from openai.ChatCompletion.create
mock_chunk1 = {
mock_response = {
'choices': [
{
'index': 0,
'delta': {
'message': {
'content': 'Answer 1'
},
'finish_reason': None
}
},
{
'index': 1,
'delta': {
'message': {
'content': 'Answer 2'
},
'finish_reason': None
}
}
],
'usage': {
'prompt_tokens': 10,
'completion_tokens': 20,
'total_tokens': 30
}
mock_chunk2 = {
'choices': [
{
'index': 0,
'finish_reason': 'stop'
},
{
'index': 1,
'finish_reason': 'stop'
}
],
}
mock_create.return_value = iter([mock_chunk1, mock_chunk2])
mock_create.return_value = mock_response
# Create test data
question = Message(Question('Question'))
@@ -68,9 +57,9 @@ class OpenAITest(unittest.TestCase):
self.assertIsNotNone(response.tokens)
self.assertIsInstance(response.tokens, Tokens)
assert response.tokens
self.assertEqual(response.tokens.prompt, 53)
self.assertEqual(response.tokens.completion, 6)
self.assertEqual(response.tokens.total, 59)
self.assertEqual(response.tokens.prompt, 10)
self.assertEqual(response.tokens.completion, 20)
self.assertEqual(response.tokens.total, 30)
# Assert the mock call to openai.ChatCompletion.create
mock_create.assert_called_once_with(
@@ -87,7 +76,6 @@ class OpenAITest(unittest.TestCase):
max_tokens=config.max_tokens,
top_p=config.top_p,
n=2,
stream=True,
frequency_penalty=config.frequency_penalty,
presence_penalty=config.presence_penalty
)
+140 -175
View File
@@ -7,21 +7,7 @@ from io import StringIO
from unittest.mock import patch
from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, ChatError, msg_location
msg_suffix: str = Message.file_suffix_write
def msg_to_file_force_suffix(msg: Message) -> None:
"""
Force writing a message file with illegal suffixes.
"""
def_suffix = Message.file_suffix_write
assert msg.file_path
Message.file_suffix_write = msg.file_path.suffix
msg.to_file()
Message.file_suffix_write = def_suffix
from chatmastermind.chat import Chat, ChatDB, ChatError
class TestChatBase(unittest.TestCase):
@@ -41,15 +27,11 @@ class TestChat(TestChatBase):
self.message1 = Message(Question('Question 1'),
Answer('Answer 1'),
{Tag('atag1'), Tag('btag2')},
ai='FakeAI',
model='FakeModel',
file_path=pathlib.Path(f'0001{msg_suffix}'))
file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('btag2')},
ai='FakeAI',
model='FakeModel',
file_path=pathlib.Path(f'0002{msg_suffix}'))
file_path=pathlib.Path('0002.txt'))
self.maxDiff = None
def test_unique_id(self) -> None:
@@ -117,24 +99,24 @@ class TestChat(TestChatBase):
def test_find_remove_messages(self) -> None:
self.chat.msg_add([self.message1, self.message2])
msgs = self.chat.msg_find(['0001'])
msgs = self.chat.msg_find(['0001.txt'])
self.assertListEqual(msgs, [self.message1])
msgs = self.chat.msg_find(['0001', '0002'])
msgs = self.chat.msg_find(['0001.txt', '0002.txt'])
self.assertListEqual(msgs, [self.message1, self.message2])
# add new Message with full path
message3 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('btag2')},
file_path=pathlib.Path(f'/foo/bla/0003{msg_suffix}'))
file_path=pathlib.Path('/foo/bla/0003.txt'))
self.chat.msg_add([message3])
# find new Message by full path
msgs = self.chat.msg_find([f'/foo/bla/0003{msg_suffix}'])
msgs = self.chat.msg_find(['/foo/bla/0003.txt'])
self.assertListEqual(msgs, [message3])
# find Message with full path only by filename
msgs = self.chat.msg_find([f'0003{msg_suffix}'])
msgs = self.chat.msg_find(['0003.txt'])
self.assertListEqual(msgs, [message3])
# remove last message
self.chat.msg_remove(['0003'])
self.chat.msg_remove(['0003.txt'])
self.assertListEqual(self.chat.messages, [self.message1, self.message2])
def test_latest_message(self) -> None:
@@ -147,7 +129,7 @@ class TestChat(TestChatBase):
@patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None:
self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False, tight=True)
self.chat.print(paged=False)
expected_output = f"""{Question.txt_header}
Question 1
{Answer.txt_header}
@@ -160,21 +142,17 @@ Answer 2
self.assertEqual(mock_stdout.getvalue(), expected_output)
@patch('sys.stdout', new_callable=StringIO)
def test_print_with_metadata(self, mock_stdout: StringIO) -> None:
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False, with_metadata=True, tight=True)
self.chat.print(paged=False, with_tags=True, with_files=True)
expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: 0001{msg_suffix}
AI: FakeAI
MODEL: FakeModel
FILE: 0001.txt
{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{TagLine.prefix} btag2
FILE: 0002{msg_suffix}
AI: FakeAI
MODEL: FakeModel
FILE: 0002.txt
{Question.txt_header}
Question 2
{Answer.txt_header}
@@ -190,27 +168,31 @@ class TestChatDB(TestChatBase):
self.message1 = Message(Question('Question 1'),
Answer('Answer 1'),
{Tag('tag1')})
{Tag('tag1')},
file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('tag2')})
{Tag('tag2')},
file_path=pathlib.Path('0002.yaml'))
self.message3 = Message(Question('Question 3'),
Answer('Answer 3'),
{Tag('tag3')})
{Tag('tag3')},
file_path=pathlib.Path('0003.txt'))
self.message4 = Message(Question('Question 4'),
Answer('Answer 4'),
{Tag('tag4')})
{Tag('tag4')},
file_path=pathlib.Path('0004.yaml'))
self.message1.to_file(pathlib.Path(self.db_path.name, '0001'), mformat='txt')
self.message2.to_file(pathlib.Path(self.db_path.name, '0002'), mformat='yaml')
self.message3.to_file(pathlib.Path(self.db_path.name, '0003'), mformat='txt')
self.message4.to_file(pathlib.Path(self.db_path.name, '0004'), mformat='yaml')
self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt'))
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
# make the next FID match the current state
next_fname = pathlib.Path(self.db_path.name) / '.next'
with open(next_fname, 'w') as f:
f.write('4')
# add some "trash" in order to test if it's correctly handled / ignored
self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt', 'fubar.msg']
self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt']
for file in self.trash_files:
with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
f.write('test trash')
@@ -225,7 +207,7 @@ class TestChatDB(TestChatBase):
List all Message files in the given TemporaryDirectory.
"""
# exclude '.next'
return [f for f in pathlib.Path(tmp_dir.name).glob('*.[tym]*') if f.name not in self.trash_files]
return [f for f in pathlib.Path(tmp_dir.name).glob('*.[ty]*') if f.name not in self.trash_files]
def tearDown(self) -> None:
self.db_path.cleanup()
@@ -236,32 +218,13 @@ class TestChatDB(TestChatBase):
duplicate_message = Message(Question('Question 4'),
Answer('Answer 4'),
{Tag('tag4')},
file_path=pathlib.Path(self.db_path.name, '0004.txt'))
msg_to_file_force_suffix(duplicate_message)
file_path=pathlib.Path('0004.txt'))
duplicate_message.to_file(pathlib.Path(self.db_path.name, '0004.txt'))
with self.assertRaises(ChatError) as cm:
ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
glob='*')
pathlib.Path(self.db_path.name))
self.assertEqual(str(cm.exception), "Validation failed")
def test_file_path_ID_exists(self) -> None:
"""
Tests if the CacheDB chooses another ID if a file path with
the given one exists.
"""
# create a new and empty CacheDB
db_path = tempfile.TemporaryDirectory()
cache_path = tempfile.TemporaryDirectory()
chat_db = ChatDB.from_dir(pathlib.Path(cache_path.name),
pathlib.Path(db_path.name))
# add a message file
message = Message(Question('What?'),
file_path=pathlib.Path(cache_path.name) / f'0001{msg_suffix}')
message.to_file()
message1 = Message(Question('Where?'))
chat_db.cache_write([message1])
self.assertEqual(message1.msg_id(), '0002')
def test_from_dir(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
@@ -270,23 +233,25 @@ class TestChatDB(TestChatBase):
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
# check that the files are sorted
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path,
pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path,
pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
pathlib.Path(self.db_path.name, '0004.yaml'))
def test_from_dir_glob(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
glob='*1.*')
self.assertEqual(len(chat_db.messages), 1)
glob='*.txt')
self.assertEqual(len(chat_db.messages), 2)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0003.txt'))
def test_from_dir_filter_tags(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
@@ -296,7 +261,7 @@ class TestChatDB(TestChatBase):
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
pathlib.Path(self.db_path.name, '0001.txt'))
def test_from_dir_filter_tags_empty(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
@@ -314,7 +279,7 @@ class TestChatDB(TestChatBase):
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
def test_from_messages(self) -> None:
@@ -359,25 +324,25 @@ class TestChatDB(TestChatBase):
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# check that Message.file_path is correct
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory
chat_db.cache_write()
# check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4)
self.assertIn(pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files)
# check that Message.file_path has been correctly updated
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'))
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
# check the timestamp of the files in the DB directory
db_dir_files = self.message_list(self.db_path)
@@ -389,18 +354,18 @@ class TestChatDB(TestChatBase):
# check if the written files are in the DB directory
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
# check if all files in the DB dir have actually been overwritten
for file in db_dir_files:
self.assertGreater(file.stat().st_mtime, old_timestamps[file])
# check that Message.file_path has been correctly updated (again)
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
def test_db_read(self) -> None:
# create a new ChatDB instance
@@ -415,65 +380,65 @@ class TestChatDB(TestChatBase):
new_message2 = Message(Question('Question 6'),
Answer('Answer 6'),
{Tag('tag6')})
new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read and check them
chat_db.db_read()
self.assertEqual(len(chat_db.messages), 6)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
# create 2 new files in the cache directory
new_message3 = Message(Question('Question 7'),
Answer('Answer 7'),
Answer('Answer 5'),
{Tag('tag7')})
new_message4 = Message(Question('Question 8'),
Answer('Answer 8'),
Answer('Answer 6'),
{Tag('tag8')})
new_message3.to_file(pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'), mformat='txt')
new_message4.to_file(pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'), mformat='yaml')
new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml'))
# read and check them
chat_db.cache_read()
self.assertEqual(len(chat_db.messages), 8)
# check that the new message have the cache dir path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'))
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml'))
# an the old ones keep their path (since they have not been replaced)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
# now overwrite two messages in the DB directory
new_message1.question = Question('New Question 1')
new_message2.question = Question('New Question 2')
new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read from the DB dir and check if the modified messages have been updated
chat_db.db_read()
self.assertEqual(len(chat_db.messages), 8)
self.assertEqual(chat_db.messages[4].question, 'New Question 1')
self.assertEqual(chat_db.messages[5].question, 'New Question 2')
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
# now write the messages from the cache to the DB directory
new_message3.to_file(pathlib.Path(self.db_path.name, f'0007{msg_suffix}'))
new_message4.to_file(pathlib.Path(self.db_path.name, f'0008{msg_suffix}'))
new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml'))
# read and check them
chat_db.db_read()
self.assertEqual(len(chat_db.messages), 8)
# check that they now have the DB path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, f'0007{msg_suffix}'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, f'0008{msg_suffix}'))
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml'))
def test_cache_clear(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# check that Message.file_path is correct
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory
chat_db.cache_write()
@@ -485,10 +450,10 @@ class TestChatDB(TestChatBase):
chat_db.db_write()
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
# add a new message with empty file_path
message_empty = Message(question=Question("What the hell am I doing here?"),
@@ -496,7 +461,7 @@ class TestChatDB(TestChatBase):
# and one for the cache dir
message_cache = Message(question=Question("What the hell am I doing here?"),
answer=Answer("You're a creep!"),
file_path=pathlib.Path(self.cache_path.name, '0005'))
file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
chat_db.msg_add([message_empty, message_cache])
# clear the cache and check the cache dir
@@ -558,11 +523,11 @@ class TestChatDB(TestChatBase):
chat_db.msg_write([message])
# write a message with a valid file_path
message.file_path = pathlib.Path(self.cache_path.name) / '123456'
message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt'
chat_db.msg_write([message])
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, f'123456{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
def test_msg_update(self) -> None:
# create a new ChatDB instance
@@ -596,92 +561,92 @@ class TestChatDB(TestChatBase):
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# search for a DB file in memory
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.msg'], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.txt'], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1])
# and on disk
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.msg'], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.yaml'], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2])
# now search the cache -> expect empty result
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find(['0003.msg'], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find(['0003'], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003.txt'], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), [])
# search for multiple messages
# -> search one twice, expect result to be unique
search_names = ['0001', '0002.msg', self.message3.msg_id(), str(self.message3.file_path)]
search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, loc=msg_location.ALL)
result = chat_db.msg_find(search_names, loc='all')
self.assert_messages_equal(result, expected_result)
def test_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), self.message4)
self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4)
self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), self.message4)
self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), self.message4)
self.assertEqual(chat_db.msg_latest(loc='mem'), self.message4)
self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)
self.assertEqual(chat_db.msg_latest(loc='disk'), self.message4)
self.assertEqual(chat_db.msg_latest(loc='all'), self.message4)
# the cache is currently empty:
self.assertIsNone(chat_db.msg_latest(loc=msg_location.CACHE))
self.assertIsNone(chat_db.msg_latest(loc='cache'))
# add new messages to the cache dir
new_message = Message(question=Question("New Question"),
answer=Answer("New Answer"))
chat_db.cache_add([new_message])
self.assertEqual(chat_db.msg_latest(loc=msg_location.CACHE), new_message)
self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), new_message)
self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), new_message)
self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), new_message)
self.assertEqual(chat_db.msg_latest(loc='cache'), new_message)
self.assertEqual(chat_db.msg_latest(loc='mem'), new_message)
self.assertEqual(chat_db.msg_latest(loc='disk'), new_message)
self.assertEqual(chat_db.msg_latest(loc='all'), new_message)
# the DB does not contain the new message
self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4)
self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)
def test_msg_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# add a new message, but only to the internal list
new_message = Message(Question("What?"))
all_messages_mem = all_messages + [new_message]
chat_db.msg_add([new_message])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages_mem)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages_mem)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages_mem)
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages_mem)
# the nr. of messages on disk did not change -> expect old result
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# test with MessageFilter
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL, mfilter=MessageFilter(tags_or={Tag('tag1')})),
self.assert_messages_equal(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})),
[self.message1])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK, mfilter=MessageFilter(tags_or={Tag('tag2')})),
self.assert_messages_equal(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})),
[self.message2])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE, mfilter=MessageFilter(tags_or={Tag('tag3')})),
self.assert_messages_equal(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})),
[])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM, mfilter=MessageFilter(question_contains="What")),
self.assert_messages_equal(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")),
[new_message])
def test_msg_move_and_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# move first message to the cache
chat_db.cache_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [self.message1])
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [self.message1])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), [self.message2, self.message3, self.message4])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4])
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages)
# now move first message back to the DB
chat_db.db_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
-100
View File
@@ -1,100 +0,0 @@
import unittest
import argparse
from typing import Union, Optional
from chatmastermind.configuration import Config, AIConfig
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Answer
from chatmastermind.chat import Chat
from chatmastermind.ai import AI, AIResponse, Tokens, AIError
class FakeAI(AI):
"""
A mocked version of the 'AI' class.
"""
ID: str
name: str
config: AIConfig
def models(self) -> list[str]:
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int:
return 123
def print(self) -> None:
pass
def print_models(self) -> None:
pass
def __init__(self, ID: str, model: str, error: bool = False):
self.ID = ID
self.model = model
self.error = error
def request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Mock the 'ai.request()' function by either returning fake
answers or raising an exception.
"""
if self.error:
raise AIError
question.answer = Answer("Answer 0")
question.tags = set(otags) if otags is not None else None
question.ai = self.ID
question.model = self.model
answers: list[Message] = [question]
for n in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(f"Answer {n}"),
tags=otags,
ai=self.ID,
model=self.model))
return AIResponse(answers, Tokens(10, 10, 20))
class TestWithFakeAI(unittest.TestCase):
"""
Base class for all tests that need to use the FakeAI.
"""
def assert_msgs_equal_except_file_path(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using Question, Answer and all metadata excecot for the file_path.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
def assert_msgs_all_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using Question, Answer and ALL metadata.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
self.assertTrue(m1.equals(m2, verbose=True))
def assert_msgs_content_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using only Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
self.assertEqual(m1, m2)
def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance.
"""
return FakeAI(args.AI, args.model)
def mock_create_ai_with_error(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance.
"""
return FakeAI(args.AI, args.model, error=True)
+1 -6
View File
@@ -71,13 +71,11 @@ class TestConfig(unittest.TestCase):
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
},
'glossaries': './glossaries/'
}
}
config = Config.from_dict(source_dict)
self.assertEqual(config.cache, '.')
self.assertEqual(config.db, './test_db/')
self.assertEqual(config.glossaries, './glossaries/')
self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['myopenai'].name, 'openai')
self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system')
@@ -107,7 +105,6 @@ class TestConfig(unittest.TestCase):
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
# omit glossaries, since it's optional
}
}
with open(self.test_file.name, 'w') as f:
@@ -116,8 +113,6 @@ class TestConfig(unittest.TestCase):
self.assertIsInstance(config, Config)
self.assertEqual(config.cache, './test_cache/')
self.assertEqual(config.db, './test_db/')
# missing 'glossaries' should result in 'None'
self.assertEqual(config.glossaries, None)
self.assertEqual(len(config.ais), 1)
self.assertIsInstance(config.ais['default'], AIConfig)
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
-209
View File
@@ -1,209 +0,0 @@
import unittest
import tempfile
from pathlib import Path
from chatmastermind.glossary import Glossary, GlossaryError
glossary_suffix: str = Glossary.file_suffix
class TestGlossary(unittest.TestCase):
def test_from_file_yaml_unquoted(self) -> None:
"""
Test glossary creatiom from YAML with unquoted entries.
"""
with tempfile.NamedTemporaryFile('w', delete=False, suffix=glossary_suffix) as yaml_file:
yaml_file.write("Name: Sample\n"
"Description: A brief description\n"
"ID: '123'\n"
"SourceLang: en\n"
"TargetLang: es\n"
"Entries:\n"
" hello: hola\n"
" goodbye: adiós\n"
# 'yes' is a YAML keyword and would normally be quoted
" yes: sí\n"
" I'm going home: me voy a casa\n")
yaml_file_path = Path(yaml_file.name)
# create and check valid glossary
glossary = Glossary.from_file(yaml_file_path)
self.assertEqual(glossary.name, "Sample")
self.assertEqual(glossary.desc, "A brief description")
self.assertEqual(glossary.ID, "123")
self.assertEqual(glossary.source_lang, "en")
self.assertEqual(glossary.target_lang, "es")
self.assertEqual(glossary.entries, {"hello": "hola",
"goodbye": "adiós",
"yes": "",
"I'm going home": "me voy a casa"})
yaml_file_path.unlink() # Remove the temporary file
def test_from_file_yaml_quoted(self) -> None:
"""
Test glossary creatiom from YAML with quoted entries.
"""
with tempfile.NamedTemporaryFile('w', delete=False, suffix=glossary_suffix) as yaml_file:
yaml_file.write("Name: Sample\n"
"Description: A brief description\n"
"ID: '123'\n"
"SourceLang: en\n"
"TargetLang: es\n"
"Entries:\n"
" 'hello': 'hola'\n"
" 'goodbye': 'adiós'\n"
" 'yes': ''\n"
" \"I'm going home\": 'me voy a casa'\n")
yaml_file_path = Path(yaml_file.name)
# create and check valid glossary
glossary = Glossary.from_file(yaml_file_path)
self.assertEqual(glossary.name, "Sample")
self.assertEqual(glossary.desc, "A brief description")
self.assertEqual(glossary.ID, "123")
self.assertEqual(glossary.source_lang, "en")
self.assertEqual(glossary.target_lang, "es")
self.assertEqual(glossary.entries, {"hello": "hola",
"goodbye": "adiós",
"yes": "",
"I'm going home": "me voy a casa"})
yaml_file_path.unlink() # Remove the temporary file
def test_to_file_writes_yaml(self) -> None:
# Create glossary instance
glossary = Glossary(name="Test",
desc="Test description",
ID="666",
source_lang="en",
target_lang="fr",
entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', suffix=glossary_suffix) as tmp_file:
file_path = Path(tmp_file.name)
glossary.to_file(file_path)
# read and check valid YAML
with open(file_path, 'r') as file:
content = file.read()
self.assertIn("Name: Test", content)
self.assertIn("Description: Test description", content)
self.assertIn("ID: '666'", content)
self.assertIn("SourceLang: en", content)
self.assertIn("TargetLang: fr", content)
self.assertIn("Entries", content)
# 'yes' is a YAML keyword and therefore quoted
self.assertIn("'yes': oui", content)
def test_write_read_glossary(self) -> None:
# Create glossary instance
# -> use 'yes' in order to test if the YAML quoting is correctly removed when reading the file
glossary_write = Glossary(name="Test", source_lang="en", target_lang="fr", entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', suffix=glossary_suffix) as tmp_file:
file_path = Path(tmp_file.name)
glossary_write.to_file(file_path)
# create new instance from glossary file
glossary_read = Glossary.from_file(file_path)
self.assertEqual(glossary_write.name, glossary_read.name)
self.assertEqual(glossary_write.source_lang, glossary_read.source_lang)
self.assertEqual(glossary_write.target_lang, glossary_read.target_lang)
self.assertDictEqual(glossary_write.entries, glossary_read.entries)
def test_import_export_csv(self) -> None:
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={})
# First export to CSV
with tempfile.NamedTemporaryFile('w', suffix=glossary_suffix) as csvfile:
csv_file_path = Path(csvfile.name)
glossary.entries = {"hello": "salut", "goodbye": "au revoir"}
glossary.export_csv(glossary.entries, csv_file_path)
# Now import CSV
glossary.import_csv(csv_file_path)
self.assertEqual(glossary.entries, {"hello": "salut", "goodbye": "au revoir"})
def test_import_export_tsv(self) -> None:
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={})
# First export to TSV
with tempfile.NamedTemporaryFile('w', suffix=glossary_suffix) as tsvfile:
tsv_file_path = Path(tsvfile.name)
glossary.entries = {"hello": "salut", "goodbye": "au revoir"}
glossary.export_tsv(glossary.entries, tsv_file_path)
# Now import TSV
glossary.import_tsv(tsv_file_path)
self.assertEqual(glossary.entries, {"hello": "salut", "goodbye": "au revoir"})
def test_to_file_wrong_suffix(self) -> None:
"""
Test for exception if suffix is wrong.
"""
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', suffix='.wrong') as tmp_file:
file_path = Path(tmp_file.name)
with self.assertRaises(GlossaryError) as err:
glossary.to_file(file_path)
self.assertEqual(str(err.exception), "File suffix '.wrong' is not supported")
def test_to_file_auto_suffix(self) -> None:
"""
Test if suffix is auto-generated if omitted.
"""
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', suffix='') as tmp_file:
file_path = Path(tmp_file.name)
glossary.to_file(file_path)
assert glossary.file_path is not None
self.assertEqual(glossary.file_path.suffix, glossary_suffix)
# remove glossary file (differs from 'tmp_file' because of the added suffix
glossary.file_path.unlink()
def test_to_str_with_id(self) -> None:
# Create a Glossary instance with an ID
glossary_with_id = Glossary(name="TestGlossary", source_lang="en", target_lang="fr",
desc="A simple test glossary", ID="1001", entries={"one": "un"})
glossary_str = glossary_with_id.to_str()
self.assertIn("TestGlossary (ID: 1001):", glossary_str)
self.assertIn("- A simple test glossary", glossary_str)
self.assertIn("- Languages: en -> fr", glossary_str)
self.assertIn("- Entries: 1", glossary_str)
def test_to_str_with_id_and_entries(self) -> None:
# Create a Glossary instance with an ID and include entries
glossary_with_entries = Glossary(name="TestGlossaryWithEntries", source_lang="en", target_lang="fr",
desc="Another test glossary", ID="2002",
entries={"hello": "salut", "goodbye": "au revoir"})
glossary_str_with_entries = glossary_with_entries.to_str(with_entries=True)
self.assertIn("TestGlossaryWithEntries (ID: 2002):", glossary_str_with_entries)
self.assertIn("- Entries:", glossary_str_with_entries)
self.assertIn(" hello : salut", glossary_str_with_entries)
self.assertIn(" goodbye : au revoir", glossary_str_with_entries)
def test_to_str_without_id(self) -> None:
# Create a Glossary instance without an ID
glossary_without_id = Glossary(name="TestGlossaryNoID", source_lang="en", target_lang="fr",
desc="A test glossary without an ID", ID=None, entries={"yes": "oui"})
glossary_str_no_id = glossary_without_id.to_str()
self.assertIn("TestGlossaryNoID (ID: None):", glossary_str_no_id)
self.assertIn("- A test glossary without an ID", glossary_str_no_id)
self.assertIn("- Languages: en -> fr", glossary_str_no_id)
self.assertIn("- Entries: 1", glossary_str_no_id)
def test_to_str_without_id_and_no_entries(self) -> None:
# Create a Glossary instance without an ID and no entries
glossary_no_id_no_entries = Glossary(name="EmptyGlossary", source_lang="en", target_lang="fr",
desc="An empty test glossary", ID=None, entries={})
glossary_str_no_id_no_entries = glossary_no_id_no_entries.to_str()
self.assertIn("EmptyGlossary (ID: None):", glossary_str_no_id_no_entries)
self.assertIn("- An empty test glossary", glossary_str_no_id_no_entries)
self.assertIn("- Languages: en -> fr", glossary_str_no_id_no_entries)
self.assertIn("- Entries: 0", glossary_str_no_id_no_entries)
def test_to_str_no_description(self) -> None:
# Create a Glossary instance with an ID
glossary_with_id = Glossary(name="TestGlossary", source_lang="en", target_lang="fr",
ID="1001", entries={"one": "un"})
glossary_str = glossary_with_id.to_str()
expected_str = """TestGlossary (ID: 1001):
- Languages: en -> fr
- Entries: 1"""
self.assertEqual(expected_str, glossary_str)
-149
View File
@@ -1,149 +0,0 @@
import unittest
import argparse
import tempfile
import io
from contextlib import redirect_stdout
from chatmastermind.configuration import Config
from chatmastermind.commands.glossary import (
Glossary,
GlossaryCmdError,
glossary_cmd,
get_glossary_file_path,
create_glossary,
print_glossary,
list_glossaries
)
class TestGlossaryCmdNoGlossaries(unittest.TestCase):
def setUp(self) -> None:
# create DB and cache
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
self.glossaries_dir = tempfile.TemporaryDirectory()
# create configuration
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
self.config.glossaries = self.glossaries_dir.name
# create a mock argparse.Namespace
self.args = argparse.Namespace(
create=True,
list=False,
print=False,
name='new_glossary',
file=None,
source_lang='en',
target_lang='de',
description=False,
)
def test_glossary_create_no_glossaries_err(self) -> None:
self.config.glossaries = None
with self.assertRaises(GlossaryCmdError) as err:
create_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "glossaries directory missing")
def test_glossary_create_no_name_err(self) -> None:
self.args.name = None
with self.assertRaises(GlossaryCmdError) as err:
create_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "missing glossary name")
def test_glossary_create_no_source_lang_err(self) -> None:
self.args.source_lang = None
with self.assertRaises(GlossaryCmdError) as err:
create_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "missing source language")
def test_glossary_create_no_target_lang_err(self) -> None:
self.args.target_lang = None
with self.assertRaises(GlossaryCmdError) as err:
create_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "missing target language")
def test_glossary_print_no_name_err(self) -> None:
self.args.name = None
with self.assertRaises(GlossaryCmdError) as err:
print_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "missing glossary name")
def test_glossary_list_no_glossaries_err(self) -> None:
self.config.glossaries = None
with self.assertRaises(GlossaryCmdError) as err:
list_glossaries(self.args, self.config)
self.assertIn(str(err.exception).lower(), "glossaries directory missing")
def test_glossary_create(self) -> None:
self.args.create = True
self.args.list = False
self.args.print = False
glossary_cmd(self.args, self.config)
expected_path = get_glossary_file_path(self.args.name, self.config)
glo = Glossary.from_file(expected_path)
self.assertEqual(glo.name, self.args.name)
expected_path.unlink()
def test_glossary_create_twice_err(self) -> None:
self.args.create = True
self.args.list = False
self.args.print = False
glossary_cmd(self.args, self.config)
expected_path = get_glossary_file_path(self.args.name, self.config)
glo = Glossary.from_file(expected_path)
self.assertEqual(glo.name, self.args.name)
# create glossary with the same name again
with self.assertRaises(GlossaryCmdError) as err:
create_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "already exists")
expected_path.unlink()
class TestGlossaryCmdWithGlossaries(unittest.TestCase):
def setUp(self) -> None:
# create DB and cache
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
self.glossaries_dir = tempfile.TemporaryDirectory()
# create configuration
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
self.config.glossaries = self.glossaries_dir.name
# create a mock argparse.Namespace
self.args = argparse.Namespace(
create=True,
list=False,
print=False,
name='Glossary1',
file=None,
source_lang='en',
target_lang='de',
description=False,
)
# create Glossary1
glossary_cmd(self.args, self.config)
self.Glossary1_path = get_glossary_file_path('Glossary1', self.config)
# create Glossary2
self.args.name = 'Glossary2'
glossary_cmd(self.args, self.config)
self.Glossary2_path = get_glossary_file_path('Glossary2', self.config)
def test_glossaries_exist(self) -> None:
"""
Test if the default glossaries created in setUp exist.
"""
glo = Glossary.from_file(self.Glossary1_path)
self.assertEqual(glo.name, 'Glossary1')
glo = Glossary.from_file(self.Glossary2_path)
self.assertEqual(glo.name, 'Glossary2')
def test_glossaries_list(self) -> None:
self.args.create = False
self.args.list = True
with redirect_stdout(io.StringIO()) as list_output:
glossary_cmd(self.args, self.config)
self.assertIn('Glossary1', list_output.getvalue())
self.assertIn('Glossary2', list_output.getvalue())
-62
View File
@@ -1,62 +0,0 @@
import unittest
import argparse
import tempfile
import yaml
from pathlib import Path
from chatmastermind.message import Message, Question
from chatmastermind.chat import ChatDB, ChatError, msg_location
from chatmastermind.configuration import Config
from chatmastermind.commands.hist import convert_messages
msg_suffix = Message.file_suffix_write
class TestConvertMessages(unittest.TestCase):
def setUp(self) -> None:
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
self.db_path = Path(self.db_dir.name)
self.cache_path = Path(self.cache_dir.name)
self.args = argparse.Namespace()
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
# Prepare some messages
self.chat = ChatDB.from_dir(Path(self.cache_path),
Path(self.db_path))
self.messages = [Message(Question(f'Question {i}')) for i in range(0, 6)]
self.chat.db_write(self.messages[0:2])
self.chat.cache_write(self.messages[2:])
# Change some of the suffixes
assert self.messages[0].file_path
assert self.messages[1].file_path
self.messages[0].file_path.rename(self.messages[0].file_path.with_suffix('.txt'))
self.messages[1].file_path.rename(self.messages[1].file_path.with_suffix('.yaml'))
def tearDown(self) -> None:
self.db_dir.cleanup()
self.cache_dir.cleanup()
def test_convert_messages(self) -> None:
self.args.convert = 'yaml'
convert_messages(self.args, self.config)
msgs = self.chat.msg_gather(loc=msg_location.DISK, glob='*.*')
# Check if the number of messages is the same as before
self.assertEqual(len(msgs), len(self.messages))
# Check if all messages have the requested suffix
for msg in msgs:
assert msg.file_path
self.assertEqual(msg.file_path.suffix, msg_suffix)
# Check if the message IDs are correctly maintained
for m_new, m_old in zip(msgs, self.messages):
self.assertEqual(m_new.msg_id(), m_old.msg_id())
# check if all messages have the new format
for m in msgs:
with open(str(m.file_path), "r") as fd:
yaml.load(fd, Loader=yaml.FullLoader)
def test_convert_messages_wrong_format(self) -> None:
self.args.convert = 'foo'
with self.assertRaises(ChatError):
convert_messages(self.args, self.config)
+40 -103
View File
@@ -1,16 +1,11 @@
import unittest
import pathlib
import tempfile
import itertools
from typing import cast
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine,\
MessageFilter, message_in, message_valid_formats
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
from chatmastermind.tags import Tag, TagLine
msg_suffix: str = Message.file_suffix_write
class SourceCodeTestCase(unittest.TestCase):
def test_source_code_with_include_delims(self) -> None:
text = """
@@ -91,7 +86,7 @@ class QuestionTestCase(unittest.TestCase):
class AnswerTestCase(unittest.TestCase):
def test_answer_with_header(self) -> None:
with self.assertRaises(MessageError):
str(Answer(f"{Answer.txt_header}\nno"))
Answer(f"{Answer.txt_header}\nno")
def test_answer_with_legal_header(self) -> None:
answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.")
@@ -106,7 +101,7 @@ class AnswerTestCase(unittest.TestCase):
class MessageToFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
self.message_complete = Message(Question('This is a question.'),
Answer('This is an answer.'),
@@ -122,7 +117,7 @@ class MessageToFileTxtTestCase(unittest.TestCase):
self.file_path.unlink()
def test_to_file_txt_complete(self) -> None:
self.message_complete.to_file(self.file_path, mformat='txt')
self.message_complete.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
@@ -137,7 +132,7 @@ This is an answer.
self.assertEqual(content, expected_content)
def test_to_file_txt_min(self) -> None:
self.message_min.to_file(self.file_path, mformat='txt')
self.message_min.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
@@ -146,17 +141,11 @@ This is a question.
"""
self.assertEqual(content, expected_content)
def test_to_file_unsupported_file_suffix(self) -> None:
def test_to_file_unsupported_file_type(self) -> None:
unsupported_file_path = pathlib.Path("example.doc")
with self.assertRaises(MessageError) as cm:
self.message_complete.to_file(unsupported_file_path)
self.assertEqual(str(cm.exception), "File suffix '.doc' is not supported")
def test_to_file_unsupported_file_format(self) -> None:
unsupported_file_format = pathlib.Path(f"example{msg_suffix}")
with self.assertRaises(MessageError) as cm:
self.message_complete.to_file(unsupported_file_format, mformat='doc') # type: ignore [arg-type]
self.assertEqual(str(cm.exception), "File format 'doc' is not supported")
self.assertEqual(str(cm.exception), "File type '.doc' is not supported")
def test_to_file_no_file_path(self) -> None:
"""
@@ -170,24 +159,10 @@ This is a question.
# reset the internal file_path
self.message_complete.file_path = self.file_path
def test_to_file_txt_auto_suffix(self) -> None:
"""
Test if suffix is auto-generated if omitted.
"""
file_path_no_suffix = self.file_path.with_suffix('')
# test with file_path member
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(mformat='txt')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
# test with explicit file_path
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(file_path=file_path_no_suffix, mformat='txt')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
class MessageToFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name)
self.message_complete = Message(Question('This is a question.'),
Answer('This is an answer.'),
@@ -209,7 +184,7 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.file_path.unlink()
def test_to_file_yaml_complete(self) -> None:
self.message_complete.to_file(self.file_path, mformat='yaml')
self.message_complete.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
@@ -224,7 +199,7 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.assertEqual(content, expected_content)
def test_to_file_yaml_multiline(self) -> None:
self.message_multiline.to_file(self.file_path, mformat='yaml')
self.message_multiline.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
@@ -243,31 +218,17 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.assertEqual(content, expected_content)
def test_to_file_yaml_min(self) -> None:
self.message_min.to_file(self.file_path, mformat='yaml')
self.message_min.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
expected_content = f"{Question.yaml_key}: This is a question.\n"
self.assertEqual(content, expected_content)
def test_to_file_yaml_auto_suffix(self) -> None:
"""
Test if suffix is auto-generated if omitted.
"""
file_path_no_suffix = self.file_path.with_suffix('')
# test with file_path member
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(mformat='yaml')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
# test with explicit file_path
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(file_path=file_path_no_suffix, mformat='yaml')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
class MessageFromFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2
@@ -278,7 +239,7 @@ This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_min = pathlib.Path(self.file_min.name)
with open(self.file_path_min, "w") as fd:
fd.write(f"""{Question.txt_header}
@@ -298,7 +259,7 @@ This is a question.
message = Message.from_file(self.file_path)
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
assert message
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
@@ -313,7 +274,7 @@ This is a question.
message = Message.from_file(self.file_path_min)
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
assert message
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer)
@@ -323,7 +284,7 @@ This is a question.
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
assert message
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
@@ -350,13 +311,13 @@ This is a question.
MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
assert message
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
def test_from_file_not_exists(self) -> None:
file_not_exists = pathlib.Path(f"example{msg_suffix}")
file_not_exists = pathlib.Path("example.txt")
with self.assertRaises(MessageError) as cm:
Message.from_file(file_not_exists)
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
@@ -435,7 +396,7 @@ This is a question.
class MessageFromFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd:
fd.write(f"""
@@ -449,7 +410,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
- tag1
- tag2
""")
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path_min = pathlib.Path(self.file_min.name)
with open(self.file_path_min, "w") as fd:
fd.write(f"""
@@ -470,7 +431,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
message = Message.from_file(self.file_path)
self.assertIsInstance(message, Message)
self.assertIsNotNone(message)
assert message
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
@@ -485,14 +446,14 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
message = Message.from_file(self.file_path_min)
self.assertIsInstance(message, Message)
self.assertIsNotNone(message)
assert message
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer)
def test_from_file_not_exists(self) -> None:
file_not_exists = pathlib.Path(f"example{msg_suffix}")
file_not_exists = pathlib.Path("example.yaml")
with self.assertRaises(MessageError) as cm:
Message.from_file(file_not_exists)
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
@@ -502,7 +463,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
assert message
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
@@ -523,7 +484,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
assert message
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
@@ -602,7 +563,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
class TagsFromFileTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt = pathlib.Path(self.file_txt.name)
with open(self.file_path_txt, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3
@@ -611,7 +572,7 @@ This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name)
with open(self.file_path_txt_no_tags, "w") as fd:
fd.write(f"""{Question.txt_header}
@@ -619,7 +580,7 @@ This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name)
with open(self.file_path_txt_tags_empty, "w") as fd:
fd.write(f"""TAGS:
@@ -628,7 +589,7 @@ This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path_yaml = pathlib.Path(self.file_yaml.name)
with open(self.file_path_yaml, "w") as fd:
fd.write(f"""
@@ -641,7 +602,7 @@ This is an answer.
- tag2
- ptag3
""")
self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name)
with open(self.file_path_yaml_no_tags, "w") as fd:
fd.write(f"""
@@ -718,25 +679,24 @@ class TagsFromDirTestCase(unittest.TestCase):
{Tag('ctag5'), Tag('ctag6')}
]
self.files = [
pathlib.Path(self.temp_dir.name, f'file1{msg_suffix}'),
pathlib.Path(self.temp_dir.name, f'file2{msg_suffix}'),
pathlib.Path(self.temp_dir.name, f'file3{msg_suffix}')
pathlib.Path(self.temp_dir.name, 'file1.txt'),
pathlib.Path(self.temp_dir.name, 'file2.yaml'),
pathlib.Path(self.temp_dir.name, 'file3.txt')
]
self.files_no_tags = [
pathlib.Path(self.temp_dir_no_tags.name, f'file4{msg_suffix}'),
pathlib.Path(self.temp_dir_no_tags.name, f'file5{msg_suffix}'),
pathlib.Path(self.temp_dir_no_tags.name, f'file6{msg_suffix}')
pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'),
pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'),
pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt')
]
mformats = itertools.cycle(message_valid_formats)
for file, tags in zip(self.files, self.tag_sets):
message = Message(Question('This is a question.'),
Answer('This is an answer.'),
tags)
message.to_file(file, next(mformats))
message.to_file(file)
for file in self.files_no_tags:
message = Message(Question('This is a question.'),
Answer('This is an answer.'))
message.to_file(file, next(mformats))
message.to_file(file)
def tearDown(self) -> None:
self.temp_dir.cleanup()
@@ -759,7 +719,7 @@ class TagsFromDirTestCase(unittest.TestCase):
class MessageIDTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
self.message = Message(Question('This is a question.'),
file_path=self.file_path)
@@ -856,8 +816,6 @@ class MessageToStrTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
Answer('This is an answer.'),
ai=('FakeAI'),
model=('FakeModel'),
tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/foo/bla'))
@@ -871,29 +829,8 @@ This is an answer."""
def test_to_str_with_tags_and_file(self) -> None:
expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: /tmp/foo/bla
AI: FakeAI
MODEL: FakeModel
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer."""
self.assertEqual(self.message.to_str(with_metadata=True), expected_output)
class MessageRmFileTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path = pathlib.Path(self.file.name)
self.message = Message(Question('This is a question.'),
file_path=self.file_path)
self.message.to_file()
def tearDown(self) -> None:
self.file.close()
self.file_path.unlink(missing_ok=True)
def test_rm_file(self) -> None:
assert self.message.file_path
self.assertTrue(self.message.file_path.exists())
self.message.rm_file()
self.assertFalse(self.message.file_path.exists())
self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output)
+153 -235
View File
@@ -1,23 +1,31 @@
import os
import unittest
import argparse
import tempfile
from copy import copy
from pathlib import Path
from unittest import mock
from unittest.mock import MagicMock, call
from unittest.mock import MagicMock, call, ANY
from typing import Optional
from chatmastermind.configuration import Config
from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat, ChatDB, msg_location
from chatmastermind.ai import AIError
from .test_common import TestWithFakeAI
from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AI, AIResponse, Tokens, AIError
msg_suffix = Message.file_suffix_write
class TestQuestionCmdBase(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using more than just Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
class TestMessageCreate(TestWithFakeAI):
class TestMessageCreate(TestQuestionCmdBase):
"""
Test if messages created by the 'question' command have
the correct format.
@@ -41,8 +49,6 @@ class TestMessageCreate(TestWithFakeAI):
self.args.AI = None
self.args.model = None
self.args.output_tags = None
self.args.ask = None
self.args.create = None
# File 1 : no source code block, only text
self.source_file1 = tempfile.NamedTemporaryFile(delete=False)
self.source_file1_content = """This is just text.
@@ -88,7 +94,7 @@ Aaaand again some text."""
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return list(Path(tmp_dir.name).glob(f'*{msg_suffix}'))
return list(Path(tmp_dir.name).glob('*.[ty]*'))
def test_message_file_created(self) -> None:
self.args.ask = ["What is this?"]
@@ -206,22 +212,7 @@ It is embedded code
"""))
class TestCreateOption(TestMessageCreate):
def test_message_file_created(self) -> None:
self.args.create = ["How does question --create work?"]
self.args.ask = None
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("How does question --create work?")) # type: ignore [union-attr]
class TestQuestionCmd(TestWithFakeAI):
class TestQuestionCmd(TestQuestionCmdBase):
def setUp(self) -> None:
# create DB and cache
@@ -234,12 +225,10 @@ class TestQuestionCmd(TestWithFakeAI):
# create a mock argparse.Namespace
self.args = argparse.Namespace(
ask=['What is the meaning of life?'],
glob=None,
location='db',
num_answers=1,
output_tags=['science'],
AI='FakeAI',
model='FakeModel',
AI='openai',
model='gpt-3.5-turbo',
or_tags=None,
and_tags=None,
exclude_tags=None,
@@ -250,27 +239,57 @@ class TestQuestionCmd(TestWithFakeAI):
process=None,
overwrite=None
)
# create a mock AI instance
self.ai = MagicMock(spec=AI)
self.ai.request.side_effect = self.mock_request
def input_message(self, args: argparse.Namespace) -> Message:
"""
Create the expected input message for a question using the
given arguments.
"""
# NOTE: we only use the first question from the "ask" list
# -> message creation using "question.create_message()" is
# tested above
# the answer is always empty for the input message
return Message(Question(args.ask[0]),
tags=args.output_tags,
ai=args.AI,
model=args.model)
def mock_request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Mock the 'ai.request()' function
"""
question.answer = Answer("Answer 0")
question.tags = set(otags) if otags else None
question.ai = 'FakeAI'
question.model = 'FakeModel'
answers: list[Message] = [question]
for n in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(f"Answer {n}"),
tags=otags,
ai='FakeAI',
model='FakeModel'))
return AIResponse(answers, Tokens(10, 10, 20))
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return sorted([f for f in Path(tmp_dir.name).glob(f'*{msg_suffix}')])
class TestQuestionCmdAsk(TestQuestionCmd):
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
"""
Test single answer with no errors.
"""
mock_create_ai.side_effect = self.mock_create_ai
expected_question = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path('<NOT COMPARED>'))
fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question,
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
@@ -278,12 +297,17 @@ class TestQuestionCmdAsk(TestQuestionCmd):
# execute the command
question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
self.assert_messages_equal(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
@mock.patch('chatmastermind.commands.question.create_ai')
@@ -294,14 +318,9 @@ class TestQuestionCmdAsk(TestQuestionCmd):
chat = MagicMock(spec=ChatDB)
mock_from_dir.return_value = chat
mock_create_ai.side_effect = self.mock_create_ai
expected_question = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path('<NOT COMPARED>'))
fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question,
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
@@ -309,6 +328,12 @@ class TestQuestionCmdAsk(TestQuestionCmd):
# execute the command
question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
chat,
self.args.num_answers,
self.args.output_tags)
# check for the correct ChatDB calls:
# - initial question has been written (prior to the actual request)
# - responses have been written (after the request)
@@ -325,98 +350,86 @@ class TestQuestionCmdAsk(TestQuestionCmd):
Provoke an error during the AI request and verify that the question
has been correctly stored in the cache.
"""
mock_create_ai.side_effect = self.mock_create_ai_with_error
expected_question = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path('<NOT COMPARED>'))
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
self.ai.request.side_effect = AIError
# execute the command
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_question])
class TestQuestionCmdRepeat(TestQuestionCmd):
self.assert_messages_equal(cached_msg, [expected_question])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question.
"""
mock_create_ai.side_effect = self.mock_create_ai
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# repeat the last question (without overwriting)
# 2. repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path)
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai=message.ai,
model=message.model,
tags=message.tags,
file_path=Path('<NOT COMPARED>'))
# we expect the original message + the one with the new response
expected_responses = [message] + [expected_response]
expected_responses += expected_responses
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
print(self.message_list(self.cache_dir))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
self.assert_messages_equal(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question and overwrite the old one.
"""
mock_create_ai.side_effect = self.mock_create_ai
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# repeat the last question (WITH overwriting)
# -> expect a single message afterwards (with a new answer)
# 2. repeat the last question (WITH overwriting)
# -> expect a single message afterwards
self.args.ask = None
self.args.repeat = []
self.args.overwrite = True
expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai=message.ai,
model=message.model,
tags=message.tags,
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
self.assert_messages_equal(cached_msg, expected_responses)
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@@ -426,37 +439,35 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
"""
Repeat a single question after an error.
"""
mock_create_ai.side_effect = self.mock_create_ai
# 1. ask a question and provoke an error
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
self.ai.request.side_effect = AIError
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a question WITHOUT an answer
# -> just like after an error, which is tested above
message = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, [expected_question])
# repeat the last question (without overwriting)
# 2. repeat the last question (without overwriting)
# -> expect a single message because if the original has
# no answer, it should be overwritten by default
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai=message.ai,
model=message.model,
tags=message.tags,
file_path=Path('<NOT COMPARED>'))
self.ai.request.side_effect = self.mock_request
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
self.assert_messages_equal(cached_msg, expected_responses)
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@@ -466,130 +477,37 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
"""
Repeat a single question with new arguments.
"""
mock_create_ai.side_effect = self.mock_create_ai
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# repeat the last question with new arguments (without overwriting)
# -> expect two messages with identical question but different metadata and new answer
# 2. repeat the last question with new arguments (without overwriting)
# -> expect two messages with identical question and answer, but different metadata
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
self.args.output_tags = ['newtag']
self.args.AI = 'newai'
self.args.model = 'newmodel'
new_expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai='newai',
model='newmodel',
tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_msgs_equal_except_file_path(cached_msg, [message] + [new_expected_response])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_new_args_overwrite(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question with new arguments, overwriting the old one.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
new_expected_question = Message(question=Question(expected_question.question),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path
# repeat the last question with new arguments
self.args.ask = None
self.args.repeat = []
self.args.overwrite = True
self.args.output_tags = ['newtag']
self.args.AI = 'newai'
self.args.model = 'newmodel'
new_expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai='newai',
model='newmodel',
tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None:
"""
Repeat multiple questions.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# 1. === create three questions ===
# cached message without an answer
message1 = Message(Question(self.args.ask[0]),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
# cached message with an answer
message2 = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0002{msg_suffix}')
# DB message without an answer
message3 = Message(Question(self.args.ask[0]),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.db_dir.name) / f'0003{msg_suffix}')
chat.msg_write([message1, message2, message3])
questions = [message1, message2, message3]
expected_responses: list[Message] = []
fake_ai = self.mock_create_ai(self.args, self.config)
for question in questions:
# since the message's answer is modified, we use a copy
# -> the original is used for comparison below
expected_responses += fake_ai.request(copy(question),
model=self.args.model)
expected_responses += self.mock_request(new_expected_question,
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
# 2. === repeat all three questions (without overwriting) ===
self.args.ask = None
self.args.repeat = ['0001', '0002', '0003']
self.args.overwrite = False
question_cmd(self.args, self.config)
# two new files should be in the cache directory
# * the repeated cached message with answer
# * the repeated DB message
# -> the cached message without answer should be overwritten
self.assertEqual(len(self.message_list(self.cache_dir)), 4)
self.assertEqual(len(self.message_list(self.db_dir)), 1)
expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
# check that the DB message has not been modified at all
db_msg = chat.msg_gather(loc=msg_location.DB)
self.assert_msgs_all_equal(db_msg, [message3])
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, expected_responses)