Compare commits
5 Commits
repeat
..
0ee5645853
| Author | SHA1 | Date | |
|---|---|---|---|
| 0ee5645853 | |||
| ba16e325d1 | |||
| 42723d40ed | |||
| 4538624247 | |||
| f964c5471e |
@@ -105,6 +105,35 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
|
|||||||
print(response.tokens)
|
print(response.tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
    """
    Repeat the given messages using the given arguments.

    For every message an AI instance is created — falling back to the
    AI/model stored in the original message when they were not specified
    on the command line — and the question is re-sent via make_request().

    A message is repeated in place (its answer cleared and overwritten)
    when it has no answer yet or ``--overwrite`` was given, but never when
    it already lives in the DB; in that case a fresh message is created so
    the stored original stays untouched.
    """
    import copy  # local import: only needed to duplicate the args namespace

    ai: AI
    for msg in messages:
        ai_args = args
        # if AI or model have not been specified, use those from the original message
        if args.AI is None or args.model is None:
            # BUGFIX: argparse.Namespace has no .copy() method, so the
            # original `args.copy()` raised AttributeError here.  A shallow
            # copy.copy() gives the intended duplicate without mutating the
            # caller's args.
            ai_args = copy.copy(args)
            if args.AI is None and msg.ai is not None:
                ai_args.AI = msg.ai
            if args.model is None and msg.model is not None:
                ai_args.model = msg.model
        ai = create_ai(ai_args, config)
        print(f"--------- Repeating message '{msg.msg_id()}': ---------")
        # overwrite the latest message if requested or empty
        # -> but not if it's in the DB!
        if ((msg.answer is None or args.overwrite is True)
                and (not chat.msg_in_db(msg))):  # noqa: W503
            msg.clear_answer()
            make_request(ai, chat, msg, args)
        # otherwise create a new one
        else:
            args.ask = [msg.question]
            message = create_message(chat, args)
            make_request(ai, chat, message, args)
|
||||||
|
|
||||||
|
|
||||||
def question_cmd(args: argparse.Namespace, config: Config) -> None:
|
def question_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||||
"""
|
"""
|
||||||
Handler for the 'question' command.
|
Handler for the 'question' command.
|
||||||
@@ -121,30 +150,24 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
|
|||||||
if args.create:
|
if args.create:
|
||||||
return
|
return
|
||||||
|
|
||||||
# create the correct AI instance
|
|
||||||
ai: AI = create_ai(args, config)
|
|
||||||
|
|
||||||
# === ASK ===
|
# === ASK ===
|
||||||
if args.ask:
|
if args.ask:
|
||||||
|
ai: AI = create_ai(args, config)
|
||||||
make_request(ai, chat, message, args)
|
make_request(ai, chat, message, args)
|
||||||
# === REPEAT ===
|
# === REPEAT ===
|
||||||
elif args.repeat is not None:
|
elif args.repeat is not None:
|
||||||
|
repeat_msgs: list[Message] = []
|
||||||
|
# repeat latest message
|
||||||
|
if len(args.repeat) == 0:
|
||||||
lmessage = chat.msg_latest(loc='cache')
|
lmessage = chat.msg_latest(loc='cache')
|
||||||
if lmessage is None:
|
if lmessage is None:
|
||||||
print("No message found to repeat!")
|
print("No message found to repeat!")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
repeat_msgs.append(lmessage)
|
||||||
|
# repeat given message(s)
|
||||||
else:
|
else:
|
||||||
print(f"Repeating message '{lmessage.msg_id()}':")
|
repeat_msgs = chat.msg_find(args.repeat, loc='disk')
|
||||||
# overwrite the latest message if requested or empty
|
repeat_messages(repeat_msgs, chat, args, config)
|
||||||
if lmessage.answer is None or args.overwrite is True:
|
|
||||||
lmessage.clear_answer()
|
|
||||||
make_request(ai, chat, lmessage, args)
|
|
||||||
# otherwise create a new one
|
|
||||||
else:
|
|
||||||
args.ask = [lmessage.question]
|
|
||||||
message = create_message(chat, args)
|
|
||||||
make_request(ai, chat, message, args)
|
|
||||||
|
|
||||||
# === PROCESS ===
|
# === PROCESS ===
|
||||||
elif args.process is not None:
|
elif args.process is not None:
|
||||||
# TODO: process either all questions without an
|
# TODO: process either all questions without an
|
||||||
|
|||||||
@@ -282,6 +282,9 @@ class TestQuestionCmd(TestQuestionCmdBase):
|
|||||||
# exclude '.next'
|
# exclude '.next'
|
||||||
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
|
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuestionCmdAsk(TestQuestionCmd):
|
||||||
|
|
||||||
@mock.patch('chatmastermind.commands.question.create_ai')
|
@mock.patch('chatmastermind.commands.question.create_ai')
|
||||||
def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
|
def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -370,6 +373,9 @@ class TestQuestionCmd(TestQuestionCmdBase):
|
|||||||
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
||||||
self.assert_messages_equal(cached_msg, [expected_question])
|
self.assert_messages_equal(cached_msg, [expected_question])
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuestionCmdRepeat(TestQuestionCmd):
|
||||||
|
|
||||||
@mock.patch('chatmastermind.commands.question.create_ai')
|
@mock.patch('chatmastermind.commands.question.create_ai')
|
||||||
def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
|
def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -511,3 +517,44 @@ class TestQuestionCmd(TestQuestionCmdBase):
|
|||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
|
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
|
||||||
self.assert_messages_equal(cached_msg, expected_responses)
|
self.assert_messages_equal(cached_msg, expected_responses)
|
||||||
|
|
||||||
|
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None:
    """
    Repeat multiple questions.

    Exercises the three repeat cases in a single run: a cached message
    without an answer (repeated in place), a cached message with an
    answer (repeated as a new message), and a DB message (repeated as a
    new message, leaving the DB untouched).
    """
    # 1. create some questions / messages
    # cached message without an answer
    message1 = Message(Question('Question 1'),
                       ai='foo',
                       model='bla',
                       file_path=Path(self.cache_dir.name) / '0001.txt')
    # cached message with an answer
    message2 = Message(Question('Question 2'),
                       Answer('Answer 2'),
                       ai='openai',
                       model='gpt-3.5-turbo',
                       file_path=Path(self.cache_dir.name) / '0002.txt')
    # DB message without an answer
    message3 = Message(Question('Question 3'),
                       ai='openai',
                       model='gpt-3.5-turbo',
                       file_path=Path(self.db_dir.name) / '0003.txt')
    # persist all three so question_cmd can find them on disk
    message1.to_file()
    message2.to_file()
    message3.to_file()
    # chat = ChatDB.from_dir(Path(self.cache_dir.name),
    #                        Path(self.db_dir.name))

    # 2. repeat all three questions (without overwriting)
    self.args.ask = None
    self.args.repeat = ['0001', '0002', '0003']
    self.args.overwrite = False
    question_cmd(self.args, self.config)
    # two new files should be in the cache directory
    # * the repeated cached message with answer
    # * the repeated DB message
    # -> the cached message without an answer should be overwritten
    self.assertEqual(len(self.message_list(self.cache_dir)), 4)
    self.assertEqual(len(self.message_list(self.db_dir)), 1)
    # FIXME: also compare actual content!
|
||||||
|
|||||||
Reference in New Issue
Block a user