From 0657a1bab8c7953860a632c8f27be42abbf8688a Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 21 Sep 2023 18:21:43 +0200 Subject: [PATCH] question_cmd: fixed AI and model arguments when repeating messages --- chatmastermind/commands/question.py | 34 +++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index e35bfe5..bc4a8c4 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -2,6 +2,7 @@ import sys import argparse from pathlib import Path from itertools import zip_longest +from copy import deepcopy from ..configuration import Config from ..chat import ChatDB from ..message import Message, MessageFilter, MessageError, Question, source_code @@ -105,11 +106,37 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac print(response.tokens) -def repeat_messages(messages: list[Message], ai: AI, chat: ChatDB, args: argparse.Namespace) -> None: +def make_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace: + """ + Takes an existing message and CLI arguments, and returns modified args based + on the members of the given message. Used e.g. when repeating messages, where + it's necessary to determine the correct AI, module and output tags to use + (either from the existing message or the given args). + """ + msg_args = args + # if AI, model or output tags have not been specified, + # use those from the original message + if (args.AI is None + or args.model is None # noqa: W503 + or args.output_tags is None # noqa: W503 + or len(args.output_tags) == 0): # noqa: W503 + msg_args = deepcopy(args) + if args.AI is None and msg.ai is not None: + msg_args.AI = msg.ai + if args.model is None and msg.model is not None: + msg_args.model = msg.model + if (args.output_tags is None or len(args.output_tags) == 0) and msg.tags is not None: + msg_args.output_tags = msg.tags + return msg_args + + +def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None: """ Repeat the given messages using the given arguments. """ + ai: AI for msg in messages: + ai = create_ai(make_msg_args(msg, args), config) print(f"--------- Repeating message '{msg.msg_id()}': ---------") # overwrite the latest message if requested or empty # -> but not if it's in the DB! @@ -139,11 +166,10 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: message = create_message(chat, args) if args.create: return - # create the correct AI instance - ai: AI = create_ai(args, config) # === ASK === if args.ask: + ai: AI = create_ai(args, config) make_request(ai, chat, message, args) # === REPEAT === elif args.repeat is not None: @@ -158,7 +184,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: # repeat given message(s) else: repeat_msgs = chat.msg_find(args.repeat, loc='disk') - repeat_messages(repeat_msgs, ai, chat, args) + repeat_messages(repeat_msgs, chat, args, config) # === PROCESS === elif args.process is not None: # TODO: process either all questions without an