7 Commits

9 changed files with 195 additions and 42 deletions
+64
View File
@@ -0,0 +1,64 @@
from dataclasses import dataclass
from abc import abstractmethod
from typing import Protocol, Optional, Union
from .configuration import AIConfig
from .message import Message
from .chat import Chat
class AIError(Exception):
    """
    Exception type declared next to the AI client interface. It adds
    nothing to 'Exception' (no extra attributes or methods).
    """
@dataclass
class Tokens:
    """
    Three integer token counters, all of them defaulting to zero.
    """
    # NOTE(review): presumably the tokens consumed by the prompt - confirm
    prompt: int = 0
    # NOTE(review): presumably the tokens consumed by the generated answer(s) - confirm
    completion: int = 0
    # NOTE(review): presumably 'prompt' + 'completion'; nothing here enforces that
    total: int = 0
@dataclass
class AIResponse:
    """
    Result of an AI request: one or more messages (every message holds
    the question together with a single answer) and, optionally, the
    nr. of tokens that were used.
    """
    # one 'Message' per answer
    messages: list[Message]
    # token usage; stays 'None' unless it is passed explicitly
    tokens: Optional[Tokens] = None
class AI(Protocol):
    """
    Interface of all AI clients: 'request()' and 'models()' are
    abstract, 'tokens()' is optional and raises 'NotImplementedError'
    unless a client overrides it.
    """

    name: str
    config: AIConfig

    @abstractmethod
    def request(self,
                question: Message,
                context: Chat,
                num_answers: int = 1) -> AIResponse:
        """
        Send the given question to the AI, using the given chat as
        context (i. e. as chat history). The returned 'AIResponse'
        holds as many messages as answers were requested.
        """
        raise NotImplementedError

    @abstractmethod
    def models(self) -> list[str]:
        """
        Return the names of all models this AI supports.
        """
        raise NotImplementedError

    def tokens(self, data: Union[Message, Chat]) -> int:
        """
        Compute the nr. of AI language tokens of the given message or
        chat. The result may not be 100% accurate and not every AI
        implements this computation.
        """
        # Deliberately not abstract: clients without token counting keep this default.
        raise NotImplementedError
+43
View File
@@ -0,0 +1,43 @@
"""
Implements the OpenAI client classes and functions.
"""
import openai
from ..message import Message
from ..chat import Chat
from ..ai import AI, AIResponse
class OpenAI(AI):
    """
    AI client for OpenAI. 'request()' and 'models()' are placeholders
    that still raise 'NotImplementedError'.
    """

    def request(self,
                question: Message,
                context: Chat,
                num_answers: int = 1) -> AIResponse:
        """
        Send the given question to the AI, using the given chat as
        context (i. e. as chat history). The returned 'AIResponse'
        holds as many messages as answers were requested.
        """
        raise NotImplementedError

    def models(self) -> list[str]:
        """
        Return the names of all models this AI supports.
        """
        raise NotImplementedError

    def print_models(self) -> None:
        """
        Print the IDs of all engines reported by 'openai.Engine.list()',
        sorted by ID. Engines flagged as ready are printed one per line,
        all others are collected and printed last on one 'Not ready' line.
        """
        engines = sorted(openai.Engine.list()['data'], key=lambda x: x['id'])
        pending = []
        for engine in engines:
            if not engine['ready']:
                pending.append(engine['id'])
                continue
            print(engine['id'])
        # The summary line is only printed if at least one engine is not ready.
        if pending:
            print('\nNot ready: ' + ', '.join(pending))
