20 Commits

Author SHA1 Message Date
juk0de d4021eeb11 configuration: made 'default' AI ID optional 2023-09-11 07:38:49 +02:00
juk0de c143c001f9 configuration: improved config file format 2023-09-10 19:57:06 +02:00
juk0de 59b851650a question_cmd: when no tags are specified, no tags are selected 2023-09-10 19:57:06 +02:00
juk0de 6f71a2ff69 message: to_file() now uses intermediate temporary file 2023-09-10 19:57:06 +02:00
juk0de eca44b14cb message: fixed matching with empty tag sets 2023-09-10 19:55:21 +02:00
juk0de b48667bfa0 openai: stores AI.ID instead of AI.name in message 2023-09-10 19:18:44 +02:00
juk0de 533ee1c1a9 question_cmd: added message filtering by tags 2023-09-10 19:18:44 +02:00
juk0de cf50818f28 question_cmd: fixed '--ask' command 2023-09-10 19:18:44 +02:00
juk0de dd3d3ffc82 chat: added check for existing files when creating new filenames 2023-09-10 19:18:44 +02:00
juk0de 1e3bfdd67f chat: added 'update_messages()' function and test 2023-09-10 19:14:11 +02:00
juk0de 53582a7123 question_cmd: fixed source code extraction and added a testcase 2023-09-10 19:14:11 +02:00
Oleksandr Kozachuk 39b518a8a6 Small fixes. 2023-09-09 16:05:27 +02:00
Oleksandr Kozachuk d22877a0f1 Port print arguments -q/-a/-S from main to restructuring. 2023-09-09 15:38:40 +02:00
Oleksandr Kozachuk 7cf62c54ef Allow in question -s for just sourcing file and -S to source file with ``` encapsulation. 2023-09-09 15:16:17 +02:00
juk0de 5fb5dde550 question cmd: added tests 2023-09-09 09:12:21 +02:00
juk0de c0b7d17587 question_cmd: fixes 2023-09-09 08:51:44 +02:00
juk0de 76f2373397 configuration: added tests 2023-09-09 08:31:45 +02:00
juk0de eaa399bcb9 configuration et al: implemented new Config format 2023-09-09 08:31:45 +02:00
juk0de b1a23394fc cmm: splitted commands into separate modules (and more cleanup) 2023-09-09 08:31:45 +02:00
juk0de 2df9dd6427 cmm: removed all the old code and modules 2023-09-08 13:04:11 +02:00
22 changed files with 663 additions and 648 deletions
+6
View File
@@ -66,3 +66,9 @@ class AI(Protocol):
and is not implemented for all AIs. and is not implemented for all AIs.
""" """
raise NotImplementedError raise NotImplementedError
def print(self) -> None:
    """
    Print some info about the current AI, like system message.

    This base version prints nothing and returns None.
    """
+27 -10
View File
@@ -4,23 +4,40 @@ Creates different AI instances, based on the given configuration.
import argparse import argparse
from typing import cast from typing import cast
from .configuration import Config, OpenAIConfig, default_ai_ID from .configuration import Config, AIConfig, OpenAIConfig
from .ai import AI, AIError from .ai import AI, AIError
from .ais.openai import OpenAI from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config) -> AI: def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
""" """
Creates an AI subclass instance from the given args and configuration. Creates an AI subclass instance from the given arguments
and configuration file. If AI has not been set in the
arguments, it searches for the ID 'default'. If that
is not found, it uses the first AI in the list.
""" """
if args.ai: ai_conf: AIConfig
ai_conf = config.ais[args.ai] if args.AI:
elif default_ai_ID in config.ais: try:
ai_conf = config.ais[default_ai_ID] ai_conf = config.ais[args.AI]
except KeyError:
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
elif 'default' in config.ais:
ai_conf = config.ais['default']
else: else:
raise AIError("No AI name given and no default exists") try:
ai_conf = next(iter(config.ais.values()))
except StopIteration:
raise AIError("No AI found in this configuration")
if ai_conf.name == 'openai': if ai_conf.name == 'openai':
return OpenAI(cast(OpenAIConfig, ai_conf)) ai = OpenAI(cast(OpenAIConfig, ai_conf))
if args.model:
ai.config.model = args.model
if args.max_tokens:
ai.config.max_tokens = args.max_tokens
if args.temperature:
ai.config.temperature = args.temperature
return ai
else: else:
raise AIError(f"AI '{args.ai}' is not supported") raise AIError(f"AI '{args.AI}' is not supported")
+15 -6
View File
@@ -43,16 +43,20 @@ class OpenAI(AI):
n=num_answers, n=num_answers,
frequency_penalty=self.config.frequency_penalty, frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty) presence_penalty=self.config.presence_penalty)
answers: list[Message] = [] question.answer = Answer(response['choices'][0]['message']['content'])
for choice in response['choices']: # type: ignore question.tags = otags
question.ai = self.ID
question.model = self.config.model
answers: list[Message] = [question]
for choice in response['choices'][1:]: # type: ignore
answers.append(Message(question=question.question, answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']), answer=Answer(choice['message']['content']),
tags=otags, tags=otags,
ai=self.name, ai=self.ID,
model=self.config.model)) model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt'], return AIResponse(answers, Tokens(response['usage']['prompt_tokens'],
response['usage']['completion'], response['usage']['completion_tokens'],
response['usage']['total'])) response['usage']['total_tokens']))
def models(self) -> list[str]: def models(self) -> list[str]:
""" """
@@ -95,3 +99,8 @@ class OpenAI(AI):
def tokens(self, data: Union[Message, Chat]) -> int: def tokens(self, data: Union[Message, Chat]) -> int:
raise NotImplementedError raise NotImplementedError
def print(self) -> None:
    """
    Print the configured model followed by the system message.
    """
    cfg = self.config
    print(f"MODEL: {cfg.model}")
    print("=== SYSTEM ===")
    print(cfg.system)
-45
View File
@@ -1,45 +0,0 @@
import openai
from .utils import ChatType
from .configuration import Config
def openai_api_key(api_key: str) -> None:
    """
    Store the given API key on the 'openai' module.
    """
    setattr(openai, 'api_key', api_key)
def print_models() -> None:
    """
    Print all models supported by the current AI.

    Engines are visited sorted by their ID. IDs of engines flagged as
    'ready' are printed one per line; all other IDs are collected and
    printed on a single 'Not ready' line at the end.
    """
    pending: list[str] = []
    engines = openai.Engine.list()['data']
    for entry in sorted(engines, key=lambda e: e['id']):
        if not entry['ready']:
            pending.append(entry['id'])
            continue
        print(entry['id'])
    if pending:
        print('\nNot ready: ' + ', '.join(pending))
def ai(chat: ChatType,
       config: Config,
       number: int
       ) -> tuple[list[str], dict[str, int]]:
    """
    Make AI request with the given chat history and configuration.
    Return AI response and tokens used.

    The first element of the result holds the stripped content of every
    returned choice, the second one is the 'usage' entry of the response
    converted to a dict.
    """
    settings = config.openai
    response = openai.ChatCompletion.create(
        model=settings.model,
        messages=chat,
        temperature=settings.temperature,
        max_tokens=settings.max_tokens,
        top_p=settings.top_p,
        n=number,
        frequency_penalty=settings.frequency_penalty,
        presence_penalty=settings.presence_penalty)
    contents = [choice['message']['content'].strip()
                for choice in response['choices']]  # type: ignore
    return contents, dict(response['usage'])  # type: ignore
+20 -1
View File
@@ -62,7 +62,10 @@ def make_file_path(dir_path: Path,
Create a file_path for the given directory using the Create a file_path for the given directory using the
given file_suffix and ID generator function. given file_suffix and ID generator function.
""" """
return dir_path / f"{next_fid():04d}{file_suffix}" file_path = dir_path / f"{next_fid():04d}{file_suffix}"
while file_path.exists():
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
return file_path
def write_dir(dir_path: Path, def write_dir(dir_path: Path,
@@ -386,3 +389,19 @@ class ChatDB(Chat):
msgs = iter(messages if messages else self.messages) msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)): while (m := next(msgs, None)):
m.to_file() m.to_file()
def update_messages(self, messages: list[Message], write: bool = True) -> None:
    """
    Replace already known messages by the given versions. Whether a message is
    known is decided by 'message_in()' against the internal message list. If at
    least one of the given messages is unknown, a ChatError is raised and
    nothing is changed. If 'write' is True, the given messages are written
    using 'write_messages()' afterwards.
    """
    if not all(message_in(m, self.messages) for m in messages):
        raise ChatError("Can't update messages that are not in the internal list")
    # drop the outdated versions, then append the new ones
    untouched = [m for m in self.messages if not message_in(m, messages)]
    self.messages = untouched + messages
    self.sort()
    # only the UPDATED messages are written
    if write:
        self.write_messages(messages)
+11
View File
@@ -0,0 +1,11 @@
import argparse
from pathlib import Path
from ..configuration import Config
def config_cmd(args: argparse.Namespace) -> None:
    """
    Handler for the 'config' command.

    Creates a default configuration at the path given by 'args.create';
    does nothing if that argument is not set.
    """
    target = args.create
    if not target:
        return
    Config.create_default(Path(target))
+23
View File
@@ -0,0 +1,23 @@
import argparse
from pathlib import Path
from ..configuration import Config
from ..chat import ChatDB
from ..message import MessageFilter
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'hist' command.

    Builds a message filter from the tag and content arguments, reads the
    chat from the current directory and the configured DB and prints it.
    """
    filter_args = dict(tags_or=args.or_tags,
                       tags_and=args.and_tags,
                       tags_not=args.exclude_tags,
                       question_contains=args.question,
                       answer_contains=args.answer)
    history = ChatDB.from_dir(Path('.'),
                              Path(config.db),
                              mfilter=MessageFilter(**filter_args))
    history.print(args.source_code_only,
                  args.with_tags,
                  args.with_files)
+27
View File
@@ -0,0 +1,27 @@
import sys
import argparse
from pathlib import Path
from ..configuration import Config
from ..message import Message, MessageError
def print_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'print' command.

    Prints either the question, the answer, the source code found in the
    answer or (default) the whole message read from 'args.file'. Exits
    with status 1 if a MessageError is raised while doing so.
    """
    path = Path(args.file)
    try:
        message = Message.from_file(path)
        if not message:
            return
        if args.question:
            print(message.question)
        elif args.answer:
            print(message.answer)
        elif message.answer and args.only_source_code:
            for snippet in message.answer.source_code():
                print(snippet)
        else:
            print(message.to_str())
    except MessageError:
        print(f"File is not a valid message: {args.file}")
        sys.exit(1)
+94
View File
@@ -0,0 +1,94 @@
import argparse
from pathlib import Path
from itertools import zip_longest
from ..configuration import Config
from ..chat import ChatDB
from ..message import Message, MessageFilter, Question, source_code
from ..ai_factory import create_ai
from ..ai import AI, AIResponse
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
    """
    Creates (and writes) a new message from the given arguments.

    The question text is assembled from 'args.ask', 'args.source_text' and
    'args.source_code', which are walked through in lockstep. The parts are
    joined by empty lines and the resulting message is added to the cache
    of the given chat.
    """
    def read_stripped(path: str) -> str:
        # content of the given file without surrounding whitespace
        with open(path) as handle:
            return handle.read().strip()

    parts = []
    questions = args.ask if args.ask is not None else []
    text_files = args.source_text if args.source_text is not None else []
    code_files = args.source_code if args.source_code is not None else []
    for question, text_file, code_file in zip_longest(questions, text_files, code_files, fillvalue=None):
        if question is not None and question.strip():
            parts.append(question)
        if text_file:
            text = read_stripped(text_file)
            if text:
                parts.append(text)
        if not code_file:
            continue
        code_text = read_stripped(code_file)
        if not code_text:
            continue
        # prefer the extracted source code blocks, fall back
        # to the whole file wrapped in code delimiters
        snippets = source_code(code_text, include_delims=True)
        if snippets:
            parts.extend(snippets)
        else:
            parts.append(f"```\n{code_text}\n```")

    message = Message(question=Question('\n\n'.join(parts)),
                      tags=args.output_tags,  # FIXME
                      ai=args.AI,
                      model=args.model)
    chat.add_to_cache([message])
    return message
def question_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'question' command.

    Reads the chat (filtered by the given tags), optionally creates a new
    message and, for '--ask', sends it to the AI and stores and prints the
    answers. Repeating and processing are not implemented yet.
    """
    def given_or_empty(selected):  # type: ignore
        # an omitted tag option counts as an empty tag set
        return selected if selected is not None else set()

    mfilter = MessageFilter(tags_or=given_or_empty(args.or_tags),
                            tags_and=given_or_empty(args.and_tags),
                            tags_not=given_or_empty(args.exclude_tags))
    chat = ChatDB.from_dir(cache_path=Path('.'),
                           db_path=Path(config.db),
                           mfilter=mfilter)
    # a new question is created and stored right away
    if args.ask or args.create:
        message = create_message(chat, args)
        if args.create:
            return

    # create the correct AI instance
    ai: AI = create_ai(args, config)
    if args.ask:
        ai.print()
        chat.print(paged=False)
        response: AIResponse = ai.request(message,
                                          chat,
                                          args.num_answers,  # FIXME
                                          args.output_tags)  # FIXME
        # the first message is the (now answered) question created
        # above, all others are additional answers
        chat.update_messages([response.messages[0]])
        chat.add_to_cache(response.messages[1:])
        for number, msg in enumerate(response.messages, start=1):
            print(f"=== ANSWER {number} ===")
            print(msg.answer)
        if response.tokens:
            print("===============")
            print(response.tokens)
        return
    if args.repeat is not None:
        lmessage = chat.latest_message()
        assert lmessage
        # TODO: repeat either the last question or the
        #       one(s) given in 'args.repeat' (overwrite
        #       existing ones if 'args.overwrite' is True)
        return
    if args.process is not None:
        # TODO: process either all questions without an
        #       answer or the one(s) given in 'args.process'
        pass
+17
View File
@@ -0,0 +1,17 @@
import argparse
from pathlib import Path
from ..configuration import Config
from ..chat import ChatDB
def tags_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'tags' command.

    With 'args.list' set, prints every tag returned by
    'chat.tags_frequency()' together with its frequency.
    """
    chat = ChatDB.from_dir(cache_path=Path('.'),
                           db_path=Path(config.db))
    # TODO: add renaming
    if not args.list:
        return
    frequencies = chat.tags_frequency(args.prefix, args.contain)
    for tag in frequencies:
        print(f"- {tag}: {frequencies[tag]}")
+38 -14
View File
@@ -1,6 +1,6 @@
import yaml import yaml
from pathlib import Path from pathlib import Path
from typing import Type, TypeVar, Any, Optional, Final from typing import Type, TypeVar, Any, Optional, ClassVar
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
ConfigInst = TypeVar('ConfigInst', bound='Config') ConfigInst = TypeVar('ConfigInst', bound='Config')
@@ -9,7 +9,6 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
supported_ais: list[str] = ['openai'] supported_ais: list[str] = ['openai']
default_ai_ID: str = 'default'
default_config_path = '.config.yaml' default_config_path = '.config.yaml'
@@ -17,13 +16,36 @@ class ConfigError(Exception):
pass pass
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
    """
    Changes the YAML dump style to multiline syntax for multiline strings.

    Strings spanning more than one line are emitted in literal block
    style ('|'), all others with the default style.
    """
    is_multiline = len(data.splitlines()) > 1
    if not is_multiline:
        return dumper.represent_scalar('tag:yaml.org,2002:str', data)
    return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')


# use the presenter above for all strings dumped by 'yaml'
yaml.add_representer(str, str_presenter)
@dataclass @dataclass
class AIConfig: class AIConfig:
""" """
The base class of all AI configurations. The base class of all AI configurations.
""" """
# the name of the AI the config class represents
# -> it's a class variable and thus not part of the
# dataclass constructor
name: ClassVar[str]
# a user-defined ID for an AI configuration entry
ID: str ID: str
name: str
# the name must not be changed
def __setattr__(self, name: str, value: Any) -> None:
    """
    Set an attribute, but refuse any assignment to 'name'.

    Raises AttributeError when the attribute 'name' is assigned; every
    other attribute is set as usual.
    """
    if name == 'name':
        # f-string, so the message shows the attribute instead of '{name}'
        raise AttributeError(f"'{name}' is not allowed to be changed")
    super().__setattr__(name, value)
@dataclass @dataclass
@@ -31,19 +53,19 @@ class OpenAIConfig(AIConfig):
""" """
The OpenAI section of the configuration file. The OpenAI section of the configuration file.
""" """
# the name must not be changed name: ClassVar[str] = 'openai'
name: Final[str] = 'openai'
# all members have default values, so we can easily create # all members have default values, so we can easily create
# a default configuration # a default configuration
ID: str = 'default' ID: str = 'myopenai'
api_key: str = '0123456789' api_key: str = '0123456789'
system: str = 'You are an assistant'
model: str = 'gpt-3.5-turbo-16k' model: str = 'gpt-3.5-turbo-16k'
temperature: float = 1.0 temperature: float = 1.0
max_tokens: int = 4000 max_tokens: int = 4000
top_p: float = 1.0 top_p: float = 1.0
frequency_penalty: float = 0.0 frequency_penalty: float = 0.0
presence_penalty: float = 0.0 presence_penalty: float = 0.0
system: str = 'You are an assistant'
@classmethod @classmethod
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
@@ -51,23 +73,20 @@ class OpenAIConfig(AIConfig):
Create OpenAIConfig from a dict. Create OpenAIConfig from a dict.
""" """
res = cls( res = cls(
system=str(source['system']),
api_key=str(source['api_key']), api_key=str(source['api_key']),
model=str(source['model']), model=str(source['model']),
max_tokens=int(source['max_tokens']), max_tokens=int(source['max_tokens']),
temperature=float(source['temperature']), temperature=float(source['temperature']),
top_p=float(source['top_p']), top_p=float(source['top_p']),
frequency_penalty=float(source['frequency_penalty']), frequency_penalty=float(source['frequency_penalty']),
presence_penalty=float(source['presence_penalty']) presence_penalty=float(source['presence_penalty']),
system=str(source['system'])
) )
# overwrite default ID if provided # overwrite default ID if provided
if 'ID' in source: if 'ID' in source:
res.ID = source['ID'] res.ID = source['ID']
return res return res
def as_dict(self) -> dict[str, Any]:
return asdict(self)
def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig: def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
""" """
@@ -79,7 +98,7 @@ def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) ->
else: else:
return OpenAIConfig.from_dict(conf_dict) return OpenAIConfig.from_dict(conf_dict)
else: else:
raise ConfigError(f"AI '{name}' is not supported") raise ConfigError(f"Unknown AI '{name}'")
def create_default_ai_configs() -> dict[str, AIConfig]: def create_default_ai_configs() -> dict[str, AIConfig]:
@@ -139,4 +158,9 @@ class Config:
yaml.dump(data, f, sort_keys=False) yaml.dump(data, f, sort_keys=False)
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
return asdict(self) res = asdict(self)
# add the AI name manually (as first element)
# (not done by 'asdict' because it's a class variable)
for ID, conf in res['ais'].items():
res['ais'][ID] = {**{'name': self.ais[ID].name}, **conf}
return res
+26 -116
View File
@@ -6,12 +6,14 @@ import sys
import argcomplete import argcomplete
import argparse import argparse
from pathlib import Path from pathlib import Path
from .configuration import Config, default_config_path
from .chat import ChatDB
from .message import Message, MessageFilter, MessageError, Question
from .ai_factory import create_ai
from .ai import AI, AIResponse
from typing import Any from typing import Any
from .configuration import Config, default_config_path
from .message import Message
from .commands.question import question_cmd
from .commands.tags import tags_cmd
from .commands.config import config_cmd
from .commands.hist import hist_cmd
from .commands.print import print_cmd
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
@@ -19,101 +21,6 @@ def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
return list(Message.tags_from_dir(Path(config.db), prefix=prefix)) return list(Message.tags_from_dir(Path(config.db), prefix=prefix))
def tags_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'tags' command.

    Lists all tags and their frequency if 'args.list' is set.
    """
    db = ChatDB.from_dir(cache_path=Path('.'),
                         db_path=Path(config.db))
    # TODO: add renaming
    if not args.list:
        return
    for tag, count in db.tags_frequency(args.prefix, args.contain).items():
        print(f"- {tag}: {count}")
def config_cmd(args: argparse.Namespace) -> None:
    """
    Handler for the 'config' command.

    Writes a default configuration to the path in 'args.create' (if set).
    """
    destination = args.create
    if destination:
        Config.create_default(Path(destination))
def question_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'question' command.

    Creates a new message for '--ask' / '--create' and adds it to the
    cache. For '--ask' the message is sent to the AI; storing the answer
    is still a TODO, as are repeating and processing.
    """
    chat = ChatDB.from_dir(cache_path=Path('.'),
                           db_path=Path(config.db))
    # if it's a new question, create and store it immediately
    if args.ask or args.create:
        # FIXME: add sources to the question
        # NOTE(review): 'args.question' is kept as found; confirm that the
        # 'question' parser really provides it
        # 'output_tags' and 'AI' are the names argparse derives from the
        # '--output-tags' and '-A/--AI' options
        message = Message(question=Question(args.question),
                          tags=args.output_tags,  # FIXME
                          ai=args.AI,
                          model=args.model)
        chat.add_to_cache([message])
        if args.create:
            return

    # create the correct AI instance
    ai: AI = create_ai(args, config)
    if args.ask:
        response: AIResponse = ai.request(message,
                                          chat,
                                          args.num_answers,  # FIXME
                                          args.output_tags)  # FIXME
        assert response
        # TODO:
        # * add answer to the message above (and create
        #   more messages for any additional answers)
    elif args.repeat:
        lmessage = chat.latest_message()
        assert lmessage
        # TODO: repeat either the last question or the
        #       one(s) given in 'args.repeat' (overwrite
        #       existing ones if 'args.overwrite' is True)
    elif args.process:
        # TODO: process either all questions without an
        #       answer or the one(s) given in 'args.process'
        pass
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'hist' command.

    Prints the chat read from the current directory and the configured
    DB, restricted to the messages matching the given filter arguments.
    """
    selection = MessageFilter(tags_or=args.or_tags,
                              tags_and=args.and_tags,
                              tags_not=args.exclude_tags,
                              question_contains=args.question,
                              answer_contains=args.answer)
    cache_dir, db_dir = Path('.'), Path(config.db)
    chat = ChatDB.from_dir(cache_dir, db_dir, mfilter=selection)
    chat.print(args.source_code_only, args.with_tags, args.with_files)
def print_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'print' command.

    Prints the message stored in 'args.file'. Exits with status 1 if a
    MessageError is raised while reading or formatting it.
    """
    path = Path(args.file)
    try:
        message = Message.from_file(path)
        if not message:
            return
        text = message.to_str(source_code_only=args.source_code_only)
        print(text)
    except MessageError:
        print(f"File is not a valid message: {args.file}")
        sys.exit(1)
def create_parser() -> argparse.ArgumentParser: def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="ChatMastermind is a Python application that automates conversation with AI") description="ChatMastermind is a Python application that automates conversation with AI")
@@ -128,20 +35,28 @@ def create_parser() -> argparse.ArgumentParser:
# a parent parser for all commands that support tag selection # a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False) tag_parser = argparse.ArgumentParser(add_help=False)
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+', tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+',
help='List of tag names (one must match)', metavar='OTAGS') help='List of tags (one must match)', metavar='OTAGS')
tag_arg.completer = tags_completer # type: ignore tag_arg.completer = tags_completer # type: ignore
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+', atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+',
help='List of tag names (all must match)', metavar='ATAGS') help='List of tags (all must match)', metavar='ATAGS')
atag_arg.completer = tags_completer # type: ignore atag_arg.completer = tags_completer # type: ignore
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+', etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+',
help='List of tag names to exclude', metavar='XTAGS') help='List of tags to exclude', metavar='XTAGS')
etag_arg.completer = tags_completer # type: ignore etag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
help='List of output tag names, default is input', metavar='OUTTAGS') help='List of output tags (default: use input tags)', metavar='OUTTAGS')
otag_arg.completer = tags_completer # type: ignore otag_arg.completer = tags_completer # type: ignore
# a parent parser for all commands that support AI configuration
ai_parser = argparse.ArgumentParser(add_help=False)
ai_parser.add_argument('-A', '--AI', help='AI ID to use')
ai_parser.add_argument('-M', '--model', help='Model to use')
ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1)
ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int)
ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float)
# 'question' command parser # 'question' command parser
question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser], question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser, ai_parser],
help="ask, create and process questions.", help="ask, create and process questions.",
aliases=['q']) aliases=['q'])
question_cmd_parser.set_defaults(func=question_cmd) question_cmd_parser.set_defaults(func=question_cmd)
@@ -152,15 +67,8 @@ def create_parser() -> argparse.ArgumentParser:
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true') action='store_true')
question_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query')
question_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history')
question_cmd_parser.add_argument('-A', '--AI', help='AI to use')
question_cmd_parser.add_argument('-M', '--model', help='Model to use')
question_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int,
default=1)
question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history',
action='store_true')
# 'hist' command parser # 'hist' command parser
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
@@ -205,8 +113,10 @@ def create_parser() -> argparse.ArgumentParser:
aliases=['p']) aliases=['p'])
print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.set_defaults(func=print_cmd)
print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True)
print_cmd_parser.add_argument('-S', '--source-code-only', help='Print source code only (from the answer, if available)', print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group()
action='store_true') print_cmd_modes.add_argument('-q', '--question', help='Print only question', action='store_true')
print_cmd_modes.add_argument('-a', '--answer', help='Print only answer', action='store_true')
print_cmd_modes.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true')
argcomplete.autocomplete(parser) argcomplete.autocomplete(parser)
return parser return parser
+17 -11
View File
@@ -3,6 +3,8 @@ Module implementing message related functions and classes.
""" """
import pathlib import pathlib
import yaml import yaml
import tempfile
import shutil
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags from .tags import Tag, TagLine, TagError, match_tags, rename_tags
@@ -312,7 +314,7 @@ class Message():
mfilter.tags_not if mfilter else None) mfilter.tags_not if mfilter else None)
else: else:
message = cls.__from_file_yaml(file_path) message = cls.__from_file_yaml(file_path)
if message and (not mfilter or (mfilter and message.match(mfilter))): if message and (mfilter is None or message.match(mfilter)):
return message return message
else: else:
return None return None
@@ -414,7 +416,7 @@ class Message():
return '\n'.join(output) return '\n'.join(output)
def __str__(self) -> str: def __str__(self) -> str:
return self.to_str(False, False, False) return self.to_str(True, True, False)
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
""" """
@@ -445,16 +447,18 @@ class Message():
* Answer.txt_header * Answer.txt_header
* Answer * Answer
""" """
with open(file_path, "w") as fd: with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
temp_file_path = pathlib.Path(temp_fd.name)
if self.tags: if self.tags:
fd.write(f'{TagLine.from_set(self.tags)}\n') temp_fd.write(f'{TagLine.from_set(self.tags)}\n')
if self.ai: if self.ai:
fd.write(f'{AILine.from_ai(self.ai)}\n') temp_fd.write(f'{AILine.from_ai(self.ai)}\n')
if self.model: if self.model:
fd.write(f'{ModelLine.from_model(self.model)}\n') temp_fd.write(f'{ModelLine.from_model(self.model)}\n')
fd.write(f'{Question.txt_header}\n{self.question}\n') temp_fd.write(f'{Question.txt_header}\n{self.question}\n')
if self.answer: if self.answer:
fd.write(f'{Answer.txt_header}\n{self.answer}\n') temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n')
shutil.move(temp_file_path, file_path)
def __to_file_yaml(self, file_path: pathlib.Path) -> None: def __to_file_yaml(self, file_path: pathlib.Path) -> None:
""" """
@@ -466,7 +470,8 @@ class Message():
* Message.ai_yaml_key: str [Optional] * Message.ai_yaml_key: str [Optional]
* Message.model_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional]
""" """
with open(file_path, "w") as fd: with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
temp_file_path = pathlib.Path(temp_fd.name)
data: YamlDict = {Question.yaml_key: str(self.question)} data: YamlDict = {Question.yaml_key: str(self.question)}
if self.answer: if self.answer:
data[Answer.yaml_key] = str(self.answer) data[Answer.yaml_key] = str(self.answer)
@@ -476,7 +481,8 @@ class Message():
data[self.model_yaml_key] = self.model data[self.model_yaml_key] = self.model
if self.tags: if self.tags:
data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
yaml.dump(data, fd, sort_keys=False) yaml.dump(data, temp_fd, sort_keys=False)
shutil.move(temp_file_path, file_path)
def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
""" """
@@ -508,7 +514,7 @@ class Message():
Return True if all attributes match, else False. Return True if all attributes match, else False.
""" """
mytags = self.tags or set() mytags = self.tags or set()
if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not) if (((mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None)
and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503 and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503
or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503
or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503
-121
View File
@@ -1,121 +0,0 @@
import yaml
import io
import pathlib
from .utils import terminal_width, append_message, message_to_chat, ChatType
from .configuration import Config
from typing import Any, Optional
def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]:
    """
    Read a message file in text format.

    The first line holds the tags after a ':'; the rest of the file
    contains the question and answer below their respective header lines.
    Returns a dict with the key 'tags' only if 'tags_only' is True, else
    with the keys 'question', 'answer', 'tags' and 'file'.
    Raises IndexError if the first line contains no ':' and ValueError
    if one of the header lines is missing.
    """
    with open(fname, "r") as fd:
        first_line = fd.readline().strip()
        tagline = first_line.split(':', maxsplit=1)[1].strip()
        # also support tags separated by ',' (old format)
        if ',' in tagline:
            raw_tags = tagline.split(',')
        else:
            raw_tags = tagline.split(' ')
        tags = [raw.strip() for raw in raw_tags]
        if tags_only:
            return {"tags": tags}
        lines = fd.read().strip().split('\n')
    question_start = lines.index("=== QUESTION ===") + 1
    answer_header = lines.index("==== ANSWER ====")
    return {
        "question": "\n".join(lines[question_start:answer_header]).strip(),
        "answer": "\n".join(lines[answer_header + 1:]).strip(),
        "tags": tags,
        "file": fname.name,
    }
def dump_data(data: dict[str, Any]) -> str:
    """
    Render a message dict in the legacy '.txt' format and return it as a string.

    'data' must provide the keys 'tags' (an iterable of strings), 'question'
    and 'answer'. The tags are written space-separated on the first line.
    """
    tag_line = " ".join(data["tags"])
    sections = (
        f'TAGS: {tag_line}\n',
        f'=== QUESTION ===\n{data["question"]}\n',
        f'==== ANSWER ====\n{data["answer"]}\n',
    )
    return "".join(sections)
def write_file(fname: str, data: dict[str, Any]) -> None:
    """
    Write a message dict to 'fname' in the legacy '.txt' format.

    'data' must provide the keys 'tags', 'question' and 'answer'.
    An existing file is overwritten.
    """
    with open(fname, "w") as out:
        out.writelines((
            f'TAGS: {" ".join(data["tags"])}\n',
            f'=== QUESTION ===\n{data["question"]}\n',
            f'==== ANSWER ====\n{data["answer"]}\n',
        ))
def save_answers(question: str,
                 answers: list[str],
                 tags: list[str],
                 otags: Optional[list[str]],
                 config: Config
                 ) -> None:
    """
    Print every answer and store each one as a numbered '.txt' message file.

    The file number continues from the counter stored in '<config.db>/.next'.
    After all answers have been written, the counter file is updated with the
    last number used.

    Args:
        question: The question all answers belong to.
        answers: The answers to print and store (one file per answer).
        tags: Tags stored with each answer if no output tags are given.
        otags: Output tags; if non-empty they are stored instead of 'tags'.
        config: Configuration providing the DB directory ('config.db').
    """
    wtags = otags or tags
    next_fname = pathlib.Path(str(config.db)) / '.next'
    num = 0
    try:
        with open(next_fname, 'r') as f:
            num = int(f.read())
    except (OSError, ValueError):
        # Best effort: a missing, unreadable or corrupt counter file
        # simply restarts the numbering at 0.
        pass
    for inum, answer in enumerate(answers, start=1):
        num += 1
        title = f'-- ANSWER {inum} '
        title_end = '-' * (terminal_width() - len(title))
        print(f'{title}{title_end}')
        print(answer)
        # NOTE(review): the file name is relative, so the file is created in
        # the current working directory and not in 'config.db' - confirm
        # that this is intended.
        write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags})
    with open(next_fname, 'w') as f:
        f.write(f'{num}')
def create_chat_hist(question: Optional[str],
                     tags: Optional[list[str]],
                     extags: Optional[list[str]],
                     config: Config,
                     match_all_tags: bool = False,
                     with_tags: bool = False,
                     with_file: bool = False
                     ) -> ChatType:
    """
    Build a chat history from the message files in the DB directory.

    The history starts with the system message ('config.system'), followed by
    all '.yaml' / '.txt' messages (in sorted file order) that pass the tag
    selection, and ends with 'question' (if given) as the last user message.

    Args:
        question: New question to append at the end (skipped if empty / None).
        tags: Tags to select messages by; if empty / None, all are selected.
        extags: Tags that exclude a message if it carries any of them.
        config: Configuration providing 'system' and 'db'.
        match_all_tags: If True a message needs all of 'tags', else any.
        with_tags: Forwarded to 'message_to_chat()'.
        with_file: Forwarded to 'message_to_chat()'.
    """
    history: ChatType = []
    append_message(history, 'system', str(config.system).strip())
    db_dir = pathlib.Path(str(config.db))
    for entry in sorted(db_dir.iterdir()):
        suffix = entry.suffix
        if suffix not in ('.yaml', '.txt'):
            continue
        if suffix == '.yaml':
            # NOTE(review): 'yaml.FullLoader' is used on DB files - confirm
            # that these files are always trusted.
            with open(entry, 'r') as f:
                record = yaml.load(f, Loader=yaml.FullLoader)
            record['file'] = entry.name
        else:
            record = read_file(entry)
        record_tags = set(record.get('tags', []))
        if tags:
            if match_all_tags:
                selected = set(tags).issubset(record_tags)
            else:
                selected = bool(record_tags.intersection(tags))
            if not selected:
                continue
        if extags and record_tags.intersection(extags):
            continue
        message_to_chat(record, history, with_tags, with_file)
    if question:
        append_message(history, 'user', question)
    return history
def get_tags(config: Config, prefix: Optional[str]) -> list[str]:
    """
    Collect the tags of all '.yaml' and '.txt' messages in the DB directory.

    The files are visited in sorted order and duplicates are kept. If a
    non-empty 'prefix' is given, only tags starting with it are returned.
    """
    collected = []
    for entry in sorted(pathlib.Path(str(config.db)).iterdir()):
        if entry.suffix == '.txt':
            record = read_file(entry, tags_only=True)
        elif entry.suffix == '.yaml':
            with open(entry, 'r') as f:
                record = yaml.load(f, Loader=yaml.FullLoader)
        else:
            continue
        file_tags = record.get('tags', [])
        if prefix:
            collected.extend(tag for tag in file_tags if tag.startswith(prefix))
        else:
            collected.extend(file_tags)
    return collected
def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]:
    """
    Like 'get_tags()', but every tag is returned only once.
    The order of the returned tags is not defined.
    """
    distinct = set(get_tags(config, prefix))
    return list(distinct)
-80
View File
@@ -1,80 +0,0 @@
import shutil
from pprint import PrettyPrinter
from typing import Any
ChatType = list[dict[str, str]]
def terminal_width() -> int:
    """
    Return the width of the terminal in columns.
    """
    columns, _ = shutil.get_terminal_size()
    return columns
def pp(*args: Any, **kwargs: Any) -> None:
    """
    Pretty-print the given object using the full terminal width.
    All arguments are passed on to 'PrettyPrinter.pprint()'.
    """
    printer = PrettyPrinter(width=terminal_width())
    return printer.pprint(*args, **kwargs)
def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None:
    """
    Print the tags, excluded tags and output tags (one line per non-empty
    list), followed by an empty line. Prints nothing if all lists are empty.
    """
    labelled = (
        ("Tags: ", tags),
        ("Excluding tags: ", extags),
        ("Output tags: ", otags),
    )
    lines = [f"{label}{' '.join(values)}" for label, values in labelled if values]
    if not lines:
        return
    print("\n".join(lines))
    print()
def append_message(chat: ChatType,
                   role: str,
                   content: str
                   ) -> None:
    """
    Append a message with the given role and content to 'chat' (in place).
    Doubled single quotes ('') in the content are replaced by one (').
    """
    cleaned = content.replace("''", "'")
    chat.append({'role': role, 'content': cleaned})
def message_to_chat(message: dict[str, str],
                    chat: ChatType,
                    with_tags: bool = False,
                    with_file: bool = False
                    ) -> None:
    """
    Append the question ('user') and the answer ('assistant') of 'message'
    to 'chat'. Optionally also append its tags (role 'tags', space-separated)
    and its file name (role 'file').
    """
    for role, key in (('user', 'question'), ('assistant', 'answer')):
        append_message(chat, role, message[key])
    if with_tags:
        append_message(chat, 'tags', " ".join(message['tags']))
    if with_file:
        append_message(chat, 'file', message['file'])
def display_source_code(content: str) -> None:
    """
    Print the text between the first and the last '```' fence of 'content'
    (stripped of surrounding whitespace). Prints nothing if there are fewer
    than two fences or nothing between them.
    """
    fence = '```'
    first = content.find(fence)
    if first == -1:
        return
    last = content.rfind(fence)
    body_start = first + 3
    if body_start < last:
        print(content[body_start:last].strip())
def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None:
    """
    Print a chat history.

    With 'dump' the raw structure is pretty-printed. With 'source_code' only
    the fenced source code of each message is shown. Otherwise each message
    is printed as 'ROLE: content', with a separator line before each 'user'
    message; content that does not fit on one terminal line is printed below
    the role instead.
    """
    if dump:
        pp(chat)
        return
    for entry in chat:
        role, content = entry['role'], entry['content']
        if source_code:
            display_source_code(content)
            continue
        if role == 'user':
            print('-' * terminal_width())
        label = role.upper()
        fits_on_one_line = len(content) <= terminal_width() - len(role) - 2
        if fits_on_one_line:
            print(f"{label}: {content}")
        else:
            print(f"{label}:")
            print(content)
+1 -1
View File
@@ -12,7 +12,7 @@ setup(
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
url="https://github.com/ok2/ChatMastermind", url="https://github.com/ok2/ChatMastermind",
packages=find_packages() + ["chatmastermind.ais"], packages=find_packages() + ["chatmastermind.ais", "chatmastermind.commands"],
classifiers=[ classifiers=[
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",
"Environment :: Console", "Environment :: Console",
+48
View File
@@ -0,0 +1,48 @@
import argparse
import unittest
from unittest.mock import MagicMock
from chatmastermind.ai_factory import create_ai
from chatmastermind.configuration import Config
from chatmastermind.ai import AIError
from chatmastermind.ais.openai import OpenAI
class TestCreateAI(unittest.TestCase):
    """
    Tests for 'create_ai()': AI ID given via arguments, no AI ID given,
    and the error cases (no AIs configured, unknown AI ID).
    """

    def setUp(self) -> None:
        # arguments without any AI parameter overrides
        namespace = MagicMock(spec=argparse.Namespace)
        namespace.AI = 'myopenai'
        for option in ('model', 'max_tokens', 'temperature'):
            setattr(namespace, option, None)
        self.args = namespace

    def test_create_ai_from_args(self) -> None:
        # AI ID given explicitly, default configuration
        default_config = Config()
        self.args.AI = 'myopenai'
        created = create_ai(self.args, default_config)
        self.assertIsInstance(created, OpenAI)

    def test_create_ai_from_default(self) -> None:
        # no AI ID given, default configuration
        self.args.AI = None
        default_config = Config()
        created = create_ai(self.args, default_config)
        self.assertIsInstance(created, OpenAI)

    def test_create_empty_ai_error(self) -> None:
        self.args.AI = None
        # configuration without any AIs -> expect an AIError
        empty_config = Config()
        empty_config.ais = {}
        self.assertRaises(AIError, create_ai, self.args, empty_config)

    def test_create_unsupported_ai_error(self) -> None:
        # AI ID that is not part of the default configuration
        # -> expect an AIError
        self.args.AI = 'invalid_ai'
        default_config = Config()
        self.assertRaises(AIError, create_ai, self.args, default_config)
+48 -2
View File
@@ -202,7 +202,25 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[1].file_path, self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0003.txt')) pathlib.Path(self.db_path.name, '0003.txt'))
def test_chat_db_filter(self) -> None: def test_chat_db_from_dir_filter_tags(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or={Tag('tag1')}))
self.assertEqual(len(chat_db.messages), 1)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
def test_chat_db_from_dir_filter_tags_empty(self) -> None:
    """
    A filter whose tag sets are all empty (but not None) selects no message.
    """
    empty_tags_filter = MessageFilter(tags_or=set(),
                                      tags_and=set(),
                                      tags_not=set())
    chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                              pathlib.Path(self.db_path.name),
                              mfilter=empty_tags_filter)
    self.assertEqual(len(chat_db.messages), 0)
def test_chat_db_from_dir_filter_answer(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
mfilter=MessageFilter(answer_contains='Answer 2')) mfilter=MessageFilter(answer_contains='Answer 2'))
@@ -213,7 +231,7 @@ class TestChatDB(unittest.TestCase):
pathlib.Path(self.db_path.name, '0002.yaml')) pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[0].answer, 'Answer 2') self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
def test_chat_db_from_messges(self) -> None: def test_chat_db_from_messages(self) -> None:
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
messages=[self.message1, self.message2, messages=[self.message1, self.message2,
@@ -440,3 +458,31 @@ class TestChatDB(unittest.TestCase):
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1) self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
def test_chat_db_update_messages(self) -> None:
    """
    Check 'ChatDB.update_messages()' with and without writing to disk, and
    that a message without a file path is rejected with a ChatError.
    """
    # create a new ChatDB instance
    chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                              pathlib.Path(self.db_path.name))
    db_dir_files = self.message_list(self.db_path)
    self.assertEqual(len(db_dir_files), 4)
    cache_dir_files = self.message_list(self.cache_path)
    self.assertEqual(len(cache_dir_files), 0)
    message = chat_db.messages[0]
    message.answer = Answer("New answer")
    # update message without writing
    chat_db.update_messages([message], write=False)
    self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
    # re-read the message and check for old content
    chat_db.read_db()
    self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1"))
    # now check with writing (message should be overwritten)
    chat_db.update_messages([message], write=True)
    chat_db.read_db()
    self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
    # test without file_path -> expect error
    message1 = Message(question=Question("Question 1"),
                       answer=Answer("Answer 1"))
    with self.assertRaises(ChatError):
        chat_db.update_messages([message1])
+77 -5
View File
@@ -4,7 +4,7 @@ import yaml
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from pathlib import Path from pathlib import Path
from typing import cast from typing import cast
from chatmastermind.configuration import OpenAIConfig, ConfigError, ai_config_instance, Config from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config
class TestAIConfigInstance(unittest.TestCase): class TestAIConfigInstance(unittest.TestCase):
@@ -59,7 +59,7 @@ class TestConfig(unittest.TestCase):
source_dict = { source_dict = {
'db': './test_db/', 'db': './test_db/',
'ais': { 'ais': {
'default': { 'myopenai': {
'name': 'openai', 'name': 'openai',
'system': 'Custom system', 'system': 'Custom system',
'api_key': '9876543210', 'api_key': '9876543210',
@@ -75,10 +75,10 @@ class TestConfig(unittest.TestCase):
config = Config.from_dict(source_dict) config = Config.from_dict(source_dict)
self.assertEqual(config.db, './test_db/') self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1) self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['default'].name, 'openai') self.assertEqual(config.ais['myopenai'].name, 'openai')
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system')
# check that 'ID' has been added # check that 'ID' has been added
self.assertEqual(config.ais['default'].ID, 'default') self.assertEqual(config.ais['myopenai'].ID, 'myopenai')
def test_create_default_should_create_default_config(self) -> None: def test_create_default_should_create_default_config(self) -> None:
Config.create_default(Path(self.test_file.name)) Config.create_default(Path(self.test_file.name))
@@ -86,3 +86,75 @@ class TestConfig(unittest.TestCase):
default_config = yaml.load(f, Loader=yaml.FullLoader) default_config = yaml.load(f, Loader=yaml.FullLoader)
config_reference = Config() config_reference = Config()
self.assertEqual(default_config['db'], config_reference.db) self.assertEqual(default_config['db'], config_reference.db)
def test_from_file_should_load_config_from_file(self) -> None:
    """
    Dump a configuration with a single 'openai' entry as YAML and check
    that 'Config.from_file()' loads it.
    """
    openai_settings = {
        'name': 'openai',
        'system': 'Custom system',
        'api_key': '9876543210',
        'model': 'custom_model',
        'max_tokens': 5000,
        'temperature': 0.5,
        'top_p': 0.8,
        'frequency_penalty': 0.7,
        'presence_penalty': 0.2
    }
    raw_config = {'db': './test_db/', 'ais': {'default': openai_settings}}
    with open(self.test_file.name, 'w') as config_file:
        yaml.dump(raw_config, config_file)
    loaded = Config.from_file(self.test_file.name)
    self.assertIsInstance(loaded, Config)
    self.assertEqual(loaded.db, './test_db/')
    self.assertEqual(len(loaded.ais), 1)
    default_ai = loaded.ais['default']
    self.assertIsInstance(default_ai, AIConfig)
    self.assertEqual(cast(OpenAIConfig, default_ai).system, 'Custom system')
def test_to_file_should_save_config_to_file(self) -> None:
    """
    Save a configuration with 'Config.to_file()' and check the YAML
    content of the resulting file.
    """
    openai_config = OpenAIConfig(
        ID='myopenai',
        system='Custom system',
        api_key='9876543210',
        model='custom_model',
        max_tokens=5000,
        temperature=0.5,
        top_p=0.8,
        frequency_penalty=0.7,
        presence_penalty=0.2
    )
    config = Config(db='./test_db/', ais={'myopenai': openai_config})
    config.to_file(Path(self.test_file.name))
    with open(self.test_file.name, 'r') as config_file:
        stored = yaml.load(config_file, Loader=yaml.FullLoader)
    self.assertEqual(stored['db'], './test_db/')
    stored_ais = stored['ais']
    self.assertEqual(len(stored_ais), 1)
    self.assertEqual(stored_ais['myopenai']['system'], 'Custom system')
def test_from_file_error_unknown_ai(self) -> None:
    """
    Loading a configuration whose AI entry has the name 'foobla'
    must raise a ConfigError.
    """
    unknown_ai_settings = {
        'name': 'foobla',
        'system': 'Custom system',
        'api_key': '9876543210',
        'model': 'custom_model',
        'max_tokens': 5000,
        'temperature': 0.5,
        'top_p': 0.8,
        'frequency_penalty': 0.7,
        'presence_penalty': 0.2
    }
    raw_config = {'db': './test_db/', 'ais': {'default': unknown_ai_settings}}
    with open(self.test_file.name, 'w') as config_file:
        yaml.dump(raw_config, config_file)
    self.assertRaises(ConfigError, Config.from_file, self.test_file.name)
-236
View File
@@ -1,236 +0,0 @@
# import unittest
# import io
# import pathlib
# import argparse
# from chatmastermind.utils import terminal_width
# from chatmastermind.main import create_parser, ask_cmd
# from chatmastermind.api_client import ai
# from chatmastermind.configuration import Config
# from chatmastermind.storage import create_chat_hist, save_answers, dump_data
# from unittest import mock
# from unittest.mock import patch, MagicMock, Mock, ANY
# class CmmTestCase(unittest.TestCase):
# """
# Base class for all cmm testcases.
# """
# def dummy_config(self, db: str) -> Config:
# """
# Creates a dummy configuration.
# """
# return Config.from_dict(
# {'system': 'dummy_system',
# 'db': db,
# 'openai': {'api_key': 'dummy_key',
# 'model': 'dummy_model',
# 'max_tokens': 4000,
# 'temperature': 1.0,
# 'top_p': 1,
# 'frequency_penalty': 0,
# 'presence_penalty': 0}}
# )
#
#
# class TestCreateChat(CmmTestCase):
#
# def setUp(self) -> None:
# self.config = self.dummy_config(db='test_files')
# self.question = "test question"
# self.tags = ['test_tag']
#
# @patch('os.listdir')
# @patch('pathlib.Path.iterdir')
# @patch('builtins.open')
# def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
# listdir_mock.return_value = ['testfile.txt']
# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
# {'question': 'test_content', 'answer': 'some answer',
# 'tags': ['test_tag']}))
#
# test_chat = create_chat_hist(self.question, self.tags, None, self.config)
#
# self.assertEqual(len(test_chat), 4)
# self.assertEqual(test_chat[0],
# {'role': 'system', 'content': self.config.system})
# self.assertEqual(test_chat[1],
# {'role': 'user', 'content': 'test_content'})
# self.assertEqual(test_chat[2],
# {'role': 'assistant', 'content': 'some answer'})
# self.assertEqual(test_chat[3],
# {'role': 'user', 'content': self.question})
#
# @patch('os.listdir')
# @patch('pathlib.Path.iterdir')
# @patch('builtins.open')
# def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
# listdir_mock.return_value = ['testfile.txt']
# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
# {'question': 'test_content', 'answer': 'some answer',
# 'tags': ['other_tag']}))
#
# test_chat = create_chat_hist(self.question, self.tags, None, self.config)
#
# self.assertEqual(len(test_chat), 2)
# self.assertEqual(test_chat[0],
# {'role': 'system', 'content': self.config.system})
# self.assertEqual(test_chat[1],
# {'role': 'user', 'content': self.question})
#
# @patch('os.listdir')
# @patch('pathlib.Path.iterdir')
# @patch('builtins.open')
# def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
# listdir_mock.return_value = ['testfile.txt', 'testfile2.txt']
# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
# open_mock.side_effect = (
# io.StringIO(dump_data({'question': 'test_content',
# 'answer': 'some answer',
# 'tags': ['test_tag']})),
# io.StringIO(dump_data({'question': 'test_content2',
# 'answer': 'some answer2',
# 'tags': ['test_tag2']})),
# )
#
# test_chat = create_chat_hist(self.question, [], None, self.config)
#
# self.assertEqual(len(test_chat), 6)
# self.assertEqual(test_chat[0],
# {'role': 'system', 'content': self.config.system})
# self.assertEqual(test_chat[1],
# {'role': 'user', 'content': 'test_content'})
# self.assertEqual(test_chat[2],
# {'role': 'assistant', 'content': 'some answer'})
# self.assertEqual(test_chat[3],
# {'role': 'user', 'content': 'test_content2'})
# self.assertEqual(test_chat[4],
# {'role': 'assistant', 'content': 'some answer2'})
#
#
# class TestHandleQuestion(CmmTestCase):
#
# def setUp(self) -> None:
# self.question = "test question"
# self.args = argparse.Namespace(
# or_tags=['tag1'],
# and_tags=None,
# exclude_tags=['xtag1'],
# output_tags=None,
# question=[self.question],
# source=None,
# source_code_only=False,
# num_answers=3,
# max_tokens=None,
# temperature=None,
# model=None,
# match_all_tags=False,
# with_tags=False,
# with_file=False,
# )
# self.config = self.dummy_config(db='test_files')
#
# @patch("chatmastermind.main.create_chat_hist", return_value="test_chat")
# @patch("chatmastermind.main.print_tag_args")
# @patch("chatmastermind.main.print_chat_hist")
# @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
# @patch("chatmastermind.utils.pp")
# @patch("builtins.print")
# def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock,
# mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock,
# mock_create_chat_hist: MagicMock) -> None:
# open_mock = MagicMock()
# with patch("chatmastermind.storage.open", open_mock):
# ask_cmd(self.args, self.config)
# mock_print_tag_args.assert_called_once_with(self.args.or_tags,
# self.args.exclude_tags,
# [])
# mock_create_chat_hist.assert_called_once_with(self.question,
# self.args.or_tags,
# self.args.exclude_tags,
# self.config,
# match_all_tags=False,
# with_tags=False,
# with_file=False)
# mock_print_chat_hist.assert_called_once_with('test_chat',
# False,
# self.args.source_code_only)
# mock_ai.assert_called_with("test_chat",
# self.config,
# self.args.num_answers)
# expected_calls = []
# for num, answer in enumerate(mock_ai.return_value[0], start=1):
# title = f'-- ANSWER {num} '
# title_end = '-' * (terminal_width() - len(title))
# expected_calls.append(((f'{title}{title_end}',),))
# expected_calls.append(((answer,),))
# expected_calls.append((("-" * terminal_width(),),))
# expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
# self.assertEqual(mock_print.call_args_list, expected_calls)
# open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)])
# open_mock.assert_has_calls(open_expected_calls, any_order=True)
#
#
# class TestSaveAnswers(CmmTestCase):
# @mock.patch('builtins.open')
# @mock.patch('chatmastermind.storage.print')
# def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None:
# question = "Test question?"
# answers = ["Answer 1", "Answer 2"]
# tags = ["tag1", "tag2"]
# otags = ["otag1", "otag2"]
# config = self.dummy_config(db='test_db')
#
# with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \
# mock.patch('chatmastermind.storage.yaml.dump'), \
# mock.patch('io.StringIO') as stringio_mock:
# stringio_instance = stringio_mock.return_value
# stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"]
# save_answers(question, answers, tags, otags, config)
#
# open_calls = [
# mock.call(pathlib.Path('test_db/.next'), 'r'),
# mock.call(pathlib.Path('test_db/.next'), 'w'),
# ]
# open_mock.assert_has_calls(open_calls, any_order=True)
#
#
# class TestAI(CmmTestCase):
#
# @patch("openai.ChatCompletion.create")
# def test_ai(self, mock_create: MagicMock) -> None:
# mock_create.return_value = {
# 'choices': [
# {'message': {'content': 'response_text_1'}},
# {'message': {'content': 'response_text_2'}}
# ],
# 'usage': {'tokens': 10}
# }
#
# chat = [{"role": "system", "content": "hello ai"}]
# config = self.dummy_config(db='dummy')
# config.openai.model = "text-davinci-002"
# config.openai.max_tokens = 150
# config.openai.temperature = 0.5
#
# result = ai(chat, config, 2)
# expected_result = (['response_text_1', 'response_text_2'],
# {'tokens': 10})
# self.assertEqual(result, expected_result)
#
#
# class TestCreateParser(CmmTestCase):
# def test_create_parser(self) -> None:
# with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers:
# mock_cmdparser = Mock()
# mock_add_subparsers.return_value = mock_cmdparser
# parser = create_parser()
# self.assertIsInstance(parser, argparse.ArgumentParser)
# mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True)
# mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY)
# mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY)
# mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY)
# mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY)
# mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY)
# self.assertTrue('.config.yaml' in parser.get_default('config'))
+6
View File
@@ -300,6 +300,12 @@ This is a question.
MessageFilter(tags_or={Tag('tag1')})) MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNone(message) self.assertIsNone(message)
def test_from_file_txt_empty_tags_dont_match(self) -> None:
    """
    A filter with empty (but not None) 'tags_or' / 'tags_and' sets
    must not match the message in 'self.file_path_min'.
    """
    empty_tags_filter = MessageFilter(tags_or=set(), tags_and=set())
    message = Message.from_file(self.file_path_min, empty_tags_filter)
    self.assertIsNone(message)
