8 Commits

Author SHA1 Message Date
Oleksandr Kozachuk 9a957a89ac Switch to current version of OpenAI. 2024-03-30 14:17:38 +01:00
Oleksandr Kozachuk 5d1bb1f9e4 Fix some of the commands. 2023-11-10 10:42:46 +01:00
Oleksandr Kozachuk 75a123eb72 Fix usage of the dynamic answer is some cases. 2023-10-24 12:59:13 +02:00
juk0de 7c1c67f8ff Merge pull request 'Dynamic Answer class and OpenAI streaming API' (#19) from dynamic_answer into main
Introduces several changes with the main objective of enabling OpenAI's streaming API in the chatmastermind application. This allows for the retrieval of AI responses gradually as a stream, which can significantly improve the user experience in interactions that involve large result sets.

* Added tiktoken import in 'openai.py' and modifications to the OpenAI class to support streaming. This includes the addition of a new class OpenAIAnswer to handle streaming API responses.
* Modified request function in the OpenAI class: the stream=True flag is added to the openai.ChatCompletion.create method to enable streaming API.
* Modified 'question.py' to print the answer parts as they are streamed.
* Replaced the Answer class's string data type with a generator which supports str and Generator[str, None, None] data types. Modifications are made to the Answer class methods to handle both data types accordingly.
* Updated the tests in 'test_ais_openai.py' and 'test_message.py' to reflect and validate these changes.
2023-10-21 15:50:45 +02:00
Oleksandr Kozachuk dbe72ff11c Activate and use OpenAI streaming API. 2023-10-21 14:21:48 +02:00
Oleksandr Kozachuk bbc1ab5a0a Fix source_code function with the dynamic answer class. 2023-10-20 14:02:09 +02:00
Oleksandr Kozachuk 2aee018708 Refactor message.Answer class in a way, that it can be constructed dynamically step by step, in preparation of using streaming API. 2023-10-20 13:43:31 +02:00
ok 17c6fa2453 Merge pull request 'Configurable glob and location on question and hist commands' (#18) from cust_loc_glob into main
Reviewed-on: #18
2023-10-20 09:47:03 +02:00
10 changed files with 94 additions and 658 deletions
+14 -18
View File
@@ -47,12 +47,12 @@ class OpenAIAnswer:
self.finished = True
if not self.finished:
found_choice = False
for choice in chunk['choices']:
if not choice['finish_reason']:
self.streams[choice['index']].data.append(choice['delta']['content'])
self.tokens.completion += len(self.encoding.encode(choice['delta']['content']))
for choice in chunk.choices:
if not choice.finish_reason:
self.streams[choice.index].data.append(choice.delta.content)
self.tokens.completion += len(self.encoding.encode(choice.delta.content))
self.tokens.total = self.tokens.prompt + self.tokens.completion
if choice['index'] == self.idx:
if choice.index == self.idx:
found_choice = True
if not found_choice:
return False
@@ -68,7 +68,10 @@ class OpenAI(AI):
self.ID = config.ID
self.name = config.name
self.config = config
openai.api_key = self.config.api_key
self.client = openai.OpenAI(api_key=self.config.api_key)
def _completions(self, *args, **kw): # type: ignore
return self.client.chat.completions.create(*args, **kw)
def request(self,
question: Message,
@@ -83,7 +86,7 @@ class OpenAI(AI):
self.encoding = tiktoken.encoding_for_model(self.config.model)
oai_chat, prompt_tokens = self.openai_chat(chat, self.config.system, question)
tokens: Tokens = Tokens(prompt_tokens, 0, prompt_tokens)
response = openai.ChatCompletion.create(
response = self._completions(
model=self.config.model,
messages=oai_chat,
temperature=self.config.temperature,
@@ -114,9 +117,8 @@ class OpenAI(AI):
Return all models supported by this AI.
"""
ret = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']:
ret.append(engine['id'])
for engine in sorted(self.client.models.list().data, key=lambda x: x.id):
ret.append(engine.id)
ret.sort()
return ret
@@ -124,14 +126,8 @@ class OpenAI(AI):
"""
Print all models supported by the current AI.
"""
not_ready = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']:
print(engine['id'])
else:
not_ready.append(engine['id'])
if len(not_ready) > 0:
print('\nNot ready: ' + ', '.join(not_ready))
for model in self.models():
print(model)
def openai_chat(self, chat: Chat, system: str,
question: Optional[Message] = None) -> tuple[ChatType, int]:
-69
View File
@@ -1,69 +0,0 @@
"""
Contains shared functions for the various CMM subcommands.
"""
import argparse
from pathlib import Path
from ..message import Message, MessageError, source_code
def read_text_file(file: Path) -> str:
with open(file) as r:
content = r.read().strip()
return content
def add_file_as_text(question_parts: list[str], file: str) -> None:
"""
Add the given file as plain text to the question part list.
If the file is a Message, add the answer.
"""
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
content = read_text_file(Path(file))
if len(content) > 0:
question_parts.append(content)
def add_file_as_code(question_parts: list[str], file: str) -> None:
"""
Add all source code from the given file. If no code segments can be extracted,
the whole content is added as source code segment. If the file is a Message,
extract the source code from the answer.
"""
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
with open(file) as r:
content = r.read().strip()
# extract and add source code
code_parts = source_code(content, include_delims=True)
if len(code_parts) > 0:
question_parts += code_parts
else:
question_parts.append(f"```\n{content}\n```")
def invert_input_tag_args(args: argparse.Namespace) -> None:
"""
Changes the semantics of the INPUT tags for this command:
* not tags specified on the CLI -> no tags are selected
* empty tags specified on the CLI -> all tags are selected
"""
if args.or_tags is None:
args.or_tags = set()
elif len(args.or_tags) == 0:
args.or_tags = None
if args.and_tags is None:
args.and_tags = set()
elif len(args.and_tags) == 0:
args.and_tags = None
+58 -2
View File
@@ -3,10 +3,9 @@ import argparse
from pathlib import Path
from itertools import zip_longest
from copy import deepcopy
from .common import invert_input_tag_args, add_file_as_code, add_file_as_text
from ..configuration import Config
from ..chat import ChatDB, msg_location
from ..message import Message, MessageFilter, Question
from ..message import Message, MessageFilter, MessageError, Question, source_code
from ..ai_factory import create_ai
from ..ai import AI, AIResponse
@@ -15,6 +14,47 @@ class QuestionCmdError(Exception):
pass
def add_file_as_text(question_parts: list[str], file: str) -> None:
"""
Add the given file as plain text to the question part list.
If the file is a Message, add the answer.
"""
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
with open(file) as r:
content = r.read().strip()
if len(content) > 0:
question_parts.append(content)
def add_file_as_code(question_parts: list[str], file: str) -> None:
"""
Add all source code from the given file. If no code segments can be extracted,
the whole content is added as source code segment. If the file is a Message,
extract the source code from the answer.
"""
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
with open(file) as r:
content = r.read().strip()
# extract and add source code
code_parts = source_code(content, include_delims=True)
if len(code_parts) > 0:
question_parts += code_parts
else:
question_parts.append(f"```\n{content}\n```")
def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
"""
Takes an existing message and CLI arguments, and returns modified args based
@@ -123,6 +163,22 @@ def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namesp
make_request(ai, chat, message, msg_args)
def invert_input_tag_args(args: argparse.Namespace) -> None:
"""
Changes the semantics of the INPUT tags for this command:
* not tags specified on the CLI -> no tags are selected
* empty tags specified on the CLI -> all tags are selected
"""
if args.or_tags is None:
args.or_tags = set()
elif len(args.or_tags) == 0:
args.or_tags = None
if args.and_tags is None:
args.and_tags = set()
elif len(args.and_tags) == 0:
args.and_tags = None
def question_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'question' command.
-105
View File
@@ -1,105 +0,0 @@
import argparse
import mimetypes
from pathlib import Path
from .common import invert_input_tag_args, read_text_file
from ..configuration import Config
from ..message import MessageFilter, Message, Question
from ..chat import ChatDB, msg_location
class TranslationCmdError(Exception):
pass
text_separator: str = 'TEXT:'
def assert_document_type_supported_openai(document_file: Path) -> None:
doctype = mimetypes.guess_type(document_file)
if doctype != 'text/plain':
raise TranslationCmdError("AI 'OpenAI' only supports document type 'text/plain''")
def translation_prompt_openai(source_lang: str, target_lang: str) -> str:
"""
Return the prompt for GPT that tells it to do the translation.
"""
return f"Translate the text below the line {text_separator} from {source_lang} to {target_lang}."
def create_message_openai(chat: ChatDB, args: argparse.Namespace) -> Message:
"""
Create a new message from the given arguments and write it to the cache directory.
Message format
1. Translation prompt (tells GPT to do a translation)
2. Glossary (if specified as an argument)
3. User provided prompt enhancements
4. Translation separator
5. User provided text to be translated
The text to be translated is determined as a follows:
- if a document is provided in the arguments, translate its content
- if no document is provided, translate the last text argument
The other text arguments will be put into the "header" and can be used
to improve the translation prompt.
"""
text_args: list[str] = []
if args.create is not None:
text_args = args.create
elif args.ask is not None:
text_args = args.ask
else:
raise TranslationCmdError("No input text found")
# extract user prompt and user text to be translated
user_text: str
user_prompt: str
if args.input_document is not None:
assert_document_type_supported_openai(Path(args.input_document))
user_text = read_text_file(Path(args.input_document))
user_prompt = '\n\n'.join([str(s) for s in text_args])
else:
user_text = text_args[-1]
user_prompt = '\n\n'.join([str(s) for s in text_args[:-1]])
# build full question string
# FIXME: add glossaries if given
question_text: str = '\n\n'.join([translation_prompt_openai(args.source_lang, args.target_lang),
user_prompt,
text_separator,
user_text])
# create and write the message
message = Message(question=Question(question_text),
tags=args.output_tags,
ai=args.AI,
model=args.model)
# only write the new message to the cache,
# don't add it to the internal list
chat.cache_write([message])
return message
def translation_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'translation' command. Creates and executes translation
requests based on the input and selected AI. Depending on the AI, the
whole process may be significantly different (e.g. DeepL vs OpenAI).
"""
invert_input_tag_args(args)
mfilter = MessageFilter(tags_or=args.or_tags,
tags_and=args.and_tags,
tags_not=args.exclude_tags)
chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db),
mfilter=mfilter,
glob=args.glob,
loc=msg_location(args.location))
# if it's a new translation, create and store it immediately
# FIXME: check AI type
if args.ask or args.create:
# message = create_message(chat, args)
create_message_openai(chat, args)
if args.create:
return
+1 -3
View File
@@ -118,7 +118,6 @@ class Config:
# a default configuration
cache: str = '.'
db: str = './db/'
glossaries: str | None = './glossaries/'
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
@classmethod
@@ -136,8 +135,7 @@ class Config:
return cls(
cache=str(source['cache']) if 'cache' in source else '.',
db=str(source['db']),
ais=ais,
glossaries=str(source['glossaries']) if 'glossaries' in source else None
ais=ais
)
@classmethod
-161
View File
@@ -1,161 +0,0 @@
"""
Module implementing glossaries for translations.
"""
import yaml
import tempfile
import shutil
import csv
from pathlib import Path
from dataclasses import dataclass, field
from typing import Type, TypeVar, ClassVar
GlossaryInst = TypeVar('GlossaryInst', bound='Glossary')
class GlossaryError(Exception):
pass
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
"""
Changes the YAML dump style to multiline syntax for multiline strings.
"""
if len(data.splitlines()) > 1:
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
return dumper.represent_scalar('tag:yaml.org,2002:str', data)
@dataclass
class Glossary:
"""
A glossary consists of the following parameters:
- Name (freely selectable)
- Path (full file path, suffix is automatically generated)
- Source language
- Target language
- Description (optional)
- Entries (pairs of source lang and target lang terms)
- ID (automatically generated / modified, required by DeepL)
"""
name: str
source_lang: str
target_lang: str
file_path: Path | None = None
desc: str | None = None
entries: dict[str, str] = field(default_factory=lambda: dict())
ID: str | None = None
file_suffix: ClassVar[str] = '.glo'
@classmethod
def from_file(cls: Type[GlossaryInst], file_path: Path) -> GlossaryInst:
"""
Create a glossary from the given file.
"""
if not file_path.exists():
raise GlossaryError(f"Glossary file '{file_path}' does not exist")
if file_path.suffix != cls.file_suffix:
raise GlossaryError(f"File type '{file_path.suffix}' is not supported")
with open(file_path, "r") as fd:
try:
# use BaseLoader so every entry is read as a string
# - disables automatic conversions
# - makes it possible to omit quoting for YAML keywords in entries (e. g. 'yes')
# - also correctly reads quoted entries
data = yaml.load(fd, Loader=yaml.BaseLoader)
clean_entries = data['Entries']
return cls(name=data['Name'],
source_lang=data['SourceLang'],
target_lang=data['TargetLang'],
file_path=file_path,
desc=data['Description'],
entries=clean_entries,
ID=data['ID'] if data['ID'] != 'None' else None)
except Exception:
raise GlossaryError(f"'{file_path}' does not contain a valid glossary")
def to_file(self, file_path: Path | None = None) -> None:
"""
Write glossary to given file.
"""
if file_path:
self.file_path = file_path
if not self.file_path:
raise GlossaryError("Got no valid path to write glossary")
# check / add valid suffix
if not self.file_path.suffix:
self.file_path = self.file_path.with_suffix(self.file_suffix)
elif self.file_path.suffix != self.file_suffix:
raise GlossaryError(f"File suffix '{self.file_path.suffix}' is not supported")
# write YAML
with tempfile.NamedTemporaryFile(dir=self.file_path.parent, prefix=self.file_path.name, mode="w", delete=False) as temp_fd:
temp_file_path = Path(temp_fd.name)
data = {'Name': self.name,
'Description': self.desc,
'ID': str(self.ID),
'SourceLang': self.source_lang,
'TargetLang': self.target_lang,
'Entries': self.entries}
yaml.dump(data, temp_fd, sort_keys=False)
shutil.move(temp_file_path, self.file_path)
def export_csv(self, dictionary: dict[str, str], file_path: Path) -> None:
"""
Export the 'entries' of this glossary to a file in CSV format (compatible with DeepL).
"""
with open(file_path, 'w', newline='', encoding='utf-8') as csvfile:
writer = csv.writer(csvfile, delimiter=',', quotechar='"', quoting=csv.QUOTE_ALL)
for source_entry, target_entry in self.entries.items():
writer.writerow([source_entry, target_entry])
def export_tsv(self, entries: dict[str, str], file_path: Path) -> None:
"""
Export the 'entries' of this glossary to a file in TSV format (compatible with DeepL).
"""
with open(file_path, 'w', encoding='utf-8') as file:
for source_entry, target_entry in self.entries.items():
file.write(f"{source_entry}\t{target_entry}\n")
def import_csv(self, file_path: Path) -> None:
"""
Import the entries from the given CSV file to those of the current glossary.
Existing entries are overwritten.
"""
try:
with open(file_path, mode='r', encoding='utf-8') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
self.entries = {rows[0]: rows[1] for rows in reader if len(rows) >= 2}
except Exception as e:
raise GlossaryError(f"Error importing CSV: {e}")
def import_tsv(self, file_path: Path) -> None:
"""
Import the entries from the given CSV file to those of the current glossary.
Existing entries are overwritten.
"""
try:
with open(file_path, mode='r', encoding='utf-8') as tsvfile:
self.entries = {}
for line in tsvfile:
parts = line.strip().split('\t')
if len(parts) == 2:
self.entries[parts[0]] = parts[1]
except Exception as e:
raise GlossaryError(f"Error importing TSV: {e}")
def to_str(self, with_entries: bool = False) -> str:
"""
Return the current glossary as a string.
"""
output: list[str] = []
output.append(f'{self.name} (ID: {self.ID}):')
if self.desc:
output.append('- ' + self.desc)
output.append(f'- Languages: {self.source_lang} -> {self.target_lang}')
if with_entries:
output.append('- Entries:')
for source, target in self.entries.items():
output.append(f' {source} : {target}')
else:
output.append(f'- Entries: {len(self.entries)}')
return '\n'.join(output)
+1 -59
View File
@@ -3,7 +3,6 @@
# vim: set fileencoding=utf-8 :
import sys
import os
import argcomplete
import argparse
from pathlib import Path
@@ -15,7 +14,6 @@ from .commands.tags import tags_cmd
from .commands.config import config_cmd
from .commands.hist import hist_cmd
from .commands.print import print_cmd
from .commands.translation import translation_cmd
from .chat import msg_location
@@ -104,7 +102,7 @@ def create_parser() -> argparse.ArgumentParser:
# 'tags' command parser
tags_cmd_parser = cmdparser.add_parser('tags',
help="Manage tags.",
aliases=['T'])
aliases=['t'])
tags_cmd_parser.set_defaults(func=tags_cmd)
tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True)
tags_group.add_argument('-l', '--list', help="List all tags and their frequency",
@@ -138,65 +136,10 @@ def create_parser() -> argparse.ArgumentParser:
print_cmd_modes.add_argument('-a', '--answer', help='Only print the answer', action='store_true')
print_cmd_modes.add_argument('-S', '--only-source-code', help='Only print embedded source code', action='store_true')
# 'translation' command parser
translation_cmd_parser = cmdparser.add_parser('translation', parents=[ai_parser, tag_parser],
help="ask, create and repeat translations.",
aliases=['t'])
translation_cmd_parser.set_defaults(func=translation_cmd)
translation_group = translation_cmd_parser.add_mutually_exclusive_group(required=True)
translation_group.add_argument('-a', '--ask', nargs='+', help='Ask to translate the given text', metavar='TEXT')
translation_group.add_argument('-c', '--create', nargs='+', help='Create a translation', metavar='TEXT')
translation_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a translation', metavar='MESSAGE')
translation_cmd_parser.add_argument('-S', '--source-lang', help="Source language", metavar="LANGUAGE", required=True)
translation_cmd_parser.add_argument('-T', '--target-lang', help="Target language", metavar="LANGUAGE", required=True)
translation_cmd_parser.add_argument('-G', '--glossaries', nargs='+', help="List of glossaries", metavar="GLOSSARY")
translation_cmd_parser.add_argument('-d', '--input-document', help="Document to translate", metavar="FILE")
translation_cmd_parser.add_argument('-D', '--output-document', help="Path for the translated document", metavar="FILE")
argcomplete.autocomplete(parser)
return parser
def create_directories(config: Config) -> None: # noqa: 11
"""
Create the directories in the given configuration if they don't exist.
"""
def make_dir(path: Path) -> None:
try:
os.makedirs(path.absolute())
except Exception as e:
print(f"Creating directory '{path.absolute()}' failed with: {e}")
sys.exit(1)
# Cache
cache_path = Path(config.cache)
if not cache_path.exists():
answer = input(f"Cache directory '{cache_path}' does not exist. Create it? [y/n]")
if answer.lower() in ['y', 'yes']:
make_dir(cache_path.absolute())
else:
print("Can't continue without a valid cache directory!")
sys.exit(1)
# DB
db_path = Path(config.db)
if not db_path.exists():
answer = input(f"DB directory '{db_path}' does not exist. Create it? [y/n]")
if answer.lower() in ['y', 'yes']:
make_dir(db_path.absolute())
else:
print("Can't continue without a valid DB directory!")
sys.exit(1)
# Glossaries
if config.glossaries:
glossaries_path = Path(config.glossaries)
if not glossaries_path.exists():
answer = input(f"Glossaries directory '{glossaries_path}' does not exist. Create it? [y/n]")
if answer.lower() in ['y', 'yes']:
make_dir(glossaries_path.absolute())
else:
print("Can't continue without a valid glossaries directory. Create it or remove it from the configuration.")
sys.exit(1)
def main() -> int:
parser = create_parser()
args = parser.parse_args()
@@ -206,7 +149,6 @@ def main() -> int:
command.func(command)
else:
config = Config.from_file(args.config)
create_directories(config)
command.func(command, config)
return 0
+19 -31
View File
@@ -9,43 +9,31 @@ from chatmastermind.configuration import OpenAIConfig
class OpenAITest(unittest.TestCase):
@mock.patch('openai.ChatCompletion.create')
@mock.patch('chatmastermind.ais.openai.OpenAI._completions')
def test_request(self, mock_create: mock.MagicMock) -> None:
# Create a test instance of OpenAI
config = OpenAIConfig()
openai = OpenAI(config)
# Set up the mock response from openai.ChatCompletion.create
mock_chunk1 = {
'choices': [
{
'index': 0,
'delta': {
'content': 'Answer 1'
},
'finish_reason': None
},
{
'index': 1,
'delta': {
'content': 'Answer 2'
},
'finish_reason': None
}
],
}
mock_chunk2 = {
'choices': [
{
'index': 0,
'finish_reason': 'stop'
},
{
'index': 1,
'finish_reason': 'stop'
}
],
}
class mock_obj:
pass
mock_chunk1 = mock_obj()
mock_chunk1.choices = [mock_obj(), mock_obj()] # type: ignore
mock_chunk1.choices[0].index = 0 # type: ignore
mock_chunk1.choices[0].delta = mock_obj() # type: ignore
mock_chunk1.choices[0].delta.content = 'Answer 1' # type: ignore
mock_chunk1.choices[0].finish_reason = None # type: ignore
mock_chunk1.choices[1].index = 1 # type: ignore
mock_chunk1.choices[1].delta = mock_obj() # type: ignore
mock_chunk1.choices[1].delta.content = 'Answer 2' # type: ignore
mock_chunk1.choices[1].finish_reason = None # type: ignore
mock_chunk2 = mock_obj()
mock_chunk2.choices = [mock_obj(), mock_obj()] # type: ignore
mock_chunk2.choices[0].index = 0 # type: ignore
mock_chunk2.choices[0].finish_reason = 'stop' # type: ignore
mock_chunk2.choices[1].index = 1 # type: ignore
mock_chunk2.choices[1].finish_reason = 'stop' # type: ignore
mock_create.return_value = iter([mock_chunk1, mock_chunk2])
# Create test data
+1 -6
View File
@@ -71,13 +71,11 @@ class TestConfig(unittest.TestCase):
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
},
'glossaries': './glossaries/'
}
}
config = Config.from_dict(source_dict)
self.assertEqual(config.cache, '.')
self.assertEqual(config.db, './test_db/')
self.assertEqual(config.glossaries, './glossaries/')
self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['myopenai'].name, 'openai')
self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system')
@@ -107,7 +105,6 @@ class TestConfig(unittest.TestCase):
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
# omit glossaries, since it's optional
}
}
with open(self.test_file.name, 'w') as f:
@@ -116,8 +113,6 @@ class TestConfig(unittest.TestCase):
self.assertIsInstance(config, Config)
self.assertEqual(config.cache, './test_cache/')
self.assertEqual(config.db, './test_db/')
# missing 'glossaries' should result in 'None'
self.assertEqual(config.glossaries, None)
self.assertEqual(len(config.ais), 1)
self.assertIsInstance(config.ais['default'], AIConfig)
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
-204
View File
@@ -1,204 +0,0 @@
import unittest
import tempfile
from pathlib import Path
from chatmastermind.glossary import Glossary, GlossaryError
glossary_suffix: str = Glossary.file_suffix
class TestGlossary(unittest.TestCase):
def test_from_file_yaml_unquoted(self) -> None:
"""
Test glossary creatiom from YAML with unquoted entries.
"""
with tempfile.NamedTemporaryFile('w', delete=False, suffix=glossary_suffix) as yaml_file:
yaml_file.write("Name: Sample\n"
"Description: A brief description\n"
"ID: '123'\n"
"SourceLang: en\n"
"TargetLang: es\n"
"Entries:\n"
" hello: hola\n"
" goodbye: adiós\n"
# 'yes' is a YAML keyword and would normally be quoted
" yes: sí\n"
" I'm going home: me voy a casa\n")
yaml_file_path = Path(yaml_file.name)
glossary = Glossary.from_file(yaml_file_path)
self.assertEqual(glossary.name, "Sample")
self.assertEqual(glossary.desc, "A brief description")
self.assertEqual(glossary.ID, "123")
self.assertEqual(glossary.source_lang, "en")
self.assertEqual(glossary.target_lang, "es")
self.assertEqual(glossary.entries, {"hello": "hola",
"goodbye": "adiós",
"yes": "",
"I'm going home": "me voy a casa"})
yaml_file_path.unlink() # Remove the temporary file
def test_from_file_yaml_quoted(self) -> None:
"""
Test glossary creatiom from YAML with quoted entries.
"""
with tempfile.NamedTemporaryFile('w', delete=False, suffix=glossary_suffix) as yaml_file:
yaml_file.write("Name: Sample\n"
"Description: A brief description\n"
"ID: '123'\n"
"SourceLang: en\n"
"TargetLang: es\n"
"Entries:\n"
" 'hello': 'hola'\n"
" 'goodbye': 'adiós'\n"
" 'yes': ''\n"
" \"I'm going home\": 'me voy a casa'\n")
yaml_file_path = Path(yaml_file.name)
glossary = Glossary.from_file(yaml_file_path)
self.assertEqual(glossary.name, "Sample")
self.assertEqual(glossary.desc, "A brief description")
self.assertEqual(glossary.ID, "123")
self.assertEqual(glossary.source_lang, "en")
self.assertEqual(glossary.target_lang, "es")
self.assertEqual(glossary.entries, {"hello": "hola",
"goodbye": "adiós",
"yes": "",
"I'm going home": "me voy a casa"})
yaml_file_path.unlink() # Remove the temporary file
def test_to_file_writes_yaml(self) -> None:
# Create glossary instance
glossary = Glossary(name="Test",
desc="Test description",
ID="666",
source_lang="en",
target_lang="fr",
entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', delete=False, suffix=glossary_suffix) as tmp_file:
file_path = Path(tmp_file.name)
glossary.to_file(file_path)
with open(file_path, 'r') as file:
content = file.read()
self.assertIn("Name: Test", content)
self.assertIn("Description: Test description", content)
self.assertIn("ID: '666'", content)
self.assertIn("SourceLang: en", content)
self.assertIn("TargetLang: fr", content)
self.assertIn("Entries", content)
# 'yes' is a YAML keyword and therefore quoted
self.assertIn("'yes': oui", content)
file_path.unlink() # Remove the temporary file
def test_write_read_glossary(self) -> None:
# Create glossary instance
# -> use 'yes' in order to test if the YAML quoting is correctly removed when reading the file
glossary_write = Glossary(name="Test", source_lang="en", target_lang="fr", entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', delete=False, suffix=glossary_suffix) as tmp_file:
file_path = Path(tmp_file.name)
glossary_write.to_file(file_path)
# create new instance from glossary file
glossary_read = Glossary.from_file(file_path)
self.assertEqual(glossary_write.name, glossary_read.name)
self.assertEqual(glossary_write.source_lang, glossary_read.source_lang)
self.assertEqual(glossary_write.target_lang, glossary_read.target_lang)
self.assertDictEqual(glossary_write.entries, glossary_read.entries)
file_path.unlink() # Remove the temporary file
def test_import_export_csv(self) -> None:
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={})
# First export to CSV
with tempfile.NamedTemporaryFile('w', delete=False, suffix=glossary_suffix) as csvfile:
csv_file_path = Path(csvfile.name)
glossary.entries = {"hello": "salut", "goodbye": "au revoir"}
glossary.export_csv(glossary.entries, csv_file_path)
# Now import CSV
glossary.import_csv(csv_file_path)
self.assertEqual(glossary.entries, {"hello": "salut", "goodbye": "au revoir"})
csv_file_path.unlink() # Remove the temporary file
def test_import_export_tsv(self) -> None:
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={})
# First export to TSV
with tempfile.NamedTemporaryFile('w', delete=False, suffix=glossary_suffix) as tsvfile:
tsv_file_path = Path(tsvfile.name)
glossary.entries = {"hello": "salut", "goodbye": "au revoir"}
glossary.export_tsv(glossary.entries, tsv_file_path)
# Now import TSV
glossary.import_tsv(tsv_file_path)
self.assertEqual(glossary.entries, {"hello": "salut", "goodbye": "au revoir"})
tsv_file_path.unlink() # Remove the temporary file
def test_to_file_wrong_suffix(self) -> None:
"""
Test for exception if suffix is wrong.
"""
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', delete=False, suffix='.wrong') as tmp_file:
file_path = Path(tmp_file.name)
with self.assertRaises(GlossaryError) as err:
glossary.to_file(file_path)
self.assertEqual(str(err.exception), "File suffix '.wrong' is not supported")
def test_to_file_auto_suffix(self) -> None:
"""
Test if suffix is auto-generated if omitted.
"""
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', delete=False, suffix='') as tmp_file:
file_path = Path(tmp_file.name)
glossary.to_file(file_path)
assert glossary.file_path is not None
self.assertEqual(glossary.file_path.suffix, glossary_suffix)
def test_to_str_with_id(self) -> None:
# Create a Glossary instance with an ID
glossary_with_id = Glossary(name="TestGlossary", source_lang="en", target_lang="fr",
desc="A simple test glossary", ID="1001", entries={"one": "un"})
glossary_str = glossary_with_id.to_str()
self.assertIn("TestGlossary (ID: 1001):", glossary_str)
self.assertIn("- A simple test glossary", glossary_str)
self.assertIn("- Languages: en -> fr", glossary_str)
self.assertIn("- Entries: 1", glossary_str)
def test_to_str_with_id_and_entries(self) -> None:
# Create a Glossary instance with an ID and include entries
glossary_with_entries = Glossary(name="TestGlossaryWithEntries", source_lang="en", target_lang="fr",
desc="Another test glossary", ID="2002",
entries={"hello": "salut", "goodbye": "au revoir"})
glossary_str_with_entries = glossary_with_entries.to_str(with_entries=True)
self.assertIn("TestGlossaryWithEntries (ID: 2002):", glossary_str_with_entries)
self.assertIn("- Entries:", glossary_str_with_entries)
self.assertIn(" hello : salut", glossary_str_with_entries)
self.assertIn(" goodbye : au revoir", glossary_str_with_entries)
def test_to_str_without_id(self) -> None:
# Create a Glossary instance without an ID
glossary_without_id = Glossary(name="TestGlossaryNoID", source_lang="en", target_lang="fr",
desc="A test glossary without an ID", ID=None, entries={"yes": "oui"})
glossary_str_no_id = glossary_without_id.to_str()
self.assertIn("TestGlossaryNoID (ID: None):", glossary_str_no_id)
self.assertIn("- A test glossary without an ID", glossary_str_no_id)
self.assertIn("- Languages: en -> fr", glossary_str_no_id)
self.assertIn("- Entries: 1", glossary_str_no_id)
def test_to_str_without_id_and_no_entries(self) -> None:
# Create a Glossary instance without an ID and no entries
glossary_no_id_no_entries = Glossary(name="EmptyGlossary", source_lang="en", target_lang="fr",
desc="An empty test glossary", ID=None, entries={})
glossary_str_no_id_no_entries = glossary_no_id_no_entries.to_str()
self.assertIn("EmptyGlossary (ID: None):", glossary_str_no_id_no_entries)
self.assertIn("- An empty test glossary", glossary_str_no_id_no_entries)
self.assertIn("- Languages: en -> fr", glossary_str_no_id_no_entries)
self.assertIn("- Entries: 0", glossary_str_no_id_no_entries)