2 Commits

Author SHA1 Message Date
juk0de 065038b1c4 added tests for 'chat.py' 2023-08-31 12:36:33 +02:00
juk0de 2c054aa8db added new module 'chat.py' 2023-08-31 12:36:33 +02:00
25 changed files with 807 additions and 1550 deletions
-74
View File
@@ -1,74 +0,0 @@
from dataclasses import dataclass
from typing import Protocol, Optional, Union
from .configuration import AIConfig
from .tags import Tag
from .message import Message
from .chat import Chat
class AIError(Exception):
pass
@dataclass
class Tokens:
prompt: int = 0
completion: int = 0
total: int = 0
@dataclass
class AIResponse:
"""
The response to an AI request. Consists of one or more messages
(each containing the question and a single answer) and the nr.
of used tokens.
"""
messages: list[Message]
tokens: Optional[Tokens] = None
class AI(Protocol):
"""
The base class for AI clients.
"""
ID: str
name: str
config: AIConfig
def request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Make an AI request. Parameters:
* question: the question to ask
* chat: the chat history to be added as context
* num_answers: nr. of requested answers (corresponds
to the nr. of messages in the 'AIResponse')
* otags: the output tags, i. e. the tags that all
returned messages should contain
"""
raise NotImplementedError
def models(self) -> list[str]:
"""
Return all models supported by this AI.
"""
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int:
"""
Computes the nr. of AI language tokens for the given message
or chat. Note that the computation may not be 100% accurate
and is not implemented for all AIs.
"""
raise NotImplementedError
def print(self) -> None:
"""
Print some info about the current AI, like system message.
"""
pass
-43
View File
@@ -1,43 +0,0 @@
"""
Creates different AI instances, based on the given configuration.
"""
import argparse
from typing import cast
from .configuration import Config, AIConfig, OpenAIConfig
from .ai import AI, AIError
from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
"""
Creates an AI subclass instance from the given arguments
and configuration file. If AI has not been set in the
arguments, it searches for the ID 'default'. If that
is not found, it uses the first AI in the list.
"""
ai_conf: AIConfig
if args.AI:
try:
ai_conf = config.ais[args.AI]
except KeyError:
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
elif 'default' in config.ais:
ai_conf = config.ais['default']
else:
try:
ai_conf = next(iter(config.ais.values()))
except StopIteration:
raise AIError("No AI found in this configuration")
if ai_conf.name == 'openai':
ai = OpenAI(cast(OpenAIConfig, ai_conf))
if args.model:
ai.config.model = args.model
if args.max_tokens:
ai.config.max_tokens = args.max_tokens
if args.temperature:
ai.config.temperature = args.temperature
return ai
else:
raise AIError(f"AI '{args.AI}' is not supported")
View File
-106
View File
@@ -1,106 +0,0 @@
"""
Implements the OpenAI client classes and functions.
"""
import openai
from typing import Optional, Union
from ..tags import Tag
from ..message import Message, Answer
from ..chat import Chat
from ..ai import AI, AIResponse, Tokens
from ..configuration import OpenAIConfig
ChatType = list[dict[str, str]]
class OpenAI(AI):
"""
The OpenAI AI client.
"""
def __init__(self, config: OpenAIConfig) -> None:
self.ID = config.ID
self.name = config.name
self.config = config
openai.api_key = config.api_key
def request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Make an AI request, asking the given question with the given
chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'.
"""
oai_chat = self.openai_chat(chat, self.config.system, question)
response = openai.ChatCompletion.create(
model=self.config.model,
messages=oai_chat,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
top_p=self.config.top_p,
n=num_answers,
frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty)
question.answer = Answer(response['choices'][0]['message']['content'])
question.tags = otags
question.ai = self.ID
question.model = self.config.model
answers: list[Message] = [question]
for choice in response['choices'][1:]: # type: ignore
answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']),
tags=otags,
ai=self.ID,
model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt_tokens'],
response['usage']['completion_tokens'],
response['usage']['total_tokens']))
def models(self) -> list[str]:
"""
Return all models supported by this AI.
"""
raise NotImplementedError
def print_models(self) -> None:
"""
Print all models supported by the current AI.
"""
not_ready = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']:
print(engine['id'])
else:
not_ready.append(engine['id'])
if len(not_ready) > 0:
print('\nNot ready: ' + ', '.join(not_ready))
def openai_chat(self, chat: Chat, system: str,
question: Optional[Message] = None) -> ChatType:
"""
Create a chat history with system message in OpenAI format.
Optionally append a new question.
"""
oai_chat: ChatType = []
def append(role: str, content: str) -> None:
oai_chat.append({'role': role, 'content': content.replace("''", "'")})
append('system', system)
for message in chat.messages:
if message.answer:
append('user', message.question)
append('assistant', message.answer)
if question:
append('user', question.question)
return oai_chat
def tokens(self, data: Union[Message, Chat]) -> int:
raise NotImplementedError
def print(self) -> None:
print(f"MODEL: {self.config.model}")
print("=== SYSTEM ===")
print(self.config.system)
+45
View File
@@ -0,0 +1,45 @@
import openai
from .utils import ChatType
from .configuration import Config
def openai_api_key(api_key: str) -> None:
openai.api_key = api_key
def print_models() -> None:
"""
Print all models supported by the current AI.
"""
not_ready = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']:
print(engine['id'])
else:
not_ready.append(engine['id'])
if len(not_ready) > 0:
print('\nNot ready: ' + ', '.join(not_ready))
def ai(chat: ChatType,
config: Config,
number: int
) -> tuple[list[str], dict[str, int]]:
"""
Make AI request with the given chat history and configuration.
Return AI response and tokens used.
"""
response = openai.ChatCompletion.create(
model=config.openai.model,
messages=chat,
temperature=config.openai.temperature,
max_tokens=config.openai.max_tokens,
top_p=config.openai.top_p,
n=number,
frequency_penalty=config.openai.frequency_penalty,
presence_penalty=config.openai.presence_penalty)
result = []
for choice in response['choices']: # type: ignore
result.append(choice['message']['content'].strip())
return result, dict(response['usage']) # type: ignore
+43 -181
View File
@@ -2,12 +2,12 @@
Module implementing various chat classes and functions for managing a chat history. Module implementing various chat classes and functions for managing a chat history.
""" """
import shutil import shutil
from pathlib import Path import pathlib
from pprint import PrettyPrinter from pprint import PrettyPrinter
from pydoc import pager from pydoc import pager
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable from typing import TypeVar, Type, Optional, ClassVar, Any, Callable
from .message import Message, MessageFilter, MessageError, message_in from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in
from .tags import Tag from .tags import Tag
ChatInst = TypeVar('ChatInst', bound='Chat') ChatInst = TypeVar('ChatInst', bound='Chat')
@@ -30,7 +30,7 @@ def print_paged(text: str) -> None:
pager(text) pager(text)
def read_dir(dir_path: Path, def read_dir(dir_path: pathlib.Path,
glob: Optional[str] = None, glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> list[Message]: mfilter: Optional[MessageFilter] = None) -> list[Message]:
""" """
@@ -45,7 +45,7 @@ def read_dir(dir_path: Path,
messages: list[Message] = [] messages: list[Message] = []
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in sorted(file_iter): for file_path in sorted(file_iter):
if file_path.is_file() and file_path.suffix in Message.file_suffixes: if file_path.is_file():
try: try:
message = Message.from_file(file_path, mfilter) message = Message.from_file(file_path, mfilter)
if message: if message:
@@ -55,20 +55,7 @@ def read_dir(dir_path: Path,
return messages return messages
def make_file_path(dir_path: Path, def write_dir(dir_path: pathlib.Path,
file_suffix: str,
next_fid: Callable[[], int]) -> Path:
"""
Create a file_path for the given directory using the
given file_suffix and ID generator function.
"""
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
while file_path.exists():
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
return file_path
def write_dir(dir_path: Path,
messages: list[Message], messages: list[Message],
file_suffix: str, file_suffix: str,
next_fid: Callable[[], int]) -> None: next_fid: Callable[[], int]) -> None:
@@ -86,24 +73,15 @@ def write_dir(dir_path: Path,
file_path = message.file_path file_path = message.file_path
# message has no file_path: create one # message has no file_path: create one
if not file_path: if not file_path:
file_path = make_file_path(dir_path, file_suffix, next_fid) fid = next_fid()
fname = f"{fid:04d}{file_suffix}"
file_path = dir_path / fname
# file_path does not point to given directory: modify it # file_path does not point to given directory: modify it
elif not file_path.parent.samefile(dir_path): elif not file_path.parent.samefile(dir_path):
file_path = dir_path / file_path.name file_path = dir_path / file_path.name
message.to_file(file_path) message.to_file(file_path)
def clear_dir(dir_path: Path,
glob: Optional[str] = None) -> None:
"""
Deletes all Message files in the given directory.
"""
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in file_iter:
if file_path.is_file() and file_path.suffix in Message.file_suffixes:
file_path.unlink(missing_ok=True)
@dataclass @dataclass
class Chat: class Chat:
""" """
@@ -135,39 +113,11 @@ class Chat:
""" """
self.messages = [] self.messages = []
def add_messages(self, messages: list[Message]) -> None: def add_msgs(self, msgs: list[Message]) -> None:
""" """
Add new messages and sort them if possible. Add new messages and sort them if possible.
""" """
self.messages += messages self.messages += msgs
self.sort()
def latest_message(self) -> Optional[Message]:
"""
Returns the last added message (according to the file ID).
"""
if len(self.messages) > 0:
self.sort()
return self.messages[-1]
else:
return None
def find_messages(self, msg_names: list[str]) -> list[Message]:
"""
Search and return the messages with the given names. Names can either be filenames
(incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the
caller should check the result if he requires all messages).
"""
return [m for m in self.messages
if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
def remove_messages(self, msg_names: list[str]) -> None:
"""
Remove the messages with the given names. Names can either be filenames
(incl. the suffix) or full paths.
"""
self.messages = [m for m in self.messages
if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
self.sort() self.sort()
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
@@ -177,34 +127,29 @@ class Chat:
tags: set[Tag] = set() tags: set[Tag] = set()
for m in self.messages: for m in self.messages:
tags |= m.filter_tags(prefix, contain) tags |= m.filter_tags(prefix, contain)
return set(sorted(tags)) return tags
def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]: def print(self, dump: bool = False, source_code_only: bool = False,
""" with_tags: bool = False, with_file: bool = False,
Get the frequency of all tags of all messages, optionally filtered by prefix or substring.
"""
tags: list[Tag] = []
for m in self.messages:
tags += [tag for tag in m.filter_tags(prefix, contain)]
return {tag: tags.count(tag) for tag in sorted(tags)}
def tokens(self) -> int:
"""
Returns the nr. of AI language tokens used by all messages in this chat.
If unknown, 0 is returned.
"""
return sum(m.tokens() for m in self.messages)
def print(self, source_code_only: bool = False,
with_tags: bool = False, with_files: bool = False,
paged: bool = True) -> None: paged: bool = True) -> None:
if dump:
pp(self)
return
output: list[str] = [] output: list[str] = []
for message in self.messages: for message in self.messages:
if source_code_only: if source_code_only:
output.append(message.to_str(source_code_only=True)) output.extend(source_code(message.question, include_delims=True))
continue continue
output.append(message.to_str(with_tags, with_files)) output.append('-' * terminal_width())
output.append('\n' + ('-' * terminal_width()) + '\n') output.append(Question.txt_header)
output.append(message.question)
if message.answer:
output.append(Answer.txt_header)
output.append(message.answer)
if with_tags:
output.append(message.tags_str())
if with_file:
output.append('FILE: ' + str(message.file_path))
if paged: if paged:
print_paged('\n'.join(output)) print_paged('\n'.join(output))
else: else:
@@ -223,8 +168,8 @@ class ChatDB(Chat):
default_file_suffix: ClassVar[str] = '.txt' default_file_suffix: ClassVar[str] = '.txt'
cache_path: Path cache_path: pathlib.Path
db_path: Path db_path: pathlib.Path
# a MessageFilter that all messages must match (if given) # a MessageFilter that all messages must match (if given)
mfilter: Optional[MessageFilter] = None mfilter: Optional[MessageFilter] = None
file_suffix: str = default_file_suffix file_suffix: str = default_file_suffix
@@ -234,14 +179,11 @@ class ChatDB(Chat):
def __post_init__(self) -> None: def __post_init__(self) -> None:
# contains the latest message ID # contains the latest message ID
self.next_fname = self.db_path / '.next' self.next_fname = self.db_path / '.next'
# make all paths absolute
self.cache_path = self.cache_path.absolute()
self.db_path = self.db_path.absolute()
@classmethod @classmethod
def from_dir(cls: Type[ChatDBInst], def from_dir(cls: Type[ChatDBInst],
cache_path: Path, cache_path: pathlib.Path,
db_path: Path, db_path: pathlib.Path,
glob: Optional[str] = None, glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> ChatDBInst: mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
""" """
@@ -261,8 +203,8 @@ class ChatDB(Chat):
@classmethod @classmethod
def from_messages(cls: Type[ChatDBInst], def from_messages(cls: Type[ChatDBInst],
cache_path: Path, cache_path: pathlib.Path,
db_path: Path, db_path: pathlib.Path,
messages: list[Message], messages: list[Message],
mfilter: Optional[MessageFilter] = None) -> ChatDBInst: mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
""" """
@@ -310,98 +252,18 @@ class ChatDB(Chat):
self.messages += new_messages self.messages += new_messages
self.sort() self.sort()
def write_db(self, messages: Optional[list[Message]] = None) -> None: def write_db(self) -> None:
""" """
Write messages to the DB directory. If a message has no file_path, a new one Write all messages to the DB directory. If a message has no file_path,
will be created. If message.file_path exists, it will be modified to point a new one will be created. If message.file_path exists, it will be modified
to the DB directory. to point to the DB directory.
""" """
write_dir(self.db_path, write_dir(self.db_path, self.messages, self.file_suffix, self.get_next_fid)
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def write_cache(self, messages: Optional[list[Message]] = None) -> None: def write_cache(self) -> None:
""" """
Write messages to the cache directory. If a message has no file_path, a new one Write all messages to the cache directory. If a message has no file_path,
will be created. If message.file_path exists, it will be modified to point to a new one will be created. If message.file_path exists, it will be modified
the cache directory. to point to the cache directory.
""" """
write_dir(self.cache_path, write_dir(self.cache_path, self.messages, self.file_suffix, self.get_next_fid)
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def clear_cache(self) -> None:
"""
Deletes all Message files from the cache dir and removes those messages from
the internal list.
"""
clear_dir(self.cache_path, self.glob)
# only keep messages from DB dir (or those that have not yet been written)
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
def add_to_db(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the DB directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.db_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.sort()
def add_to_cache(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the cache directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.cache_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.sort()
def write_messages(self, messages: Optional[list[Message]] = None) -> None:
"""
Write either the given messages or the internal ones to their current file_path.
If messages are given, they all must have a valid file_path. When writing the
internal messages, the ones with a valid file_path are written, the others
are ignored.
"""
if messages and any(m.file_path is None for m in messages):
raise ChatError("Can't write files without a valid file_path")
msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)):
m.to_file()
def update_messages(self, messages: list[Message], write: bool = True) -> None:
"""
Update existing messages. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list. Only accepts
existing messages.
"""
if any(not message_in(m, self.messages) for m in messages):
raise ChatError("Can't update messages that are not in the internal list")
# remove old versions and add new ones
self.messages = [m for m in self.messages if not message_in(m, messages)]
self.messages += messages
self.sort()
# write the UPDATED messages if requested
if write:
self.write_messages(messages)
-11
View File
@@ -1,11 +0,0 @@
import argparse
from pathlib import Path
from ..configuration import Config
def config_cmd(args: argparse.Namespace) -> None:
"""
Handler for the 'config' command.
"""
if args.create:
Config.create_default(Path(args.create))
-23
View File
@@ -1,23 +0,0 @@
import argparse
from pathlib import Path
from ..configuration import Config
from ..chat import ChatDB
from ..message import MessageFilter
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'hist' command.
"""
mfilter = MessageFilter(tags_or=args.or_tags,
tags_and=args.and_tags,
tags_not=args.exclude_tags,
question_contains=args.question,
answer_contains=args.answer)
chat = ChatDB.from_dir(Path('.'),
Path(config.db),
mfilter=mfilter)
chat.print(args.source_code_only,
args.with_tags,
args.with_files)
-27
View File
@@ -1,27 +0,0 @@
import sys
import argparse
from pathlib import Path
from ..configuration import Config
from ..message import Message, MessageError
def print_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'print' command.
"""
fname = Path(args.file)
try:
message = Message.from_file(fname)
if message:
if args.question:
print(message.question)
elif args.answer:
print(message.answer)
elif message.answer and args.only_source_code:
for code in message.answer.source_code():
print(code)
else:
print(message.to_str())
except MessageError:
print(f"File is not a valid message: {args.file}")
sys.exit(1)
-94
View File
@@ -1,94 +0,0 @@
import argparse
from pathlib import Path
from itertools import zip_longest
from ..configuration import Config
from ..chat import ChatDB
from ..message import Message, MessageFilter, Question, source_code
from ..ai_factory import create_ai
from ..ai import AI, AIResponse
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
"""
Creates (and writes) a new message from the given arguments.
"""
question_parts = []
question_list = args.ask if args.ask is not None else []
text_files = args.source_text if args.source_text is not None else []
code_files = args.source_code if args.source_code is not None else []
for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None):
if question is not None and len(question.strip()) > 0:
question_parts.append(question)
if source is not None and len(source) > 0:
with open(source) as r:
content = r.read().strip()
if len(content) > 0:
question_parts.append(content)
if code is not None and len(code) > 0:
with open(code) as r:
content = r.read().strip()
if len(content) == 0:
continue
# try to extract and add source code
code_parts = source_code(content, include_delims=True)
if len(code_parts) > 0:
question_parts += code_parts
# if there's none, add the whole file
else:
question_parts.append(f"```\n{content}\n```")
full_question = '\n\n'.join(question_parts)
message = Message(question=Question(full_question),
tags=args.output_tags, # FIXME
ai=args.AI,
model=args.model)
chat.add_to_cache([message])
return message
def question_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'question' command.
"""
mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(),
tags_and=args.and_tags if args.and_tags is not None else set(),
tags_not=args.exclude_tags if args.exclude_tags is not None else set())
chat = ChatDB.from_dir(cache_path=Path('.'),
db_path=Path(config.db),
mfilter=mfilter)
# if it's a new question, create and store it immediately
if args.ask or args.create:
message = create_message(chat, args)
if args.create:
return
# create the correct AI instance
ai: AI = create_ai(args, config)
if args.ask:
ai.print()
chat.print(paged=False)
response: AIResponse = ai.request(message,
chat,
args.num_answers, # FIXME
args.output_tags) # FIXME
chat.update_messages([response.messages[0]])
chat.add_to_cache(response.messages[1:])
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
elif args.repeat is not None:
lmessage = chat.latest_message()
assert lmessage
# TODO: repeat either the last question or the
# one(s) given in 'args.repeat' (overwrite
# existing ones if 'args.overwrite' is True)
pass
elif args.process is not None:
# TODO: process either all questions without an
# answer or the one(s) given in 'args.process'
pass
-17
View File
@@ -1,17 +0,0 @@
import argparse
from pathlib import Path
from ..configuration import Config
from ..chat import ChatDB
def tags_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'tags' command.
"""
chat = ChatDB.from_dir(cache_path=Path('.'),
db_path=Path(config.db))
if args.list:
tags_freq = chat.tags_frequency(args.prefix, args.contain)
for tag, freq in tags_freq.items():
print(f"- {tag}: {freq}")
# TODO: add renaming
+23 -120
View File
@@ -1,166 +1,69 @@
import yaml import yaml
from pathlib import Path from typing import Type, TypeVar, Any
from typing import Type, TypeVar, Any, Optional, ClassVar from dataclasses import dataclass, asdict
from dataclasses import dataclass, asdict, field
ConfigInst = TypeVar('ConfigInst', bound='Config') ConfigInst = TypeVar('ConfigInst', bound='Config')
AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig')
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
supported_ais: list[str] = ['openai']
default_config_path = '.config.yaml'
class ConfigError(Exception):
pass
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
"""
Changes the YAML dump style to multiline syntax for multiline strings.
"""
if len(data.splitlines()) > 1:
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
return dumper.represent_scalar('tag:yaml.org,2002:str', data)
yaml.add_representer(str, str_presenter)
@dataclass @dataclass
class AIConfig: class OpenAIConfig():
"""
The base class of all AI configurations.
"""
# the name of the AI the config class represents
# -> it's a class variable and thus not part of the
# dataclass constructor
name: ClassVar[str]
# a user-defined ID for an AI configuration entry
ID: str
# the name must not be changed
def __setattr__(self, name: str, value: Any) -> None:
if name == 'name':
raise AttributeError("'{name}' is not allowed to be changed")
else:
super().__setattr__(name, value)
@dataclass
class OpenAIConfig(AIConfig):
""" """
The OpenAI section of the configuration file. The OpenAI section of the configuration file.
""" """
name: ClassVar[str] = 'openai' api_key: str
model: str
# all members have default values, so we can easily create temperature: float
# a default configuration max_tokens: int
ID: str = 'myopenai' top_p: float
api_key: str = '0123456789' frequency_penalty: float
model: str = 'gpt-3.5-turbo-16k' presence_penalty: float
temperature: float = 1.0
max_tokens: int = 4000
top_p: float = 1.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
system: str = 'You are an assistant'
@classmethod @classmethod
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
""" """
Create OpenAIConfig from a dict. Create OpenAIConfig from a dict.
""" """
res = cls( return cls(
api_key=str(source['api_key']), api_key=str(source['api_key']),
model=str(source['model']), model=str(source['model']),
max_tokens=int(source['max_tokens']), max_tokens=int(source['max_tokens']),
temperature=float(source['temperature']), temperature=float(source['temperature']),
top_p=float(source['top_p']), top_p=float(source['top_p']),
frequency_penalty=float(source['frequency_penalty']), frequency_penalty=float(source['frequency_penalty']),
presence_penalty=float(source['presence_penalty']), presence_penalty=float(source['presence_penalty'])
system=str(source['system'])
) )
# overwrite default ID if provided
if 'ID' in source:
res.ID = source['ID']
return res
def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
"""
Creates an AIConfig instance of the given name.
"""
if name.lower() == 'openai':
if conf_dict is None:
return OpenAIConfig()
else:
return OpenAIConfig.from_dict(conf_dict)
else:
raise ConfigError(f"Unknown AI '{name}'")
def create_default_ai_configs() -> dict[str, AIConfig]:
"""
Create a dict containing default configurations for all supported AIs.
"""
return {ai_config_instance(name).ID: ai_config_instance(name) for name in supported_ais}
@dataclass @dataclass
class Config: class Config():
""" """
The configuration file structure. The configuration file structure.
""" """
# all members have default values, so we can easily create system: str
# a default configuration db: str
db: str = './db/' openai: OpenAIConfig
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
@classmethod @classmethod
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
""" """
Create Config from a dict (with the same format as the config file). Create OpenAIConfig from a dict.
""" """
# create the correct AI type instances
ais: dict[str, AIConfig] = {}
for ID, conf in source['ais'].items():
# add the AI ID to the config (for easy internal access)
conf['ID'] = ID
ai_conf = ai_config_instance(conf['name'], conf)
ais[ID] = ai_conf
return cls( return cls(
system=str(source['system']),
db=str(source['db']), db=str(source['db']),
ais=ais openai=OpenAIConfig.from_dict(source['openai'])
) )
@classmethod
def create_default(self, file_path: Path) -> None:
"""
Creates a default Config in the given file.
"""
conf = Config()
conf.to_file(file_path)
@classmethod @classmethod
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
with open(path, 'r') as f: with open(path, 'r') as f:
source = yaml.load(f, Loader=yaml.FullLoader) source = yaml.load(f, Loader=yaml.FullLoader)
return cls.from_dict(source) return cls.from_dict(source)
def to_file(self, file_path: Path) -> None: def to_file(self, path: str) -> None:
# remove the AI name from the config (for a cleaner format) with open(path, 'w') as f:
data = self.as_dict() yaml.dump(asdict(self), f, sort_keys=False)
for conf in data['ais'].values():
del (conf['ID'])
with open(file_path, 'w') as f:
yaml.dump(data, f, sort_keys=False)
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
res = asdict(self) return asdict(self)
# add the AI name manually (as first element)
# (not done by 'asdict' because it's a class variable)
for ID, conf in res['ais'].items():
res['ais'][ID] = {**{'name': self.ais[ID].name}, **conf}
return res
+164 -66
View File
@@ -2,29 +2,138 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 : # vim: set fileencoding=utf-8 :
import yaml
import sys import sys
import argcomplete import argcomplete
import argparse import argparse
from pathlib import Path import pathlib
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data
from .api_client import ai, openai_api_key, print_models
from .configuration import Config
from itertools import zip_longest
from typing import Any from typing import Any
from .configuration import Config, default_config_path
from .message import Message default_config = '.config.yaml'
from .commands.question import question_cmd
from .commands.tags import tags_cmd
from .commands.config import config_cmd
from .commands.hist import hist_cmd
from .commands.print import print_cmd
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
config = Config.from_file(parsed_args.config) with open(parsed_args.config, 'r') as f:
return list(Message.tags_from_dir(Path(config.db), prefix=prefix)) config = yaml.load(f, Loader=yaml.FullLoader)
return get_tags_unique(config, prefix)
def create_question_with_hist(args: argparse.Namespace,
config: Config,
) -> tuple[ChatType, str, list[str]]:
"""
Creates the "AI request", including the question and chat history as determined
by the specified tags.
"""
tags = args.tags or []
extags = args.extags or []
otags = args.output_tags or []
if not args.only_source_code:
print_tag_args(tags, extags, otags)
question_parts = []
question_list = args.question if args.question is not None else []
source_list = args.source if args.source is not None else []
for question, source in zip_longest(question_list, source_list, fillvalue=None):
if question is not None and source is not None:
with open(source) as r:
question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```")
elif question is not None:
question_parts.append(question)
elif source is not None:
with open(source) as r:
question_parts.append(f"```\n{r.read().strip()}\n```")
full_question = '\n\n'.join(question_parts)
chat = create_chat_hist(full_question, tags, extags, config,
args.match_all_tags, False, False)
return chat, full_question, tags
def tag_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'tag' command.
"""
if args.list:
print_tags_frequency(get_tags(config, None))
def config_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'config' command.
"""
if args.list_models:
print_models()
elif args.print_model:
print(config.openai.model)
elif args.model:
config.openai.model = args.model
config.to_file(args.config)
def ask_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'ask' command.
"""
if args.max_tokens:
config.openai.max_tokens = args.max_tokens
if args.temperature:
config.openai.temperature = args.temperature
if args.model:
config.openai.model = args.model
chat, question, tags = create_question_with_hist(args, config)
print_chat_hist(chat, False, args.only_source_code)
otags = args.output_tags or []
answers, usage = ai(chat, config, args.number)
save_answers(question, answers, tags, otags, config)
print("-" * terminal_width())
print(f"Usage: {usage}")
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'hist' command.
"""
tags = args.tags or []
extags = args.extags or []
chat = create_chat_hist(None, tags, extags, config,
args.match_all_tags,
args.with_tags,
args.with_files)
print_chat_hist(chat, args.dump, args.only_source_code)
def print_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'print' command.
"""
fname = pathlib.Path(args.file)
if fname.suffix == '.yaml':
with open(args.file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
elif fname.suffix == '.txt':
data = read_file(fname)
else:
print(f"Unknown file type: {args.file}")
sys.exit(1)
if args.only_source_code:
display_source_code(data['answer'])
else:
print(dump_data(data).strip())
def create_parser() -> argparse.ArgumentParser: def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="ChatMastermind is a Python application that automates conversation with AI") description="ChatMastermind is a Python application that automates conversation with AI")
parser.add_argument('-C', '--config', help='Config file name.', default=default_config_path) parser.add_argument('-c', '--config', help='Config file name.', default=default_config)
# subcommand-parser # subcommand-parser
cmdparser = parser.add_subparsers(dest='command', cmdparser = parser.add_subparsers(dest='command',
@@ -34,66 +143,58 @@ def create_parser() -> argparse.ArgumentParser:
# a parent parser for all commands that support tag selection # a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False) tag_parser = argparse.ArgumentParser(add_help=False)
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+', tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+',
help='List of tags (one must match)', metavar='OTAGS') help='List of tag names', metavar='TAGS')
tag_arg.completer = tags_completer # type: ignore tag_arg.completer = tags_completer # type: ignore
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+', extag_arg = tag_parser.add_argument('-e', '--extags', nargs='+',
help='List of tags (all must match)', metavar='ATAGS') help='List of tag names to exclude', metavar='EXTAGS')
atag_arg.completer = tags_completer # type: ignore extag_arg.completer = tags_completer # type: ignore
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+',
help='List of tags to exclude', metavar='XTAGS')
etag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
help='List of output tags (default: use input tags)', metavar='OUTTAGS') help='List of output tag names, default is input', metavar='OTAGS')
otag_arg.completer = tags_completer # type: ignore otag_arg.completer = tags_completer # type: ignore
tag_parser.add_argument('-a', '--match-all-tags',
help="All given tags must match when selecting chat history entries",
action='store_true')
# enable autocompletion for tags
# a parent parser for all commands that support AI configuration # 'ask' command parser
ai_parser = argparse.ArgumentParser(add_help=False) ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser],
ai_parser.add_argument('-A', '--AI', help='AI ID to use') help="Ask a question.",
ai_parser.add_argument('-M', '--model', help='Model to use') aliases=['a'])
ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1) ask_cmd_parser.set_defaults(func=ask_cmd)
ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int) ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask',
ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float) required=True)
ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
# 'question' command parser ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser, ai_parser], ask_cmd_parser.add_argument('-M', '--model', help='Model to use')
help="ask, create and process questions.", ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int,
aliases=['q']) default=1)
question_cmd_parser.set_defaults(func=question_cmd) ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
question_group = question_cmd_parser.add_mutually_exclusive_group(required=True) ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history',
question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question') action='store_true')
question_group.add_argument('-c', '--create', nargs='+', help='Create a question')
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question')
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true')
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query')
question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history')
# 'hist' command parser # 'hist' command parser
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
help="Print chat history.", help="Print chat history.",
aliases=['h']) aliases=['h'])
hist_cmd_parser.set_defaults(func=hist_cmd) hist_cmd_parser.set_defaults(func=hist_cmd)
hist_cmd_parser.add_argument('-d', '--dump', help="Print chat history as Python structure",
action='store_true')
hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.",
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.", hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.",
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code',
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring')
hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring')
# 'tags' command parser # 'tag' command parser
tags_cmd_parser = cmdparser.add_parser('tags', tag_cmd_parser = cmdparser.add_parser('tag',
help="Manage tags.", help="Manage tags.",
aliases=['t']) aliases=['t'])
tags_cmd_parser.set_defaults(func=tags_cmd) tag_cmd_parser.set_defaults(func=tag_cmd)
tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True) tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True)
tags_group.add_argument('-l', '--list', help="List all tags and their frequency", tag_group.add_argument('-l', '--list', help="List all tags and their frequency",
action='store_true') action='store_true')
tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix")
tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring")
# 'config' command parser # 'config' command parser
config_cmd_parser = cmdparser.add_parser('config', config_cmd_parser = cmdparser.add_parser('config',
@@ -105,18 +206,16 @@ def create_parser() -> argparse.ArgumentParser:
action='store_true') action='store_true')
config_group.add_argument('-m', '--print-model', help="Print the currently configured model", config_group.add_argument('-m', '--print-model', help="Print the currently configured model",
action='store_true') action='store_true')
config_group.add_argument('-c', '--create', help="Create config with default settings in the given file") config_group.add_argument('-M', '--model', help="Set model in the config file")
# 'print' command parser # 'print' command parser
print_cmd_parser = cmdparser.add_parser('print', print_cmd_parser = cmdparser.add_parser('print',
help="Print message files.", help="Print files.",
aliases=['p']) aliases=['p'])
print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.set_defaults(func=print_cmd)
print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True)
print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group() print_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code',
print_cmd_modes.add_argument('-q', '--question', help='Print only question', action='store_true') action='store_true')
print_cmd_modes.add_argument('-a', '--answer', help='Print only answer', action='store_true')
print_cmd_modes.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true')
argcomplete.autocomplete(parser) argcomplete.autocomplete(parser)
return parser return parser
@@ -126,12 +225,11 @@ def main() -> int:
parser = create_parser() parser = create_parser()
args = parser.parse_args() args = parser.parse_args()
command = parser.parse_args() command = parser.parse_args()
config = Config.from_file(args.config)
if command.func == config_cmd: openai_api_key(config.openai.api_key)
command.func(command)
else: command.func(command, config)
config = Config.from_file(args.config)
command.func(command, config)
return 0 return 0
+36 -93
View File
@@ -3,11 +3,9 @@ Module implementing message related functions and classes.
""" """
import pathlib import pathlib
import yaml import yaml
import tempfile
import shutil
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags from .tags import Tag, TagLine, TagError, match_tags
QuestionInst = TypeVar('QuestionInst', bound='Question') QuestionInst = TypeVar('QuestionInst', bound='Question')
AnswerInst = TypeVar('AnswerInst', bound='Answer') AnswerInst = TypeVar('AnswerInst', bound='Answer')
@@ -98,7 +96,7 @@ class AILine(str):
def __new__(cls: Type[AILineInst], string: str) -> AILineInst: def __new__(cls: Type[AILineInst], string: str) -> AILineInst:
if not string.startswith(cls.prefix): if not string.startswith(cls.prefix):
raise MessageError(f"AILine '{string}' is missing prefix '{cls.prefix}'") raise TagError(f"AILine '{string}' is missing prefix '{cls.prefix}'")
instance = super().__new__(cls, string) instance = super().__new__(cls, string)
return instance return instance
@@ -118,7 +116,7 @@ class ModelLine(str):
def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst: def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst:
if not string.startswith(cls.prefix): if not string.startswith(cls.prefix):
raise MessageError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'") raise TagError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'")
instance = super().__new__(cls, string) instance = super().__new__(cls, string)
return instance return instance
@@ -130,29 +128,28 @@ class ModelLine(str):
return cls(' '.join([cls.prefix, model])) return cls(' '.join([cls.prefix, model]))
class Answer(str): class Question(str):
""" """
A single answer with a defined header. A single question with a defined header.
""" """
tokens: int = 0 # tokens used by this answer txt_header: ClassVar[str] = '=== QUESTION ==='
txt_header: ClassVar[str] = '==== ANSWER ====' yaml_key: ClassVar[str] = 'question'
yaml_key: ClassVar[str] = 'answer'
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
""" """
Make sure the answer string does not contain the header as a whole line. Make sure the question string does not contain the header.
""" """
if cls.txt_header in string.split('\n'): if cls.txt_header in string:
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'")
instance = super().__new__(cls, string) instance = super().__new__(cls, string)
return instance return instance
@classmethod @classmethod
def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst:
""" """
Build Question from a list of strings. Make sure strings do not contain the header. Build Question from a list of strings. Make sure strings do not contain the header.
""" """
if cls.txt_header in strings: if any(cls.txt_header in string for string in strings):
raise MessageError(f"Question contains the header '{cls.txt_header}'") raise MessageError(f"Question contains the header '{cls.txt_header}'")
instance = super().__new__(cls, '\n'.join(strings).strip()) instance = super().__new__(cls, '\n'.join(strings).strip())
return instance return instance
@@ -164,33 +161,28 @@ class Answer(str):
return source_code(self, include_delims) return source_code(self, include_delims)
class Question(str): class Answer(str):
""" """
A single question with a defined header. A single answer with a defined header.
""" """
tokens: int = 0 # tokens used by this question txt_header: ClassVar[str] = '=== ANSWER ==='
txt_header: ClassVar[str] = '=== QUESTION ===' yaml_key: ClassVar[str] = 'answer'
yaml_key: ClassVar[str] = 'question'
def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
""" """
Make sure the question string does not contain the header as a whole line Make sure the answer string does not contain the header.
(also not that from 'Answer', so it's always clear where the answer starts).
""" """
string_lines = string.split('\n') if cls.txt_header in string:
if cls.txt_header in string_lines: raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'")
if Answer.txt_header in string_lines:
raise MessageError(f"Question '{string}' contains the header '{Answer.txt_header}'")
instance = super().__new__(cls, string) instance = super().__new__(cls, string)
return instance return instance
@classmethod @classmethod
def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
""" """
Build Question from a list of strings. Make sure strings do not contain the header. Build Question from a list of strings. Make sure strings do not contain the header.
""" """
if cls.txt_header in strings: if any(cls.txt_header in string for string in strings):
raise MessageError(f"Question contains the header '{cls.txt_header}'") raise MessageError(f"Question contains the header '{cls.txt_header}'")
instance = super().__new__(cls, '\n'.join(strings).strip()) instance = super().__new__(cls, '\n'.join(strings).strip())
return instance return instance
@@ -314,7 +306,7 @@ class Message():
mfilter.tags_not if mfilter else None) mfilter.tags_not if mfilter else None)
else: else:
message = cls.__from_file_yaml(file_path) message = cls.__from_file_yaml(file_path)
if message and (mfilter is None or message.match(mfilter)): if message and (not mfilter or (mfilter and message.match(mfilter))):
return message return message
else: else:
return None return None
@@ -357,20 +349,17 @@ class Message():
try: try:
pos = fd.tell() pos = fd.tell()
ai = AILine(fd.readline()).ai() ai = AILine(fd.readline()).ai()
except MessageError: except TagError:
fd.seek(pos) fd.seek(pos)
# ModelLine (Optional) # ModelLine (Optional)
try: try:
pos = fd.tell() pos = fd.tell()
model = ModelLine(fd.readline()).model() model = ModelLine(fd.readline()).model()
except MessageError: except TagError:
fd.seek(pos) fd.seek(pos)
# Question and Answer # Question and Answer
text = fd.read().strip().split('\n') text = fd.read().strip().split('\n')
try: question_idx = text.index(Question.txt_header) + 1
question_idx = text.index(Question.txt_header) + 1
except ValueError:
raise MessageError(f"Question header '{Question.txt_header}' not found in '{file_path}'")
try: try:
answer_idx = text.index(Answer.txt_header) answer_idx = text.index(Answer.txt_header)
question = Question.from_list(text[question_idx:answer_idx]) question = Question.from_list(text[question_idx:answer_idx])
@@ -394,30 +383,6 @@ class Message():
data[cls.file_yaml_key] = file_path data[cls.file_yaml_key] = file_path
return cls.from_dict(data) return cls.from_dict(data)
def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str:
"""
Return the current Message as a string.
"""
output: list[str] = []
if source_code_only:
# use the source code from answer only
if self.answer:
output.extend(self.answer.source_code(include_delims=True))
return '\n'.join(output) if len(output) > 0 else ''
if with_tags:
output.append(self.tags_str())
if with_file:
output.append('FILE: ' + str(self.file_path))
output.append(Question.txt_header)
output.append(self.question)
if self.answer:
output.append(Answer.txt_header)
output.append(self.answer)
return '\n'.join(output)
def __str__(self) -> str:
return self.to_str(True, True, False)
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
""" """
Write a Message to the given file. Type is determined based on the suffix. Write a Message to the given file. Type is determined based on the suffix.
@@ -447,18 +412,16 @@ class Message():
* Answer.txt_header * Answer.txt_header
* Answer * Answer
""" """
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: with open(file_path, "w") as fd:
temp_file_path = pathlib.Path(temp_fd.name)
if self.tags: if self.tags:
temp_fd.write(f'{TagLine.from_set(self.tags)}\n') fd.write(f'{TagLine.from_set(self.tags)}\n')
if self.ai: if self.ai:
temp_fd.write(f'{AILine.from_ai(self.ai)}\n') fd.write(f'{AILine.from_ai(self.ai)}\n')
if self.model: if self.model:
temp_fd.write(f'{ModelLine.from_model(self.model)}\n') fd.write(f'{ModelLine.from_model(self.model)}\n')
temp_fd.write(f'{Question.txt_header}\n{self.question}\n') fd.write(f'{Question.txt_header}\n{self.question}\n')
if self.answer: if self.answer:
temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n') fd.write(f'{Answer.txt_header}\n{self.answer}\n')
shutil.move(temp_file_path, file_path)
def __to_file_yaml(self, file_path: pathlib.Path) -> None: def __to_file_yaml(self, file_path: pathlib.Path) -> None:
""" """
@@ -470,8 +433,7 @@ class Message():
* Message.ai_yaml_key: str [Optional] * Message.ai_yaml_key: str [Optional]
* Message.model_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional]
""" """
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: with open(file_path, "w") as fd:
temp_file_path = pathlib.Path(temp_fd.name)
data: YamlDict = {Question.yaml_key: str(self.question)} data: YamlDict = {Question.yaml_key: str(self.question)}
if self.answer: if self.answer:
data[Answer.yaml_key] = str(self.answer) data[Answer.yaml_key] = str(self.answer)
@@ -481,8 +443,7 @@ class Message():
data[self.model_yaml_key] = self.model data[self.model_yaml_key] = self.model
if self.tags: if self.tags:
data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
yaml.dump(data, temp_fd, sort_keys=False) yaml.dump(data, fd, sort_keys=False)
shutil.move(temp_file_path, file_path)
def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
""" """
@@ -514,7 +475,7 @@ class Message():
Return True if all attributes match, else False. Return True if all attributes match, else False.
""" """
mytags = self.tags or set() mytags = self.tags or set()
if (((mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None) if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not)
and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503 and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503
or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503
or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503
@@ -529,14 +490,6 @@ class Message():
return False return False
return True return True
def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> None:
"""
Renames the given tags. The first tuple element is the old name,
the second one is the new name.
"""
if self.tags:
self.tags = rename_tags(self.tags, tags_rename)
def msg_id(self) -> str: def msg_id(self) -> str:
""" """
Returns an ID that is unique throughout all messages in the same (DB) directory. Returns an ID that is unique throughout all messages in the same (DB) directory.
@@ -549,13 +502,3 @@ class Message():
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
def tokens(self) -> int:
"""
Returns the nr. of AI language tokens used by this message.
If unknown, 0 is returned.
"""
if self.answer:
return self.question.tokens + self.answer.tokens
else:
return self.question.tokens
+121
View File
@@ -0,0 +1,121 @@
import yaml
import io
import pathlib
from .utils import terminal_width, append_message, message_to_chat, ChatType
from .configuration import Config
from typing import Any, Optional
def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]:
with open(fname, "r") as fd:
tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip()
# also support tags separated by ',' (old format)
separator = ',' if ',' in tagline else ' '
tags = [t.strip() for t in tagline.split(separator)]
if tags_only:
return {"tags": tags}
text = fd.read().strip().split('\n')
question_idx = text.index("=== QUESTION ===") + 1
answer_idx = text.index("==== ANSWER ====")
question = "\n".join(text[question_idx:answer_idx]).strip()
answer = "\n".join(text[answer_idx + 1:]).strip()
return {"question": question, "answer": answer, "tags": tags,
"file": fname.name}
def dump_data(data: dict[str, Any]) -> str:
with io.StringIO() as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
fd.write(f'==== ANSWER ====\n{data["answer"]}\n')
return fd.getvalue()
def write_file(fname: str, data: dict[str, Any]) -> None:
with open(fname, "w") as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
fd.write(f'==== ANSWER ====\n{data["answer"]}\n')
def save_answers(question: str,
answers: list[str],
tags: list[str],
otags: Optional[list[str]],
config: Config
) -> None:
wtags = otags or tags
num, inum = 0, 0
next_fname = pathlib.Path(str(config.db)) / '.next'
try:
with open(next_fname, 'r') as f:
num = int(f.read())
except Exception:
pass
for answer in answers:
num += 1
inum += 1
title = f'-- ANSWER {inum} '
title_end = '-' * (terminal_width() - len(title))
print(f'{title}{title_end}')
print(answer)
write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags})
with open(next_fname, 'w') as f:
f.write(f'{num}')
def create_chat_hist(question: Optional[str],
tags: Optional[list[str]],
extags: Optional[list[str]],
config: Config,
match_all_tags: bool = False,
with_tags: bool = False,
with_file: bool = False
) -> ChatType:
chat: ChatType = []
append_message(chat, 'system', str(config.system).strip())
for file in sorted(pathlib.Path(str(config.db)).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
data['file'] = file.name
elif file.suffix == '.txt':
data = read_file(file)
else:
continue
data_tags = set(data.get('tags', []))
tags_match: bool
if match_all_tags:
tags_match = not tags or set(tags).issubset(data_tags)
else:
tags_match = not tags or bool(data_tags.intersection(tags))
extags_do_not_match = \
not extags or not data_tags.intersection(extags)
if tags_match and extags_do_not_match:
message_to_chat(data, chat, with_tags, with_file)
if question:
append_message(chat, 'user', question)
return chat
def get_tags(config: Config, prefix: Optional[str]) -> list[str]:
result = []
for file in sorted(pathlib.Path(str(config.db)).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
elif file.suffix == '.txt':
data = read_file(file, tags_only=True)
else:
continue
for tag in data.get('tags', []):
if prefix and len(prefix) > 0:
if tag.startswith(prefix):
result.append(tag)
else:
result.append(tag)
return result
def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]:
return list(set(get_tags(config, prefix)))
+1 -2
View File
@@ -77,8 +77,7 @@ def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[s
i. e. you can select a TagLine if it either contains one of the tags in 'tags_or' i. e. you can select a TagLine if it either contains one of the tags in 'tags_or'
or all of the tags in 'tags_and' but it must never contain any of the tags in or all of the tags in 'tags_and' but it must never contain any of the tags in
'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag
exclusion is still done if 'tags_not' is not 'None'). If they are empty (set()), exclusion is still done if 'tags_not' is not 'None').
they match no tags.
""" """
required_tags_present = False required_tags_present = False
excluded_tags_missing = False excluded_tags_missing = False
+85
View File
@@ -0,0 +1,85 @@
import shutil
from pprint import PrettyPrinter
from typing import Any
ChatType = list[dict[str, str]]
def terminal_width() -> int:
return shutil.get_terminal_size().columns
def pp(*args: Any, **kwargs: Any) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None:
"""
Prints the tags specified in the given args.
"""
printed_messages = []
if tags:
printed_messages.append(f"Tags: {' '.join(tags)}")
if extags:
printed_messages.append(f"Excluding tags: {' '.join(extags)}")
if otags:
printed_messages.append(f"Output tags: {' '.join(otags)}")
if printed_messages:
print("\n".join(printed_messages))
print()
def append_message(chat: ChatType,
role: str,
content: str
) -> None:
chat.append({'role': role, 'content': content.replace("''", "'")})
def message_to_chat(message: dict[str, str],
chat: ChatType,
with_tags: bool = False,
with_file: bool = False
) -> None:
append_message(chat, 'user', message['question'])
append_message(chat, 'assistant', message['answer'])
if with_tags:
tags = " ".join(message['tags'])
append_message(chat, 'tags', tags)
if with_file:
append_message(chat, 'file', message['file'])
def display_source_code(content: str) -> None:
try:
content_start = content.index('```')
content_end = content.rindex('```')
if content_start + 3 < content_end:
print(content[content_start + 3:content_end].strip())
except ValueError:
pass
def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None:
if dump:
pp(chat)
return
for message in chat:
text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2
if source_code:
display_source_code(message['content'])
continue
if message['role'] == 'user':
print('-' * terminal_width())
if text_too_long:
print(f"{message['role'].upper()}:")
print(message['content'])
else:
print(f"{message['role'].upper()}: {message['content']}")
def print_tags_frequency(tags: list[str]) -> None:
for tag in sorted(set(tags)):
print(f"- {tag}: {tags.count(tag)}")
+2 -2
View File
@@ -12,7 +12,7 @@ setup(
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
url="https://github.com/ok2/ChatMastermind", url="https://github.com/ok2/ChatMastermind",
packages=find_packages() + ["chatmastermind.ais", "chatmastermind.commands"], packages=find_packages(),
classifiers=[ classifiers=[
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",
"Environment :: Console", "Environment :: Console",
@@ -32,7 +32,7 @@ setup(
"openai", "openai",
"PyYAML", "PyYAML",
"argcomplete", "argcomplete",
"pytest", "pytest"
], ],
python_requires=">=3.9", python_requires=">=3.9",
test_suite="tests", test_suite="tests",
-48
View File
@@ -1,48 +0,0 @@
import argparse
import unittest
from unittest.mock import MagicMock
from chatmastermind.ai_factory import create_ai
from chatmastermind.configuration import Config
from chatmastermind.ai import AIError
from chatmastermind.ais.openai import OpenAI
class TestCreateAI(unittest.TestCase):
def setUp(self) -> None:
self.args = MagicMock(spec=argparse.Namespace)
self.args.AI = 'myopenai'
self.args.model = None
self.args.max_tokens = None
self.args.temperature = None
def test_create_ai_from_args(self) -> None:
# Create an AI with the default configuration
config = Config()
self.args.AI = 'myopenai'
ai = create_ai(self.args, config)
self.assertIsInstance(ai, OpenAI)
def test_create_ai_from_default(self) -> None:
self.args.AI = None
# Create an AI with the default configuration
config = Config()
ai = create_ai(self.args, config)
self.assertIsInstance(ai, OpenAI)
def test_create_empty_ai_error(self) -> None:
self.args.AI = None
# Create Config with empty AIs
config = Config()
config.ais = {}
# Call create_ai function and assert that it raises AIError
with self.assertRaises(AIError):
create_ai(self.args, config)
def test_create_unsupported_ai_error(self) -> None:
# Mock argparse.Namespace with ai='invalid_ai'
self.args.AI = 'invalid_ai'
# Create default Config
config = Config()
# Call create_ai function and assert that it raises AIError
with self.assertRaises(AIError):
create_ai(self.args, config)
+30 -221
View File
@@ -1,4 +1,3 @@
import unittest
import pathlib import pathlib
import tempfile import tempfile
import time import time
@@ -6,15 +5,16 @@ from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from chatmastermind.tags import TagLine from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError from chatmastermind.chat import Chat, ChatDB, terminal_width
from .test_main import CmmTestCase
class TestChat(unittest.TestCase): class TestChat(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.chat = Chat([]) self.chat = Chat([])
self.message1 = Message(Question('Question 1'), self.message1 = Message(Question('Question 1'),
Answer('Answer 1'), Answer('Answer 1'),
{Tag('atag1'), Tag('btag2')}, {Tag('atag1')},
file_path=pathlib.Path('0001.txt')) file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'), self.message2 = Message(Question('Question 2'),
Answer('Answer 2'), Answer('Answer 2'),
@@ -22,14 +22,14 @@ class TestChat(unittest.TestCase):
file_path=pathlib.Path('0002.txt')) file_path=pathlib.Path('0002.txt'))
def test_filter(self) -> None: def test_filter(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.add_msgs([self.message1, self.message2])
self.chat.filter(MessageFilter(answer_contains='Answer 1')) self.chat.filter(MessageFilter(answer_contains='Answer 1'))
self.assertEqual(len(self.chat.messages), 1) self.assertEqual(len(self.chat.messages), 1)
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
def test_sort(self) -> None: def test_sort(self) -> None:
self.chat.add_messages([self.message2, self.message1]) self.chat.add_msgs([self.message2, self.message1])
self.chat.sort() self.chat.sort()
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2') self.assertEqual(self.chat.messages[1].question, 'Question 2')
@@ -38,18 +38,18 @@ class TestChat(unittest.TestCase):
self.assertEqual(self.chat.messages[1].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 1')
def test_clear(self) -> None: def test_clear(self) -> None:
self.chat.add_messages([self.message1]) self.chat.add_msgs([self.message1])
self.chat.clear() self.chat.clear()
self.assertEqual(len(self.chat.messages), 0) self.assertEqual(len(self.chat.messages), 0)
def test_add_messages(self) -> None: def test_add_msgs(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.add_msgs([self.message1, self.message2])
self.assertEqual(len(self.chat.messages), 2) self.assertEqual(len(self.chat.messages), 2)
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2') self.assertEqual(self.chat.messages[1].question, 'Question 2')
def test_tags(self) -> None: def test_tags(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.add_msgs([self.message1, self.message2])
tags_all = self.chat.tags() tags_all = self.chat.tags()
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
tags_pref = self.chat.tags(prefix='a') tags_pref = self.chat.tags(prefix='a')
@@ -57,81 +57,46 @@ class TestChat(unittest.TestCase):
tags_cont = self.chat.tags(contain='2') tags_cont = self.chat.tags(contain='2')
self.assertSetEqual(tags_cont, {Tag('btag2')}) self.assertSetEqual(tags_cont, {Tag('btag2')})
def test_tags_frequency(self) -> None:
self.chat.add_messages([self.message1, self.message2])
tags_freq = self.chat.tags_frequency()
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
def test_find_remove_messages(self) -> None:
self.chat.add_messages([self.message1, self.message2])
msgs = self.chat.find_messages(['0001.txt'])
self.assertListEqual(msgs, [self.message1])
msgs = self.chat.find_messages(['0001.txt', '0002.txt'])
self.assertListEqual(msgs, [self.message1, self.message2])
# add new Message with full path
message3 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('btag2')},
file_path=pathlib.Path('/foo/bla/0003.txt'))
self.chat.add_messages([message3])
# find new Message by full path
msgs = self.chat.find_messages(['/foo/bla/0003.txt'])
self.assertListEqual(msgs, [message3])
# find Message with full path only by filename
msgs = self.chat.find_messages(['0003.txt'])
self.assertListEqual(msgs, [message3])
# remove last message
self.chat.remove_messages(['0003.txt'])
self.assertListEqual(self.chat.messages, [self.message1, self.message2])
@patch('sys.stdout', new_callable=StringIO) @patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None: def test_print(self, mock_stdout: StringIO) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.add_msgs([self.message1, self.message2])
self.chat.print(paged=False) self.chat.print(paged=False)
expected_output = f"""{Question.txt_header} expected_output = f"""{'-'*terminal_width()}
{Question.txt_header}
Question 1 Question 1
{Answer.txt_header} {Answer.txt_header}
Answer 1 Answer 1
{'-'*terminal_width()} {'-'*terminal_width()}
{Question.txt_header} {Question.txt_header}
Question 2 Question 2
{Answer.txt_header} {Answer.txt_header}
Answer 2 Answer 2
{'-'*terminal_width()}
""" """
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
@patch('sys.stdout', new_callable=StringIO) @patch('sys.stdout', new_callable=StringIO)
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.add_msgs([self.message1, self.message2])
self.chat.print(paged=False, with_tags=True, with_files=True) self.chat.print(paged=False, with_tags=True, with_file=True)
expected_output = f"""{TagLine.prefix} atag1 btag2 expected_output = f"""{'-'*terminal_width()}
FILE: 0001.txt
{Question.txt_header} {Question.txt_header}
Question 1 Question 1
{Answer.txt_header} {Answer.txt_header}
Answer 1 Answer 1
{TagLine.prefix} atag1
FILE: 0001.txt
{'-'*terminal_width()} {'-'*terminal_width()}
{TagLine.prefix} btag2
FILE: 0002.txt
{Question.txt_header} {Question.txt_header}
Question 2 Question 2
{Answer.txt_header} {Answer.txt_header}
Answer 2 Answer 2
{TagLine.prefix} btag2
{'-'*terminal_width()} FILE: 0002.txt
""" """
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
class TestChatDB(unittest.TestCase): class TestChatDB(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.db_path = tempfile.TemporaryDirectory() self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory() self.cache_path = tempfile.TemporaryDirectory()
@@ -157,17 +122,6 @@ class TestChatDB(unittest.TestCase):
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml')) self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt')) self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml')) self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
# make the next FID match the current state
next_fname = pathlib.Path(self.db_path.name) / '.next'
with open(next_fname, 'w') as f:
f.write('4')
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
"""
List all Message files in the given TemporaryDirectory.
"""
# exclude '.next'
return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*'))
def tearDown(self) -> None: def tearDown(self) -> None:
self.db_path.cleanup() self.db_path.cleanup()
@@ -202,25 +156,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[1].file_path, self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0003.txt')) pathlib.Path(self.db_path.name, '0003.txt'))
def test_chat_db_from_dir_filter_tags(self) -> None: def test_chat_db_filter(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or={Tag('tag1')}))
self.assertEqual(len(chat_db.messages), 1)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
def test_chat_db_from_dir_filter_tags_empty(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or=set(),
tags_and=set(),
tags_not=set()))
self.assertEqual(len(chat_db.messages), 0)
def test_chat_db_from_dir_filter_answer(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
mfilter=MessageFilter(answer_contains='Answer 2')) mfilter=MessageFilter(answer_contains='Answer 2'))
@@ -231,7 +167,7 @@ class TestChatDB(unittest.TestCase):
pathlib.Path(self.db_path.name, '0002.yaml')) pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[0].answer, 'Answer 2') self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
def test_chat_db_from_messages(self) -> None: def test_chat_db_from_messges(self) -> None:
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
messages=[self.message1, self.message2, messages=[self.message1, self.message2,
@@ -243,11 +179,11 @@ class TestChatDB(unittest.TestCase):
def test_chat_db_fids(self) -> None: def test_chat_db_fids(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.get_next_fid(), 5) self.assertEqual(chat_db.get_next_fid(), 1)
self.assertEqual(chat_db.get_next_fid(), 6) self.assertEqual(chat_db.get_next_fid(), 2)
self.assertEqual(chat_db.get_next_fid(), 7) self.assertEqual(chat_db.get_next_fid(), 3)
with open(chat_db.next_fname, 'r') as f: with open(chat_db.next_fname, 'r') as f:
self.assertEqual(f.read(), '7') self.assertEqual(f.read(), '3')
def test_chat_db_write(self) -> None: def test_chat_db_write(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
@@ -262,7 +198,7 @@ class TestChatDB(unittest.TestCase):
# write the messages to the cache directory # write the messages to the cache directory
chat_db.write_cache() chat_db.write_cache()
# check if the written files are in the cache directory # check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
self.assertEqual(len(cache_dir_files), 4) self.assertEqual(len(cache_dir_files), 4)
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
@@ -275,14 +211,14 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
# check the timestamp of the files in the DB directory # check the timestamp of the files in the DB directory
db_dir_files = self.message_list(self.db_path) db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
self.assertEqual(len(db_dir_files), 4) self.assertEqual(len(db_dir_files), 4)
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
# overwrite the messages in the db directory # overwrite the messages in the db directory
time.sleep(0.05) time.sleep(0.05)
chat_db.write_db() chat_db.write_db()
# check if the written files are in the DB directory # check if the written files are in the DB directory
db_dir_files = self.message_list(self.db_path) db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
self.assertEqual(len(db_dir_files), 4) self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
@@ -359,130 +295,3 @@ class TestChatDB(unittest.TestCase):
# check that they now have the DB path # check that they now have the DB path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml'))
def test_chat_db_clear(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# check that Message.file_path is correct
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory
chat_db.write_cache()
# check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4)
# now rewrite them to the DB dir and check for modified paths
chat_db.write_db()
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
# add a new message with empty file_path
message_empty = Message(question=Question("What the hell am I doing here?"),
answer=Answer("You don't belong here!"))
# and one for the cache dir
message_cache = Message(question=Question("What the hell am I doing here?"),
answer=Answer("You're a creep!"),
file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
chat_db.add_messages([message_empty, message_cache])
# clear the cache and check the cache dir
chat_db.clear_cache()
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
# make sure that the DB messages (and the new message) are still there
self.assertEqual(len(chat_db.messages), 5)
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
# but not the message with the cache dir path
self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))
def test_chat_db_add(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
# add new messages to the cache dir
message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
chat_db.add_to_cache([message1])
# check if the file_path has been correctly set
self.assertIsNotNone(message1.file_path)
self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
# add new messages to the DB dir
message2 = Message(question=Question("Question 2"),
answer=Answer("Answer 2"))
chat_db.add_to_db([message2])
# check if the file_path has been correctly set
self.assertIsNotNone(message2.file_path)
self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 5)
with self.assertRaises(ChatError):
chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))])
def test_chat_db_write_messages(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
# try to write a message without a valid file_path
message = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
with self.assertRaises(ChatError):
chat_db.write_messages([message])
# write a message with a valid file_path
message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt'
chat_db.write_messages([message])
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
def test_chat_db_update_messages(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
message = chat_db.messages[0]
message.answer = Answer("New answer")
# update message without writing
chat_db.update_messages([message], write=False)
self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# re-read the message and check for old content
chat_db.read_db()
self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1"))
# now check with writing (message should be overwritten)
chat_db.update_messages([message], write=True)
chat_db.read_db()
self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# test without file_path -> expect error
message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
with self.assertRaises(ChatError):
chat_db.update_messages([message1])
-160
View File
@@ -1,160 +0,0 @@
import os
import unittest
import yaml
from tempfile import NamedTemporaryFile
from pathlib import Path
from typing import cast
from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config
class TestAIConfigInstance(unittest.TestCase):
def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None:
ai_config = cast(OpenAIConfig, ai_config_instance('openai'))
ai_reference = OpenAIConfig()
self.assertEqual(ai_config.ID, ai_reference.ID)
self.assertEqual(ai_config.name, ai_reference.name)
self.assertEqual(ai_config.api_key, ai_reference.api_key)
self.assertEqual(ai_config.system, ai_reference.system)
self.assertEqual(ai_config.model, ai_reference.model)
self.assertEqual(ai_config.temperature, ai_reference.temperature)
self.assertEqual(ai_config.max_tokens, ai_reference.max_tokens)
self.assertEqual(ai_config.top_p, ai_reference.top_p)
self.assertEqual(ai_config.frequency_penalty, ai_reference.frequency_penalty)
self.assertEqual(ai_config.presence_penalty, ai_reference.presence_penalty)
def test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None:
conf_dict = {
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
ai_config = cast(OpenAIConfig, ai_config_instance('openai', conf_dict))
self.assertEqual(ai_config.system, 'Custom system')
self.assertEqual(ai_config.api_key, '9876543210')
self.assertEqual(ai_config.model, 'custom_model')
self.assertEqual(ai_config.max_tokens, 5000)
self.assertAlmostEqual(ai_config.temperature, 0.5)
self.assertAlmostEqual(ai_config.top_p, 0.8)
self.assertAlmostEqual(ai_config.frequency_penalty, 0.7)
self.assertAlmostEqual(ai_config.presence_penalty, 0.2)
def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None:
with self.assertRaises(ConfigError):
ai_config_instance('invalid_name')
class TestConfig(unittest.TestCase):
def setUp(self) -> None:
self.test_file = NamedTemporaryFile(delete=False)
def tearDown(self) -> None:
os.remove(self.test_file.name)
def test_from_dict_should_create_config_from_dict(self) -> None:
source_dict = {
'db': './test_db/',
'ais': {
'myopenai': {
'name': 'openai',
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
}
}
config = Config.from_dict(source_dict)
self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['myopenai'].name, 'openai')
self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system')
# check that 'ID' has been added
self.assertEqual(config.ais['myopenai'].ID, 'myopenai')
def test_create_default_should_create_default_config(self) -> None:
Config.create_default(Path(self.test_file.name))
with open(self.test_file.name, 'r') as f:
default_config = yaml.load(f, Loader=yaml.FullLoader)
config_reference = Config()
self.assertEqual(default_config['db'], config_reference.db)
def test_from_file_should_load_config_from_file(self) -> None:
source_dict = {
'db': './test_db/',
'ais': {
'default': {
'name': 'openai',
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
}
}
with open(self.test_file.name, 'w') as f:
yaml.dump(source_dict, f)
config = Config.from_file(self.test_file.name)
self.assertIsInstance(config, Config)
self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1)
self.assertIsInstance(config.ais['default'], AIConfig)
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
def test_to_file_should_save_config_to_file(self) -> None:
config = Config(
db='./test_db/',
ais={
'myopenai': OpenAIConfig(
ID='myopenai',
system='Custom system',
api_key='9876543210',
model='custom_model',
max_tokens=5000,
temperature=0.5,
top_p=0.8,
frequency_penalty=0.7,
presence_penalty=0.2
)
}
)
config.to_file(Path(self.test_file.name))
with open(self.test_file.name, 'r') as f:
saved_config = yaml.load(f, Loader=yaml.FullLoader)
self.assertEqual(saved_config['db'], './test_db/')
self.assertEqual(len(saved_config['ais']), 1)
self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system')
def test_from_file_error_unknown_ai(self) -> None:
source_dict = {
'db': './test_db/',
'ais': {
'default': {
'name': 'foobla',
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
}
}
with open(self.test_file.name, 'w') as f:
yaml.dump(source_dict, f)
with self.assertRaises(ConfigError):
Config.from_file(self.test_file.name)
+233
View File
@@ -0,0 +1,233 @@
import unittest
import io
import pathlib
import argparse
from chatmastermind.utils import terminal_width
from chatmastermind.main import create_parser, ask_cmd
from chatmastermind.api_client import ai
from chatmastermind.configuration import Config
from chatmastermind.storage import create_chat_hist, save_answers, dump_data
from unittest import mock
from unittest.mock import patch, MagicMock, Mock, ANY
class CmmTestCase(unittest.TestCase):
"""
Base class for all cmm testcases.
"""
def dummy_config(self, db: str) -> Config:
"""
Creates a dummy configuration.
"""
return Config.from_dict(
{'system': 'dummy_system',
'db': db,
'openai': {'api_key': 'dummy_key',
'model': 'dummy_model',
'max_tokens': 4000,
'temperature': 1.0,
'top_p': 1,
'frequency_penalty': 0,
'presence_penalty': 0}}
)
class TestCreateChat(CmmTestCase):
def setUp(self) -> None:
self.config = self.dummy_config(db='test_files')
self.question = "test question"
self.tags = ['test_tag']
@patch('os.listdir')
@patch('pathlib.Path.iterdir')
@patch('builtins.open')
def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
{'question': 'test_content', 'answer': 'some answer',
'tags': ['test_tag']}))
test_chat = create_chat_hist(self.question, self.tags, None, self.config)
self.assertEqual(len(test_chat), 4)
self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1],
{'role': 'user', 'content': 'test_content'})
self.assertEqual(test_chat[2],
{'role': 'assistant', 'content': 'some answer'})
self.assertEqual(test_chat[3],
{'role': 'user', 'content': self.question})
@patch('os.listdir')
@patch('pathlib.Path.iterdir')
@patch('builtins.open')
def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
{'question': 'test_content', 'answer': 'some answer',
'tags': ['other_tag']}))
test_chat = create_chat_hist(self.question, self.tags, None, self.config)
self.assertEqual(len(test_chat), 2)
self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1],
{'role': 'user', 'content': self.question})
@patch('os.listdir')
@patch('pathlib.Path.iterdir')
@patch('builtins.open')
def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt', 'testfile2.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.side_effect = (
io.StringIO(dump_data({'question': 'test_content',
'answer': 'some answer',
'tags': ['test_tag']})),
io.StringIO(dump_data({'question': 'test_content2',
'answer': 'some answer2',
'tags': ['test_tag2']})),
)
test_chat = create_chat_hist(self.question, [], None, self.config)
self.assertEqual(len(test_chat), 6)
self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1],
{'role': 'user', 'content': 'test_content'})
self.assertEqual(test_chat[2],
{'role': 'assistant', 'content': 'some answer'})
self.assertEqual(test_chat[3],
{'role': 'user', 'content': 'test_content2'})
self.assertEqual(test_chat[4],
{'role': 'assistant', 'content': 'some answer2'})
class TestHandleQuestion(CmmTestCase):
def setUp(self) -> None:
self.question = "test question"
self.args = argparse.Namespace(
tags=['tag1'],
extags=['extag1'],
output_tags=None,
question=[self.question],
source=None,
only_source_code=False,
number=3,
max_tokens=None,
temperature=None,
model=None,
match_all_tags=False,
with_tags=False,
with_file=False,
)
self.config = self.dummy_config(db='test_files')
@patch("chatmastermind.main.create_chat_hist", return_value="test_chat")
@patch("chatmastermind.main.print_tag_args")
@patch("chatmastermind.main.print_chat_hist")
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
@patch("chatmastermind.utils.pp")
@patch("builtins.print")
def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock,
mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock,
mock_create_chat_hist: MagicMock) -> None:
open_mock = MagicMock()
with patch("chatmastermind.storage.open", open_mock):
ask_cmd(self.args, self.config)
mock_print_tag_args.assert_called_once_with(self.args.tags,
self.args.extags,
[])
mock_create_chat_hist.assert_called_once_with(self.question,
self.args.tags,
self.args.extags,
self.config,
False, False, False)
mock_print_chat_hist.assert_called_once_with('test_chat',
False,
self.args.only_source_code)
mock_ai.assert_called_with("test_chat",
self.config,
self.args.number)
expected_calls = []
for num, answer in enumerate(mock_ai.return_value[0], start=1):
title = f'-- ANSWER {num} '
title_end = '-' * (terminal_width() - len(title))
expected_calls.append(((f'{title}{title_end}',),))
expected_calls.append(((answer,),))
expected_calls.append((("-" * terminal_width(),),))
expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
self.assertEqual(mock_print.call_args_list, expected_calls)
open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)])
open_mock.assert_has_calls(open_expected_calls, any_order=True)
class TestSaveAnswers(CmmTestCase):
@mock.patch('builtins.open')
@mock.patch('chatmastermind.storage.print')
def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None:
question = "Test question?"
answers = ["Answer 1", "Answer 2"]
tags = ["tag1", "tag2"]
otags = ["otag1", "otag2"]
config = self.dummy_config(db='test_db')
with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \
mock.patch('chatmastermind.storage.yaml.dump'), \
mock.patch('io.StringIO') as stringio_mock:
stringio_instance = stringio_mock.return_value
stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"]
save_answers(question, answers, tags, otags, config)
open_calls = [
mock.call(pathlib.Path('test_db/.next'), 'r'),
mock.call(pathlib.Path('test_db/.next'), 'w'),
]
open_mock.assert_has_calls(open_calls, any_order=True)
class TestAI(CmmTestCase):
@patch("openai.ChatCompletion.create")
def test_ai(self, mock_create: MagicMock) -> None:
mock_create.return_value = {
'choices': [
{'message': {'content': 'response_text_1'}},
{'message': {'content': 'response_text_2'}}
],
'usage': {'tokens': 10}
}
chat = [{"role": "system", "content": "hello ai"}]
config = self.dummy_config(db='dummy')
config.openai.model = "text-davinci-002"
config.openai.max_tokens = 150
config.openai.temperature = 0.5
result = ai(chat, config, 2)
expected_result = (['response_text_1', 'response_text_2'],
{'tokens': 10})
self.assertEqual(result, expected_result)
class TestCreateParser(CmmTestCase):
def test_create_parser(self) -> None:
with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers:
mock_cmdparser = Mock()
mock_add_subparsers.return_value = mock_cmdparser
parser = create_parser()
self.assertIsInstance(parser, argparse.ArgumentParser)
mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True)
mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('tag', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY)
self.assertTrue('.config.yaml' in parser.get_default('config'))
+21 -80
View File
@@ -1,12 +1,12 @@
import unittest
import pathlib import pathlib
import tempfile import tempfile
from typing import cast from typing import cast
from .test_main import CmmTestCase
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
from chatmastermind.tags import Tag, TagLine from chatmastermind.tags import Tag, TagLine
class SourceCodeTestCase(unittest.TestCase): class SourceCodeTestCase(CmmTestCase):
def test_source_code_with_include_delims(self) -> None: def test_source_code_with_include_delims(self) -> None:
text = """ text = """
Some text before the code block Some text before the code block
@@ -60,46 +60,29 @@ class SourceCodeTestCase(unittest.TestCase):
self.assertEqual(result, expected_result) self.assertEqual(result, expected_result)
class QuestionTestCase(unittest.TestCase): class QuestionTestCase(CmmTestCase):
def test_question_with_header(self) -> None: def test_question_with_prefix(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
Question(f"{Question.txt_header}\nWhat is your name?") Question("=== QUESTION === What is your name?")
def test_question_with_answer_header(self) -> None: def test_question_without_prefix(self) -> None:
with self.assertRaises(MessageError):
Question(f"{Answer.txt_header}\nBob")
def test_question_with_legal_header(self) -> None:
"""
If the header is just a part of a line, it's fine.
"""
question = Question(f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?")
self.assertIsInstance(question, Question)
self.assertEqual(question, f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?")
def test_question_without_header(self) -> None:
question = Question("What is your favorite color?") question = Question("What is your favorite color?")
self.assertIsInstance(question, Question) self.assertIsInstance(question, Question)
self.assertEqual(question, "What is your favorite color?") self.assertEqual(question, "What is your favorite color?")
class AnswerTestCase(unittest.TestCase): class AnswerTestCase(CmmTestCase):
def test_answer_with_header(self) -> None: def test_answer_with_prefix(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
Answer(f"{Answer.txt_header}\nno") Answer("=== ANSWER === Yes")
def test_answer_with_legal_header(self) -> None: def test_answer_without_prefix(self) -> None:
answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.")
self.assertIsInstance(answer, Answer)
self.assertEqual(answer, f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.")
def test_answer_without_header(self) -> None:
answer = Answer("No") answer = Answer("No")
self.assertIsInstance(answer, Answer) self.assertIsInstance(answer, Answer)
self.assertEqual(answer, "No") self.assertEqual(answer, "No")
class MessageToFileTxtTestCase(unittest.TestCase): class MessageToFileTxtTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
@@ -160,7 +143,7 @@ This is a question.
self.message_complete.file_path = self.file_path self.message_complete.file_path = self.file_path
class MessageToFileYamlTestCase(unittest.TestCase): class MessageToFileYamlTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
@@ -226,7 +209,7 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.assertEqual(content, expected_content) self.assertEqual(content, expected_content)
class MessageFromFileTxtTestCase(unittest.TestCase): class MessageFromFileTxtTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
@@ -300,12 +283,6 @@ This is a question.
MessageFilter(tags_or={Tag('tag1')})) MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNone(message) self.assertIsNone(message)
def test_from_file_txt_empty_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(tags_or=set(),
tags_and=set()))
self.assertIsNone(message)
def test_from_file_txt_no_tags_match_tags_not(self) -> None: def test_from_file_txt_no_tags_match_tags_not(self) -> None:
message = Message.from_file(self.file_path_min, message = Message.from_file(self.file_path_min,
MessageFilter(tags_not={Tag('tag1')})) MessageFilter(tags_not={Tag('tag1')}))
@@ -394,7 +371,7 @@ This is a question.
self.assertIsNone(message) self.assertIsNone(message)
class MessageFromFileYamlTestCase(unittest.TestCase): class MessageFromFileYamlTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
@@ -561,7 +538,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
self.assertIsNone(message) self.assertIsNone(message)
class TagsFromFileTestCase(unittest.TestCase): class TagsFromFileTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt = pathlib.Path(self.file_txt.name) self.file_path_txt = pathlib.Path(self.file_txt.name)
@@ -669,7 +646,7 @@ This is an answer.
self.assertSetEqual(tags, set()) self.assertSetEqual(tags, set())
class TagsFromDirTestCase(unittest.TestCase): class TagsFromDirTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir = tempfile.TemporaryDirectory()
self.temp_dir_no_tags = tempfile.TemporaryDirectory() self.temp_dir_no_tags = tempfile.TemporaryDirectory()
@@ -717,7 +694,7 @@ class TagsFromDirTestCase(unittest.TestCase):
self.assertSetEqual(all_tags, set()) self.assertSetEqual(all_tags, set())
class MessageIDTestCase(unittest.TestCase): class MessageIDTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
@@ -737,7 +714,7 @@ class MessageIDTestCase(unittest.TestCase):
self.message_no_file_path.msg_id() self.message_no_file_path.msg_id()
class MessageHashTestCase(unittest.TestCase): class MessageHashTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message1 = Message(Question('This is a question.'), self.message1 = Message(Question('This is a question.'),
tags={Tag('tag1')}, tags={Tag('tag1')},
@@ -761,7 +738,7 @@ class MessageHashTestCase(unittest.TestCase):
self.assertIn(msg, msgs) self.assertIn(msg, msgs)
class MessageTagsStrTestCase(unittest.TestCase): class MessageTagsStrTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message = Message(Question('This is a question.'), self.message = Message(Question('This is a question.'),
tags={Tag('tag1')}, tags={Tag('tag1')},
@@ -771,7 +748,7 @@ class MessageTagsStrTestCase(unittest.TestCase):
self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')
class MessageFilterTagsTestCase(unittest.TestCase): class MessageFilterTagsTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message = Message(Question('This is a question.'), self.message = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')}, tags={Tag('atag1'), Tag('btag2')},
@@ -786,7 +763,7 @@ class MessageFilterTagsTestCase(unittest.TestCase):
self.assertSetEqual(tags_cont, {Tag('btag2')}) self.assertSetEqual(tags_cont, {Tag('btag2')})
class MessageInTestCase(unittest.TestCase): class MessageInTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message1 = Message(Question('This is a question.'), self.message1 = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')}, tags={Tag('atag1'), Tag('btag2')},
@@ -798,39 +775,3 @@ class MessageInTestCase(unittest.TestCase):
def test_message_in(self) -> None: def test_message_in(self) -> None:
self.assertTrue(message_in(self.message1, [self.message1])) self.assertTrue(message_in(self.message1, [self.message1]))
self.assertFalse(message_in(self.message1, [self.message2])) self.assertFalse(message_in(self.message1, [self.message2]))
class MessageRenameTagsTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/foo/bla'))
def test_rename_tags(self) -> None:
self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))})
self.assertIsNotNone(self.message.tags)
self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type]
class MessageToStrTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
Answer('This is an answer.'),
tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/foo/bla'))
def test_to_str(self) -> None:
expected_output = f"""{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer."""
self.assertEqual(self.message.to_str(), expected_output)
def test_to_str_with_tags_and_file(self) -> None:
expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: /tmp/foo/bla
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer."""
self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output)
-162
View File
@@ -1,162 +0,0 @@
import os
import unittest
import argparse
import tempfile
from pathlib import Path
from unittest.mock import MagicMock
from chatmastermind.commands.question import create_message
from chatmastermind.message import Message, Question
from chatmastermind.chat import ChatDB
class TestMessageCreate(unittest.TestCase):
"""
Test if messages created by the 'question' command have
the correct format.
"""
def setUp(self) -> None:
# create ChatDB structure
self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory()
self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name),
db_path=Path(self.db_path.name))
# create arguments mock
self.args = MagicMock(spec=argparse.Namespace)
self.args.source_text = None
self.args.source_code = None
self.args.AI = None
self.args.model = None
self.args.output_tags = None
# File 1 : no source code block, only text
self.source_file1 = tempfile.NamedTemporaryFile(delete=False)
self.source_file1_content = """This is just text.
No source code.
Nope. Go look elsewhere!"""
with open(self.source_file1.name, 'w') as f:
f.write(self.source_file1_content)
# File 2 : one embedded source code block
self.source_file2 = tempfile.NamedTemporaryFile(delete=False)
self.source_file2_content = """This is just text.
```
This is embedded source code.
```
And some text again."""
with open(self.source_file2.name, 'w') as f:
f.write(self.source_file2_content)
# File 3 : all source code
self.source_file3 = tempfile.NamedTemporaryFile(delete=False)
self.source_file3_content = """This is all source code.
Yes, really.
Language is called 'brainfart'."""
with open(self.source_file3.name, 'w') as f:
f.write(self.source_file3_content)
# File 4 : two source code blocks
self.source_file4 = tempfile.NamedTemporaryFile(delete=False)
self.source_file4_content = """This is just text.
```
This is embedded source code.
```
And some text again.
```
This is embedded source code.
```
Aaaand again some text."""
with open(self.source_file4.name, 'w') as f:
f.write(self.source_file4_content)
def tearDown(self) -> None:
os.remove(self.source_file1.name)
os.remove(self.source_file2.name)
os.remove(self.source_file3.name)
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return list(Path(tmp_dir.name).glob('*.[ty]*'))
def test_message_file_created(self) -> None:
self.args.ask = ["What is this?"]
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr]
def test_single_question(self) -> None:
self.args.ask = ["What is this?"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("What is this?"))
self.assertEqual(len(message.question.source_code()), 0)
def test_multipart_question(self) -> None:
self.args.ask = ["What is this", "'bard' thing?", "Is it good?"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("""What is this
'bard' thing?
Is it good?"""))
def test_single_question_with_text_only_file(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_text = [f"{self.source_file1.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains no source code (only text)
# -> don't expect any in the question
self.assertEqual(len(message.question.source_code()), 0)
self.assertEqual(message.question, Question(f"""What is this?
{self.source_file1_content}"""))
def test_single_question_with_text_file_and_embedded_code(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.source_file2.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains 1 source code block
# -> expect it in the question
self.assertEqual(len(message.question.source_code()), 1)
self.assertEqual(message.question, Question("""What is this?
```
This is embedded source code.
```
"""))
def test_single_question_with_code_only_file(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.source_file3.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file is complete source code
self.assertEqual(len(message.question.source_code()), 1)
self.assertEqual(message.question, Question(f"""What is this?
```
{self.source_file3_content}
```"""))
def test_single_question_with_text_file_and_multi_embedded_code(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.source_file4.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains 2 source code blocks
# -> expect them in the question
self.assertEqual(len(message.question.source_code()), 2)
self.assertEqual(message.question, Question("""What is this?
```
This is embedded source code.
```
```
This is embedded source code.
```
"""))
+3 -20
View File
@@ -1,8 +1,8 @@
import unittest from .test_main import CmmTestCase
from chatmastermind.tags import Tag, TagLine, TagError from chatmastermind.tags import Tag, TagLine, TagError
class TestTag(unittest.TestCase): class TestTag(CmmTestCase):
def test_valid_tag(self) -> None: def test_valid_tag(self) -> None:
tag = Tag('mytag') tag = Tag('mytag')
self.assertEqual(tag, 'mytag') self.assertEqual(tag, 'mytag')
@@ -18,7 +18,7 @@ class TestTag(unittest.TestCase):
self.assertEqual(Tag.alternative_separators, [',']) self.assertEqual(Tag.alternative_separators, [','])
class TestTagLine(unittest.TestCase): class TestTagLine(CmmTestCase):
def test_valid_tagline(self) -> None: def test_valid_tagline(self) -> None:
tagline = TagLine('TAGS: tag1 tag2') tagline = TagLine('TAGS: tag1 tag2')
self.assertEqual(tagline, 'TAGS: tag1 tag2') self.assertEqual(tagline, 'TAGS: tag1 tag2')
@@ -144,20 +144,3 @@ class TestTagLine(unittest.TestCase):
# Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags
tags_not = {Tag('tag2')} tags_not = {Tag('tag2')}
self.assertFalse(tagline.match_tags(None, None, tags_not)) self.assertFalse(tagline.match_tags(None, None, tags_not))
# Test case 10: 'tags_or' and 'tags_and' are empty, match no tags
self.assertFalse(tagline.match_tags(set(), set(), None))
# Test case 11: 'tags_or' is empty, match no tags
self.assertFalse(tagline.match_tags(set(), None, None))
# Test case 12: 'tags_and' is empty, match no tags
self.assertFalse(tagline.match_tags(None, set(), None))
# Test case 13: 'tags_or' is empty, match 'tags_and'
tags_and = {Tag('tag1'), Tag('tag2')}
self.assertTrue(tagline.match_tags(None, tags_and, None))
# Test case 14: 'tags_and' is empty, match 'tags_or'
tags_or = {Tag('tag1'), Tag('tag2')}
self.assertTrue(tagline.match_tags(tags_or, None, None))