def test_from_file_txt_no_tags_match_tags_not(self) -> None: def test_from_file_txt_no_tags_match_tags_not(self) -> None:
message = Message.from_file(self.file_path_min, message = Message.from_file(self.file_path_min,
MessageFilter(tags_not={Tag('tag1')})) MessageFilter(tags_not={Tag('tag1')}))
+162
View File
@@ -0,0 +1,162 @@
import os
import unittest
import argparse
import tempfile
from pathlib import Path
from unittest.mock import MagicMock
from chatmastermind.commands.question import create_message
from chatmastermind.message import Message, Question
from chatmastermind.chat import ChatDB
class TestMessageCreate(unittest.TestCase):
"""
Test if messages created by the 'question' command have
the correct format.
"""
def setUp(self) -> None:
    """
    Create a ChatDB in temporary directories, a mocked argument namespace
    and four temporary source files with different amounts of embedded
    source code.
    """
    # create ChatDB structure
    self.db_path = tempfile.TemporaryDirectory()
    self.cache_path = tempfile.TemporaryDirectory()
    self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name),
                                db_path=Path(self.db_path.name))
    # create arguments mock
    self.args = MagicMock(spec=argparse.Namespace)
    self.args.source_text = None
    self.args.source_code = None
    self.args.AI = None
    self.args.model = None
    self.args.output_tags = None
    # File 1 : no source code block, only text
    # (delete=False: the file is removed in tearDown())
    self.source_file1 = tempfile.NamedTemporaryFile(delete=False)
    self.source_file1_content = """This is just text.
No source code.
Nope. Go look elsewhere!"""
    with open(self.source_file1.name, 'w') as f:
        f.write(self.source_file1_content)
    # File 2 : one embedded source code block
    self.source_file2 = tempfile.NamedTemporaryFile(delete=False)
    self.source_file2_content = """This is just text.
```
This is embedded source code.
```
And some text again."""
    with open(self.source_file2.name, 'w') as f:
        f.write(self.source_file2_content)
    # File 3 : all source code
    self.source_file3 = tempfile.NamedTemporaryFile(delete=False)
    self.source_file3_content = """This is all source code.
Yes, really.
Language is called 'brainfart'."""
    with open(self.source_file3.name, 'w') as f:
        f.write(self.source_file3_content)
    # File 4 : two source code blocks
    self.source_file4 = tempfile.NamedTemporaryFile(delete=False)
    self.source_file4_content = """This is just text.
```
This is embedded source code.
```
And some text again.
```
This is embedded source code.
```
Aaaand again some text."""
    with open(self.source_file4.name, 'w') as f:
        f.write(self.source_file4_content)
