From 2aee0187083cfbada1820a9a991766f92cc70783 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Fri, 20 Oct 2023 13:43:31 +0200 Subject: [PATCH 1/3] Refactor message.Answer class in a way, that it can be constructed dynamically step by step, in preparation of using streaming API. --- chatmastermind/commands/question.py | 2 +- chatmastermind/message.py | 102 ++++++++++++++++++++++++---- tests/test_message.py | 2 +- 3 files changed, 89 insertions(+), 17 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index cd31d54..ae96bac 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -101,7 +101,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: if code_file is not None and len(code_file) > 0: add_file_as_code(question_parts, code_file) - full_question = '\n\n'.join(question_parts) + full_question = '\n\n'.join([str(s) for s in question_parts]) message = Message(question=Question(full_question), tags=args.output_tags, diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 8e7a55d..97e3e3a 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -5,7 +5,9 @@ import pathlib import yaml import tempfile import shutil +import io from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple +from typing import Generator, Iterator from typing import get_args as typing_get_args from dataclasses import dataclass, asdict, field from .tags import Tag, TagLine, TagError, match_tags, rename_tags @@ -142,30 +144,100 @@ class Answer(str): txt_header: ClassVar[str] = '==== ANSWER ====' yaml_key: ClassVar[str] = 'answer' - def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: + def __init__(self, data: Union[str, Generator[str, None, None]]) -> None: + # Indicator of whether all of data has been processed + self.is_exhausted: bool = False + + # Initialize data + self.iterator: Iterator[str] = self._init_data(data) + + # Set up the buffer to hold the 'Answer' content + self.buffer: io.StringIO = io.StringIO() + + def _init_data(self, data: Union[str, Generator[str, None, None]]) -> Iterator[str]: """ - Make sure the answer string does not contain the header as a whole line. + Process input data (either a string or a string generator) """ - if cls.txt_header in string.split('\n'): - raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") - instance = super().__new__(cls, string) - return instance + if isinstance(data, str): + yield data + else: + yield from data + + def __str__(self) -> str: + """ + Output all content when converted into a string + """ + # Ensure all data has been processed + for _ in self: + pass + # Return the 'Answer' content + return self.buffer.getvalue() + + def __repr__(self) -> str: + return repr(str(self)) + + def __iter__(self) -> Generator[str, None, None]: + """ + Allows the object to be iterable + """ + # Generate content if not all data has been processed + if not self.is_exhausted: + yield from self.generator_iter() + else: + yield self.buffer.getvalue() + + def generator_iter(self) -> Generator[str, None, None]: + """ + Main generator method to process data + """ + for piece in self.iterator: + # Write to buffer and yield piece for the iterator + self.buffer.write(piece) + yield piece + self.is_exhausted = True # Set the flag that all data has been processed + # If the header occurs in the 'Answer' content, raise an error + if f'\n{self.txt_header}' in self.buffer.getvalue() or self.buffer.getvalue().startswith(self.txt_header): + raise MessageError(f"Answer {repr(self.buffer.getvalue())} contains the header {repr(Answer.txt_header)}") + + def __eq__(self, other: object) -> bool: + """ + Comparing the object to a string or another object + """ + if isinstance(other, str): + return str(self) == other # Compare the string value of this object to the other string + # Default behavior for comparing non-string objects + return super().__eq__(other) + + def __hash__(self) -> int: + """ + Generate a hash for the object based on its string representation. + """ + return hash(str(self)) + + def __format__(self, format_spec: str) -> str: + """ + Return a formatted version of the string as per the format specification. + """ + return str(self).__format__(format_spec) @classmethod def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: """ - Build Question from a list of strings. Make sure strings do not contain the header. + Build Answer from a list of strings. Make sure strings do not contain the header. """ - if cls.txt_header in strings: - raise MessageError(f"Question contains the header '{cls.txt_header}'") - instance = super().__new__(cls, '\n'.join(strings).strip()) - return instance + def _gen() -> Generator[str, None, None]: + if len(strings) > 0: + yield strings[0] + for s in strings[1:]: + yield '\n' + yield s + return cls(_gen()) def source_code(self, include_delims: bool = False) -> list[str]: """ Extract and return all source code sections. """ - return source_code(self, include_delims) + return source_code(str(self), include_delims) class Question(str): @@ -441,7 +513,7 @@ class Message(): output.append(self.question) if self.answer: output.append(Answer.txt_header) - output.append(self.answer) + output.append(str(self.answer)) return '\n'.join(output) def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11 @@ -491,7 +563,7 @@ class Message(): temp_fd.write(f'{ModelLine.from_model(self.model)}\n') temp_fd.write(f'{Question.txt_header}\n{self.question}\n') if self.answer: - temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n') + temp_fd.write(f'{Answer.txt_header}\n{str(self.answer)}\n') shutil.move(temp_file_path, file_path) def __to_file_yaml(self, file_path: pathlib.Path) -> None: @@ -560,7 +632,7 @@ class Message(): or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503 - or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503 + or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in str(self.answer))) # noqa: W503 or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503 or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503 or (mfilter.model_state == 'available' and not self.model) # noqa: W503 diff --git a/tests/test_message.py b/tests/test_message.py index b79bcae..0a6c2de 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -91,7 +91,7 @@ class QuestionTestCase(unittest.TestCase): class AnswerTestCase(unittest.TestCase): def test_answer_with_header(self) -> None: with self.assertRaises(MessageError): - Answer(f"{Answer.txt_header}\nno") + str(Answer(f"{Answer.txt_header}\nno")) def test_answer_with_legal_header(self) -> None: answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") From bbc1ab5a0a3ce76d9ae90e461152321236d0530f Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Fri, 20 Oct 2023 14:02:09 +0200 Subject: [PATCH 2/3] Fix source_code function with the dynamic answer class. --- chatmastermind/message.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 97e3e3a..e8b19ba 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -51,7 +51,7 @@ def source_code(text: str, include_delims: bool = False) -> list[str]: code_lines: list[str] = [] in_code_block = False - for line in text.split('\n'): + for line in str(text).split('\n'): if line.strip().startswith('```'): if include_delims: code_lines.append(line) From dbe72ff11c6ba5f04fb05748e0a05e18acad41dd Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 21 Oct 2023 14:21:48 +0200 Subject: [PATCH 3/3] Activate and use OpenAI streaming API. --- chatmastermind/ais/openai.py | 85 +++++++++++++++++++++++------ chatmastermind/commands/question.py | 9 ++- requirements.txt | 1 + tests/test_ais_openai.py | 40 +++++++++----- 4 files changed, 102 insertions(+), 33 deletions(-) diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index d7bb12f..a8ceb34 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -2,7 +2,8 @@ Implements the OpenAI client classes and functions. """ import openai -from typing import Optional, Union +import tiktoken +from typing import Optional, Union, Generator from ..tags import Tag from ..message import Message, Answer from ..chat import Chat @@ -12,6 +13,52 @@ from ..configuration import OpenAIConfig ChatType = list[dict[str, str]] +class OpenAIAnswer: + def __init__(self, + idx: int, + streams: dict[int, 'OpenAIAnswer'], + response: openai.ChatCompletion, + tokens: Tokens, + encoding: tiktoken.core.Encoding) -> None: + self.idx = idx + self.streams = streams + self.response = response + self.position: int = 0 + self.encoding = encoding + self.data: list[str] = [] + self.finished: bool = False + self.tokens = tokens + + def stream(self) -> Generator[str, None, None]: + while True: + if not self.next(): + continue + if len(self.data) <= self.position: + break + yield self.data[self.position] + self.position += 1 + + def next(self) -> bool: + if self.finished: + return True + try: + chunk = next(self.response) + except StopIteration: + self.finished = True + if not self.finished: + found_choice = False + for choice in chunk['choices']: + if not choice['finish_reason']: + self.streams[choice['index']].data.append(choice['delta']['content']) + self.tokens.completion += len(self.encoding.encode(choice['delta']['content'])) + self.tokens.total = self.tokens.prompt + self.tokens.completion + if choice['index'] == self.idx: + found_choice = True + if not found_choice: + return False + return True + + class OpenAI(AI): """ The OpenAI AI client. @@ -21,7 +68,6 @@ class OpenAI(AI): self.ID = config.ID self.name = config.name self.config = config - openai.api_key = config.api_key def request(self, question: Message, @@ -33,7 +79,10 @@ class OpenAI(AI): chat history. The nr. of requested answers corresponds to the nr. of messages in the 'AIResponse'. """ - oai_chat = self.openai_chat(chat, self.config.system, question) + self.encoding = tiktoken.encoding_for_model(self.config.model) + openai.api_key = self.config.api_key + oai_chat, prompt_tokens = self.openai_chat(chat, self.config.system, question) + tokens: Tokens = Tokens(prompt_tokens, 0, prompt_tokens) response = openai.ChatCompletion.create( model=self.config.model, messages=oai_chat, @@ -41,22 +90,24 @@ class OpenAI(AI): max_tokens=self.config.max_tokens, top_p=self.config.top_p, n=num_answers, + stream=True, frequency_penalty=self.config.frequency_penalty, presence_penalty=self.config.presence_penalty) - question.answer = Answer(response['choices'][0]['message']['content']) + streams: dict[int, OpenAIAnswer] = {} + for n in range(num_answers): + streams[n] = OpenAIAnswer(n, streams, response, tokens, self.encoding) + question.answer = Answer(streams[0].stream()) question.tags = set(otags) if otags is not None else None question.ai = self.ID question.model = self.config.model answers: list[Message] = [question] - for choice in response['choices'][1:]: # type: ignore + for idx in range(1, num_answers): answers.append(Message(question=question.question, - answer=Answer(choice['message']['content']), + answer=Answer(streams[idx].stream()), tags=otags, ai=self.ID, model=self.config.model)) - return AIResponse(answers, Tokens(response['usage']['prompt_tokens'], - response['usage']['completion_tokens'], - response['usage']['total_tokens'])) + return AIResponse(answers, tokens) def models(self) -> list[str]: """ @@ -83,24 +134,26 @@ class OpenAI(AI): print('\nNot ready: ' + ', '.join(not_ready)) def openai_chat(self, chat: Chat, system: str, - question: Optional[Message] = None) -> ChatType: + question: Optional[Message] = None) -> tuple[ChatType, int]: """ Create a chat history with system message in OpenAI format. Optionally append a new question. """ oai_chat: ChatType = [] + prompt_tokens: int = 0 - def append(role: str, content: str) -> None: + def append(role: str, content: str) -> int: oai_chat.append({'role': role, 'content': content.replace("''", "'")}) + return len(self.encoding.encode(', '.join(['role:', oai_chat[-1]['role'], 'content:', oai_chat[-1]['content']]))) - append('system', system) + prompt_tokens += append('system', system) for message in chat.messages: if message.answer: - append('user', message.question) - append('assistant', message.answer) + prompt_tokens += append('user', message.question) + prompt_tokens += append('assistant', message.answer) if question: - append('user', question.question) - return oai_chat + prompt_tokens += append('user', question.question) + return oai_chat, prompt_tokens def tokens(self, data: Union[Message, Chat]) -> int: raise NotImplementedError diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index ae96bac..79c37da 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -129,13 +129,16 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac args.output_tags) # only write the response messages to the cache, # don't add them to the internal list - chat.cache_write(response.messages) for idx, msg in enumerate(response.messages): - print(f"=== ANSWER {idx+1} ===") - print(msg.answer) + print(f"=== ANSWER {idx+1} ===", flush=True) + if msg.answer: + for piece in msg.answer: + print(piece, end='', flush=True) + print() if response.tokens: print("===============") print(response.tokens) + chat.cache_write(response.messages) def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None: diff --git a/requirements.txt b/requirements.txt index 0762ecf..00e89b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ openai PyYAML argcomplete pytest +tiktoken diff --git a/tests/test_ais_openai.py b/tests/test_ais_openai.py index b53a14d..eab84e6 100644 --- a/tests/test_ais_openai.py +++ b/tests/test_ais_openai.py @@ -16,26 +16,37 @@ class OpenAITest(unittest.TestCase): openai = OpenAI(config) # Set up the mock response from openai.ChatCompletion.create - mock_response = { + mock_chunk1 = { 'choices': [ { - 'message': { + 'index': 0, + 'delta': { 'content': 'Answer 1' - } + }, + 'finish_reason': None }, { - 'message': { + 'index': 1, + 'delta': { 'content': 'Answer 2' - } + }, + 'finish_reason': None } ], - 'usage': { - 'prompt_tokens': 10, - 'completion_tokens': 20, - 'total_tokens': 30 - } } - mock_create.return_value = mock_response + mock_chunk2 = { + 'choices': [ + { + 'index': 0, + 'finish_reason': 'stop' + }, + { + 'index': 1, + 'finish_reason': 'stop' + } + ], + } + mock_create.return_value = iter([mock_chunk1, mock_chunk2]) # Create test data question = Message(Question('Question')) @@ -57,9 +68,9 @@ class OpenAITest(unittest.TestCase): self.assertIsNotNone(response.tokens) self.assertIsInstance(response.tokens, Tokens) assert response.tokens - self.assertEqual(response.tokens.prompt, 10) - self.assertEqual(response.tokens.completion, 20) - self.assertEqual(response.tokens.total, 30) + self.assertEqual(response.tokens.prompt, 53) + self.assertEqual(response.tokens.completion, 6) + self.assertEqual(response.tokens.total, 59) # Assert the mock call to openai.ChatCompletion.create mock_create.assert_called_once_with( @@ -76,6 +87,7 @@ class OpenAITest(unittest.TestCase): max_tokens=config.max_tokens, top_p=config.top_p, n=2, + stream=True, frequency_penalty=config.frequency_penalty, presence_penalty=config.presence_penalty )