question_cmd: fixed AI and model arguments when repeating messages

This commit is contained in:
2023-09-21 18:21:43 +02:00
parent e9175aface
commit 0657a1bab8
+30 -4
View File
@@ -2,6 +2,7 @@ import sys
import argparse
from pathlib import Path
from itertools import zip_longest
from copy import deepcopy
from ..configuration import Config
from ..chat import ChatDB
from ..message import Message, MessageFilter, MessageError, Question, source_code
@@ -105,11 +106,37 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
print(response.tokens)
def repeat_messages(messages: list[Message], ai: AI, chat: ChatDB, args: argparse.Namespace) -> None:
def make_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
"""
Takes an existing message and CLI arguments, and returns modified args based
on the members of the given message. Used e.g. when repeating messages, where
it's necessary to determine the correct AI, module and output tags to use
(either from the existing message or the given args).
"""
msg_args = args
# if AI, model or output tags have not been specified,
# use those from the original message
if (args.AI is None
or args.model is None # noqa: W503
or args.output_tags is None # noqa: W503
or len(args.output_tags) == 0): # noqa: W503
msg_args = deepcopy(args)
if args.AI is None and msg.ai is not None:
msg_args.AI = msg.ai
if args.model is None and msg.model is not None:
msg_args.model = msg.model
if (args.output_tags is None or len(args.output_tags) == 0) and msg.tags is not None:
msg_args.output_tags = msg.tags
return msg_args
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
"""
Repeat the given messages using the given arguments.
"""
ai: AI
for msg in messages:
ai = create_ai(make_msg_args(msg, args), config)
print(f"--------- Repeating message '{msg.msg_id()}': ---------")
# overwrite the latest message if requested or empty
# -> but not if it's in the DB!
@@ -139,11 +166,10 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
message = create_message(chat, args)
if args.create:
return
# create the correct AI instance
ai: AI = create_ai(args, config)
# === ASK ===
if args.ask:
ai: AI = create_ai(args, config)
make_request(ai, chat, message, args)
# === REPEAT ===
elif args.repeat is not None:
@@ -158,7 +184,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
# repeat given message(s)
else:
repeat_msgs = chat.msg_find(args.repeat, loc='disk')
repeat_messages(repeat_msgs, ai, chat, args)
repeat_messages(repeat_msgs, chat, args, config)
# === PROCESS ===
elif args.process is not None:
# TODO: process either all questions without an