def tearDown(self) -> None:
os.remove(self.source_file1.name)
os.remove(self.source_file2.name)
os.remove(self.source_file3.name)
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
    """
    Return the files in 'tmp_dir' whose suffix starts with 't' or 'y'.
    """
    # exclude '.next'
    directory = Path(tmp_dir.name)
    return [*directory.glob('*.[ty]*')]
def test_message_file_created(self) -> None:
    """
    'create_message()' must store exactly one message file in the cache
    directory, containing the question that was asked.
    """
    self.args.ask = ["What is this?"]
    files_before = self.message_list(self.cache_path)
    self.assertEqual(len(files_before), 0)
    create_message(self.chat, self.args)
    files_after = self.message_list(self.cache_path)
    self.assertEqual(len(files_after), 1)
    stored = Message.from_file(files_after[0])
    self.assertIsInstance(stored, Message)
    self.assertEqual(stored.question, Question("What is this?"))  # type: ignore [union-attr]
def test_single_question(self) -> None:
    """
    A single question part becomes the message's question
    and contains no source code.
    """
    self.args.ask = ["What is this?"]
    created = create_message(self.chat, self.args)
    self.assertIsInstance(created, Message)
    self.assertEqual(created.question, Question("What is this?"))
    self.assertEqual(len(created.question.source_code()), 0)
