Merge pull request 'cmm question --repeat supports multiple questions, added tests and fixes' (#15) from repeat_multi into main
This PR primarily modifies the `cmm question --repeat` command to allow repeating multiple questions, instead of only the last one. Additionally, this PR includes the following changes: - In `ai_factory.py`, added optional parameters 'def_ai' and 'def_model' to the `create_ai` function which allows specifying a default AI and model. - In `openai.py`, a potential bug was fixed where the 'tags' attribute was updated to ensure it is always a set, even when 'otags' is None. - In `question.py`, a significant amount of new code was added to facilitate the 'repeat' functionality. This includes functions to create modified args based on an existing message (`create_msg_args`), to repeat a given list of messages (`repeat_messages`), and to invert the semantics of the INPUT tags for this command (`invert_input_tag_args`). - In `main.py`, the 'nargs' parameter was changed from `+` to `*` in the 'or-tags', 'and-tags', and 'exclude-tags' arguments to accommodate the updated handling of tags in `question.py`. - A new `test_common.py` file was added which includes a `FakeAI` class for testing purposes, and a `TestWithFakeAI` class which includes a number of methods for asserting various conditions about messages. This PR also includes additional tests to verify the correct operation of the new 'repeat' functionality.
This commit was merged in pull request #15.
This commit is contained in:
@@ -3,18 +3,20 @@ Creates different AI instances, based on the given configuration.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from typing import cast
|
from typing import cast, Optional
|
||||||
from .configuration import Config, AIConfig, OpenAIConfig
|
from .configuration import Config, AIConfig, OpenAIConfig
|
||||||
from .ai import AI, AIError
|
from .ai import AI, AIError
|
||||||
from .ais.openai import OpenAI
|
from .ais.openai import OpenAI
|
||||||
|
|
||||||
|
|
||||||
def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
|
def create_ai(args: argparse.Namespace, config: Config, # noqa: 11
|
||||||
|
def_ai: Optional[str] = None,
|
||||||
|
def_model: Optional[str] = None) -> AI:
|
||||||
"""
|
"""
|
||||||
Creates an AI subclass instance from the given arguments
|
Creates an AI subclass instance from the given arguments and configuration file.
|
||||||
and configuration file. If AI has not been set in the
|
If AI has not been set in the arguments, it searches for the ID 'default'. If
|
||||||
arguments, it searches for the ID 'default'. If that
|
that is not found, it uses the first AI in the list. It's also possible to
|
||||||
is not found, it uses the first AI in the list.
|
specify a default AI and model using 'def_ai' and 'def_model'.
|
||||||
"""
|
"""
|
||||||
ai_conf: AIConfig
|
ai_conf: AIConfig
|
||||||
if hasattr(args, 'AI') and args.AI:
|
if hasattr(args, 'AI') and args.AI:
|
||||||
@@ -22,6 +24,8 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
|
|||||||
ai_conf = config.ais[args.AI]
|
ai_conf = config.ais[args.AI]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
|
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
|
||||||
|
elif def_ai:
|
||||||
|
ai_conf = config.ais[def_ai]
|
||||||
elif 'default' in config.ais:
|
elif 'default' in config.ais:
|
||||||
ai_conf = config.ais['default']
|
ai_conf = config.ais['default']
|
||||||
else:
|
else:
|
||||||
@@ -34,6 +38,8 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
|
|||||||
ai = OpenAI(cast(OpenAIConfig, ai_conf))
|
ai = OpenAI(cast(OpenAIConfig, ai_conf))
|
||||||
if hasattr(args, 'model') and args.model:
|
if hasattr(args, 'model') and args.model:
|
||||||
ai.config.model = args.model
|
ai.config.model = args.model
|
||||||
|
elif def_model:
|
||||||
|
ai.config.model = def_model
|
||||||
if hasattr(args, 'max_tokens') and args.max_tokens:
|
if hasattr(args, 'max_tokens') and args.max_tokens:
|
||||||
ai.config.max_tokens = args.max_tokens
|
ai.config.max_tokens = args.max_tokens
|
||||||
if hasattr(args, 'temperature') and args.temperature:
|
if hasattr(args, 'temperature') and args.temperature:
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class OpenAI(AI):
|
|||||||
frequency_penalty=self.config.frequency_penalty,
|
frequency_penalty=self.config.frequency_penalty,
|
||||||
presence_penalty=self.config.presence_penalty)
|
presence_penalty=self.config.presence_penalty)
|
||||||
question.answer = Answer(response['choices'][0]['message']['content'])
|
question.answer = Answer(response['choices'][0]['message']['content'])
|
||||||
question.tags = otags
|
question.tags = set(otags) if otags is not None else None
|
||||||
question.ai = self.ID
|
question.ai = self.ID
|
||||||
question.model = self.config.model
|
question.model = self.config.model
|
||||||
answers: list[Message] = [question]
|
answers: list[Message] = [question]
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import sys
|
|||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
|
from copy import deepcopy
|
||||||
from ..configuration import Config
|
from ..configuration import Config
|
||||||
from ..chat import ChatDB
|
from ..chat import ChatDB
|
||||||
from ..message import Message, MessageFilter, MessageError, Question, source_code
|
from ..message import Message, MessageFilter, MessageError, Question, source_code
|
||||||
@@ -105,13 +106,75 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
|
|||||||
print(response.tokens)
|
print(response.tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
|
||||||
|
"""
|
||||||
|
Takes an existing message and CLI arguments, and returns modified args based
|
||||||
|
on the members of the given message. Used e.g. when repeating messages, where
|
||||||
|
it's necessary to determine the correct AI, module and output tags to use
|
||||||
|
(either from the existing message or the given args).
|
||||||
|
"""
|
||||||
|
msg_args = args
|
||||||
|
# if AI, model or output tags have not been specified,
|
||||||
|
# use those from the original message
|
||||||
|
if (args.AI is None
|
||||||
|
or args.model is None # noqa: W503
|
||||||
|
or args.output_tags is None): # noqa: W503
|
||||||
|
msg_args = deepcopy(args)
|
||||||
|
if args.AI is None and msg.ai is not None:
|
||||||
|
msg_args.AI = msg.ai
|
||||||
|
if args.model is None and msg.model is not None:
|
||||||
|
msg_args.model = msg.model
|
||||||
|
if args.output_tags is None and msg.tags is not None:
|
||||||
|
msg_args.output_tags = msg.tags
|
||||||
|
return msg_args
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
|
||||||
|
"""
|
||||||
|
Repeat the given messages using the given arguments.
|
||||||
|
"""
|
||||||
|
ai: AI
|
||||||
|
for msg in messages:
|
||||||
|
msg_args = create_msg_args(msg, args)
|
||||||
|
ai = create_ai(msg_args, config)
|
||||||
|
print(f"--------- Repeating message '{msg.msg_id()}': ---------")
|
||||||
|
# overwrite the latest message if requested or empty
|
||||||
|
# -> but not if it's in the DB!
|
||||||
|
if ((msg.answer is None or msg_args.overwrite is True)
|
||||||
|
and (not chat.msg_in_db(msg))): # noqa: W503
|
||||||
|
msg.clear_answer()
|
||||||
|
make_request(ai, chat, msg, msg_args)
|
||||||
|
# otherwise create a new one
|
||||||
|
else:
|
||||||
|
msg_args.ask = [msg.question]
|
||||||
|
message = create_message(chat, msg_args)
|
||||||
|
make_request(ai, chat, message, msg_args)
|
||||||
|
|
||||||
|
|
||||||
|
def invert_input_tag_args(args: argparse.Namespace) -> None:
|
||||||
|
"""
|
||||||
|
Changes the semantics of the INPUT tags for this command:
|
||||||
|
* not tags specified on the CLI -> no tags are selected
|
||||||
|
* empty tags specified on the CLI -> all tags are selected
|
||||||
|
"""
|
||||||
|
if args.or_tags is None:
|
||||||
|
args.or_tags = set()
|
||||||
|
elif len(args.or_tags) == 0:
|
||||||
|
args.or_tags = None
|
||||||
|
if args.and_tags is None:
|
||||||
|
args.and_tags = set()
|
||||||
|
elif len(args.and_tags) == 0:
|
||||||
|
args.and_tags = None
|
||||||
|
|
||||||
|
|
||||||
def question_cmd(args: argparse.Namespace, config: Config) -> None:
|
def question_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||||
"""
|
"""
|
||||||
Handler for the 'question' command.
|
Handler for the 'question' command.
|
||||||
"""
|
"""
|
||||||
mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(),
|
invert_input_tag_args(args)
|
||||||
tags_and=args.and_tags if args.and_tags is not None else set(),
|
mfilter = MessageFilter(tags_or=args.or_tags,
|
||||||
tags_not=args.exclude_tags if args.exclude_tags is not None else set())
|
tags_and=args.and_tags,
|
||||||
|
tags_not=args.exclude_tags)
|
||||||
chat = ChatDB.from_dir(cache_path=Path(config.cache),
|
chat = ChatDB.from_dir(cache_path=Path(config.cache),
|
||||||
db_path=Path(config.db),
|
db_path=Path(config.db),
|
||||||
mfilter=mfilter)
|
mfilter=mfilter)
|
||||||
@@ -121,30 +184,24 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
|
|||||||
if args.create:
|
if args.create:
|
||||||
return
|
return
|
||||||
|
|
||||||
# create the correct AI instance
|
|
||||||
ai: AI = create_ai(args, config)
|
|
||||||
|
|
||||||
# === ASK ===
|
# === ASK ===
|
||||||
if args.ask:
|
if args.ask:
|
||||||
|
ai: AI = create_ai(args, config)
|
||||||
make_request(ai, chat, message, args)
|
make_request(ai, chat, message, args)
|
||||||
# === REPEAT ===
|
# === REPEAT ===
|
||||||
elif args.repeat is not None:
|
elif args.repeat is not None:
|
||||||
|
repeat_msgs: list[Message] = []
|
||||||
|
# repeat latest message
|
||||||
|
if len(args.repeat) == 0:
|
||||||
lmessage = chat.msg_latest(loc='cache')
|
lmessage = chat.msg_latest(loc='cache')
|
||||||
if lmessage is None:
|
if lmessage is None:
|
||||||
print("No message found to repeat!")
|
print("No message found to repeat!")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
repeat_msgs.append(lmessage)
|
||||||
|
# repeat given message(s)
|
||||||
else:
|
else:
|
||||||
print(f"Repeating message '{lmessage.msg_id()}':")
|
repeat_msgs = chat.msg_find(args.repeat, loc='disk')
|
||||||
# overwrite the latest message if requested or empty
|
repeat_messages(repeat_msgs, chat, args, config)
|
||||||
if lmessage.answer is None or args.overwrite is True:
|
|
||||||
lmessage.clear_answer()
|
|
||||||
make_request(ai, chat, lmessage, args)
|
|
||||||
# otherwise create a new one
|
|
||||||
else:
|
|
||||||
args.ask = [lmessage.question]
|
|
||||||
message = create_message(chat, args)
|
|
||||||
make_request(ai, chat, message, args)
|
|
||||||
|
|
||||||
# === PROCESS ===
|
# === PROCESS ===
|
||||||
elif args.process is not None:
|
elif args.process is not None:
|
||||||
# TODO: process either all questions without an
|
# TODO: process either all questions without an
|
||||||
|
|||||||
@@ -34,13 +34,13 @@ def create_parser() -> argparse.ArgumentParser:
|
|||||||
|
|
||||||
# a parent parser for all commands that support tag selection
|
# a parent parser for all commands that support tag selection
|
||||||
tag_parser = argparse.ArgumentParser(add_help=False)
|
tag_parser = argparse.ArgumentParser(add_help=False)
|
||||||
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+',
|
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='*',
|
||||||
help='List of tags (one must match)', metavar='OTAGS')
|
help='List of tags (one must match)', metavar='OTAGS')
|
||||||
tag_arg.completer = tags_completer # type: ignore
|
tag_arg.completer = tags_completer # type: ignore
|
||||||
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+',
|
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='*',
|
||||||
help='List of tags (all must match)', metavar='ATAGS')
|
help='List of tags (all must match)', metavar='ATAGS')
|
||||||
atag_arg.completer = tags_completer # type: ignore
|
atag_arg.completer = tags_completer # type: ignore
|
||||||
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+',
|
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='*',
|
||||||
help='List of tags to exclude', metavar='XTAGS')
|
help='List of tags to exclude', metavar='XTAGS')
|
||||||
etag_arg.completer = tags_completer # type: ignore
|
etag_arg.completer = tags_completer # type: ignore
|
||||||
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
|
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
|
||||||
|
|||||||
@@ -0,0 +1,100 @@
|
|||||||
|
import unittest
|
||||||
|
import argparse
|
||||||
|
from typing import Union, Optional
|
||||||
|
from chatmastermind.configuration import Config, AIConfig
|
||||||
|
from chatmastermind.tags import Tag
|
||||||
|
from chatmastermind.message import Message, Answer
|
||||||
|
from chatmastermind.chat import Chat
|
||||||
|
from chatmastermind.ai import AI, AIResponse, Tokens, AIError
|
||||||
|
|
||||||
|
|
||||||
|
class FakeAI(AI):
|
||||||
|
"""
|
||||||
|
A mocked version of the 'AI' class.
|
||||||
|
"""
|
||||||
|
ID: str
|
||||||
|
name: str
|
||||||
|
config: AIConfig
|
||||||
|
|
||||||
|
def models(self) -> list[str]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def tokens(self, data: Union[Message, Chat]) -> int:
|
||||||
|
return 123
|
||||||
|
|
||||||
|
def print(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def print_models(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init__(self, ID: str, model: str, error: bool = False):
|
||||||
|
self.ID = ID
|
||||||
|
self.model = model
|
||||||
|
self.error = error
|
||||||
|
|
||||||
|
def request(self,
|
||||||
|
question: Message,
|
||||||
|
chat: Chat,
|
||||||
|
num_answers: int = 1,
|
||||||
|
otags: Optional[set[Tag]] = None) -> AIResponse:
|
||||||
|
"""
|
||||||
|
Mock the 'ai.request()' function by either returning fake
|
||||||
|
answers or raising an exception.
|
||||||
|
"""
|
||||||
|
if self.error:
|
||||||
|
raise AIError
|
||||||
|
question.answer = Answer("Answer 0")
|
||||||
|
question.tags = set(otags) if otags is not None else None
|
||||||
|
question.ai = self.ID
|
||||||
|
question.model = self.model
|
||||||
|
answers: list[Message] = [question]
|
||||||
|
for n in range(1, num_answers):
|
||||||
|
answers.append(Message(question=question.question,
|
||||||
|
answer=Answer(f"Answer {n}"),
|
||||||
|
tags=otags,
|
||||||
|
ai=self.ID,
|
||||||
|
model=self.model))
|
||||||
|
return AIResponse(answers, Tokens(10, 10, 20))
|
||||||
|
|
||||||
|
|
||||||
|
class TestWithFakeAI(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Base class for all tests that need to use the FakeAI.
|
||||||
|
"""
|
||||||
|
def assert_msgs_equal_except_file_path(self, msg1: list[Message], msg2: list[Message]) -> None:
|
||||||
|
"""
|
||||||
|
Compare messages using Question, Answer and all metadata excecot for the file_path.
|
||||||
|
"""
|
||||||
|
self.assertEqual(len(msg1), len(msg2))
|
||||||
|
for m1, m2 in zip(msg1, msg2):
|
||||||
|
# exclude the file_path, compare only Q, A and metadata
|
||||||
|
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
|
||||||
|
|
||||||
|
def assert_msgs_all_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
|
||||||
|
"""
|
||||||
|
Compare messages using Question, Answer and ALL metadata.
|
||||||
|
"""
|
||||||
|
self.assertEqual(len(msg1), len(msg2))
|
||||||
|
for m1, m2 in zip(msg1, msg2):
|
||||||
|
self.assertTrue(m1.equals(m2, verbose=True))
|
||||||
|
|
||||||
|
def assert_msgs_content_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
|
||||||
|
"""
|
||||||
|
Compare messages using only Question and Answer.
|
||||||
|
"""
|
||||||
|
self.assertEqual(len(msg1), len(msg2))
|
||||||
|
for m1, m2 in zip(msg1, msg2):
|
||||||
|
self.assertEqual(m1, m2)
|
||||||
|
|
||||||
|
def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI:
|
||||||
|
"""
|
||||||
|
Mocked 'create_ai' that returns a 'FakeAI' instance.
|
||||||
|
"""
|
||||||
|
return FakeAI(args.AI, args.model)
|
||||||
|
|
||||||
|
def mock_create_ai_with_error(self, args: argparse.Namespace, config: Config) -> AI:
|
||||||
|
"""
|
||||||
|
Mocked 'create_ai' that returns a 'FakeAI' instance.
|
||||||
|
"""
|
||||||
|
return FakeAI(args.AI, args.model, error=True)
|
||||||
+211
-149
@@ -1,31 +1,20 @@
|
|||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
import argparse
|
import argparse
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from copy import copy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import MagicMock, call, ANY
|
from unittest.mock import MagicMock, call
|
||||||
from typing import Optional
|
|
||||||
from chatmastermind.configuration import Config
|
from chatmastermind.configuration import Config
|
||||||
from chatmastermind.commands.question import create_message, question_cmd
|
from chatmastermind.commands.question import create_message, question_cmd
|
||||||
from chatmastermind.tags import Tag
|
from chatmastermind.tags import Tag
|
||||||
from chatmastermind.message import Message, Question, Answer
|
from chatmastermind.message import Message, Question, Answer
|
||||||
from chatmastermind.chat import Chat, ChatDB
|
from chatmastermind.chat import Chat, ChatDB
|
||||||
from chatmastermind.ai import AI, AIResponse, Tokens, AIError
|
from chatmastermind.ai import AIError
|
||||||
|
from .test_common import TestWithFakeAI
|
||||||
|
|
||||||
|
|
||||||
class TestQuestionCmdBase(unittest.TestCase):
|
class TestMessageCreate(TestWithFakeAI):
|
||||||
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
|
|
||||||
"""
|
|
||||||
Compare messages using more than just Question and Answer.
|
|
||||||
"""
|
|
||||||
self.assertEqual(len(msg1), len(msg2))
|
|
||||||
for m1, m2 in zip(msg1, msg2):
|
|
||||||
# exclude the file_path, compare only Q, A and metadata
|
|
||||||
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
|
|
||||||
|
|
||||||
|
|
||||||
class TestMessageCreate(TestQuestionCmdBase):
|
|
||||||
"""
|
"""
|
||||||
Test if messages created by the 'question' command have
|
Test if messages created by the 'question' command have
|
||||||
the correct format.
|
the correct format.
|
||||||
@@ -212,7 +201,7 @@ It is embedded code
|
|||||||
"""))
|
"""))
|
||||||
|
|
||||||
|
|
||||||
class TestQuestionCmd(TestQuestionCmdBase):
|
class TestQuestionCmd(TestWithFakeAI):
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
# create DB and cache
|
# create DB and cache
|
||||||
@@ -227,8 +216,8 @@ class TestQuestionCmd(TestQuestionCmdBase):
|
|||||||
ask=['What is the meaning of life?'],
|
ask=['What is the meaning of life?'],
|
||||||
num_answers=1,
|
num_answers=1,
|
||||||
output_tags=['science'],
|
output_tags=['science'],
|
||||||
AI='openai',
|
AI='FakeAI',
|
||||||
model='gpt-3.5-turbo',
|
model='FakeModel',
|
||||||
or_tags=None,
|
or_tags=None,
|
||||||
and_tags=None,
|
and_tags=None,
|
||||||
exclude_tags=None,
|
exclude_tags=None,
|
||||||
@@ -239,57 +228,27 @@ class TestQuestionCmd(TestQuestionCmdBase):
|
|||||||
process=None,
|
process=None,
|
||||||
overwrite=None
|
overwrite=None
|
||||||
)
|
)
|
||||||
# create a mock AI instance
|
|
||||||
self.ai = MagicMock(spec=AI)
|
|
||||||
self.ai.request.side_effect = self.mock_request
|
|
||||||
|
|
||||||
def input_message(self, args: argparse.Namespace) -> Message:
|
|
||||||
"""
|
|
||||||
Create the expected input message for a question using the
|
|
||||||
given arguments.
|
|
||||||
"""
|
|
||||||
# NOTE: we only use the first question from the "ask" list
|
|
||||||
# -> message creation using "question.create_message()" is
|
|
||||||
# tested above
|
|
||||||
# the answer is always empty for the input message
|
|
||||||
return Message(Question(args.ask[0]),
|
|
||||||
tags=args.output_tags,
|
|
||||||
ai=args.AI,
|
|
||||||
model=args.model)
|
|
||||||
|
|
||||||
def mock_request(self,
|
|
||||||
question: Message,
|
|
||||||
chat: Chat,
|
|
||||||
num_answers: int = 1,
|
|
||||||
otags: Optional[set[Tag]] = None) -> AIResponse:
|
|
||||||
"""
|
|
||||||
Mock the 'ai.request()' function
|
|
||||||
"""
|
|
||||||
question.answer = Answer("Answer 0")
|
|
||||||
question.tags = set(otags) if otags else None
|
|
||||||
question.ai = 'FakeAI'
|
|
||||||
question.model = 'FakeModel'
|
|
||||||
answers: list[Message] = [question]
|
|
||||||
for n in range(1, num_answers):
|
|
||||||
answers.append(Message(question=question.question,
|
|
||||||
answer=Answer(f"Answer {n}"),
|
|
||||||
tags=otags,
|
|
||||||
ai='FakeAI',
|
|
||||||
model='FakeModel'))
|
|
||||||
return AIResponse(answers, Tokens(10, 10, 20))
|
|
||||||
|
|
||||||
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
|
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
|
||||||
# exclude '.next'
|
# exclude '.next'
|
||||||
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
|
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuestionCmdAsk(TestQuestionCmd):
|
||||||
|
|
||||||
@mock.patch('chatmastermind.commands.question.create_ai')
|
@mock.patch('chatmastermind.commands.question.create_ai')
|
||||||
def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
|
def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
|
||||||
"""
|
"""
|
||||||
Test single answer with no errors.
|
Test single answer with no errors.
|
||||||
"""
|
"""
|
||||||
mock_create_ai.return_value = self.ai
|
mock_create_ai.side_effect = self.mock_create_ai
|
||||||
expected_question = self.input_message(self.args)
|
expected_question = Message(Question(self.args.ask[0]),
|
||||||
expected_responses = self.mock_request(expected_question,
|
tags=set(self.args.output_tags),
|
||||||
|
ai=self.args.AI,
|
||||||
|
model=self.args.model,
|
||||||
|
file_path=Path('<NOT COMPARED>'))
|
||||||
|
fake_ai = self.mock_create_ai(self.args, self.config)
|
||||||
|
expected_responses = fake_ai.request(expected_question,
|
||||||
Chat([]),
|
Chat([]),
|
||||||
self.args.num_answers,
|
self.args.num_answers,
|
||||||
self.args.output_tags).messages
|
self.args.output_tags).messages
|
||||||
@@ -297,17 +256,12 @@ class TestQuestionCmd(TestQuestionCmdBase):
|
|||||||
# execute the command
|
# execute the command
|
||||||
question_cmd(self.args, self.config)
|
question_cmd(self.args, self.config)
|
||||||
|
|
||||||
# check for correct request call
|
|
||||||
self.ai.request.assert_called_once_with(expected_question,
|
|
||||||
ANY,
|
|
||||||
self.args.num_answers,
|
|
||||||
self.args.output_tags)
|
|
||||||
# check for the expected message files
|
# check for the expected message files
|
||||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
Path(self.db_dir.name))
|
Path(self.db_dir.name))
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
||||||
self.assert_messages_equal(cached_msg, expected_responses)
|
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
|
||||||
|
|
||||||
@mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
|
@mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
|
||||||
@mock.patch('chatmastermind.commands.question.create_ai')
|
@mock.patch('chatmastermind.commands.question.create_ai')
|
||||||
@@ -318,9 +272,14 @@ class TestQuestionCmd(TestQuestionCmdBase):
|
|||||||
chat = MagicMock(spec=ChatDB)
|
chat = MagicMock(spec=ChatDB)
|
||||||
mock_from_dir.return_value = chat
|
mock_from_dir.return_value = chat
|
||||||
|
|
||||||
mock_create_ai.return_value = self.ai
|
mock_create_ai.side_effect = self.mock_create_ai
|
||||||
expected_question = self.input_message(self.args)
|
expected_question = Message(Question(self.args.ask[0]),
|
||||||
expected_responses = self.mock_request(expected_question,
|
tags=set(self.args.output_tags),
|
||||||
|
ai=self.args.AI,
|
||||||
|
model=self.args.model,
|
||||||
|
file_path=Path('<NOT COMPARED>'))
|
||||||
|
fake_ai = self.mock_create_ai(self.args, self.config)
|
||||||
|
expected_responses = fake_ai.request(expected_question,
|
||||||
Chat([]),
|
Chat([]),
|
||||||
self.args.num_answers,
|
self.args.num_answers,
|
||||||
self.args.output_tags).messages
|
self.args.output_tags).messages
|
||||||
@@ -328,12 +287,6 @@ class TestQuestionCmd(TestQuestionCmdBase):
|
|||||||
# execute the command
|
# execute the command
|
||||||
question_cmd(self.args, self.config)
|
question_cmd(self.args, self.config)
|
||||||
|
|
||||||
# check for correct request call
|
|
||||||
self.ai.request.assert_called_once_with(expected_question,
|
|
||||||
chat,
|
|
||||||
self.args.num_answers,
|
|
||||||
self.args.output_tags)
|
|
||||||
|
|
||||||
# check for the correct ChatDB calls:
|
# check for the correct ChatDB calls:
|
||||||
# - initial question has been written (prior to the actual request)
|
# - initial question has been written (prior to the actual request)
|
||||||
# - responses have been written (after the request)
|
# - responses have been written (after the request)
|
||||||
@@ -350,86 +303,98 @@ class TestQuestionCmd(TestQuestionCmdBase):
|
|||||||
Provoke an error during the AI request and verify that the question
|
Provoke an error during the AI request and verify that the question
|
||||||
has been correctly stored in the cache.
|
has been correctly stored in the cache.
|
||||||
"""
|
"""
|
||||||
mock_create_ai.return_value = self.ai
|
mock_create_ai.side_effect = self.mock_create_ai_with_error
|
||||||
expected_question = self.input_message(self.args)
|
expected_question = Message(Question(self.args.ask[0]),
|
||||||
self.ai.request.side_effect = AIError
|
tags=set(self.args.output_tags),
|
||||||
|
ai=self.args.AI,
|
||||||
|
model=self.args.model,
|
||||||
|
file_path=Path('<NOT COMPARED>'))
|
||||||
|
|
||||||
# execute the command
|
# execute the command
|
||||||
with self.assertRaises(AIError):
|
with self.assertRaises(AIError):
|
||||||
question_cmd(self.args, self.config)
|
question_cmd(self.args, self.config)
|
||||||
|
|
||||||
# check for correct request call
|
|
||||||
self.ai.request.assert_called_once_with(expected_question,
|
|
||||||
ANY,
|
|
||||||
self.args.num_answers,
|
|
||||||
self.args.output_tags)
|
|
||||||
# check for the expected message files
|
# check for the expected message files
|
||||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
Path(self.db_dir.name))
|
Path(self.db_dir.name))
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
||||||
self.assert_messages_equal(cached_msg, [expected_question])
|
self.assert_msgs_equal_except_file_path(cached_msg, [expected_question])
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuestionCmdRepeat(TestQuestionCmd):
|
||||||
|
|
||||||
@mock.patch('chatmastermind.commands.question.create_ai')
|
@mock.patch('chatmastermind.commands.question.create_ai')
|
||||||
def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
|
def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
|
||||||
"""
|
"""
|
||||||
Repeat a single question.
|
Repeat a single question.
|
||||||
"""
|
"""
|
||||||
# 1. ask a question
|
mock_create_ai.side_effect = self.mock_create_ai
|
||||||
mock_create_ai.return_value = self.ai
|
# create a message
|
||||||
expected_question = self.input_message(self.args)
|
message = Message(Question(self.args.ask[0]),
|
||||||
expected_responses = self.mock_request(expected_question,
|
Answer('Old Answer'),
|
||||||
Chat([]),
|
tags=set(self.args.output_tags),
|
||||||
self.args.num_answers,
|
ai=self.args.AI,
|
||||||
self.args.output_tags).messages
|
model=self.args.model,
|
||||||
question_cmd(self.args, self.config)
|
file_path=Path(self.cache_dir.name) / '0001.txt')
|
||||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
message.to_file()
|
||||||
Path(self.db_dir.name))
|
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
|
||||||
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
|
||||||
self.assert_messages_equal(cached_msg, expected_responses)
|
|
||||||
|
|
||||||
# 2. repeat the last question (without overwriting)
|
# repeat the last question (without overwriting)
|
||||||
# -> expect two identical messages (except for the file_path)
|
# -> expect two identical messages (except for the file_path)
|
||||||
self.args.ask = None
|
self.args.ask = None
|
||||||
self.args.repeat = []
|
self.args.repeat = []
|
||||||
self.args.overwrite = False
|
self.args.overwrite = False
|
||||||
expected_responses += expected_responses
|
expected_response = Message(Question(message.question),
|
||||||
|
Answer('Answer 0'),
|
||||||
|
ai=message.ai,
|
||||||
|
model=message.model,
|
||||||
|
tags=message.tags,
|
||||||
|
file_path=Path('<NOT COMPARED>'))
|
||||||
|
# we expect the original message + the one with the new response
|
||||||
|
expected_responses = [message] + [expected_response]
|
||||||
question_cmd(self.args, self.config)
|
question_cmd(self.args, self.config)
|
||||||
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
|
Path(self.db_dir.name))
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
|
print(self.message_list(self.cache_dir))
|
||||||
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
|
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
|
||||||
self.assert_messages_equal(cached_msg, expected_responses)
|
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
|
||||||
|
|
||||||
@mock.patch('chatmastermind.commands.question.create_ai')
|
@mock.patch('chatmastermind.commands.question.create_ai')
|
||||||
def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
|
def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
|
||||||
"""
|
"""
|
||||||
Repeat a single question and overwrite the old one.
|
Repeat a single question and overwrite the old one.
|
||||||
"""
|
"""
|
||||||
# 1. ask a question
|
mock_create_ai.side_effect = self.mock_create_ai
|
||||||
mock_create_ai.return_value = self.ai
|
# create a message
|
||||||
expected_question = self.input_message(self.args)
|
message = Message(Question(self.args.ask[0]),
|
||||||
expected_responses = self.mock_request(expected_question,
|
Answer('Old Answer'),
|
||||||
Chat([]),
|
tags=set(self.args.output_tags),
|
||||||
self.args.num_answers,
|
ai=self.args.AI,
|
||||||
self.args.output_tags).messages
|
model=self.args.model,
|
||||||
question_cmd(self.args, self.config)
|
file_path=Path(self.cache_dir.name) / '0001.txt')
|
||||||
|
message.to_file()
|
||||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
Path(self.db_dir.name))
|
Path(self.db_dir.name))
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
assert cached_msg[0].file_path
|
assert cached_msg[0].file_path
|
||||||
cached_msg_file_id = cached_msg[0].file_path.stem
|
cached_msg_file_id = cached_msg[0].file_path.stem
|
||||||
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
|
||||||
self.assert_messages_equal(cached_msg, expected_responses)
|
|
||||||
|
|
||||||
# 2. repeat the last question (WITH overwriting)
|
# repeat the last question (WITH overwriting)
|
||||||
# -> expect a single message afterwards
|
# -> expect a single message afterwards (with a new answer)
|
||||||
self.args.ask = None
|
self.args.ask = None
|
||||||
self.args.repeat = []
|
self.args.repeat = []
|
||||||
self.args.overwrite = True
|
self.args.overwrite = True
|
||||||
|
expected_response = Message(Question(message.question),
|
||||||
|
Answer('Answer 0'),
|
||||||
|
ai=message.ai,
|
||||||
|
model=message.model,
|
||||||
|
tags=message.tags,
|
||||||
|
file_path=Path('<NOT COMPARED>'))
|
||||||
question_cmd(self.args, self.config)
|
question_cmd(self.args, self.config)
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
||||||
self.assert_messages_equal(cached_msg, expected_responses)
|
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
|
||||||
# also check that the file ID has not been changed
|
# also check that the file ID has not been changed
|
||||||
assert cached_msg[0].file_path
|
assert cached_msg[0].file_path
|
||||||
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
|
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
|
||||||
@@ -439,35 +404,37 @@ class TestQuestionCmd(TestQuestionCmdBase):
|
|||||||
"""
|
"""
|
||||||
Repeat a single question after an error.
|
Repeat a single question after an error.
|
||||||
"""
|
"""
|
||||||
# 1. ask a question and provoke an error
|
mock_create_ai.side_effect = self.mock_create_ai
|
||||||
mock_create_ai.return_value = self.ai
|
# create a question WITHOUT an answer
|
||||||
expected_question = self.input_message(self.args)
|
# -> just like after an error, which is tested above
|
||||||
self.ai.request.side_effect = AIError
|
message = Message(Question(self.args.ask[0]),
|
||||||
with self.assertRaises(AIError):
|
tags=set(self.args.output_tags),
|
||||||
question_cmd(self.args, self.config)
|
ai=self.args.AI,
|
||||||
|
model=self.args.model,
|
||||||
|
file_path=Path(self.cache_dir.name) / '0001.txt')
|
||||||
|
message.to_file()
|
||||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
Path(self.db_dir.name))
|
Path(self.db_dir.name))
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
assert cached_msg[0].file_path
|
assert cached_msg[0].file_path
|
||||||
cached_msg_file_id = cached_msg[0].file_path.stem
|
cached_msg_file_id = cached_msg[0].file_path.stem
|
||||||
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
|
||||||
self.assert_messages_equal(cached_msg, [expected_question])
|
|
||||||
|
|
||||||
# 2. repeat the last question (without overwriting)
|
# repeat the last question (without overwriting)
|
||||||
# -> expect a single message because if the original has
|
# -> expect a single message because if the original has
|
||||||
# no answer, it should be overwritten by default
|
# no answer, it should be overwritten by default
|
||||||
self.args.ask = None
|
self.args.ask = None
|
||||||
self.args.repeat = []
|
self.args.repeat = []
|
||||||
self.args.overwrite = False
|
self.args.overwrite = False
|
||||||
self.ai.request.side_effect = self.mock_request
|
expected_response = Message(Question(message.question),
|
||||||
expected_responses = self.mock_request(expected_question,
|
Answer('Answer 0'),
|
||||||
Chat([]),
|
ai=message.ai,
|
||||||
self.args.num_answers,
|
model=message.model,
|
||||||
self.args.output_tags).messages
|
tags=message.tags,
|
||||||
|
file_path=Path('<NOT COMPARED>'))
|
||||||
question_cmd(self.args, self.config)
|
question_cmd(self.args, self.config)
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
||||||
self.assert_messages_equal(cached_msg, expected_responses)
|
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
|
||||||
# also check that the file ID has not been changed
|
# also check that the file ID has not been changed
|
||||||
assert cached_msg[0].file_path
|
assert cached_msg[0].file_path
|
||||||
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
|
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
|
||||||
@@ -477,37 +444,132 @@ class TestQuestionCmd(TestQuestionCmdBase):
|
|||||||
"""
|
"""
|
||||||
Repeat a single question with new arguments.
|
Repeat a single question with new arguments.
|
||||||
"""
|
"""
|
||||||
# 1. ask a question
|
mock_create_ai.side_effect = self.mock_create_ai
|
||||||
mock_create_ai.return_value = self.ai
|
# create a message
|
||||||
expected_question = self.input_message(self.args)
|
message = Message(Question(self.args.ask[0]),
|
||||||
expected_responses = self.mock_request(expected_question,
|
Answer('Old Answer'),
|
||||||
Chat([]),
|
tags=set(self.args.output_tags),
|
||||||
self.args.num_answers,
|
ai=self.args.AI,
|
||||||
self.args.output_tags).messages
|
model=self.args.model,
|
||||||
question_cmd(self.args, self.config)
|
file_path=Path(self.cache_dir.name) / '0001.txt')
|
||||||
|
message.to_file()
|
||||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
Path(self.db_dir.name))
|
Path(self.db_dir.name))
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
assert cached_msg[0].file_path
|
||||||
self.assert_messages_equal(cached_msg, expected_responses)
|
|
||||||
|
|
||||||
# 2. repeat the last question with new arguments (without overwriting)
|
# repeat the last question with new arguments (without overwriting)
|
||||||
# -> expect two messages with identical question and answer, but different metadata
|
# -> expect two messages with identical question but different metadata and new answer
|
||||||
self.args.ask = None
|
self.args.ask = None
|
||||||
self.args.repeat = []
|
self.args.repeat = []
|
||||||
self.args.overwrite = False
|
self.args.overwrite = False
|
||||||
self.args.output_tags = ['newtag']
|
self.args.output_tags = ['newtag']
|
||||||
self.args.AI = 'newai'
|
self.args.AI = 'newai'
|
||||||
self.args.model = 'newmodel'
|
self.args.model = 'newmodel'
|
||||||
new_expected_question = Message(question=Question(expected_question.question),
|
new_expected_response = Message(Question(message.question),
|
||||||
tags=set(self.args.output_tags),
|
Answer('Answer 0'),
|
||||||
ai=self.args.AI,
|
ai='newai',
|
||||||
model=self.args.model)
|
model='newmodel',
|
||||||
expected_responses += self.mock_request(new_expected_question,
|
tags={Tag('newtag')},
|
||||||
Chat([]),
|
file_path=Path('<NOT COMPARED>'))
|
||||||
self.args.num_answers,
|
|
||||||
set(self.args.output_tags)).messages
|
|
||||||
question_cmd(self.args, self.config)
|
question_cmd(self.args, self.config)
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
|
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
|
||||||
self.assert_messages_equal(cached_msg, expected_responses)
|
self.assert_msgs_equal_except_file_path(cached_msg, [message] + [new_expected_response])
|
||||||
|
|
||||||
|
@mock.patch('chatmastermind.commands.question.create_ai')
|
||||||
|
def test_repeat_single_question_new_args_overwrite(self, mock_create_ai: MagicMock) -> None:
|
||||||
|
"""
|
||||||
|
Repeat a single question with new arguments, overwriting the old one.
|
||||||
|
"""
|
||||||
|
mock_create_ai.side_effect = self.mock_create_ai
|
||||||
|
# create a message
|
||||||
|
message = Message(Question(self.args.ask[0]),
|
||||||
|
Answer('Old Answer'),
|
||||||
|
tags=set(self.args.output_tags),
|
||||||
|
ai=self.args.AI,
|
||||||
|
model=self.args.model,
|
||||||
|
file_path=Path(self.cache_dir.name) / '0001.txt')
|
||||||
|
message.to_file()
|
||||||
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
|
Path(self.db_dir.name))
|
||||||
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
|
assert cached_msg[0].file_path
|
||||||
|
|
||||||
|
# repeat the last question with new arguments
|
||||||
|
self.args.ask = None
|
||||||
|
self.args.repeat = []
|
||||||
|
self.args.overwrite = True
|
||||||
|
self.args.output_tags = ['newtag']
|
||||||
|
self.args.AI = 'newai'
|
||||||
|
self.args.model = 'newmodel'
|
||||||
|
new_expected_response = Message(Question(message.question),
|
||||||
|
Answer('Answer 0'),
|
||||||
|
ai='newai',
|
||||||
|
model='newmodel',
|
||||||
|
tags={Tag('newtag')},
|
||||||
|
file_path=Path('<NOT COMPARED>'))
|
||||||
|
question_cmd(self.args, self.config)
|
||||||
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
|
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
||||||
|
self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response])
|
||||||
|
|
||||||
|
@mock.patch('chatmastermind.commands.question.create_ai')
|
||||||
|
def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None:
|
||||||
|
"""
|
||||||
|
Repeat multiple questions.
|
||||||
|
"""
|
||||||
|
mock_create_ai.side_effect = self.mock_create_ai
|
||||||
|
# 1. === create three questions ===
|
||||||
|
# cached message without an answer
|
||||||
|
message1 = Message(Question(self.args.ask[0]),
|
||||||
|
tags=self.args.output_tags,
|
||||||
|
ai=self.args.AI,
|
||||||
|
model=self.args.model,
|
||||||
|
file_path=Path(self.cache_dir.name) / '0001.txt')
|
||||||
|
# cached message with an answer
|
||||||
|
message2 = Message(Question(self.args.ask[0]),
|
||||||
|
Answer('Old Answer'),
|
||||||
|
tags=self.args.output_tags,
|
||||||
|
ai=self.args.AI,
|
||||||
|
model=self.args.model,
|
||||||
|
file_path=Path(self.cache_dir.name) / '0002.txt')
|
||||||
|
# DB message without an answer
|
||||||
|
message3 = Message(Question(self.args.ask[0]),
|
||||||
|
tags=self.args.output_tags,
|
||||||
|
ai=self.args.AI,
|
||||||
|
model=self.args.model,
|
||||||
|
file_path=Path(self.db_dir.name) / '0003.txt')
|
||||||
|
message1.to_file()
|
||||||
|
message2.to_file()
|
||||||
|
message3.to_file()
|
||||||
|
questions = [message1, message2, message3]
|
||||||
|
expected_responses: list[Message] = []
|
||||||
|
fake_ai = self.mock_create_ai(self.args, self.config)
|
||||||
|
for question in questions:
|
||||||
|
# since the message's answer is modified, we use a copy
|
||||||
|
# -> the original is used for comparison below
|
||||||
|
expected_responses += fake_ai.request(copy(question),
|
||||||
|
Chat([]),
|
||||||
|
self.args.num_answers,
|
||||||
|
set(self.args.output_tags)).messages
|
||||||
|
|
||||||
|
# 2. === repeat all three questions (without overwriting) ===
|
||||||
|
self.args.ask = None
|
||||||
|
self.args.repeat = ['0001', '0002', '0003']
|
||||||
|
self.args.overwrite = False
|
||||||
|
question_cmd(self.args, self.config)
|
||||||
|
# two new files should be in the cache directory
|
||||||
|
# * the repeated cached message with answer
|
||||||
|
# * the repeated DB message
|
||||||
|
# -> the cached message without answer should be overwritten
|
||||||
|
self.assertEqual(len(self.message_list(self.cache_dir)), 4)
|
||||||
|
self.assertEqual(len(self.message_list(self.db_dir)), 1)
|
||||||
|
expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
|
||||||
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
|
Path(self.db_dir.name))
|
||||||
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
|
self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
|
||||||
|
# check that the DB message has not been modified at all
|
||||||
|
db_msg = chat.msg_gather(loc='db')
|
||||||
|
self.assert_msgs_all_equal(db_msg, [message3])
|
||||||
|
|||||||
Reference in New Issue
Block a user