question_cmd: fixed source code extraction and added a testcase

This commit is contained in:
2023-09-09 18:28:10 +02:00
parent 39b518a8a6
commit 53582a7123
4 changed files with 72 additions and 28 deletions
+12 -5
View File
@@ -3,7 +3,7 @@ from pathlib import Path
from itertools import zip_longest
from ..configuration import Config
from ..chat import ChatDB
from ..message import Message, Question
from ..message import Message, Question, source_code
from ..ai_factory import create_ai
from ..ai import AI, AIResponse
@@ -14,10 +14,10 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
"""
question_parts = []
question_list = args.ask if args.ask is not None else []
source_list = args.source if args.source is not None else []
code_list = args.source_code if args.source_code is not None else []
text_files = args.source_text if args.source_text is not None else []
code_files = args.source_code if args.source_code is not None else []
for question, source, code in zip_longest(question_list, source_list, code_list, fillvalue=None):
for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None):
if question is not None and len(question.strip()) > 0:
question_parts.append(question)
if source is not None and len(source) > 0:
@@ -28,7 +28,14 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
if code is not None and len(code) > 0:
with open(code) as r:
content = r.read().strip()
if len(content) > 0:
if len(content) == 0:
continue
# try to extract and add source code
code_parts = source_code(content, include_delims=True)
if len(code_parts) > 0:
question_parts += code_parts
# if there's none, add the whole file
else:
question_parts.append(f"```\n{content}\n```")
full_question = '\n\n'.join(question_parts)
+1 -1
View File
@@ -67,7 +67,7 @@ def create_parser() -> argparse.ArgumentParser:
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true')
question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Add content of a file to the query')
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query')
question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history')
# 'hist' command parser
+1 -1
View File
@@ -414,7 +414,7 @@ class Message():
return '\n'.join(output)
def __str__(self) -> str:
return self.to_str(False, False, False)
return self.to_str(True, True, False)
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
"""