question_cmd: fixed source code extraction and added a testcase
This commit is contained in:
@@ -3,7 +3,7 @@ from pathlib import Path
|
||||
from itertools import zip_longest
|
||||
from ..configuration import Config
|
||||
from ..chat import ChatDB
|
||||
from ..message import Message, Question
|
||||
from ..message import Message, Question, source_code
|
||||
from ..ai_factory import create_ai
|
||||
from ..ai import AI, AIResponse
|
||||
|
||||
@@ -14,10 +14,10 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
|
||||
"""
|
||||
question_parts = []
|
||||
question_list = args.ask if args.ask is not None else []
|
||||
source_list = args.source if args.source is not None else []
|
||||
code_list = args.source_code if args.source_code is not None else []
|
||||
text_files = args.source_text if args.source_text is not None else []
|
||||
code_files = args.source_code if args.source_code is not None else []
|
||||
|
||||
for question, source, code in zip_longest(question_list, source_list, code_list, fillvalue=None):
|
||||
for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None):
|
||||
if question is not None and len(question.strip()) > 0:
|
||||
question_parts.append(question)
|
||||
if source is not None and len(source) > 0:
|
||||
@@ -28,7 +28,14 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
|
||||
if code is not None and len(code) > 0:
|
||||
with open(code) as r:
|
||||
content = r.read().strip()
|
||||
if len(content) > 0:
|
||||
if len(content) == 0:
|
||||
continue
|
||||
# try to extract and add source code
|
||||
code_parts = source_code(content, include_delims=True)
|
||||
if len(code_parts) > 0:
|
||||
question_parts += code_parts
|
||||
# if there's none, add the whole file
|
||||
else:
|
||||
question_parts.append(f"```\n{content}\n```")
|
||||
|
||||
full_question = '\n\n'.join(question_parts)
|
||||
|
||||
@@ -67,7 +67,7 @@ def create_parser() -> argparse.ArgumentParser:
|
||||
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions')
|
||||
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
|
||||
action='store_true')
|
||||
question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Add content of a file to the query')
|
||||
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query')
|
||||
question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history')
|
||||
|
||||
# 'hist' command parser
|
||||
|
||||
@@ -414,7 +414,7 @@ class Message():
|
||||
return '\n'.join(output)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.to_str(False, False, False)
|
||||
return self.to_str(True, True, False)
|
||||
|
||||
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user