+11 -2
View File
@@ -45,7 +45,7 @@ def read_dir(dir_path: pathlib.Path,
messages: list[Message] = []
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in sorted(file_iter):
if file_path.is_file():
if file_path.is_file() and file_path.suffix in Message.file_suffixes:
try:
message = Message.from_file(file_path, mfilter)
if message:
@@ -127,7 +127,16 @@ class Chat:
tags: set[Tag] = set()
for m in self.messages:
tags |= m.filter_tags(prefix, contain)
return tags
return set(sorted(tags))
def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
    """
    Get the frequency of all tags of all messages, optionally filtered by prefix or substring.

    Returns a dict that maps each tag to the nr. of times 'filter_tags()'
    returned it across all messages. The keys are inserted in sorted order.
    """
    # Imported locally because the module's import block is not touched by this method.
    from collections import Counter
    # Count in a single pass instead of calling 'list.count()' once per collected tag.
    frequency = Counter(tag for m in self.messages for tag in m.filter_tags(prefix, contain))
    return {tag: frequency[tag] for tag in sorted(frequency)}
def tokens(self) -> int:
"""
+11 -2
View File
@@ -7,7 +7,15 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
@dataclass
class OpenAIConfig():
class AIConfig:
"""
The base class of all AI configurations.
"""
name: str
@dataclass
class OpenAIConfig(AIConfig):
"""
The OpenAI section of the configuration file.
"""
@@ -25,6 +33,7 @@ class OpenAIConfig():
Create OpenAIConfig from a dict.
"""
return cls(
name='OpenAI',
api_key=str(source['api_key']),
model=str(source['model']),
max_tokens=int(source['max_tokens']),
@@ -36,7 +45,7 @@ class OpenAIConfig():
@dataclass
class Config():
class Config:
"""
The configuration file structure.
"""
+10 -3
View File
@@ -7,10 +7,11 @@ import sys
import argcomplete
import argparse
import pathlib
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType
from .storage import save_answers, create_chat_hist, get_tags_unique, read_file, dump_data
from .api_client import ai, openai_api_key, print_models
from .configuration import Config
from .chat import ChatDB
from itertools import zip_longest
from typing import Any
@@ -61,8 +62,12 @@ def tag_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'tag' command.
"""
chat = ChatDB.from_dir(cache_path=pathlib.Path('.'),
db_path=pathlib.Path(config.db))
if args.list:
print_tags_frequency(get_tags(config, None))
tags_freq = chat.tags_frequency(args.prefix, args.contain)
for tag, freq in tags_freq.items():
print(f"- {tag}: {freq}")
def config_cmd(args: argparse.Namespace, config: Config) -> None:
@@ -195,6 +200,8 @@ def create_parser() -> argparse.ArgumentParser:
tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True)
tag_group.add_argument('-l', '--list', help="List all tags and their frequency",
action='store_true')
tag_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix")
tag_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring")
# 'config' command parser
config_cmd_parser = cmdparser.add_parser('config',
+26 -22
View File
@@ -128,29 +128,29 @@ class ModelLine(str):
return cls(' '.join([cls.prefix, model]))
class Question(str):
class Answer(str):
"""
A single question with a defined header.
A single answer with a defined header.
"""
tokens: int = 0 # tokens used by this question
txt_header: ClassVar[str] = '=== QUESTION ==='
yaml_key: ClassVar[str] = 'question'
tokens: int = 0 # tokens used by this answer
txt_header: ClassVar[str] = '=== ANSWER ==='
yaml_key: ClassVar[str] = 'answer'
def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
"""
Make sure the question string does not contain the header.
Make sure the answer string does not contain the header as a whole line.
"""
if cls.txt_header in string:
raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'")
if cls.txt_header in string.split('\n'):
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
instance = super().__new__(cls, string)
return instance
@classmethod
def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
    """
    Build Answer from a list of strings. Make sure none of the strings
    is the header itself (a header that is only a part of a string is fine).
    The strings are joined by newlines and the result is stripped.

    Raises 'MessageError' if one of the strings equals the header.
    """
    if cls.txt_header in strings:
        raise MessageError(f"Answer contains the header '{cls.txt_header}'")
    # Created via 'str' directly, i. e. without going through '__new__' of this class.
    instance = super().__new__(cls, '\n'.join(strings).strip())
    return instance
@@ -162,29 +162,33 @@ class Question(str):
return source_code(self, include_delims)
class Answer(str):
class Question(str):
"""
A single answer with a defined header.
A single question with a defined header.
"""
tokens: int = 0 # tokens used by this answer
txt_header: ClassVar[str] = '=== ANSWER ==='
yaml_key: ClassVar[str] = 'answer'
tokens: int = 0 # tokens used by this question
txt_header: ClassVar[str] = '=== QUESTION ==='
yaml_key: ClassVar[str] = 'question'
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
"""
Make sure the answer string does not contain the header.
Make sure the question string does not contain the header as a whole line
(also not that from 'Answer', so it's always clear where the answer starts).
"""
if cls.txt_header in string:
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
string_lines = string.split('\n')
if cls.txt_header in string_lines:
raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'")
if Answer.txt_header in string_lines:
raise MessageError(f"Question '{string}' contains the header '{Answer.txt_header}'")
instance = super().__new__(cls, string)
return instance
@classmethod
def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst:
    """
    Build Question from a list of strings. Make sure none of the strings
    is the question header or the answer header (a header that is only a
    part of a string is fine). The strings are joined by newlines and the
    result is stripped.

    Raises 'MessageError' if one of the strings equals one of the headers.
    """
    # The instance is created via 'str' directly, so the checks of '__new__'
    # are not executed and both headers have to be rejected here as well.
    for header in (cls.txt_header, Answer.txt_header):
        if header in strings:
            raise MessageError(f"Question contains the header '{header}'")
    instance = super().__new__(cls, '\n'.join(strings).strip())
    return instance
-5
View File
@@ -78,8 +78,3 @@ def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = Fals
print(message['content'])
else:
print(f"{message['role'].upper()}: {message['content']}")
def print_tags_frequency(tags: list[str]) -> None:
for tag in sorted(set(tags)):
print(f"- {tag}: {tags.count(tag)}")
+7 -2
View File
@@ -14,7 +14,7 @@ class TestChat(CmmTestCase):
self.chat = Chat([])
self.message1 = Message(Question('Question 1'),
Answer('Answer 1'),
{Tag('atag1')},
{Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'),
Answer('Answer 2'),
@@ -57,6 +57,11 @@ class TestChat(CmmTestCase):
tags_cont = self.chat.tags(contain='2')
self.assertSetEqual(tags_cont, {Tag('btag2')})
def test_tags_frequency(self) -> None:
    """
    Without filters, 'tags_frequency()' reports the count of every tag
    of the two added messages.
    """
    self.chat.add_msgs([self.message1, self.message2])
    expected = {'atag1': 1, 'btag2': 2}
    self.assertDictEqual(self.chat.tags_frequency(), expected)
@patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None:
self.chat.add_msgs([self.message1, self.message2])
@@ -83,7 +88,7 @@ Answer 2
Question 1
{Answer.txt_header}
Answer 1
{TagLine.prefix} atag1
{TagLine.prefix} atag1 btag2
FILE: 0001.txt
{'-'*terminal_width()}
{Question.txt_header}
+23 -6
View File
@@ -61,22 +61,39 @@ class SourceCodeTestCase(CmmTestCase):
class QuestionTestCase(CmmTestCase):
def test_question_with_prefix(self) -> None:
def test_question_with_header(self) -> None:
with self.assertRaises(MessageError):
Question("=== QUESTION === What is your name?")
Question(f"{Question.txt_header}\nWhat is your name?")
def test_question_without_prefix(self) -> None:
def test_question_with_answer_header(self) -> None:
    """
    A question whose first line is the answer header must be rejected.
    """
    text = f"{Answer.txt_header}\nBob"
    with self.assertRaises(MessageError):
        Question(text)
def test_question_with_legal_header(self) -> None:
    """
    A header that is only a part of a line is accepted and the
    question text is kept unchanged.
    """
    text = f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?"
    question = Question(text)
    self.assertIsInstance(question, Question)
    self.assertEqual(question, text)
def test_question_without_header(self) -> None:
question = Question("What is your favorite color?")
self.assertIsInstance(question, Question)
self.assertEqual(question, "What is your favorite color?")
class AnswerTestCase(CmmTestCase):
def test_answer_with_prefix(self) -> None:
def test_answer_with_header(self) -> None:
with self.assertRaises(MessageError):
Answer("=== ANSWER === Yes")
Answer(f"{Answer.txt_header}\nno")
def test_answer_without_prefix(self) -> None:
def test_answer_with_legal_header(self) -> None:
    """
    A header that is only a part of a line is accepted and the
    answer text is kept unchanged.
    """
    text = f"This is a line contaning '{Answer.txt_header}'\nIt is what it is."
    answer = Answer(text)
    self.assertIsInstance(answer, Answer)
    self.assertEqual(answer, text)
def test_answer_without_header(self) -> None:
answer = Answer("No")
self.assertIsInstance(answer, Answer)
self.assertEqual(answer, "No")