'cmm question --repeat' supports multiple questions, added tests and fixes #15
@@ -2,6 +2,7 @@ import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from itertools import zip_longest
|
||||
from copy import deepcopy
|
||||
from ..configuration import Config
|
||||
from ..chat import ChatDB
|
||||
from ..message import Message, MessageFilter, MessageError, Question, source_code
|
||||
@@ -105,11 +106,37 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
|
||||
print(response.tokens)
|
||||
|
||||
|
||||
def repeat_messages(messages: list[Message], ai: AI, chat: ChatDB, args: argparse.Namespace) -> None:
|
||||
def make_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
|
||||
"""
|
||||
Takes an existing message and CLI arguments, and returns modified args based
|
||||
on the members of the given message. Used e.g. when repeating messages, where
|
||||
it's necessary to determine the correct AI, module and output tags to use
|
||||
(either from the existing message or the given args).
|
||||
"""
|
||||
msg_args = args
|
||||
# if AI, model or output tags have not been specified,
|
||||
# use those from the original message
|
||||
if (args.AI is None
|
||||
or args.model is None # noqa: W503
|
||||
or args.output_tags is None # noqa: W503
|
||||
or len(args.output_tags) == 0): # noqa: W503
|
||||
msg_args = deepcopy(args)
|
||||
if args.AI is None and msg.ai is not None:
|
||||
msg_args.AI = msg.ai
|
||||
if args.model is None and msg.model is not None:
|
||||
msg_args.model = msg.model
|
||||
if (args.output_tags is None or len(args.output_tags) == 0) and msg.tags is not None:
|
||||
msg_args.output_tags = msg.tags
|
||||
return msg_args
|
||||
|
||||
|
||||
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
|
||||
"""
|
||||
Repeat the given messages using the given arguments.
|
||||
"""
|
||||
ai: AI
|
||||
for msg in messages:
|
||||
ai = create_ai(make_msg_args(msg, args), config)
|
||||
print(f"--------- Repeating message '{msg.msg_id()}': ---------")
|
||||
# overwrite the latest message if requested or empty
|
||||
# -> but not if it's in the DB!
|
||||
@@ -139,11 +166,10 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
message = create_message(chat, args)
|
||||
if args.create:
|
||||
return
|
||||
# create the correct AI instance
|
||||
ai: AI = create_ai(args, config)
|
||||
|
||||
# === ASK ===
|
||||
if args.ask:
|
||||
ai: AI = create_ai(args, config)
|
||||
make_request(ai, chat, message, args)
|
||||
# === REPEAT ===
|
||||
elif args.repeat is not None:
|
||||
@@ -158,7 +184,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
# repeat given message(s)
|
||||
else:
|
||||
repeat_msgs = chat.msg_find(args.repeat, loc='disk')
|
||||
repeat_messages(repeat_msgs, ai, chat, args)
|
||||
repeat_messages(repeat_msgs, chat, args, config)
|
||||
# === PROCESS ===
|
||||
elif args.process is not None:
|
||||
# TODO: process either all questions without an
|
||||
|
||||
Reference in New Issue
Block a user