40 Commits

Author SHA1 Message Date
juk0de f17e76203a glossary cmd / main: added --print option 2024-02-26 16:34:30 +01:00
juk0de 6a77ec1d3b glossary test: added testcase for to_str() without description 2024-02-26 16:27:00 +01:00
juk0de f298a68140 glossary cmd test: added test for listing glossaries 2024-02-26 16:27:00 +01:00
juk0de 9bbf67af67 glossary: fixed printing of empty description 2024-02-26 16:27:00 +01:00
juk0de 284dd13201 glossary cmd: fixed globbing in 'list_glossaries' 2024-02-26 16:27:00 +01:00
juk0de 1932f8f6e9 configuration: improved error message when config file is missing 2024-02-26 16:27:00 +01:00
juk0de 15e8f8fd6b main: resolved conflicting short parameters 2024-02-26 16:27:00 +01:00
juk0de 9c683be994 glossary test: fixed cleanup of temporary files 2024-02-26 16:27:00 +01:00
juk0de 92fb2bbe15 glossary: added '__post_init__' 2024-02-26 16:27:00 +01:00
juk0de 2e0da31150 added test module for the 'glossary' command 2024-02-26 16:27:00 +01:00
juk0de ff6d4ded33 added 'glossary' command 2024-02-26 16:27:00 +01:00
juk0de 5377dc0784 main: missing directories are now created if user agrees 2024-02-26 16:27:00 +01:00
juk0de 3def4cb668 configuration: added 'glossaries' directory 2024-02-26 16:27:00 +01:00
juk0de 580c506483 glossary: now supports quoted and unquoted entries (incl. tests) 2024-02-26 16:27:00 +01:00
juk0de a1a090bcae glossary test: added testcases for 'to_str()' 2024-02-26 16:27:00 +01:00
juk0de 3cca32a40b glossary: added 'to_str()' function 2024-02-26 16:27:00 +01:00
juk0de 1b39fb1ac5 glossary test: added description test 2024-02-26 16:27:00 +01:00
juk0de b4ef2e43ca glossary: added description and removed useless input stripping 2024-02-26 16:27:00 +01:00
juk0de ff1e405991 glossary test: added suffix testcases 2024-02-26 16:27:00 +01:00
juk0de 4afd6d4e94 glossary: added suffix check 2024-02-26 16:27:00 +01:00
juk0de 94b812c31e glossary: added test module for glossaries 2024-02-26 16:27:00 +01:00
juk0de be873867ea added module 'glossary.py' 2024-02-26 16:27:00 +01:00
juk0de 82ad697b68 translation: added check for valid document format when using OpenAI 2024-02-26 16:27:00 +01:00
juk0de a185c0db7b translation: speficied / implemented the question format for OpenAI based translations 2024-02-26 16:27:00 +01:00
juk0de c1dc152f48 translation: some small required refactoring 2024-02-26 16:27:00 +01:00
juk0de f0129f7060 added new command 'translation' 2024-02-26 16:27:00 +01:00
Oleksandr Kozachuk 5d1bb1f9e4 Fix some of the commands. 2023-11-10 10:42:46 +01:00
Oleksandr Kozachuk 75a123eb72 Fix usage of the dynamic answer is some cases. 2023-10-24 12:59:13 +02:00
juk0de 7c1c67f8ff Merge pull request 'Dynamic Answer class and OpenAI streaming API' (#19) from dynamic_answer into main
Introduces several changes with the main objective of enabling OpenAI's streaming API in the chatmastermind application. This allows for the retrieval of AI responses gradually as a stream, which can significantly improve the user experience in interactions that involve large result sets.

* Added tiktoken import in 'openai.py' and modifications to the OpenAI class to support streaming. This includes the addition of a new class OpenAIAnswer to handle streaming API responses.
* Modified request function in the OpenAI class: the stream=True flag is added to the openai.ChatCompletion.create method to enable streaming API.
* Modified 'question.py' to print the answer parts as they are streamed.
* Replaced the Answer class's string data type with a generator which supports str and Generator[str, None, None] data types. Modifications are made to the Answer class methods to handle both data types accordingly.
* Updated the tests in 'test_ais_openai.py' and 'test_message.py' to reflect and validate these changes.
2023-10-21 15:50:45 +02:00
Oleksandr Kozachuk dbe72ff11c Activate and use OpenAI streaming API. 2023-10-21 14:21:48 +02:00
Oleksandr Kozachuk bbc1ab5a0a Fix source_code function with the dynamic answer class. 2023-10-20 14:02:09 +02:00
Oleksandr Kozachuk 2aee018708 Refactor message.Answer class in a way, that it can be constructed dynamically step by step, in preparation of using streaming API. 2023-10-20 13:43:31 +02:00
ok 17c6fa2453 Merge pull request 'Configurable glob and location on question and hist commands' (#18) from cust_loc_glob into main
Reviewed-on: #18
2023-10-20 09:47:03 +02:00
juk0de 5774278fb7 README: added new 'question' command parameters 2023-10-20 09:16:03 +02:00
juk0de 40d0de50de cmm: limited the message locations for the new cmm parameters to those that make sense 2023-10-20 09:16:03 +02:00
juk0de 72d31c26e9 main: improved parameter descriptions 2023-10-20 09:16:03 +02:00
juk0de 980e5ac51f chat: changed default glob to '*.msg' in all ChatDB functions 2023-10-20 09:00:58 +02:00
Oleksandr Kozachuk 114282dfd8 Add --glob and --location flags to hist and question commands, to be able to specify the location and files they should use. 2023-10-19 16:03:51 +02:00
Oleksandr Kozachuk 9a493b57da Per default use only files with .msg suffix ignoring other files. 2023-10-19 16:02:40 +02:00
Oleksandr Kozachuk 9b0951cb3f Change type msg_location to an Enum instead of Literal to be able to get all values easy and improve type checks. 2023-10-19 16:00:44 +02:00
22 changed files with 1219 additions and 216 deletions
+4
View File
@@ -65,6 +65,8 @@ cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI_ID
* `-O, --overwrite`: Overwrite existing messages when repeating them * `-O, --overwrite`: Overwrite existing messages when repeating them
* `-s, --source-text FILE`: Add content of a file to the query * `-s, --source-text FILE`: Add content of a file to the query
* `-S, --source-code FILE`: Add source code file content to the chat history * `-S, --source-code FILE`: Add source code file content to the chat history
* `-l, --location {cache,db,all}`: Use given location when building the chat history (default: 'db')
* `-g, --glob GLOB`: Filter message files using the given glob pattern
#### Hist #### Hist
@@ -83,6 +85,8 @@ cmm hist [--print | --convert FORMAT] [-t OTAGS]... [-k ATAGS]... [-x XTAGS]...
* `-S, --source-code-only`: Only print embedded source code * `-S, --source-code-only`: Only print embedded source code
* `-A, --answer SUBSTRING`: Filter for answer substring * `-A, --answer SUBSTRING`: Filter for answer substring
* `-Q, --question SUBSTRING`: Filter for question substring * `-Q, --question SUBSTRING`: Filter for question substring
* `-l, --location {cache,db,all}`: Use given location when building the chat history (default: 'db')
* `-g, --glob GLOB`: Filter message files using the given glob pattern
#### Tags #### Tags
+69 -16
View File
@@ -2,7 +2,8 @@
Implements the OpenAI client classes and functions. Implements the OpenAI client classes and functions.
""" """
import openai import openai
from typing import Optional, Union import tiktoken
from typing import Optional, Union, Generator
from ..tags import Tag from ..tags import Tag
from ..message import Message, Answer from ..message import Message, Answer
from ..chat import Chat from ..chat import Chat
@@ -12,6 +13,52 @@ from ..configuration import OpenAIConfig
ChatType = list[dict[str, str]] ChatType = list[dict[str, str]]
class OpenAIAnswer:
def __init__(self,
idx: int,
streams: dict[int, 'OpenAIAnswer'],
response: openai.ChatCompletion,
tokens: Tokens,
encoding: tiktoken.core.Encoding) -> None:
self.idx = idx
self.streams = streams
self.response = response
self.position: int = 0
self.encoding = encoding
self.data: list[str] = []
self.finished: bool = False
self.tokens = tokens
def stream(self) -> Generator[str, None, None]:
while True:
if not self.next():
continue
if len(self.data) <= self.position:
break
yield self.data[self.position]
self.position += 1
def next(self) -> bool:
if self.finished:
return True
try:
chunk = next(self.response)
except StopIteration:
self.finished = True
if not self.finished:
found_choice = False
for choice in chunk['choices']:
if not choice['finish_reason']:
self.streams[choice['index']].data.append(choice['delta']['content'])
self.tokens.completion += len(self.encoding.encode(choice['delta']['content']))
self.tokens.total = self.tokens.prompt + self.tokens.completion
if choice['index'] == self.idx:
found_choice = True
if not found_choice:
return False
return True
class OpenAI(AI): class OpenAI(AI):
""" """
The OpenAI AI client. The OpenAI AI client.
@@ -21,7 +68,7 @@ class OpenAI(AI):
self.ID = config.ID self.ID = config.ID
self.name = config.name self.name = config.name
self.config = config self.config = config
openai.api_key = config.api_key openai.api_key = self.config.api_key
def request(self, def request(self,
question: Message, question: Message,
@@ -33,7 +80,9 @@ class OpenAI(AI):
chat history. The nr. of requested answers corresponds to the chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'. nr. of messages in the 'AIResponse'.
""" """
oai_chat = self.openai_chat(chat, self.config.system, question) self.encoding = tiktoken.encoding_for_model(self.config.model)
oai_chat, prompt_tokens = self.openai_chat(chat, self.config.system, question)
tokens: Tokens = Tokens(prompt_tokens, 0, prompt_tokens)
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model=self.config.model, model=self.config.model,
messages=oai_chat, messages=oai_chat,
@@ -41,22 +90,24 @@ class OpenAI(AI):
max_tokens=self.config.max_tokens, max_tokens=self.config.max_tokens,
top_p=self.config.top_p, top_p=self.config.top_p,
n=num_answers, n=num_answers,
stream=True,
frequency_penalty=self.config.frequency_penalty, frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty) presence_penalty=self.config.presence_penalty)
question.answer = Answer(response['choices'][0]['message']['content']) streams: dict[int, OpenAIAnswer] = {}
for n in range(num_answers):
streams[n] = OpenAIAnswer(n, streams, response, tokens, self.encoding)
question.answer = Answer(streams[0].stream())
question.tags = set(otags) if otags is not None else None question.tags = set(otags) if otags is not None else None
question.ai = self.ID question.ai = self.ID
question.model = self.config.model question.model = self.config.model
answers: list[Message] = [question] answers: list[Message] = [question]
for choice in response['choices'][1:]: # type: ignore for idx in range(1, num_answers):
answers.append(Message(question=question.question, answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']), answer=Answer(streams[idx].stream()),
tags=otags, tags=otags,
ai=self.ID, ai=self.ID,
model=self.config.model)) model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt_tokens'], return AIResponse(answers, tokens)
response['usage']['completion_tokens'],
response['usage']['total_tokens']))
def models(self) -> list[str]: def models(self) -> list[str]:
""" """
@@ -83,24 +134,26 @@ class OpenAI(AI):
print('\nNot ready: ' + ', '.join(not_ready)) print('\nNot ready: ' + ', '.join(not_ready))
def openai_chat(self, chat: Chat, system: str, def openai_chat(self, chat: Chat, system: str,
question: Optional[Message] = None) -> ChatType: question: Optional[Message] = None) -> tuple[ChatType, int]:
""" """
Create a chat history with system message in OpenAI format. Create a chat history with system message in OpenAI format.
Optionally append a new question. Optionally append a new question.
""" """
oai_chat: ChatType = [] oai_chat: ChatType = []
prompt_tokens: int = 0
def append(role: str, content: str) -> None: def append(role: str, content: str) -> int:
oai_chat.append({'role': role, 'content': content.replace("''", "'")}) oai_chat.append({'role': role, 'content': content.replace("''", "'")})
return len(self.encoding.encode(', '.join(['role:', oai_chat[-1]['role'], 'content:', oai_chat[-1]['content']])))
append('system', system) prompt_tokens += append('system', system)
for message in chat.messages: for message in chat.messages:
if message.answer: if message.answer:
append('user', message.question) prompt_tokens += append('user', message.question)
append('assistant', message.answer) prompt_tokens += append('assistant', str(message.answer))
if question: if question:
append('user', question.question) prompt_tokens += append('user', question.question)
return oai_chat return oai_chat, prompt_tokens
def tokens(self, data: Union[Message, Chat]) -> int: def tokens(self, data: Union[Message, Chat]) -> int:
raise NotImplementedError raise NotImplementedError
+41 -25
View File
@@ -6,7 +6,8 @@ from pathlib import Path
from pprint import PrettyPrinter from pprint import PrettyPrinter
from pydoc import pager from pydoc import pager
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union from enum import Enum
from typing import TypeVar, Type, Optional, Any, Callable, Union
from .configuration import default_config_file from .configuration import default_config_file
from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats
from .tags import Tag from .tags import Tag
@@ -16,10 +17,17 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next' db_next_file = '.next'
ignored_files = [db_next_file, default_config_file] ignored_files = [db_next_file, default_config_file]
msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
msg_suffix = Message.file_suffix_write msg_suffix = Message.file_suffix_write
class msg_location(Enum):
MEM = 'mem'
DISK = 'disk'
CACHE = 'cache'
DB = 'db'
ALL = 'all'
class ChatError(Exception): class ChatError(Exception):
pass pass
@@ -44,12 +52,12 @@ def read_dir(dir_path: Path,
Parameters: Parameters:
* 'dir_path': source directory * 'dir_path': source directory
* 'glob': if specified, files will be filtered using 'path.glob()', * 'glob': if specified, files will be filtered using 'path.glob()',
otherwise it uses 'path.iterdir()'. otherwise it reads all files with the default message suffix
* 'mfilter': use with 'Message.from_file()' to filter messages * 'mfilter': use with 'Message.from_file()' to filter messages
when reading them. when reading them.
""" """
messages: list[Message] = [] messages: list[Message] = []
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() file_iter = dir_path.glob(glob) if glob else dir_path.glob(f'*{msg_suffix}')
for file_path in sorted(file_iter): for file_path in sorted(file_iter):
if (file_path.is_file() if (file_path.is_file()
and file_path.name not in ignored_files # noqa: W503 and file_path.name not in ignored_files # noqa: W503
@@ -287,7 +295,7 @@ class ChatDB(Chat):
# a MessageFilter that all messages must match (if given) # a MessageFilter that all messages must match (if given)
mfilter: Optional[MessageFilter] = None mfilter: Optional[MessageFilter] = None
# the glob pattern for all messages # the glob pattern for all messages
glob: Optional[str] = None glob: str = f'*{msg_suffix}'
# message format (for writing) # message format (for writing)
mformat: MessageFormat = Message.default_format mformat: MessageFormat = Message.default_format
@@ -303,20 +311,28 @@ class ChatDB(Chat):
def from_dir(cls: Type[ChatDBInst], def from_dir(cls: Type[ChatDBInst],
cache_path: Path, cache_path: Path,
db_path: Path, db_path: Path,
glob: Optional[str] = None, glob: str = f'*{msg_suffix}',
mfilter: Optional[MessageFilter] = None) -> ChatDBInst: mfilter: Optional[MessageFilter] = None,
loc: msg_location = msg_location.DB) -> ChatDBInst:
""" """
Create a 'ChatDB' instance from the given directory structure. Create a 'ChatDB' instance from the given directory structure.
Reads all messages from 'db_path' into the local message list. Reads all messages from 'db_path' into the local message list.
Parameters: Parameters:
* 'cache_path': path to the directory for temporary messages * 'cache_path': path to the directory for temporary messages
* 'db_path': path to the directory for persistent messages * 'db_path': path to the directory for persistent messages
* 'glob': if specified, files will be filtered using 'path.glob()', * 'glob': if specified, files will be filtered using 'path.glob()'
otherwise it uses 'path.iterdir()'.
* 'mfilter': use with 'Message.from_file()' to filter messages * 'mfilter': use with 'Message.from_file()' to filter messages
when reading them. when reading them.
* 'loc': read messages from given location instead of 'db_path'
""" """
messages = read_dir(db_path, glob, mfilter) if loc == msg_location.MEM:
raise ChatError(f"Can't build ChatDB from message location '{loc}'")
messages: list[Message] = []
if loc in [msg_location.DB, msg_location.DISK, msg_location.ALL]:
messages.extend(read_dir(db_path, glob, mfilter))
if loc in [msg_location.CACHE, msg_location.DISK, msg_location.ALL]:
messages.extend(read_dir(cache_path, glob, mfilter))
messages.sort(key=lambda x: x.msg_id())
return cls(messages, cache_path, db_path, mfilter, glob) return cls(messages, cache_path, db_path, mfilter, glob)
@classmethod @classmethod
@@ -386,7 +402,7 @@ class ChatDB(Chat):
def msg_gather(self, def msg_gather(self,
loc: msg_location, loc: msg_location,
require_file_path: bool = False, require_file_path: bool = False,
glob: Optional[str] = None, glob: str = f'*{msg_suffix}',
mfilter: Optional[MessageFilter] = None) -> list[Message]: mfilter: Optional[MessageFilter] = None) -> list[Message]:
""" """
Gather and return messages from the given locations: Gather and return messages from the given locations:
@@ -399,14 +415,14 @@ class ChatDB(Chat):
If 'require_file_path' is True, return only files with a valid file_path. If 'require_file_path' is True, return only files with a valid file_path.
""" """
loc_messages: list[Message] = [] loc_messages: list[Message] = []
if loc in ['mem', 'all']: if loc in [msg_location.MEM, msg_location.ALL]:
if require_file_path: if require_file_path:
loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))] loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
else: else:
loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))] loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
if loc in ['cache', 'disk', 'all']: if loc in [msg_location.CACHE, msg_location.DISK, msg_location.ALL]:
loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter) loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter)
if loc in ['db', 'disk', 'all']: if loc in [msg_location.DB, msg_location.DISK, msg_location.ALL]:
loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter) loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter)
# remove_duplicates and sort the list # remove_duplicates and sort the list
unique_messages: list[Message] = [] unique_messages: list[Message] = []
@@ -422,7 +438,7 @@ class ChatDB(Chat):
def msg_find(self, def msg_find(self,
msg_names: list[str], msg_names: list[str],
loc: msg_location = 'mem', loc: msg_location = msg_location.MEM,
) -> list[Message]: ) -> list[Message]:
""" """
Search and return the messages with the given names. Names can either be filenames Search and return the messages with the given names. Names can either be filenames
@@ -440,7 +456,7 @@ class ChatDB(Chat):
return [m for m in loc_messages return [m for m in loc_messages
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)] if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def msg_remove(self, msg_names: list[str], loc: msg_location = 'mem') -> None: def msg_remove(self, msg_names: list[str], loc: msg_location = msg_location.MEM) -> None:
""" """
Remove the messages with the given names. Names can either be filenames Remove the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Also deletes the (with or without suffix), full paths or Message.msg_id(). Also deletes the
@@ -452,7 +468,7 @@ class ChatDB(Chat):
* 'db' : messages in the DB directory * 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk') * 'all' : all messages ('mem' + 'disk')
""" """
if loc != 'mem': if loc != msg_location.MEM:
# delete the message files first # delete the message files first
rm_messages = self.msg_find(msg_names, loc=loc) rm_messages = self.msg_find(msg_names, loc=loc)
for m in rm_messages: for m in rm_messages:
@@ -463,7 +479,7 @@ class ChatDB(Chat):
def msg_latest(self, def msg_latest(self,
mfilter: Optional[MessageFilter] = None, mfilter: Optional[MessageFilter] = None,
loc: msg_location = 'mem') -> Optional[Message]: loc: msg_location = msg_location.MEM) -> Optional[Message]:
""" """
Return the last added message (according to the file ID) that matches the given filter. Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if loc is 'mem'). Only consider messages with a valid file_path (except if loc is 'mem').
@@ -492,7 +508,7 @@ class ChatDB(Chat):
and message.file_path.parent.samefile(self.cache_path) # noqa: W503 and message.file_path.parent.samefile(self.cache_path) # noqa: W503
and message.file_path.exists()) # noqa: W503 and message.file_path.exists()) # noqa: W503
else: else:
return len(self.msg_find([message], loc='cache')) > 0 return len(self.msg_find([message], loc=msg_location.CACHE)) > 0
def msg_in_db(self, message: Union[Message, str]) -> bool: def msg_in_db(self, message: Union[Message, str]) -> bool:
""" """
@@ -504,9 +520,9 @@ class ChatDB(Chat):
and message.file_path.parent.samefile(self.db_path) # noqa: W503 and message.file_path.parent.samefile(self.db_path) # noqa: W503
and message.file_path.exists()) # noqa: W503 and message.file_path.exists()) # noqa: W503
else: else:
return len(self.msg_find([message], loc='db')) > 0 return len(self.msg_find([message], loc=msg_location.DB)) > 0
def cache_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None: def cache_read(self, glob: str = f'*{msg_suffix}', mfilter: Optional[MessageFilter] = None) -> None:
""" """
Read messages from the cache directory. New ones are added to the internal list, Read messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message existing ones are replaced. A message is determined as 'existing' if a message
@@ -549,7 +565,7 @@ class ChatDB(Chat):
self.messages += messages self.messages += messages
self.msg_sort() self.msg_sort()
def cache_clear(self, glob: Optional[str] = None) -> None: def cache_clear(self, glob: str = f'*{msg_suffix}') -> None:
""" """
Delete all message files from the cache dir and remove them from the internal list. Delete all message files from the cache dir and remove them from the internal list.
""" """
@@ -569,11 +585,11 @@ class ChatDB(Chat):
self.cache_write([message]) self.cache_write([message])
# remove the old one (if any) # remove the old one (if any)
if old_path: if old_path:
self.msg_remove([str(old_path)], loc='db') self.msg_remove([str(old_path)], loc=msg_location.DB)
# (re)add it to the internal list # (re)add it to the internal list
self.msg_add([message]) self.msg_add([message])
def db_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None: def db_read(self, glob: str = f'*{msg_suffix}', mfilter: Optional[MessageFilter] = None) -> None:
""" """
Read messages from the DB directory. New ones are added to the internal list, Read messages from the DB directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message existing ones are replaced. A message is determined as 'existing' if a message
@@ -628,6 +644,6 @@ class ChatDB(Chat):
self.db_write([message]) self.db_write([message])
# remove the old one (if any) # remove the old one (if any)
if old_path: if old_path:
self.msg_remove([str(old_path)], loc='cache') self.msg_remove([str(old_path)], loc=msg_location.CACHE)
# (re)add it to the internal list # (re)add it to the internal list
self.msg_add([message]) self.msg_add([message])
+69
View File
@@ -0,0 +1,69 @@
"""
Contains shared functions for the various CMM subcommands.
"""
import argparse
from pathlib import Path
from ..message import Message, MessageError, source_code
def read_text_file(file: Path) -> str:
with open(file) as r:
content = r.read().strip()
return content
def add_file_as_text(question_parts: list[str], file: str) -> None:
"""
Add the given file as plain text to the question part list.
If the file is a Message, add the answer.
"""
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
content = read_text_file(Path(file))
if len(content) > 0:
question_parts.append(content)
def add_file_as_code(question_parts: list[str], file: str) -> None:
"""
Add all source code from the given file. If no code segments can be extracted,
the whole content is added as source code segment. If the file is a Message,
extract the source code from the answer.
"""
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
with open(file) as r:
content = r.read().strip()
# extract and add source code
code_parts = source_code(content, include_delims=True)
if len(code_parts) > 0:
question_parts += code_parts
else:
question_parts.append(f"```\n{content}\n```")
def invert_input_tag_args(args: argparse.Namespace) -> None:
"""
Changes the semantics of the INPUT tags for this command:
* not tags specified on the CLI -> no tags are selected
* empty tags specified on the CLI -> all tags are selected
"""
if args.or_tags is None:
args.or_tags = set()
elif len(args.or_tags) == 0:
args.or_tags = None
if args.and_tags is None:
args.and_tags = set()
elif len(args.and_tags) == 0:
args.and_tags = None
+95
View File
@@ -0,0 +1,95 @@
import sys
import argparse
from pathlib import Path
from pydoc import pager
from ..configuration import Config
from ..glossary import Glossary
class GlossaryCmdError(Exception):
pass
def print_paged(text: str) -> None:
pager(text)
def get_glossary_file_path(name: str, config: Config) -> Path:
"""
Get the complete filename for a glossary with the given path.
"""
if not config.glossaries:
raise GlossaryCmdError("Can't create glossary name without a glossary directory")
return Path(config.glossaries, name).with_suffix(Glossary.file_suffix).absolute()
def list_glossaries(args: argparse.Namespace, config: Config) -> None:
"""
List existing glossaries in the 'glossaries' directory.
"""
if not config.glossaries:
raise GlossaryCmdError("Glossaries directory missing in the configuration file")
glossaries = Path(config.glossaries).glob(f'*{Glossary.file_suffix}')
for glo in sorted(glossaries):
print(Glossary.from_file(glo).to_str())
def print_glossary(args: argparse.Namespace, config: Config) -> None:
"""
Print an existing glossary.
"""
# sanity checks
if args.name is None:
raise GlossaryCmdError("Missing glossary name")
if config.glossaries is None and args.file is None:
raise GlossaryCmdError("Glossaries directory missing in the configuration file")
# create file path or use the given one
glo_file = Path(args.file) if args.file else get_glossary_file_path(args.name, config)
if not glo_file.exists():
raise GlossaryCmdError(f"Glossary '{glo_file}' does not exist")
# read glossary
glo = Glossary.from_file(glo_file)
print_paged(glo.to_str(with_entries=True))
def create_glossary(args: argparse.Namespace, config: Config) -> None:
"""
Create a new glossary and write it either to the glossaries directory
or the given file.
"""
# sanity checks
if args.name is None:
raise GlossaryCmdError("Missing glossary name")
if args.source_lang is None:
raise GlossaryCmdError("Missing source language")
if args.target_lang is None:
raise GlossaryCmdError("Missing target language")
if config.glossaries is None and args.file is None:
raise GlossaryCmdError("Glossaries directory missing in the configuration file")
# create file path or use the given one
glo_file = Path(args.file) if args.file else get_glossary_file_path(args.name, config)
if glo_file.exists():
raise GlossaryCmdError(f"Glossary '{glo_file}' already exists")
glo = Glossary(name=args.name,
source_lang=args.source_lang,
target_lang=args.target_lang,
desc=args.description,
file_path=glo_file)
glo.to_file()
print(f"Successfully created new glossary '{glo_file}'.")
def glossary_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'glossary' command.
"""
try:
if args.create:
create_glossary(args, config)
elif args.list:
list_glossaries(args, config)
elif args.print:
print_glossary(args, config)
except GlossaryCmdError as err:
print(f"Error: {err}")
sys.exit(1)
+9 -6
View File
@@ -2,7 +2,7 @@ import sys
import argparse import argparse
from pathlib import Path from pathlib import Path
from ..configuration import Config from ..configuration import Config
from ..chat import ChatDB from ..chat import ChatDB, msg_location
from ..message import MessageFilter, Message from ..message import MessageFilter, Message
@@ -15,9 +15,10 @@ def convert_messages(args: argparse.Namespace, config: Config) -> None:
('.txt', '.yaml') to the latest default message file suffix ('.msg'). ('.txt', '.yaml') to the latest default message file suffix ('.msg').
""" """
chat = ChatDB.from_dir(Path(config.cache), chat = ChatDB.from_dir(Path(config.cache),
Path(config.db)) Path(config.db),
glob='*')
# read all known message files # read all known message files
msgs = chat.msg_gather(loc='disk', glob='*.*') msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*')
# make a set of all message IDs # make a set of all message IDs
msg_ids = set([m.msg_id() for m in msgs]) msg_ids = set([m.msg_id() for m in msgs])
# set requested format and write all messages # set requested format and write all messages
@@ -29,14 +30,14 @@ def convert_messages(args: argparse.Namespace, config: Config) -> None:
m.file_path = m.file_path.with_suffix('') m.file_path = m.file_path.with_suffix('')
chat.msg_write(msgs) chat.msg_write(msgs)
# read all messages with the current default suffix # read all messages with the current default suffix
msgs = chat.msg_gather(loc='disk', glob=f'*{msg_suffix}') msgs = chat.msg_gather(loc=msg_location.DISK, glob=f'*{msg_suffix}')
# make sure we converted all of the original messages # make sure we converted all of the original messages
for mid in msg_ids: for mid in msg_ids:
if not any(mid == m.msg_id() for m in msgs): if not any(mid == m.msg_id() for m in msgs):
print(f"Message '{mid}' has not been found after conversion. Aborting.") print(f"Message '{mid}' has not been found after conversion. Aborting.")
sys.exit(1) sys.exit(1)
# delete messages with old suffixes # delete messages with old suffixes
msgs = chat.msg_gather(loc='disk', glob='*.*') msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*')
for m in msgs: for m in msgs:
if m.file_path and m.file_path.suffix != msg_suffix: if m.file_path and m.file_path.suffix != msg_suffix:
m.rm_file() m.rm_file()
@@ -55,7 +56,9 @@ def print_chat(args: argparse.Namespace, config: Config) -> None:
answer_contains=args.answer) answer_contains=args.answer)
chat = ChatDB.from_dir(Path(config.cache), chat = ChatDB.from_dir(Path(config.cache),
Path(config.db), Path(config.db),
mfilter=mfilter) mfilter=mfilter,
loc=msg_location(args.location),
glob=args.glob)
chat.print(args.source_code_only, chat.print(args.source_code_only,
args.with_metadata, args.with_metadata,
paged=not args.no_paging, paged=not args.no_paging,
+2 -2
View File
@@ -3,7 +3,7 @@ import argparse
from pathlib import Path from pathlib import Path
from ..configuration import Config from ..configuration import Config
from ..message import Message, MessageError from ..message import Message, MessageError
from ..chat import ChatDB from ..chat import ChatDB, msg_location
def print_message(message: Message, args: argparse.Namespace) -> None: def print_message(message: Message, args: argparse.Namespace) -> None:
@@ -38,7 +38,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None:
# print latest message # print latest message
elif args.latest: elif args.latest:
chat = ChatDB.from_dir(Path(config.cache), Path(config.db)) chat = ChatDB.from_dir(Path(config.cache), Path(config.db))
latest = chat.msg_latest(loc='disk') latest = chat.msg_latest(loc=msg_location.DISK)
if not latest: if not latest:
print("No message found!") print("No message found!")
sys.exit(1) sys.exit(1)
+15 -66
View File
@@ -3,9 +3,10 @@ import argparse
from pathlib import Path from pathlib import Path
from itertools import zip_longest from itertools import zip_longest
from copy import deepcopy from copy import deepcopy
from .common import invert_input_tag_args, add_file_as_code, add_file_as_text
from ..configuration import Config from ..configuration import Config
from ..chat import ChatDB from ..chat import ChatDB, msg_location
from ..message import Message, MessageFilter, MessageError, Question, source_code from ..message import Message, MessageFilter, Question
from ..ai_factory import create_ai from ..ai_factory import create_ai
from ..ai import AI, AIResponse from ..ai import AI, AIResponse
@@ -14,47 +15,6 @@ class QuestionCmdError(Exception):
pass pass
def add_file_as_text(question_parts: list[str], file: str) -> None:
"""
Add the given file as plain text to the question part list.
If the file is a Message, add the answer.
"""
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
with open(file) as r:
content = r.read().strip()
if len(content) > 0:
question_parts.append(content)
def add_file_as_code(question_parts: list[str], file: str) -> None:
"""
Add all source code from the given file. If no code segments can be extracted,
the whole content is added as source code segment. If the file is a Message,
extract the source code from the answer.
"""
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
with open(file) as r:
content = r.read().strip()
# extract and add source code
code_parts = source_code(content, include_delims=True)
if len(code_parts) > 0:
question_parts += code_parts
else:
question_parts.append(f"```\n{content}\n```")
def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace: def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
""" """
Takes an existing message and CLI arguments, and returns modified args based Takes an existing message and CLI arguments, and returns modified args based
@@ -101,7 +61,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
if code_file is not None and len(code_file) > 0: if code_file is not None and len(code_file) > 0:
add_file_as_code(question_parts, code_file) add_file_as_code(question_parts, code_file)
full_question = '\n\n'.join(question_parts) full_question = '\n\n'.join([str(s) for s in question_parts])
message = Message(question=Question(full_question), message = Message(question=Question(full_question),
tags=args.output_tags, tags=args.output_tags,
@@ -129,13 +89,16 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
args.output_tags) args.output_tags)
# only write the response messages to the cache, # only write the response messages to the cache,
# don't add them to the internal list # don't add them to the internal list
chat.cache_write(response.messages)
for idx, msg in enumerate(response.messages): for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===") print(f"=== ANSWER {idx+1} ===", flush=True)
print(msg.answer) if msg.answer:
for piece in msg.answer:
print(piece, end='', flush=True)
print()
if response.tokens: if response.tokens:
print("===============") print("===============")
print(response.tokens) print(response.tokens)
chat.cache_write(response.messages)
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None: def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
@@ -160,22 +123,6 @@ def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namesp
make_request(ai, chat, message, msg_args) make_request(ai, chat, message, msg_args)
def invert_input_tag_args(args: argparse.Namespace) -> None:
"""
Changes the semantics of the INPUT tags for this command:
* not tags specified on the CLI -> no tags are selected
* empty tags specified on the CLI -> all tags are selected
"""
if args.or_tags is None:
args.or_tags = set()
elif len(args.or_tags) == 0:
args.or_tags = None
if args.and_tags is None:
args.and_tags = set()
elif len(args.and_tags) == 0:
args.and_tags = None
def question_cmd(args: argparse.Namespace, config: Config) -> None: def question_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'question' command. Handler for the 'question' command.
@@ -186,7 +133,9 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
tags_not=args.exclude_tags) tags_not=args.exclude_tags)
chat = ChatDB.from_dir(cache_path=Path(config.cache), chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db), db_path=Path(config.db),
mfilter=mfilter) mfilter=mfilter,
glob=args.glob,
loc=msg_location(args.location))
# if it's a new question, create and store it immediately # if it's a new question, create and store it immediately
if args.ask or args.create: if args.ask or args.create:
message = create_message(chat, args) message = create_message(chat, args)
@@ -202,14 +151,14 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
repeat_msgs: list[Message] = [] repeat_msgs: list[Message] = []
# repeat latest message # repeat latest message
if len(args.repeat) == 0: if len(args.repeat) == 0:
lmessage = chat.msg_latest(loc='cache') lmessage = chat.msg_latest(loc=msg_location.CACHE)
if lmessage is None: if lmessage is None:
print("No message found to repeat!") print("No message found to repeat!")
sys.exit(1) sys.exit(1)
repeat_msgs.append(lmessage) repeat_msgs.append(lmessage)
# repeat given message(s) # repeat given message(s)
else: else:
repeat_msgs = chat.msg_find(args.repeat, loc='disk') repeat_msgs = chat.msg_find(args.repeat, loc=msg_location.DISK)
repeat_messages(repeat_msgs, chat, args, config) repeat_messages(repeat_msgs, chat, args, config)
# === PROCESS === # === PROCESS ===
elif args.process is not None: elif args.process is not None:
+105
View File
@@ -0,0 +1,105 @@
import argparse
import mimetypes
from pathlib import Path
from .common import invert_input_tag_args, read_text_file
from ..configuration import Config
from ..message import MessageFilter, Message, Question
from ..chat import ChatDB, msg_location
class TranslationCmdError(Exception):
pass
text_separator: str = 'TEXT:'
def assert_document_type_supported_openai(document_file: Path) -> None:
doctype = mimetypes.guess_type(document_file)
if doctype != 'text/plain':
raise TranslationCmdError("AI 'OpenAI' only supports document type 'text/plain''")
def translation_prompt_openai(source_lang: str, target_lang: str) -> str:
"""
Return the prompt for GPT that tells it to do the translation.
"""
return f"Translate the text below the line {text_separator} from {source_lang} to {target_lang}."
def create_message_openai(chat: ChatDB, args: argparse.Namespace) -> Message:
"""
Create a new message from the given arguments and write it to the cache directory.
Message format
1. Translation prompt (tells GPT to do a translation)
2. Glossary (if specified as an argument)
3. User provided prompt enhancements
4. Translation separator
5. User provided text to be translated
The text to be translated is determined as a follows:
- if a document is provided in the arguments, translate its content
- if no document is provided, translate the last text argument
The other text arguments will be put into the "header" and can be used
to improve the translation prompt.
"""
text_args: list[str] = []
if args.create is not None:
text_args = args.create
elif args.ask is not None:
text_args = args.ask
else:
raise TranslationCmdError("No input text found")
# extract user prompt and user text to be translated
user_text: str
user_prompt: str
if args.input_document is not None:
assert_document_type_supported_openai(Path(args.input_document))
user_text = read_text_file(Path(args.input_document))
user_prompt = '\n\n'.join([str(s) for s in text_args])
else:
user_text = text_args[-1]
user_prompt = '\n\n'.join([str(s) for s in text_args[:-1]])
# build full question string
# FIXME: add glossaries if given
question_text: str = '\n\n'.join([translation_prompt_openai(args.source_lang, args.target_lang),
user_prompt,
text_separator,
user_text])
# create and write the message
message = Message(question=Question(question_text),
tags=args.output_tags,
ai=args.AI,
model=args.model)
# only write the new message to the cache,
# don't add it to the internal list
chat.cache_write([message])
return message
def translation_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'translation' command. Creates and executes translation
requests based on the input and selected AI. Depending on the AI, the
whole process may be significantly different (e.g. DeepL vs OpenAI).
"""
invert_input_tag_args(args)
mfilter = MessageFilter(tags_or=args.or_tags,
tags_and=args.and_tags,
tags_not=args.exclude_tags)
chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db),
mfilter=mfilter,
glob=args.glob,
loc=msg_location(args.location))
# if it's a new translation, create and store it immediately
# FIXME: check AI type
if args.ask or args.create:
# message = create_message(chat, args)
create_message_openai(chat, args)
if args.create:
return
+5 -1
View File
@@ -118,6 +118,7 @@ class Config:
# a default configuration # a default configuration
cache: str = '.' cache: str = '.'
db: str = './db/' db: str = './db/'
glossaries: str | None = './glossaries/'
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs) ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
@classmethod @classmethod
@@ -135,7 +136,8 @@ class Config:
return cls( return cls(
cache=str(source['cache']) if 'cache' in source else '.', cache=str(source['cache']) if 'cache' in source else '.',
db=str(source['db']), db=str(source['db']),
ais=ais ais=ais,
glossaries=str(source['glossaries']) if 'glossaries' in source else None
) )
@classmethod @classmethod
@@ -148,6 +150,8 @@ class Config:
@classmethod @classmethod
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
if not Path(path).exists():
raise ConfigError(f"Configuration file '{path}' not found. Use 'cmm config --create' to create one.")
with open(path, 'r') as f: with open(path, 'r') as f:
source = yaml.load(f, Loader=yaml.FullLoader) source = yaml.load(f, Loader=yaml.FullLoader)
return cls.from_dict(source) return cls.from_dict(source)
+165
View File
@@ -0,0 +1,165 @@
"""
Module implementing glossaries for translations.
"""
import yaml
import tempfile
import shutil
import csv
from pathlib import Path
from dataclasses import dataclass, field
from typing import Type, TypeVar, ClassVar
GlossaryInst = TypeVar('GlossaryInst', bound='Glossary')
class GlossaryError(Exception):
pass
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
"""
Changes the YAML dump style to multiline syntax for multiline strings.
"""
if len(data.splitlines()) > 1:
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
return dumper.represent_scalar('tag:yaml.org,2002:str', data)
@dataclass
class Glossary:
"""
A glossary consists of the following parameters:
- Name (freely selectable)
- Path (full file path, suffix is automatically generated)
- Source language
- Target language
- Description (optional)
- Entries (pairs of source lang and target lang terms)
- ID (automatically generated / modified, required by DeepL)
"""
name: str
source_lang: str
target_lang: str
file_path: Path | None = None
desc: str | None = None
entries: dict[str, str] = field(default_factory=lambda: dict())
ID: str | None = None
file_suffix: ClassVar[str] = '.glo'
def __post_init__(self) -> None:
# FIXME: check for valid languages
pass
@classmethod
def from_file(cls: Type[GlossaryInst], file_path: Path) -> GlossaryInst:
"""
Create a glossary from the given file.
"""
if not file_path.exists():
raise GlossaryError(f"Glossary file '{file_path}' does not exist")
if file_path.suffix != cls.file_suffix:
raise GlossaryError(f"File type '{file_path.suffix}' is not supported")
with open(file_path, "r") as fd:
try:
# use BaseLoader so every entry is read as a string
# - disables automatic conversions
# - makes it possible to omit quoting for YAML keywords in entries (e. g. 'yes')
# - also correctly reads quoted entries
data = yaml.load(fd, Loader=yaml.BaseLoader)
clean_entries = data['Entries']
return cls(name=data['Name'],
source_lang=data['SourceLang'],
target_lang=data['TargetLang'],
file_path=file_path,
desc=data['Description'],
entries=clean_entries,
ID=data['ID'] if data['ID'] != 'None' else None)
except Exception:
raise GlossaryError(f"'{file_path}' does not contain a valid glossary")
def to_file(self, file_path: Path | None = None) -> None:
"""
Write glossary to given file.
"""
if file_path:
self.file_path = file_path
if not self.file_path:
raise GlossaryError("Got no valid path to write glossary")
# check / add valid suffix
if not self.file_path.suffix:
self.file_path = self.file_path.with_suffix(self.file_suffix)
elif self.file_path.suffix != self.file_suffix:
raise GlossaryError(f"File suffix '{self.file_path.suffix}' is not supported")
# write YAML
with tempfile.NamedTemporaryFile(dir=self.file_path.parent, prefix=self.file_path.name, mode="w", delete=False) as temp_fd:
temp_file_path = Path(temp_fd.name)
data = {'Name': self.name,
'Description': str(self.desc),
'ID': str(self.ID),
'SourceLang': self.source_lang,
'TargetLang': self.target_lang,
'Entries': self.entries}
yaml.dump(data, temp_fd, sort_keys=False)
shutil.move(temp_file_path, self.file_path)
def export_csv(self, dictionary: dict[str, str], file_path: Path) -> None:
"""
Export the 'entries' of this glossary to a file in CSV format (compatible with DeepL).
"""
with open(file_path, 'w', newline='', encoding='utf-8') as csvfile:
writer = csv.writer(csvfile, delimiter=',', quotechar='"', quoting=csv.QUOTE_ALL)
for source_entry, target_entry in self.entries.items():
writer.writerow([source_entry, target_entry])
def export_tsv(self, entries: dict[str, str], file_path: Path) -> None:
"""
Export the 'entries' of this glossary to a file in TSV format (compatible with DeepL).
"""
with open(file_path, 'w', encoding='utf-8') as file:
for source_entry, target_entry in self.entries.items():
file.write(f"{source_entry}\t{target_entry}\n")
def import_csv(self, file_path: Path) -> None:
"""
Import the entries from the given CSV file to those of the current glossary.
Existing entries are overwritten.
"""
try:
with open(file_path, mode='r', encoding='utf-8') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
self.entries = {rows[0]: rows[1] for rows in reader if len(rows) >= 2}
except Exception as e:
raise GlossaryError(f"Error importing CSV: {e}")
def import_tsv(self, file_path: Path) -> None:
"""
Import the entries from the given CSV file to those of the current glossary.
Existing entries are overwritten.
"""
try:
with open(file_path, mode='r', encoding='utf-8') as tsvfile:
self.entries = {}
for line in tsvfile:
parts = line.strip().split('\t')
if len(parts) == 2:
self.entries[parts[0]] = parts[1]
except Exception as e:
raise GlossaryError(f"Error importing TSV: {e}")
def to_str(self, with_entries: bool = False) -> str:
"""
Return the current glossary as a string.
"""
output: list[str] = []
output.append(f'{self.name} (ID: {self.ID}):')
if self.desc and self.desc != 'None':
output.append('- ' + self.desc)
output.append(f'- Languages: {self.source_lang} -> {self.target_lang}')
if with_entries:
output.append('- Entries:')
for source, target in self.entries.items():
output.append(f' {source} : {target}')
else:
output.append(f'- Entries: {len(self.entries)}')
return '\n'.join(output)
+92 -3
View File
@@ -3,17 +3,21 @@
# vim: set fileencoding=utf-8 : # vim: set fileencoding=utf-8 :
import sys import sys
import os
import argcomplete import argcomplete
import argparse import argparse
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from .configuration import Config, default_config_file from .configuration import Config, default_config_file, ConfigError
from .message import Message from .message import Message
from .commands.question import question_cmd from .commands.question import question_cmd
from .commands.tags import tags_cmd from .commands.tags import tags_cmd
from .commands.config import config_cmd from .commands.config import config_cmd
from .commands.hist import hist_cmd from .commands.hist import hist_cmd
from .commands.print import print_cmd from .commands.print import print_cmd
from .commands.translation import translation_cmd
from .commands.glossary import glossary_cmd
from .chat import msg_location
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
@@ -51,7 +55,7 @@ def create_parser() -> argparse.ArgumentParser:
ai_parser = argparse.ArgumentParser(add_help=False) ai_parser = argparse.ArgumentParser(add_help=False)
ai_parser.add_argument('-A', '--AI', help='AI ID to use', metavar='AI_ID') ai_parser.add_argument('-A', '--AI', help='AI ID to use', metavar='AI_ID')
ai_parser.add_argument('-M', '--model', help='Model to use', metavar='MODEL') ai_parser.add_argument('-M', '--model', help='Model to use', metavar='MODEL')
ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1) ai_parser.add_argument('-N', '--num-answers', help='Number of answers to request', type=int, default=1)
ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int) ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int)
ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float) ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float)
@@ -65,6 +69,11 @@ def create_parser() -> argparse.ArgumentParser:
question_group.add_argument('-c', '--create', nargs='+', help='Create a question', metavar='QUESTION') question_group.add_argument('-c', '--create', nargs='+', help='Create a question', metavar='QUESTION')
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question', metavar='MESSAGE') question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question', metavar='MESSAGE')
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions', metavar='MESSAGE') question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions', metavar='MESSAGE')
question_cmd_parser.add_argument('-l', '--location',
choices=[x.value for x in msg_location if x not in [msg_location.MEM, msg_location.DISK]],
default='db',
help='Use given location when building the chat history (default: \'db\')')
question_cmd_parser.add_argument('-g', '--glob', help='Filter message files using the given glob pattern')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true') action='store_true')
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query', metavar='FILE') question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query', metavar='FILE')
@@ -87,11 +96,16 @@ def create_parser() -> argparse.ArgumentParser:
hist_cmd_parser.add_argument('-Q', '--question', help='Print only questions with given substring', metavar='SUBSTRING') hist_cmd_parser.add_argument('-Q', '--question', help='Print only questions with given substring', metavar='SUBSTRING')
hist_cmd_parser.add_argument('-d', '--tight', help='Print without message separators', action='store_true') hist_cmd_parser.add_argument('-d', '--tight', help='Print without message separators', action='store_true')
hist_cmd_parser.add_argument('-P', '--no-paging', help='Print without paging', action='store_true') hist_cmd_parser.add_argument('-P', '--no-paging', help='Print without paging', action='store_true')
hist_cmd_parser.add_argument('-l', '--location',
choices=[x.value for x in msg_location if x not in [msg_location.MEM, msg_location.DISK]],
default='db',
help='Use given location when building the chat history (default: \'db\')')
hist_cmd_parser.add_argument('-g', '--glob', help='Filter message files using the given glob pattern')
# 'tags' command parser # 'tags' command parser
tags_cmd_parser = cmdparser.add_parser('tags', tags_cmd_parser = cmdparser.add_parser('tags',
help="Manage tags.", help="Manage tags.",
aliases=['t']) aliases=['T'])
tags_cmd_parser.set_defaults(func=tags_cmd) tags_cmd_parser.set_defaults(func=tags_cmd)
tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True) tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True)
tags_group.add_argument('-l', '--list', help="List all tags and their frequency", tags_group.add_argument('-l', '--list', help="List all tags and their frequency",
@@ -125,10 +139,80 @@ def create_parser() -> argparse.ArgumentParser:
print_cmd_modes.add_argument('-a', '--answer', help='Only print the answer', action='store_true') print_cmd_modes.add_argument('-a', '--answer', help='Only print the answer', action='store_true')
print_cmd_modes.add_argument('-S', '--only-source-code', help='Only print embedded source code', action='store_true') print_cmd_modes.add_argument('-S', '--only-source-code', help='Only print embedded source code', action='store_true')
# 'translation' command parser
translation_cmd_parser = cmdparser.add_parser('translation', parents=[ai_parser, tag_parser],
help="Ask, create and repeat translations.",
aliases=['t'])
translation_cmd_parser.set_defaults(func=translation_cmd)
translation_group = translation_cmd_parser.add_mutually_exclusive_group(required=True)
translation_group.add_argument('-a', '--ask', nargs='+', help='Ask to translate the given text', metavar='TEXT')
translation_group.add_argument('-c', '--create', nargs='+', help='Create a translation', metavar='TEXT')
translation_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a translation', metavar='MESSAGE')
translation_cmd_parser.add_argument('-l', '--source-lang', help="Source language", metavar="LANGUAGE", required=True)
translation_cmd_parser.add_argument('-L', '--target-lang', help="Target language", metavar="LANGUAGE", required=True)
translation_cmd_parser.add_argument('-G', '--glossaries', nargs='+', help="List of glossary names", metavar="GLOSSARY")
translation_cmd_parser.add_argument('-d', '--input-document', help="Document to translate", metavar="FILE")
translation_cmd_parser.add_argument('-D', '--output-document', help="Path for the translated document", metavar="FILE")
# 'glossary' command parser
glossary_cmd_parser = cmdparser.add_parser('glossary', parents=[ai_parser],
help="Manage glossaries.",
aliases=['g'])
glossary_cmd_parser.set_defaults(func=glossary_cmd)
glossary_group = glossary_cmd_parser.add_mutually_exclusive_group(required=True)
glossary_group.add_argument('-c', '--create', help='Create a glossary', action='store_true')
glossary_cmd_parser.add_argument('-n', '--name', help="Glossary name (not ID)", metavar="NAME")
glossary_cmd_parser.add_argument('-l', '--source-lang', help="Source language", metavar="LANGUAGE")
glossary_cmd_parser.add_argument('-L', '--target-lang', help="Target language", metavar="LANGUAGE")
glossary_cmd_parser.add_argument('-f', '--file', help='File path of the goven glossary', metavar='GLOSSARY_FILE')
glossary_cmd_parser.add_argument('-D', '--description', help="Glossary description", metavar="DESCRIPTION")
glossary_group.add_argument('-i', '--list', help='List existing glossaries', action='store_true')
glossary_group.add_argument('-p', '--print', help='Print an existing glossary', action='store_true')
argcomplete.autocomplete(parser) argcomplete.autocomplete(parser)
return parser return parser
def create_directories(config: Config) -> None: # noqa: 11
"""
Create the directories in the given configuration if they don't exist.
"""
def make_dir(path: Path) -> None:
try:
os.makedirs(path.absolute())
except Exception as e:
print(f"Creating directory '{path.absolute()}' failed with: {e}")
sys.exit(1)
# Cache
cache_path = Path(config.cache)
if not cache_path.exists():
answer = input(f"Cache directory '{cache_path}' does not exist. Create it? [y/n]")
if answer.lower() in ['y', 'yes']:
make_dir(cache_path.absolute())
else:
print("Can't continue without a valid cache directory!")
sys.exit(1)
# DB
db_path = Path(config.db)
if not db_path.exists():
answer = input(f"DB directory '{db_path}' does not exist. Create it? [y/n]")
if answer.lower() in ['y', 'yes']:
make_dir(db_path.absolute())
else:
print("Can't continue without a valid DB directory!")
sys.exit(1)
# Glossaries
if config.glossaries:
glossaries_path = Path(config.glossaries)
if not glossaries_path.exists():
answer = input(f"Glossaries directory '{glossaries_path}' does not exist. Create it? [y/n]")
if answer.lower() in ['y', 'yes']:
make_dir(glossaries_path.absolute())
else:
print("Can't continue without a valid glossaries directory. Create it or remove it from the configuration.")
sys.exit(1)
def main() -> int: def main() -> int:
parser = create_parser() parser = create_parser()
args = parser.parse_args() args = parser.parse_args()
@@ -137,7 +221,12 @@ def main() -> int:
if command.func == config_cmd: if command.func == config_cmd:
command.func(command) command.func(command)
else: else:
try:
config = Config.from_file(args.config) config = Config.from_file(args.config)
except ConfigError as err:
print(f"{err}")
return 1
create_directories(config)
command.func(command, config) command.func(command, config)
return 0 return 0
+88 -16
View File
@@ -5,7 +5,9 @@ import pathlib
import yaml import yaml
import tempfile import tempfile
import shutil import shutil
import io
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple
from typing import Generator, Iterator
from typing import get_args as typing_get_args from typing import get_args as typing_get_args
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags from .tags import Tag, TagLine, TagError, match_tags, rename_tags
@@ -49,7 +51,7 @@ def source_code(text: str, include_delims: bool = False) -> list[str]:
code_lines: list[str] = [] code_lines: list[str] = []
in_code_block = False in_code_block = False
for line in text.split('\n'): for line in str(text).split('\n'):
if line.strip().startswith('```'): if line.strip().startswith('```'):
if include_delims: if include_delims:
code_lines.append(line) code_lines.append(line)
@@ -142,30 +144,100 @@ class Answer(str):
txt_header: ClassVar[str] = '==== ANSWER ====' txt_header: ClassVar[str] = '==== ANSWER ===='
yaml_key: ClassVar[str] = 'answer' yaml_key: ClassVar[str] = 'answer'
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: def __init__(self, data: Union[str, Generator[str, None, None]]) -> None:
# Indicator of whether all of data has been processed
self.is_exhausted: bool = False
# Initialize data
self.iterator: Iterator[str] = self._init_data(data)
# Set up the buffer to hold the 'Answer' content
self.buffer: io.StringIO = io.StringIO()
def _init_data(self, data: Union[str, Generator[str, None, None]]) -> Iterator[str]:
""" """
Make sure the answer string does not contain the header as a whole line. Process input data (either a string or a string generator)
""" """
if cls.txt_header in string.split('\n'): if isinstance(data, str):
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") yield data
instance = super().__new__(cls, string) else:
return instance yield from data
def __str__(self) -> str:
"""
Output all content when converted into a string
"""
# Ensure all data has been processed
for _ in self:
pass
# Return the 'Answer' content
return self.buffer.getvalue()
def __repr__(self) -> str:
return repr(str(self))
def __iter__(self) -> Generator[str, None, None]:
"""
Allows the object to be iterable
"""
# Generate content if not all data has been processed
if not self.is_exhausted:
yield from self.generator_iter()
else:
yield self.buffer.getvalue()
def generator_iter(self) -> Generator[str, None, None]:
"""
Main generator method to process data
"""
for piece in self.iterator:
# Write to buffer and yield piece for the iterator
self.buffer.write(piece)
yield piece
self.is_exhausted = True # Set the flag that all data has been processed
# If the header occurs in the 'Answer' content, raise an error
if f'\n{self.txt_header}' in self.buffer.getvalue() or self.buffer.getvalue().startswith(self.txt_header):
raise MessageError(f"Answer {repr(self.buffer.getvalue())} contains the header {repr(Answer.txt_header)}")
def __eq__(self, other: object) -> bool:
"""
Comparing the object to a string or another object
"""
if isinstance(other, str):
return str(self) == other # Compare the string value of this object to the other string
# Default behavior for comparing non-string objects
return super().__eq__(other)
def __hash__(self) -> int:
"""
Generate a hash for the object based on its string representation.
"""
return hash(str(self))
def __format__(self, format_spec: str) -> str:
"""
Return a formatted version of the string as per the format specification.
"""
return str(self).__format__(format_spec)
@classmethod @classmethod
def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
""" """
Build Question from a list of strings. Make sure strings do not contain the header. Build Answer from a list of strings. Make sure strings do not contain the header.
""" """
if cls.txt_header in strings: def _gen() -> Generator[str, None, None]:
raise MessageError(f"Question contains the header '{cls.txt_header}'") if len(strings) > 0:
instance = super().__new__(cls, '\n'.join(strings).strip()) yield strings[0]
return instance for s in strings[1:]:
yield '\n'
yield s
return cls(_gen())
def source_code(self, include_delims: bool = False) -> list[str]: def source_code(self, include_delims: bool = False) -> list[str]:
""" """
Extract and return all source code sections. Extract and return all source code sections.
""" """
return source_code(self, include_delims) return source_code(str(self), include_delims)
class Question(str): class Question(str):
@@ -441,7 +513,7 @@ class Message():
output.append(self.question) output.append(self.question)
if self.answer: if self.answer:
output.append(Answer.txt_header) output.append(Answer.txt_header)
output.append(self.answer) output.append(str(self.answer))
return '\n'.join(output) return '\n'.join(output)
def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11 def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11
@@ -491,7 +563,7 @@ class Message():
temp_fd.write(f'{ModelLine.from_model(self.model)}\n') temp_fd.write(f'{ModelLine.from_model(self.model)}\n')
temp_fd.write(f'{Question.txt_header}\n{self.question}\n') temp_fd.write(f'{Question.txt_header}\n{self.question}\n')
if self.answer: if self.answer:
temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n') temp_fd.write(f'{Answer.txt_header}\n{str(self.answer)}\n')
shutil.move(temp_file_path, file_path) shutil.move(temp_file_path, file_path)
def __to_file_yaml(self, file_path: pathlib.Path) -> None: def __to_file_yaml(self, file_path: pathlib.Path) -> None:
@@ -560,7 +632,7 @@ class Message():
or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503
or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503
or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503 or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503
or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503 or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in str(self.answer))) # noqa: W503
or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503 or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503
or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503 or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503
or (mfilter.model_state == 'available' and not self.model) # noqa: W503 or (mfilter.model_state == 'available' and not self.model) # noqa: W503
+1
View File
@@ -2,3 +2,4 @@ openai
PyYAML PyYAML
argcomplete argcomplete
pytest pytest
tiktoken
+25 -13
View File
@@ -16,26 +16,37 @@ class OpenAITest(unittest.TestCase):
openai = OpenAI(config) openai = OpenAI(config)
# Set up the mock response from openai.ChatCompletion.create # Set up the mock response from openai.ChatCompletion.create
mock_response = { mock_chunk1 = {
'choices': [ 'choices': [
{ {
'message': { 'index': 0,
'delta': {
'content': 'Answer 1' 'content': 'Answer 1'
} },
'finish_reason': None
}, },
{ {
'message': { 'index': 1,
'delta': {
'content': 'Answer 2' 'content': 'Answer 2'
} },
'finish_reason': None
} }
], ],
'usage': {
'prompt_tokens': 10,
'completion_tokens': 20,
'total_tokens': 30
} }
mock_chunk2 = {
'choices': [
{
'index': 0,
'finish_reason': 'stop'
},
{
'index': 1,
'finish_reason': 'stop'
} }
mock_create.return_value = mock_response ],
}
mock_create.return_value = iter([mock_chunk1, mock_chunk2])
# Create test data # Create test data
question = Message(Question('Question')) question = Message(Question('Question'))
@@ -57,9 +68,9 @@ class OpenAITest(unittest.TestCase):
self.assertIsNotNone(response.tokens) self.assertIsNotNone(response.tokens)
self.assertIsInstance(response.tokens, Tokens) self.assertIsInstance(response.tokens, Tokens)
assert response.tokens assert response.tokens
self.assertEqual(response.tokens.prompt, 10) self.assertEqual(response.tokens.prompt, 53)
self.assertEqual(response.tokens.completion, 20) self.assertEqual(response.tokens.completion, 6)
self.assertEqual(response.tokens.total, 30) self.assertEqual(response.tokens.total, 59)
# Assert the mock call to openai.ChatCompletion.create # Assert the mock call to openai.ChatCompletion.create
mock_create.assert_called_once_with( mock_create.assert_called_once_with(
@@ -76,6 +87,7 @@ class OpenAITest(unittest.TestCase):
max_tokens=config.max_tokens, max_tokens=config.max_tokens,
top_p=config.top_p, top_p=config.top_p,
n=2, n=2,
stream=True,
frequency_penalty=config.frequency_penalty, frequency_penalty=config.frequency_penalty,
presence_penalty=config.presence_penalty presence_penalty=config.presence_penalty
) )
+49 -48
View File
@@ -7,7 +7,7 @@ from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from chatmastermind.tags import TagLine from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, ChatError from chatmastermind.chat import Chat, ChatDB, ChatError, msg_location
msg_suffix: str = Message.file_suffix_write msg_suffix: str = Message.file_suffix_write
@@ -240,7 +240,8 @@ class TestChatDB(TestChatBase):
msg_to_file_force_suffix(duplicate_message) msg_to_file_force_suffix(duplicate_message)
with self.assertRaises(ChatError) as cm: with self.assertRaises(ChatError) as cm:
ChatDB.from_dir(pathlib.Path(self.cache_path.name), ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name),
glob='*')
self.assertEqual(str(cm.exception), "Validation failed") self.assertEqual(str(cm.exception), "Validation failed")
def test_file_path_ID_exists(self) -> None: def test_file_path_ID_exists(self) -> None:
@@ -595,92 +596,92 @@ class TestChatDB(TestChatBase):
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
# search for a DB file in memory # search for a DB file in memory
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1]) self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1]) self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.msg'], loc='mem'), [self.message1]) self.assertEqual(chat_db.msg_find(['0001.msg'], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1]) self.assertEqual(chat_db.msg_find(['0001'], loc=msg_location.MEM), [self.message1])
# and on disk # and on disk
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2]) self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2]) self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.msg'], loc='db'), [self.message2]) self.assertEqual(chat_db.msg_find(['0002.msg'], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2]) self.assertEqual(chat_db.msg_find(['0002'], loc=msg_location.DB), [self.message2])
# now search the cache -> expect empty result # now search the cache -> expect empty result
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), []) self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), []) self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find(['0003.msg'], loc='cache'), []) self.assertEqual(chat_db.msg_find(['0003.msg'], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), []) self.assertEqual(chat_db.msg_find(['0003'], loc=msg_location.CACHE), [])
# search for multiple messages # search for multiple messages
# -> search one twice, expect result to be unique # -> search one twice, expect result to be unique
search_names = ['0001', '0002.msg', self.message3.msg_id(), str(self.message3.file_path)] search_names = ['0001', '0002.msg', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3] expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, loc='all') result = chat_db.msg_find(search_names, loc=msg_location.ALL)
self.assert_messages_equal(result, expected_result) self.assert_messages_equal(result, expected_result)
def test_msg_latest(self) -> None: def test_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.msg_latest(loc='mem'), self.message4) self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), self.message4)
self.assertEqual(chat_db.msg_latest(loc='db'), self.message4) self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4)
self.assertEqual(chat_db.msg_latest(loc='disk'), self.message4) self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), self.message4)
self.assertEqual(chat_db.msg_latest(loc='all'), self.message4) self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), self.message4)
# the cache is currently empty: # the cache is currently empty:
self.assertIsNone(chat_db.msg_latest(loc='cache')) self.assertIsNone(chat_db.msg_latest(loc=msg_location.CACHE))
# add new messages to the cache dir # add new messages to the cache dir
new_message = Message(question=Question("New Question"), new_message = Message(question=Question("New Question"),
answer=Answer("New Answer")) answer=Answer("New Answer"))
chat_db.cache_add([new_message]) chat_db.cache_add([new_message])
self.assertEqual(chat_db.msg_latest(loc='cache'), new_message) self.assertEqual(chat_db.msg_latest(loc=msg_location.CACHE), new_message)
self.assertEqual(chat_db.msg_latest(loc='mem'), new_message) self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), new_message)
self.assertEqual(chat_db.msg_latest(loc='disk'), new_message) self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), new_message)
self.assertEqual(chat_db.msg_latest(loc='all'), new_message) self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), new_message)
# the DB does not contain the new message # the DB does not contain the new message
self.assertEqual(chat_db.msg_latest(loc='db'), self.message4) self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4)
def test_msg_gather(self) -> None: def test_msg_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4] all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
# add a new message, but only to the internal list # add a new message, but only to the internal list
new_message = Message(Question("What?")) new_message = Message(Question("What?"))
all_messages_mem = all_messages + [new_message] all_messages_mem = all_messages + [new_message]
chat_db.msg_add([new_message]) chat_db.msg_add([new_message])
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages_mem) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages_mem)
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages_mem) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages_mem)
# the nr. of messages on disk did not change -> expect old result # the nr. of messages on disk did not change -> expect old result
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
# test with MessageFilter # test with MessageFilter
self.assert_messages_equal(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})), self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL, mfilter=MessageFilter(tags_or={Tag('tag1')})),
[self.message1]) [self.message1])
self.assert_messages_equal(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})), self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK, mfilter=MessageFilter(tags_or={Tag('tag2')})),
[self.message2]) [self.message2])
self.assert_messages_equal(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})), self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE, mfilter=MessageFilter(tags_or={Tag('tag3')})),
[]) [])
self.assert_messages_equal(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")), self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM, mfilter=MessageFilter(question_contains="What")),
[new_message]) [new_message])
def test_msg_move_and_gather(self) -> None: def test_msg_move_and_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4] all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
# move first message to the cache # move first message to the cache
chat_db.cache_move(self.message1) chat_db.cache_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [self.message1]) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [self.message1])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4]) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), [self.message2, self.message3, self.message4])
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages)
# now move first message back to the DB # now move first message back to the DB
chat_db.db_move(self.message1) chat_db.db_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
+6 -1
View File
@@ -71,11 +71,13 @@ class TestConfig(unittest.TestCase):
'frequency_penalty': 0.7, 'frequency_penalty': 0.7,
'presence_penalty': 0.2 'presence_penalty': 0.2
} }
} },
'glossaries': './glossaries/'
} }
config = Config.from_dict(source_dict) config = Config.from_dict(source_dict)
self.assertEqual(config.cache, '.') self.assertEqual(config.cache, '.')
self.assertEqual(config.db, './test_db/') self.assertEqual(config.db, './test_db/')
self.assertEqual(config.glossaries, './glossaries/')
self.assertEqual(len(config.ais), 1) self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['myopenai'].name, 'openai') self.assertEqual(config.ais['myopenai'].name, 'openai')
self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system') self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system')
@@ -105,6 +107,7 @@ class TestConfig(unittest.TestCase):
'frequency_penalty': 0.7, 'frequency_penalty': 0.7,
'presence_penalty': 0.2 'presence_penalty': 0.2
} }
# omit glossaries, since it's optional
} }
} }
with open(self.test_file.name, 'w') as f: with open(self.test_file.name, 'w') as f:
@@ -113,6 +116,8 @@ class TestConfig(unittest.TestCase):
self.assertIsInstance(config, Config) self.assertIsInstance(config, Config)
self.assertEqual(config.cache, './test_cache/') self.assertEqual(config.cache, './test_cache/')
self.assertEqual(config.db, './test_db/') self.assertEqual(config.db, './test_db/')
# missing 'glossaries' should result in 'None'
self.assertEqual(config.glossaries, None)
self.assertEqual(len(config.ais), 1) self.assertEqual(len(config.ais), 1)
self.assertIsInstance(config.ais['default'], AIConfig) self.assertIsInstance(config.ais['default'], AIConfig)
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
+209
View File
@@ -0,0 +1,209 @@
import unittest
import tempfile
from pathlib import Path
from chatmastermind.glossary import Glossary, GlossaryError
glossary_suffix: str = Glossary.file_suffix
class TestGlossary(unittest.TestCase):
def test_from_file_yaml_unquoted(self) -> None:
"""
Test glossary creatiom from YAML with unquoted entries.
"""
with tempfile.NamedTemporaryFile('w', delete=False, suffix=glossary_suffix) as yaml_file:
yaml_file.write("Name: Sample\n"
"Description: A brief description\n"
"ID: '123'\n"
"SourceLang: en\n"
"TargetLang: es\n"
"Entries:\n"
" hello: hola\n"
" goodbye: adiós\n"
# 'yes' is a YAML keyword and would normally be quoted
" yes: sí\n"
" I'm going home: me voy a casa\n")
yaml_file_path = Path(yaml_file.name)
# create and check valid glossary
glossary = Glossary.from_file(yaml_file_path)
self.assertEqual(glossary.name, "Sample")
self.assertEqual(glossary.desc, "A brief description")
self.assertEqual(glossary.ID, "123")
self.assertEqual(glossary.source_lang, "en")
self.assertEqual(glossary.target_lang, "es")
self.assertEqual(glossary.entries, {"hello": "hola",
"goodbye": "adiós",
"yes": "",
"I'm going home": "me voy a casa"})
yaml_file_path.unlink() # Remove the temporary file
def test_from_file_yaml_quoted(self) -> None:
"""
Test glossary creatiom from YAML with quoted entries.
"""
with tempfile.NamedTemporaryFile('w', delete=False, suffix=glossary_suffix) as yaml_file:
yaml_file.write("Name: Sample\n"
"Description: A brief description\n"
"ID: '123'\n"
"SourceLang: en\n"
"TargetLang: es\n"
"Entries:\n"
" 'hello': 'hola'\n"
" 'goodbye': 'adiós'\n"
" 'yes': ''\n"
" \"I'm going home\": 'me voy a casa'\n")
yaml_file_path = Path(yaml_file.name)
# create and check valid glossary
glossary = Glossary.from_file(yaml_file_path)
self.assertEqual(glossary.name, "Sample")
self.assertEqual(glossary.desc, "A brief description")
self.assertEqual(glossary.ID, "123")
self.assertEqual(glossary.source_lang, "en")
self.assertEqual(glossary.target_lang, "es")
self.assertEqual(glossary.entries, {"hello": "hola",
"goodbye": "adiós",
"yes": "",
"I'm going home": "me voy a casa"})
yaml_file_path.unlink() # Remove the temporary file
def test_to_file_writes_yaml(self) -> None:
# Create glossary instance
glossary = Glossary(name="Test",
desc="Test description",
ID="666",
source_lang="en",
target_lang="fr",
entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', suffix=glossary_suffix) as tmp_file:
file_path = Path(tmp_file.name)
glossary.to_file(file_path)
# read and check valid YAML
with open(file_path, 'r') as file:
content = file.read()
self.assertIn("Name: Test", content)
self.assertIn("Description: Test description", content)
self.assertIn("ID: '666'", content)
self.assertIn("SourceLang: en", content)
self.assertIn("TargetLang: fr", content)
self.assertIn("Entries", content)
# 'yes' is a YAML keyword and therefore quoted
self.assertIn("'yes': oui", content)
def test_write_read_glossary(self) -> None:
# Create glossary instance
# -> use 'yes' in order to test if the YAML quoting is correctly removed when reading the file
glossary_write = Glossary(name="Test", source_lang="en", target_lang="fr", entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', suffix=glossary_suffix) as tmp_file:
file_path = Path(tmp_file.name)
glossary_write.to_file(file_path)
# create new instance from glossary file
glossary_read = Glossary.from_file(file_path)
self.assertEqual(glossary_write.name, glossary_read.name)
self.assertEqual(glossary_write.source_lang, glossary_read.source_lang)
self.assertEqual(glossary_write.target_lang, glossary_read.target_lang)
self.assertDictEqual(glossary_write.entries, glossary_read.entries)
def test_import_export_csv(self) -> None:
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={})
# First export to CSV
with tempfile.NamedTemporaryFile('w', suffix=glossary_suffix) as csvfile:
csv_file_path = Path(csvfile.name)
glossary.entries = {"hello": "salut", "goodbye": "au revoir"}
glossary.export_csv(glossary.entries, csv_file_path)
# Now import CSV
glossary.import_csv(csv_file_path)
self.assertEqual(glossary.entries, {"hello": "salut", "goodbye": "au revoir"})
def test_import_export_tsv(self) -> None:
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={})
# First export to TSV
with tempfile.NamedTemporaryFile('w', suffix=glossary_suffix) as tsvfile:
tsv_file_path = Path(tsvfile.name)
glossary.entries = {"hello": "salut", "goodbye": "au revoir"}
glossary.export_tsv(glossary.entries, tsv_file_path)
# Now import TSV
glossary.import_tsv(tsv_file_path)
self.assertEqual(glossary.entries, {"hello": "salut", "goodbye": "au revoir"})
def test_to_file_wrong_suffix(self) -> None:
"""
Test for exception if suffix is wrong.
"""
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', suffix='.wrong') as tmp_file:
file_path = Path(tmp_file.name)
with self.assertRaises(GlossaryError) as err:
glossary.to_file(file_path)
self.assertEqual(str(err.exception), "File suffix '.wrong' is not supported")
def test_to_file_auto_suffix(self) -> None:
"""
Test if suffix is auto-generated if omitted.
"""
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', suffix='') as tmp_file:
file_path = Path(tmp_file.name)
glossary.to_file(file_path)
assert glossary.file_path is not None
self.assertEqual(glossary.file_path.suffix, glossary_suffix)
# remove glossary file (differs from 'tmp_file' because of the added suffix
glossary.file_path.unlink()
def test_to_str_with_id(self) -> None:
# Create a Glossary instance with an ID
glossary_with_id = Glossary(name="TestGlossary", source_lang="en", target_lang="fr",
desc="A simple test glossary", ID="1001", entries={"one": "un"})
glossary_str = glossary_with_id.to_str()
self.assertIn("TestGlossary (ID: 1001):", glossary_str)
self.assertIn("- A simple test glossary", glossary_str)
self.assertIn("- Languages: en -> fr", glossary_str)
self.assertIn("- Entries: 1", glossary_str)
def test_to_str_with_id_and_entries(self) -> None:
# Create a Glossary instance with an ID and include entries
glossary_with_entries = Glossary(name="TestGlossaryWithEntries", source_lang="en", target_lang="fr",
desc="Another test glossary", ID="2002",
entries={"hello": "salut", "goodbye": "au revoir"})
glossary_str_with_entries = glossary_with_entries.to_str(with_entries=True)
self.assertIn("TestGlossaryWithEntries (ID: 2002):", glossary_str_with_entries)
self.assertIn("- Entries:", glossary_str_with_entries)
self.assertIn(" hello : salut", glossary_str_with_entries)
self.assertIn(" goodbye : au revoir", glossary_str_with_entries)
def test_to_str_without_id(self) -> None:
# Create a Glossary instance without an ID
glossary_without_id = Glossary(name="TestGlossaryNoID", source_lang="en", target_lang="fr",
desc="A test glossary without an ID", ID=None, entries={"yes": "oui"})
glossary_str_no_id = glossary_without_id.to_str()
self.assertIn("TestGlossaryNoID (ID: None):", glossary_str_no_id)
self.assertIn("- A test glossary without an ID", glossary_str_no_id)
self.assertIn("- Languages: en -> fr", glossary_str_no_id)
self.assertIn("- Entries: 1", glossary_str_no_id)
def test_to_str_without_id_and_no_entries(self) -> None:
# Create a Glossary instance without an ID and no entries
glossary_no_id_no_entries = Glossary(name="EmptyGlossary", source_lang="en", target_lang="fr",
desc="An empty test glossary", ID=None, entries={})
glossary_str_no_id_no_entries = glossary_no_id_no_entries.to_str()
self.assertIn("EmptyGlossary (ID: None):", glossary_str_no_id_no_entries)
self.assertIn("- An empty test glossary", glossary_str_no_id_no_entries)
self.assertIn("- Languages: en -> fr", glossary_str_no_id_no_entries)
self.assertIn("- Entries: 0", glossary_str_no_id_no_entries)
def test_to_str_no_description(self) -> None:
# Create a Glossary instance with an ID
glossary_with_id = Glossary(name="TestGlossary", source_lang="en", target_lang="fr",
ID="1001", entries={"one": "un"})
glossary_str = glossary_with_id.to_str()
expected_str = """TestGlossary (ID: 1001):
- Languages: en -> fr
- Entries: 1"""
self.assertEqual(expected_str, glossary_str)
+149
View File
@@ -0,0 +1,149 @@
import unittest
import argparse
import tempfile
import io
from contextlib import redirect_stdout
from chatmastermind.configuration import Config
from chatmastermind.commands.glossary import (
Glossary,
GlossaryCmdError,
glossary_cmd,
get_glossary_file_path,
create_glossary,
print_glossary,
list_glossaries
)
class TestGlossaryCmdNoGlossaries(unittest.TestCase):
def setUp(self) -> None:
# create DB and cache
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
self.glossaries_dir = tempfile.TemporaryDirectory()
# create configuration
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
self.config.glossaries = self.glossaries_dir.name
# create a mock argparse.Namespace
self.args = argparse.Namespace(
create=True,
list=False,
print=False,
name='new_glossary',
file=None,
source_lang='en',
target_lang='de',
description=False,
)
def test_glossary_create_no_glossaries_err(self) -> None:
self.config.glossaries = None
with self.assertRaises(GlossaryCmdError) as err:
create_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "glossaries directory missing")
def test_glossary_create_no_name_err(self) -> None:
self.args.name = None
with self.assertRaises(GlossaryCmdError) as err:
create_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "missing glossary name")
def test_glossary_create_no_source_lang_err(self) -> None:
self.args.source_lang = None
with self.assertRaises(GlossaryCmdError) as err:
create_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "missing source language")
def test_glossary_create_no_target_lang_err(self) -> None:
self.args.target_lang = None
with self.assertRaises(GlossaryCmdError) as err:
create_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "missing target language")
def test_glossary_print_no_name_err(self) -> None:
self.args.name = None
with self.assertRaises(GlossaryCmdError) as err:
print_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "missing glossary name")
def test_glossary_list_no_glossaries_err(self) -> None:
self.config.glossaries = None
with self.assertRaises(GlossaryCmdError) as err:
list_glossaries(self.args, self.config)
self.assertIn(str(err.exception).lower(), "glossaries directory missing")
def test_glossary_create(self) -> None:
self.args.create = True
self.args.list = False
self.args.print = False
glossary_cmd(self.args, self.config)
expected_path = get_glossary_file_path(self.args.name, self.config)
glo = Glossary.from_file(expected_path)
self.assertEqual(glo.name, self.args.name)
expected_path.unlink()
def test_glossary_create_twice_err(self) -> None:
self.args.create = True
self.args.list = False
self.args.print = False
glossary_cmd(self.args, self.config)
expected_path = get_glossary_file_path(self.args.name, self.config)
glo = Glossary.from_file(expected_path)
self.assertEqual(glo.name, self.args.name)
# create glossary with the same name again
with self.assertRaises(GlossaryCmdError) as err:
create_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "already exists")
expected_path.unlink()
class TestGlossaryCmdWithGlossaries(unittest.TestCase):
def setUp(self) -> None:
# create DB and cache
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
self.glossaries_dir = tempfile.TemporaryDirectory()
# create configuration
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
self.config.glossaries = self.glossaries_dir.name
# create a mock argparse.Namespace
self.args = argparse.Namespace(
create=True,
list=False,
print=False,
name='Glossary1',
file=None,
source_lang='en',
target_lang='de',
description=False,
)
# create Glossary1
glossary_cmd(self.args, self.config)
self.Glossary1_path = get_glossary_file_path('Glossary1', self.config)
# create Glossary2
self.args.name = 'Glossary2'
glossary_cmd(self.args, self.config)
self.Glossary2_path = get_glossary_file_path('Glossary2', self.config)
def test_glossaries_exist(self) -> None:
"""
Test if the default glossaries created in setUp exist.
"""
glo = Glossary.from_file(self.Glossary1_path)
self.assertEqual(glo.name, 'Glossary1')
glo = Glossary.from_file(self.Glossary2_path)
self.assertEqual(glo.name, 'Glossary2')
def test_glossaries_list(self) -> None:
self.args.create = False
self.args.list = True
with redirect_stdout(io.StringIO()) as list_output:
glossary_cmd(self.args, self.config)
self.assertIn('Glossary1', list_output.getvalue())
self.assertIn('Glossary2', list_output.getvalue())
+2 -2
View File
@@ -4,7 +4,7 @@ import tempfile
import yaml import yaml
from pathlib import Path from pathlib import Path
from chatmastermind.message import Message, Question from chatmastermind.message import Message, Question
from chatmastermind.chat import ChatDB, ChatError from chatmastermind.chat import ChatDB, ChatError, msg_location
from chatmastermind.configuration import Config from chatmastermind.configuration import Config
from chatmastermind.commands.hist import convert_messages from chatmastermind.commands.hist import convert_messages
@@ -41,7 +41,7 @@ class TestConvertMessages(unittest.TestCase):
def test_convert_messages(self) -> None: def test_convert_messages(self) -> None:
self.args.convert = 'yaml' self.args.convert = 'yaml'
convert_messages(self.args, self.config) convert_messages(self.args, self.config)
msgs = self.chat.msg_gather(loc='disk', glob='*.*') msgs = self.chat.msg_gather(loc=msg_location.DISK, glob='*.*')
# Check if the number of messages is the same as before # Check if the number of messages is the same as before
self.assertEqual(len(msgs), len(self.messages)) self.assertEqual(len(msgs), len(self.messages))
# Check if all messages have the requested suffix # Check if all messages have the requested suffix
+1 -1
View File
@@ -91,7 +91,7 @@ class QuestionTestCase(unittest.TestCase):
class AnswerTestCase(unittest.TestCase): class AnswerTestCase(unittest.TestCase):
def test_answer_with_header(self) -> None: def test_answer_with_header(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
Answer(f"{Answer.txt_header}\nno") str(Answer(f"{Answer.txt_header}\nno"))
def test_answer_with_legal_header(self) -> None: def test_answer_with_legal_header(self) -> None:
answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.")
+16 -14
View File
@@ -9,7 +9,7 @@ from chatmastermind.configuration import Config
from chatmastermind.commands.question import create_message, question_cmd from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat, ChatDB from chatmastermind.chat import Chat, ChatDB, msg_location
from chatmastermind.ai import AIError from chatmastermind.ai import AIError
from .test_common import TestWithFakeAI from .test_common import TestWithFakeAI
@@ -234,6 +234,8 @@ class TestQuestionCmd(TestWithFakeAI):
# create a mock argparse.Namespace # create a mock argparse.Namespace
self.args = argparse.Namespace( self.args = argparse.Namespace(
ask=['What is the meaning of life?'], ask=['What is the meaning of life?'],
glob=None,
location='db',
num_answers=1, num_answers=1,
output_tags=['science'], output_tags=['science'],
AI='FakeAI', AI='FakeAI',
@@ -279,7 +281,7 @@ class TestQuestionCmdAsk(TestQuestionCmd):
# check for the expected message files # check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses) self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@@ -337,7 +339,7 @@ class TestQuestionCmdAsk(TestQuestionCmd):
# check for the expected message files # check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_question]) self.assert_msgs_equal_except_file_path(cached_msg, [expected_question])
@@ -375,7 +377,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
# we expect the original message + the one with the new response # we expect the original message + the one with the new response
expected_responses = [message] + [expected_response] expected_responses = [message] + [expected_response]
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
print(self.message_list(self.cache_dir)) print(self.message_list(self.cache_dir))
self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses) self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@@ -396,7 +398,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
model=self.args.model, model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message]) chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem cached_msg_file_id = cached_msg[0].file_path.stem
@@ -412,7 +414,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
tags=message.tags, tags=message.tags,
file_path=Path('<NOT COMPARED>')) file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response]) self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed # also check that the file ID has not been changed
@@ -435,7 +437,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
model=self.args.model, model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message]) chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem cached_msg_file_id = cached_msg[0].file_path.stem
@@ -452,7 +454,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
tags=message.tags, tags=message.tags,
file_path=Path('<NOT COMPARED>')) file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response]) self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed # also check that the file ID has not been changed
@@ -475,7 +477,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
model=self.args.model, model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message]) chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path assert cached_msg[0].file_path
# repeat the last question with new arguments (without overwriting) # repeat the last question with new arguments (without overwriting)
@@ -493,7 +495,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
tags={Tag('newtag')}, tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>')) file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_msgs_equal_except_file_path(cached_msg, [message] + [new_expected_response]) self.assert_msgs_equal_except_file_path(cached_msg, [message] + [new_expected_response])
@@ -513,7 +515,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
model=self.args.model, model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}') file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message]) chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path assert cached_msg[0].file_path
# repeat the last question with new arguments # repeat the last question with new arguments
@@ -530,7 +532,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
tags={Tag('newtag')}, tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>')) file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response]) self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response])
@@ -586,8 +588,8 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
self.assertEqual(len(self.message_list(self.cache_dir)), 4) self.assertEqual(len(self.message_list(self.cache_dir)), 4)
self.assertEqual(len(self.message_list(self.db_dir)), 1) self.assertEqual(len(self.message_list(self.db_dir)), 1)
expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]] expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages) self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
# check that the DB message has not been modified at all # check that the DB message has not been modified at all
db_msg = chat.msg_gather(loc='db') db_msg = chat.msg_gather(loc=msg_location.DB)
self.assert_msgs_all_equal(db_msg, [message3]) self.assert_msgs_all_equal(db_msg, [message3])