def test_multipart_question(self) -> None:
    """
    Several 'ask' parts must end up in one question, one part per line.
    """
    self.args.ask = ["What is this", "'bard' thing?", "Is it good?"]
    message = create_message(self.chat, self.args)
    self.assertIsInstance(message, Message)
    self.assertEqual(message.question, Question("""What is this
'bard' thing?
Is it good?"""))
def test_single_question_with_text_only_file(self) -> None:
    """
    A file passed via 'source_text' is appended to the question as is.
    """
    self.args.ask = ["What is this?"]
    self.args.source_text = [f"{self.source_file1.name}"]
    message = create_message(self.chat, self.args)
    self.assertIsInstance(message, Message)
    # file contains no source code (only text)
    # -> don't expect any in the question
    self.assertEqual(len(message.question.source_code()), 0)
    self.assertEqual(message.question, Question(f"""What is this?
{self.source_file1_content}"""))
def test_single_question_with_text_file_and_embedded_code(self) -> None:
    """
    For a 'source_code' file with one fenced block, only that block
    (without the surrounding text) is appended to the question.
    """
    self.args.ask = ["What is this?"]
    self.args.source_code = [f"{self.source_file2.name}"]
    message = create_message(self.chat, self.args)
    self.assertIsInstance(message, Message)
    # file contains 1 source code block
    # -> expect it in the question
    self.assertEqual(len(message.question.source_code()), 1)
    self.assertEqual(message.question, Question("""What is this?
```
This is embedded source code.
```
"""))
def test_single_question_with_code_only_file(self) -> None:
    """
    A 'source_code' file without any fences is treated as source code
    as a whole and appended to the question inside a fenced block.
    """
    self.args.ask = ["What is this?"]
    self.args.source_code = [f"{self.source_file3.name}"]
    message = create_message(self.chat, self.args)
    self.assertIsInstance(message, Message)
    # file is complete source code
    self.assertEqual(len(message.question.source_code()), 1)
    self.assertEqual(message.question, Question(f"""What is this?
```
{self.source_file3_content}
```"""))
def test_single_question_with_text_file_and_multi_embedded_code(self) -> None:
    """
    For a 'source_code' file with two fenced blocks, both blocks
    (without the surrounding text) are appended to the question.
    """
    self.args.ask = ["What is this?"]
    self.args.source_code = [f"{self.source_file4.name}"]
    message = create_message(self.chat, self.args)
    self.assertIsInstance(message, Message)
    # file contains 2 source code blocks
    # -> expect them in the question
    self.assertEqual(len(message.question.source_code()), 2)
    self.assertEqual(message.question, Question("""What is this?
```
This is embedded source code.
```
```
This is embedded source code.
```
"""))