Working question -a #7

Closed
ok wants to merge 6 commits from restructurings.ok into restructurings
8 changed files with 111 additions and 61 deletions
+6
View File
@@ -66,3 +66,9 @@ class AI(Protocol):
and is not implemented for all AIs. and is not implemented for all AIs.
""" """
raise NotImplementedError raise NotImplementedError
def print(self) -> None:
"""
Print some information about the current AI, such as its system message.
"""
pass
+15 -6
View File
@@ -43,16 +43,20 @@ class OpenAI(AI):
n=num_answers, n=num_answers,
frequency_penalty=self.config.frequency_penalty, frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty) presence_penalty=self.config.presence_penalty)
answers: list[Message] = [] question.answer = Answer(response['choices'][0]['message']['content'])
for choice in response['choices']: # type: ignore question.tags = otags
question.ai = self.ID
question.model = self.config.model
answers: list[Message] = [question]
for choice in response['choices'][1:]: # type: ignore
answers.append(Message(question=question.question, answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']), answer=Answer(choice['message']['content']),
tags=otags, tags=otags,
ai=self.name, ai=self.ID,
model=self.config.model)) model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt'], return AIResponse(answers, Tokens(response['usage']['prompt_tokens'],
response['usage']['completion'], response['usage']['completion_tokens'],
response['usage']['total'])) response['usage']['total_tokens']))
def models(self) -> list[str]: def models(self) -> list[str]:
""" """
@@ -95,3 +99,8 @@ class OpenAI(AI):
def tokens(self, data: Union[Message, Chat]) -> int: def tokens(self, data: Union[Message, Chat]) -> int:
raise NotImplementedError raise NotImplementedError
def print(self) -> None:
print(f"MODEL: {self.config.model}")
print("=== SYSTEM ===")
print(self.config.system)
+1 -3
View File
@@ -201,7 +201,7 @@ class Chat:
output.append(message.to_str(source_code_only=True)) output.append(message.to_str(source_code_only=True))
continue continue
output.append(message.to_str(with_tags, with_files)) output.append(message.to_str(with_tags, with_files))
output.append('\n' + ('-' * terminal_width()) + '\n') # output.append('\n' + ('-' * terminal_width()) + '\n')
if paged: if paged:
print_paged('\n'.join(output)) print_paged('\n'.join(output))
else: else:
@@ -361,8 +361,6 @@ class ChatDB(Chat):
Add the given new messages and set the file_path to the cache directory. Add the given new messages and set the file_path to the cache directory.
Only accepts messages without a file_path. Only accepts messages without a file_path.
""" """
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write: if write:
write_dir(self.cache_path, write_dir(self.cache_path,
messages, messages,
+28 -13
View File
@@ -3,7 +3,7 @@ from pathlib import Path
from itertools import zip_longest from itertools import zip_longest
from ..configuration import Config from ..configuration import Config
from ..chat import ChatDB from ..chat import ChatDB
from ..message import Message, Question from ..message import Message, MessageFilter, Question, source_code
from ..ai_factory import create_ai from ..ai_factory import create_ai
from ..ai import AI, AIResponse from ..ai import AI, AIResponse
@@ -14,10 +14,10 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
""" """
question_parts = [] question_parts = []
question_list = args.ask if args.ask is not None else [] question_list = args.ask if args.ask is not None else []
source_list = args.source if args.source is not None else [] text_files = args.source_text if args.source_text is not None else []
code_list = args.source_code if args.source_code is not None else [] code_files = args.source_code if args.source_code is not None else []
for question, source, code in zip_longest(question_list, source_list, code_list, fillvalue=None): for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None):
if question is not None and len(question.strip()) > 0: if question is not None and len(question.strip()) > 0:
question_parts.append(question) question_parts.append(question)
if source is not None and len(source) > 0: if source is not None and len(source) > 0:
@@ -28,7 +28,14 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
if code is not None and len(code) > 0: if code is not None and len(code) > 0:
with open(code) as r: with open(code) as r:
content = r.read().strip() content = r.read().strip()
if len(content) > 0: if len(content) == 0:
continue
# try to extract and add source code
code_parts = source_code(content, include_delims=True)
if len(code_parts) > 0:
question_parts += code_parts
# if there are none, add the whole file as a single fenced code block
else:
question_parts.append(f"```\n{content}\n```") question_parts.append(f"```\n{content}\n```")
full_question = '\n\n'.join(question_parts) full_question = '\n\n'.join(question_parts)
@@ -45,8 +52,12 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'question' command. Handler for the 'question' command.
""" """
mfilter = MessageFilter(tags_or=args.or_tags,
tags_and=args.and_tags,
tags_not=args.exclude_tags)
chat = ChatDB.from_dir(cache_path=Path('.'), chat = ChatDB.from_dir(cache_path=Path('.'),
db_path=Path(config.db)) db_path=Path(config.db),
mfilter=mfilter)
# if it's a new question, create and store it immediately # if it's a new question, create and store it immediately
if args.ask or args.create: if args.ask or args.create:
message = create_message(chat, args) message = create_message(chat, args)
@@ -56,23 +67,27 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
# create the correct AI instance # create the correct AI instance
ai: AI = create_ai(args, config) ai: AI = create_ai(args, config)
if args.ask: if args.ask:
ai.print()
chat.print(paged=False)
response: AIResponse = ai.request(message, response: AIResponse = ai.request(message,
chat, chat,
args.num_answers, # FIXME args.num_answers, # FIXME
args.output_tags) # FIXME args.output_tags) # FIXME
assert response chat.add_to_cache(response.messages)
# TODO: for idx, msg in enumerate(response.messages):
# * add answer to the message above (and create print(f"=== ANSWER {idx+1} ===")
# more messages for any additional answers) print(msg.answer)
pass if response.tokens:
elif args.repeat: print("===============")
print(response.tokens)
elif args.repeat is not None:
lmessage = chat.latest_message() lmessage = chat.latest_message()
assert lmessage assert lmessage
# TODO: repeat either the last question or the # TODO: repeat either the last question or the
# one(s) given in 'args.repeat' (overwrite # one(s) given in 'args.repeat' (overwrite
# existing ones if 'args.overwrite' is True) # existing ones if 'args.overwrite' is True)
pass pass
elif args.process: elif args.process is not None:
# TODO: process either all questions without an # TODO: process either all questions without an
# answer or the one(s) given in 'args.process' # answer or the one(s) given in 'args.process'
pass pass
+1 -1
View File
@@ -67,7 +67,7 @@ def create_parser() -> argparse.ArgumentParser:
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true') action='store_true')
question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Add content of a file to the query') question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query')
question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history') question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history')
# 'hist' command parser # 'hist' command parser
+1 -1
View File
@@ -414,7 +414,7 @@ class Message():
return '\n'.join(output) return '\n'.join(output)
def __str__(self) -> str: def __str__(self) -> str:
return self.to_str(False, False, False) return self.to_str(True, True, False)
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
""" """
+1 -16
View File
@@ -6,7 +6,7 @@ from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from chatmastermind.tags import TagLine from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError from chatmastermind.chat import Chat, ChatDB, ChatError
class TestChat(unittest.TestCase): class TestChat(unittest.TestCase):
@@ -92,16 +92,10 @@ class TestChat(unittest.TestCase):
Question 1 Question 1
{Answer.txt_header} {Answer.txt_header}
Answer 1 Answer 1
{'-'*terminal_width()}
{Question.txt_header} {Question.txt_header}
Question 2 Question 2
{Answer.txt_header} {Answer.txt_header}
Answer 2 Answer 2
{'-'*terminal_width()}
""" """
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
@@ -115,18 +109,12 @@ FILE: 0001.txt
Question 1 Question 1
{Answer.txt_header} {Answer.txt_header}
Answer 1 Answer 1
{'-'*terminal_width()}
{TagLine.prefix} btag2 {TagLine.prefix} btag2
FILE: 0002.txt FILE: 0002.txt
{Question.txt_header} {Question.txt_header}
Question 2 Question 2
{Answer.txt_header} {Answer.txt_header}
Answer 2 Answer 2
{'-'*terminal_width()}
""" """
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
@@ -415,9 +403,6 @@ class TestChatDB(unittest.TestCase):
db_dir_files = self.message_list(self.db_path) db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 5) self.assertEqual(len(db_dir_files), 5)
with self.assertRaises(ChatError):
chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))])
def test_chat_db_write_messages(self) -> None: def test_chat_db_write_messages(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
+58 -21
View File
@@ -22,18 +22,19 @@ class TestMessageCreate(unittest.TestCase):
db_path=Path(self.db_path.name)) db_path=Path(self.db_path.name))
# create arguments mock # create arguments mock
self.args = MagicMock(spec=argparse.Namespace) self.args = MagicMock(spec=argparse.Namespace)
self.args.source = None self.args.source_text = None
self.args.source_code = None self.args.source_code = None
self.args.AI = None self.args.AI = None
self.args.model = None self.args.model = None
self.args.output_tags = None self.args.output_tags = None
# create some files for sourcing # File 1 : no source code block, only text
self.source_file1 = tempfile.NamedTemporaryFile(delete=False) self.source_file1 = tempfile.NamedTemporaryFile(delete=False)
self.source_file1_content = """This is just text. self.source_file1_content = """This is just text.
No source code. No source code.
Nope. Go look elsewhere!""" Nope. Go look elsewhere!"""
with open(self.source_file1.name, 'w') as f: with open(self.source_file1.name, 'w') as f:
f.write(self.source_file1_content) f.write(self.source_file1_content)
# File 2 : one embedded source code block
self.source_file2 = tempfile.NamedTemporaryFile(delete=False) self.source_file2 = tempfile.NamedTemporaryFile(delete=False)
self.source_file2_content = """This is just text. self.source_file2_content = """This is just text.
``` ```
@@ -42,12 +43,26 @@ This is embedded source code.
And some text again.""" And some text again."""
with open(self.source_file2.name, 'w') as f: with open(self.source_file2.name, 'w') as f:
f.write(self.source_file2_content) f.write(self.source_file2_content)
# File 3 : all source code
self.source_file3 = tempfile.NamedTemporaryFile(delete=False) self.source_file3 = tempfile.NamedTemporaryFile(delete=False)
self.source_file3_content = """This is all source code. self.source_file3_content = """This is all source code.
Yes, really. Yes, really.
Language is called 'brainfart'.""" Language is called 'brainfart'."""
with open(self.source_file3.name, 'w') as f: with open(self.source_file3.name, 'w') as f:
f.write(self.source_file3_content) f.write(self.source_file3_content)
# File 4 : two source code blocks
self.source_file4 = tempfile.NamedTemporaryFile(delete=False)
self.source_file4_content = """This is just text.
```
This is embedded source code.
```
And some text again.
```
This is embedded source code.
```
Aaaand again some text."""
with open(self.source_file4.name, 'w') as f:
f.write(self.source_file4_content)
def tearDown(self) -> None: def tearDown(self) -> None:
os.remove(self.source_file1.name) os.remove(self.source_file1.name)
@@ -86,40 +101,62 @@ Language is called 'brainfart'."""
Is it good?""")) Is it good?"""))
def test_single_question_with_text_only_source(self) -> None: def test_single_question_with_text_only_file(self) -> None:
self.args.ask = ["What is this?"] self.args.ask = ["What is this?"]
self.args.source = [f"{self.source_file1.name}"] self.args.source_text = [f"{self.source_file1.name}"]
message = create_message(self.chat, self.args) message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
# source file contains no source code # file contains no source code (only text)
# -> don't expect any in the question # -> don't expect any in the question
self.assertEqual(len(message.question.source_code()), 0) self.assertEqual(len(message.question.source_code()), 0)
self.assertEqual(message.question, Question(f"""What is this? self.assertEqual(message.question, Question(f"""What is this?
{self.source_file1_content}""")) {self.source_file1_content}"""))
def test_single_question_with_embedded_source_source(self) -> None: def test_single_question_with_text_file_and_embedded_code(self) -> None:
self.args.ask = ["What is this?"]
self.args.source = [f"{self.source_file2.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# source file contains 1 source code block
# -> expect it in the question
self.assertEqual(len(message.question.source_code()), 1)
self.assertEqual(message.question, Question(f"""What is this?
{self.source_file2_content}"""))
def test_single_question_with_embedded_source_code_source(self) -> None:
self.args.ask = ["What is this?"] self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.source_file2.name}"] self.args.source_code = [f"{self.source_file2.name}"]
message = create_message(self.chat, self.args) message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
# source file contains 1 source code block # file contains 1 source code block
# -> expect it in the question # -> expect it in the question
self.assertEqual(len(message.question.source_code()), 2) self.assertEqual(len(message.question.source_code()), 1)
self.assertEqual(message.question, Question("""What is this?
```
This is embedded source code.
```
"""))
def test_single_question_with_code_only_file(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.source_file3.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# the whole file is source code -> expect it wrapped in a single code block
self.assertEqual(len(message.question.source_code()), 1)
self.assertEqual(message.question, Question(f"""What is this? self.assertEqual(message.question, Question(f"""What is this?
``` ```
{self.source_file2_content} {self.source_file3_content}
```""")) ```"""))
def test_single_question_with_text_file_and_multi_embedded_code(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.source_file4.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains 2 source code blocks
# -> expect them in the question
self.assertEqual(len(message.question.source_code()), 2)
self.assertEqual(message.question, Question("""What is this?
```
This is embedded source code.
```
```
This is embedded source code.
```
"""))