33 Commits

Author SHA1 Message Date
Oleksandr Kozachuk 9a957a89ac Switch to current version of OpenAI. 2024-03-30 14:17:38 +01:00
Oleksandr Kozachuk 5d1bb1f9e4 Fix some of the commands. 2023-11-10 10:42:46 +01:00
Oleksandr Kozachuk 75a123eb72 Fix usage of the dynamic answer is some cases. 2023-10-24 12:59:13 +02:00
juk0de 7c1c67f8ff Merge pull request 'Dynamic Answer class and OpenAI streaming API' (#19) from dynamic_answer into main
Introduces several changes with the main objective of enabling OpenAI's streaming API in the chatmastermind application. This allows for the retrieval of AI responses gradually as a stream, which can significantly improve the user experience in interactions that involve large result sets.

* Added tiktoken import in 'openai.py' and modifications to the OpenAI class to support streaming. This includes the addition of a new class OpenAIAnswer to handle streaming API responses.
* Modified request function in the OpenAI class: the stream=True flag is added to the openai.ChatCompletion.create method to enable streaming API.
* Modified 'question.py' to print the answer parts as they are streamed.
* Replaced the Answer class's string data type with a generator which supports str and Generator[str, None, None] data types. Modifications are made to the Answer class methods to handle both data types accordingly.
* Updated the tests in 'test_ais_openai.py' and 'test_message.py' to reflect and validate these changes.
2023-10-21 15:50:45 +02:00
Oleksandr Kozachuk dbe72ff11c Activate and use OpenAI streaming API. 2023-10-21 14:21:48 +02:00
Oleksandr Kozachuk bbc1ab5a0a Fix source_code function with the dynamic answer class. 2023-10-20 14:02:09 +02:00
Oleksandr Kozachuk 2aee018708 Refactor message.Answer class in a way, that it can be constructed dynamically step by step, in preparation of using streaming API. 2023-10-20 13:43:31 +02:00
ok 17c6fa2453 Merge pull request 'Configurable glob and location on question and hist commands' (#18) from cust_loc_glob into main
Reviewed-on: #18
2023-10-20 09:47:03 +02:00
juk0de 5774278fb7 README: added new 'question' command parameters 2023-10-20 09:16:03 +02:00
juk0de 40d0de50de cmm: limited the message locations for the new cmm parameters to those that make sense 2023-10-20 09:16:03 +02:00
juk0de 72d31c26e9 main: improved parameter descriptions 2023-10-20 09:16:03 +02:00
juk0de 980e5ac51f chat: changed default glob to '*.msg' in all ChatDB functions 2023-10-20 09:00:58 +02:00
Oleksandr Kozachuk 114282dfd8 Add --glob and --location flags to hist and question commands, to be able to specify the location and files they should use. 2023-10-19 16:03:51 +02:00
Oleksandr Kozachuk 9a493b57da Per default use only files with .msg suffix ignoring other files. 2023-10-19 16:02:40 +02:00
Oleksandr Kozachuk 9b0951cb3f Change type msg_location to an Enum instead of Literal to be able to get all values easy and improve type checks. 2023-10-19 16:00:44 +02:00
Oleksandr Kozachuk 5f29f60168 Add .old/ to git ignore, I use that dir ofter for old files, I do not want to delete. 2023-10-17 11:53:49 +02:00
juk0de 3ea1f49027 cmm: added options '--tight' and '--no-paging' to the 'hist --print' cmd 2023-10-02 08:35:19 +02:00
juk0de 8f56399844 cmm: replaced options '--with-tags' and '--with-file' with '--with-metadata' 2023-10-01 10:11:16 +02:00
juk0de e4cb6eb22b README: updated 'hist' command description 2023-10-01 09:27:40 +02:00
juk0de e19c6bb1ea hist_cmd: added module 'test_hist_cmd.py' 2023-09-30 08:25:33 +02:00
juk0de 811b2e6830 hist_cmd: implemented '--convert' option 2023-09-29 18:53:12 +02:00
juk0de 2a8f01aee4 chat: 'msg_gather()' now supports globbing 2023-09-29 07:16:20 +02:00
juk0de efdb3cae2f question: moved around some code 2023-09-29 07:16:20 +02:00
juk0de aecfd1088d chat: added message file format as ChatDB class member 2023-09-29 07:16:20 +02:00
juk0de 140dbed809 message: added function 'rm_file()' and test 2023-09-29 07:16:20 +02:00
juk0de 01860ace2c test_question_cmd: modified tests to use '.msg' file suffix 2023-09-29 07:16:20 +02:00
juk0de df42bcee09 test_chat: added test for file_path collision detection 2023-09-29 07:16:20 +02:00
juk0de e34eab6519 test_chat: changed all tests to use the new '.msg' suffix 2023-09-29 07:16:20 +02:00
juk0de d07fd13e8e test_message: changed all tests to use the new '.msg' suffix 2023-09-29 07:16:20 +02:00
juk0de b8681e8274 message: fixed tag matching for YAML file format 2023-09-29 07:16:20 +02:00
juk0de d2be53aeab chat: switched to new message suffix and formats
- no longer using file suffix to choose the format
- added 'mformat' argument to 'write_xxx()' functions
- file suffix is now set by 'Message.to_file()' per default
2023-09-29 07:16:20 +02:00
juk0de 9ca9a23569 message: introduced file suffix '.msg'
- '.msg' suffix is always used for writing
- 'Message.to_file()' will set the file suffix if the given file_path has none
- added 'mformat' argument to 'Message.to_file()' for choosing the file format
- '.txt' and '.yaml' suffixes are only supported for reading
2023-09-29 07:16:20 +02:00
juk0de 6f3758e12e question_cmd: fixed '--create' option 2023-09-29 07:15:46 +02:00
15 changed files with 518 additions and 196 deletions
+1
View File
@@ -106,6 +106,7 @@ celerybeat.pid
.venv
env/
venv/
.old/
ENV/
env.bak/
venv.bak/
+11 -6
View File
@@ -65,23 +65,28 @@ cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI_ID
* `-O, --overwrite`: Overwrite existing messages when repeating them
* `-s, --source-text FILE`: Add content of a file to the query
* `-S, --source-code FILE`: Add source code file content to the chat history
* `-l, --location {cache,db,all}`: Use given location when building the chat history (default: 'db')
* `-g, --glob GLOB`: Filter message files using the given glob pattern
#### Hist
The `hist` command is used to print the chat history.
The `hist` command is used to print and manage the chat history.
```bash
cmm hist [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A SUBSTRING] [-Q SUBSTRING]
cmm hist [--print | --convert FORMAT] [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A SUBSTRING] [-Q SUBSTRING]
```
* `-p, --print`: Print the DB chat history
* `-c, --convert FORMAT`: Convert all messages to the given format
* `-t, --or-tags OTAGS`: List of tags (one must match)
* `-k, --and-tags ATAGS`: List of tags (all must match)
* `-x, --exclude-tags XTAGS`: List of tags to exclude
* `-w, --with-tags`: Print chat history with tags
* `-W, --with-files`: Print chat history with filenames
* `-w, --with-metadata`: Print chat history with metadata (tags, filenames, AI, etc.)
* `-S, --source-code-only`: Only print embedded source code
* `-A, --answer SUBSTRING`: Search for answer substring
* `-Q, --question SUBSTRING`: Search for question substring
* `-A, --answer SUBSTRING`: Filter for answer substring
* `-Q, --question SUBSTRING`: Filter for question substring
* `-l, --location {cache,db,all}`: Use given location when building the chat history (default: 'db')
* `-g, --glob GLOB`: Filter message files using the given glob pattern
#### Tags
+77 -28
View File
@@ -2,7 +2,8 @@
Implements the OpenAI client classes and functions.
"""
import openai
from typing import Optional, Union
import tiktoken
from typing import Optional, Union, Generator
from ..tags import Tag
from ..message import Message, Answer
from ..chat import Chat
@@ -12,6 +13,52 @@ from ..configuration import OpenAIConfig
ChatType = list[dict[str, str]]
class OpenAIAnswer:
def __init__(self,
idx: int,
streams: dict[int, 'OpenAIAnswer'],
response: openai.ChatCompletion,
tokens: Tokens,
encoding: tiktoken.core.Encoding) -> None:
self.idx = idx
self.streams = streams
self.response = response
self.position: int = 0
self.encoding = encoding
self.data: list[str] = []
self.finished: bool = False
self.tokens = tokens
def stream(self) -> Generator[str, None, None]:
while True:
if not self.next():
continue
if len(self.data) <= self.position:
break
yield self.data[self.position]
self.position += 1
def next(self) -> bool:
if self.finished:
return True
try:
chunk = next(self.response)
except StopIteration:
self.finished = True
if not self.finished:
found_choice = False
for choice in chunk.choices:
if not choice.finish_reason:
self.streams[choice.index].data.append(choice.delta.content)
self.tokens.completion += len(self.encoding.encode(choice.delta.content))
self.tokens.total = self.tokens.prompt + self.tokens.completion
if choice.index == self.idx:
found_choice = True
if not found_choice:
return False
return True
class OpenAI(AI):
"""
The OpenAI AI client.
@@ -21,7 +68,10 @@ class OpenAI(AI):
self.ID = config.ID
self.name = config.name
self.config = config
openai.api_key = config.api_key
self.client = openai.OpenAI(api_key=self.config.api_key)
def _completions(self, *args, **kw): # type: ignore
return self.client.chat.completions.create(*args, **kw)
def request(self,
question: Message,
@@ -33,39 +83,42 @@ class OpenAI(AI):
chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'.
"""
oai_chat = self.openai_chat(chat, self.config.system, question)
response = openai.ChatCompletion.create(
self.encoding = tiktoken.encoding_for_model(self.config.model)
oai_chat, prompt_tokens = self.openai_chat(chat, self.config.system, question)
tokens: Tokens = Tokens(prompt_tokens, 0, prompt_tokens)
response = self._completions(
model=self.config.model,
messages=oai_chat,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
top_p=self.config.top_p,
n=num_answers,
stream=True,
frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty)
question.answer = Answer(response['choices'][0]['message']['content'])
streams: dict[int, OpenAIAnswer] = {}
for n in range(num_answers):
streams[n] = OpenAIAnswer(n, streams, response, tokens, self.encoding)
question.answer = Answer(streams[0].stream())
question.tags = set(otags) if otags is not None else None
question.ai = self.ID
question.model = self.config.model
answers: list[Message] = [question]
for choice in response['choices'][1:]: # type: ignore
for idx in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']),
answer=Answer(streams[idx].stream()),
tags=otags,
ai=self.ID,
model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt_tokens'],
response['usage']['completion_tokens'],
response['usage']['total_tokens']))
return AIResponse(answers, tokens)
def models(self) -> list[str]:
"""
Return all models supported by this AI.
"""
ret = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']:
ret.append(engine['id'])
for engine in sorted(self.client.models.list().data, key=lambda x: x.id):
ret.append(engine.id)
ret.sort()
return ret
@@ -73,34 +126,30 @@ class OpenAI(AI):
"""
Print all models supported by the current AI.
"""
not_ready = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']:
print(engine['id'])
else:
not_ready.append(engine['id'])
if len(not_ready) > 0:
print('\nNot ready: ' + ', '.join(not_ready))
for model in self.models():
print(model)
def openai_chat(self, chat: Chat, system: str,
question: Optional[Message] = None) -> ChatType:
question: Optional[Message] = None) -> tuple[ChatType, int]:
"""
Create a chat history with system message in OpenAI format.
Optionally append a new question.
"""
oai_chat: ChatType = []
prompt_tokens: int = 0
def append(role: str, content: str) -> None:
def append(role: str, content: str) -> int:
oai_chat.append({'role': role, 'content': content.replace("''", "'")})
return len(self.encoding.encode(', '.join(['role:', oai_chat[-1]['role'], 'content:', oai_chat[-1]['content']])))
append('system', system)
prompt_tokens += append('system', system)
for message in chat.messages:
if message.answer:
append('user', message.question)
append('assistant', message.answer)
prompt_tokens += append('user', message.question)
prompt_tokens += append('assistant', str(message.answer))
if question:
append('user', question.question)
return oai_chat
prompt_tokens += append('user', question.question)
return oai_chat, prompt_tokens
def tokens(self, data: Union[Message, Chat]) -> int:
raise NotImplementedError
+47 -28
View File
@@ -6,7 +6,8 @@ from pathlib import Path
from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass
from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union
from enum import Enum
from typing import TypeVar, Type, Optional, Any, Callable, Union
from .configuration import default_config_file
from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats
from .tags import Tag
@@ -16,10 +17,17 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next'
ignored_files = [db_next_file, default_config_file]
msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
msg_suffix = Message.file_suffix_write
class msg_location(Enum):
MEM = 'mem'
DISK = 'disk'
CACHE = 'cache'
DB = 'db'
ALL = 'all'
class ChatError(Exception):
pass
@@ -44,12 +52,12 @@ def read_dir(dir_path: Path,
Parameters:
* 'dir_path': source directory
* 'glob': if specified, files will be filtered using 'path.glob()',
otherwise it uses 'path.iterdir()'.
otherwise it reads all files with the default message suffix
* 'mfilter': use with 'Message.from_file()' to filter messages
when reading them.
"""
messages: list[Message] = []
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
file_iter = dir_path.glob(glob) if glob else dir_path.glob(f'*{msg_suffix}')
for file_path in sorted(file_iter):
if (file_path.is_file()
and file_path.name not in ignored_files # noqa: W503
@@ -255,14 +263,17 @@ class Chat:
return sum(m.tokens() for m in self.messages)
def print(self, source_code_only: bool = False,
with_tags: bool = False, with_files: bool = False,
paged: bool = True) -> None:
with_metadata: bool = False,
paged: bool = True,
tight: bool = False) -> None:
output: list[str] = []
for message in self.messages:
if source_code_only:
output.append(message.to_str(source_code_only=True))
continue
output.append(message.to_str(with_tags, with_files))
output.append(message.to_str(with_metadata))
if not tight:
output.append('\n' + ('-' * terminal_width()) + '\n')
if paged:
print_paged('\n'.join(output))
else:
@@ -284,7 +295,7 @@ class ChatDB(Chat):
# a MessageFilter that all messages must match (if given)
mfilter: Optional[MessageFilter] = None
# the glob pattern for all messages
glob: Optional[str] = None
glob: str = f'*{msg_suffix}'
# message format (for writing)
mformat: MessageFormat = Message.default_format
@@ -300,20 +311,28 @@ class ChatDB(Chat):
def from_dir(cls: Type[ChatDBInst],
cache_path: Path,
db_path: Path,
glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
glob: str = f'*{msg_suffix}',
mfilter: Optional[MessageFilter] = None,
loc: msg_location = msg_location.DB) -> ChatDBInst:
"""
Create a 'ChatDB' instance from the given directory structure.
Reads all messages from 'db_path' into the local message list.
Parameters:
* 'cache_path': path to the directory for temporary messages
* 'db_path': path to the directory for persistent messages
* 'glob': if specified, files will be filtered using 'path.glob()',
otherwise it uses 'path.iterdir()'.
* 'glob': if specified, files will be filtered using 'path.glob()'
* 'mfilter': use with 'Message.from_file()' to filter messages
when reading them.
* 'loc': read messages from given location instead of 'db_path'
"""
messages = read_dir(db_path, glob, mfilter)
if loc == msg_location.MEM:
raise ChatError(f"Can't build ChatDB from message location '{loc}'")
messages: list[Message] = []
if loc in [msg_location.DB, msg_location.DISK, msg_location.ALL]:
messages.extend(read_dir(db_path, glob, mfilter))
if loc in [msg_location.CACHE, msg_location.DISK, msg_location.ALL]:
messages.extend(read_dir(cache_path, glob, mfilter))
messages.sort(key=lambda x: x.msg_id())
return cls(messages, cache_path, db_path, mfilter, glob)
@classmethod
@@ -383,7 +402,7 @@ class ChatDB(Chat):
def msg_gather(self,
loc: msg_location,
require_file_path: bool = False,
glob: Optional[str] = None,
glob: str = f'*{msg_suffix}',
mfilter: Optional[MessageFilter] = None) -> list[Message]:
"""
Gather and return messages from the given locations:
@@ -396,14 +415,14 @@ class ChatDB(Chat):
If 'require_file_path' is True, return only files with a valid file_path.
"""
loc_messages: list[Message] = []
if loc in ['mem', 'all']:
if loc in [msg_location.MEM, msg_location.ALL]:
if require_file_path:
loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
else:
loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
if loc in ['cache', 'disk', 'all']:
if loc in [msg_location.CACHE, msg_location.DISK, msg_location.ALL]:
loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter)
if loc in ['db', 'disk', 'all']:
if loc in [msg_location.DB, msg_location.DISK, msg_location.ALL]:
loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter)
# remove_duplicates and sort the list
unique_messages: list[Message] = []
@@ -419,7 +438,7 @@ class ChatDB(Chat):
def msg_find(self,
msg_names: list[str],
loc: msg_location = 'mem',
loc: msg_location = msg_location.MEM,
) -> list[Message]:
"""
Search and return the messages with the given names. Names can either be filenames
@@ -437,7 +456,7 @@ class ChatDB(Chat):
return [m for m in loc_messages
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def msg_remove(self, msg_names: list[str], loc: msg_location = 'mem') -> None:
def msg_remove(self, msg_names: list[str], loc: msg_location = msg_location.MEM) -> None:
"""
Remove the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Also deletes the
@@ -449,7 +468,7 @@ class ChatDB(Chat):
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
if loc != 'mem':
if loc != msg_location.MEM:
# delete the message files first
rm_messages = self.msg_find(msg_names, loc=loc)
for m in rm_messages:
@@ -460,7 +479,7 @@ class ChatDB(Chat):
def msg_latest(self,
mfilter: Optional[MessageFilter] = None,
loc: msg_location = 'mem') -> Optional[Message]:
loc: msg_location = msg_location.MEM) -> Optional[Message]:
"""
Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if loc is 'mem').
@@ -489,7 +508,7 @@ class ChatDB(Chat):
and message.file_path.parent.samefile(self.cache_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc='cache')) > 0
return len(self.msg_find([message], loc=msg_location.CACHE)) > 0
def msg_in_db(self, message: Union[Message, str]) -> bool:
"""
@@ -501,9 +520,9 @@ class ChatDB(Chat):
and message.file_path.parent.samefile(self.db_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc='db')) > 0
return len(self.msg_find([message], loc=msg_location.DB)) > 0
def cache_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None:
def cache_read(self, glob: str = f'*{msg_suffix}', mfilter: Optional[MessageFilter] = None) -> None:
"""
Read messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message
@@ -546,7 +565,7 @@ class ChatDB(Chat):
self.messages += messages
self.msg_sort()
def cache_clear(self, glob: Optional[str] = None) -> None:
def cache_clear(self, glob: str = f'*{msg_suffix}') -> None:
"""
Delete all message files from the cache dir and remove them from the internal list.
"""
@@ -566,11 +585,11 @@ class ChatDB(Chat):
self.cache_write([message])
# remove the old one (if any)
if old_path:
self.msg_remove([str(old_path)], loc='db')
self.msg_remove([str(old_path)], loc=msg_location.DB)
# (re)add it to the internal list
self.msg_add([message])
def db_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None:
def db_read(self, glob: str = f'*{msg_suffix}', mfilter: Optional[MessageFilter] = None) -> None:
"""
Read messages from the DB directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message
@@ -625,6 +644,6 @@ class ChatDB(Chat):
self.db_write([message])
# remove the old one (if any)
if old_path:
self.msg_remove([str(old_path)], loc='cache')
self.msg_remove([str(old_path)], loc=msg_location.CACHE)
# (re)add it to the internal list
self.msg_add([message])
+59 -7
View File
@@ -1,13 +1,52 @@
import sys
import argparse
from pathlib import Path
from ..configuration import Config
from ..chat import ChatDB
from ..message import MessageFilter
from ..chat import ChatDB, msg_location
from ..message import MessageFilter, Message
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
msg_suffix = Message.file_suffix_write # currently '.msg'
def convert_messages(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'hist' command.
Convert messages to a new format. Also used to change old suffixes
('.txt', '.yaml') to the latest default message file suffix ('.msg').
"""
chat = ChatDB.from_dir(Path(config.cache),
Path(config.db),
glob='*')
# read all known message files
msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*')
# make a set of all message IDs
msg_ids = set([m.msg_id() for m in msgs])
# set requested format and write all messages
chat.set_msg_format(args.convert)
# delete the current suffix
# -> a new one will automatically be created
for m in msgs:
if m.file_path:
m.file_path = m.file_path.with_suffix('')
chat.msg_write(msgs)
# read all messages with the current default suffix
msgs = chat.msg_gather(loc=msg_location.DISK, glob=f'*{msg_suffix}')
# make sure we converted all of the original messages
for mid in msg_ids:
if not any(mid == m.msg_id() for m in msgs):
print(f"Message '{mid}' has not been found after conversion. Aborting.")
sys.exit(1)
# delete messages with old suffixes
msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*')
for m in msgs:
if m.file_path and m.file_path.suffix != msg_suffix:
m.rm_file()
print(f"Successfully converted {len(msg_ids)} messages.")
def print_chat(args: argparse.Namespace, config: Config) -> None:
"""
Print the DB chat history.
"""
mfilter = MessageFilter(tags_or=args.or_tags,
@@ -17,7 +56,20 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None:
answer_contains=args.answer)
chat = ChatDB.from_dir(Path(config.cache),
Path(config.db),
mfilter=mfilter)
mfilter=mfilter,
loc=msg_location(args.location),
glob=args.glob)
chat.print(args.source_code_only,
args.with_tags,
args.with_files)
args.with_metadata,
paged=not args.no_paging,
tight=args.tight)
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'hist' command.
"""
if args.print:
print_chat(args, config)
elif args.convert:
convert_messages(args, config)
+2 -2
View File
@@ -3,7 +3,7 @@ import argparse
from pathlib import Path
from ..configuration import Config
from ..message import Message, MessageError
from ..chat import ChatDB
from ..chat import ChatDB, msg_location
def print_message(message: Message, args: argparse.Namespace) -> None:
@@ -38,7 +38,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None:
# print latest message
elif args.latest:
chat = ChatDB.from_dir(Path(config.cache), Path(config.db))
latest = chat.msg_latest(loc='disk')
latest = chat.msg_latest(loc=msg_location.DISK)
if not latest:
print("No message found!")
sys.exit(1)
+23 -9
View File
@@ -4,12 +4,16 @@ from pathlib import Path
from itertools import zip_longest
from copy import deepcopy
from ..configuration import Config
from ..chat import ChatDB
from ..chat import ChatDB, msg_location
from ..message import Message, MessageFilter, MessageError, Question, source_code
from ..ai_factory import create_ai
from ..ai import AI, AIResponse
class QuestionCmdError(Exception):
pass
def add_file_as_text(question_parts: list[str], file: str) -> None:
"""
Add the given file as plain text to the question part list.
@@ -80,7 +84,12 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
to the cache directory.
"""
question_parts = []
question_list = args.ask if args.ask is not None else []
if args.create is not None:
question_list = args.create
elif args.ask is not None:
question_list = args.ask
else:
raise QuestionCmdError("No question found")
text_files = args.source_text if args.source_text is not None else []
code_files = args.source_code if args.source_code is not None else []
@@ -92,7 +101,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
if code_file is not None and len(code_file) > 0:
add_file_as_code(question_parts, code_file)
full_question = '\n\n'.join(question_parts)
full_question = '\n\n'.join([str(s) for s in question_parts])
message = Message(question=Question(full_question),
tags=args.output_tags,
@@ -120,13 +129,16 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
args.output_tags)
# only write the response messages to the cache,
# don't add them to the internal list
chat.cache_write(response.messages)
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
print(f"=== ANSWER {idx+1} ===", flush=True)
if msg.answer:
for piece in msg.answer:
print(piece, end='', flush=True)
print()
if response.tokens:
print("===============")
print(response.tokens)
chat.cache_write(response.messages)
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
@@ -177,7 +189,9 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
tags_not=args.exclude_tags)
chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db),
mfilter=mfilter)
mfilter=mfilter,
glob=args.glob,
loc=msg_location(args.location))
# if it's a new question, create and store it immediately
if args.ask or args.create:
message = create_message(chat, args)
@@ -193,14 +207,14 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
repeat_msgs: list[Message] = []
# repeat latest message
if len(args.repeat) == 0:
lmessage = chat.msg_latest(loc='cache')
lmessage = chat.msg_latest(loc=msg_location.CACHE)
if lmessage is None:
print("No message found to repeat!")
sys.exit(1)
repeat_msgs.append(lmessage)
# repeat given message(s)
else:
repeat_msgs = chat.msg_find(args.repeat, loc='disk')
repeat_msgs = chat.msg_find(args.repeat, loc=msg_location.DISK)
repeat_messages(repeat_msgs, chat, args, config)
# === PROCESS ===
elif args.process is not None:
+20 -6
View File
@@ -14,6 +14,7 @@ from .commands.tags import tags_cmd
from .commands.config import config_cmd
from .commands.hist import hist_cmd
from .commands.print import print_cmd
from .chat import msg_location
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
@@ -65,6 +66,11 @@ def create_parser() -> argparse.ArgumentParser:
question_group.add_argument('-c', '--create', nargs='+', help='Create a question', metavar='QUESTION')
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question', metavar='MESSAGE')
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions', metavar='MESSAGE')
question_cmd_parser.add_argument('-l', '--location',
choices=[x.value for x in msg_location if x not in [msg_location.MEM, msg_location.DISK]],
default='db',
help='Use given location when building the chat history (default: \'db\')')
question_cmd_parser.add_argument('-g', '--glob', help='Filter message files using the given glob pattern')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true')
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query', metavar='FILE')
@@ -73,17 +79,25 @@ def create_parser() -> argparse.ArgumentParser:
# 'hist' command parser
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
help="Print chat history.",
help="Print and manage chat history.",
aliases=['h'])
hist_cmd_parser.set_defaults(func=hist_cmd)
hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.",
action='store_true')
hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.",
hist_group = hist_cmd_parser.add_mutually_exclusive_group(required=True)
hist_group.add_argument('-p', '--print', help='Print the DB chat history', action='store_true')
hist_group.add_argument('-c', '--convert', help='Convert all message files to the given format [txt|yaml]', metavar='FORMAT')
hist_cmd_parser.add_argument('-w', '--with-metadata', help="Print chat history with metadata (tags, filename, AI, etc.).",
action='store_true')
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Only print embedded source code',
action='store_true')
hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring', metavar='SUBSTRING')
hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring', metavar='SUBSTRING')
hist_cmd_parser.add_argument('-A', '--answer', help='Print only answers with given substring', metavar='SUBSTRING')
hist_cmd_parser.add_argument('-Q', '--question', help='Print only questions with given substring', metavar='SUBSTRING')
hist_cmd_parser.add_argument('-d', '--tight', help='Print without message separators', action='store_true')
hist_cmd_parser.add_argument('-P', '--no-paging', help='Print without paging', action='store_true')
hist_cmd_parser.add_argument('-l', '--location',
choices=[x.value for x in msg_location if x not in [msg_location.MEM, msg_location.DISK]],
default='db',
help='Use given location when building the chat history (default: \'db\')')
hist_cmd_parser.add_argument('-g', '--glob', help='Filter message files using the given glob pattern')
# 'tags' command parser
tags_cmd_parser = cmdparser.add_parser('tags',
+92 -19
View File
@@ -5,7 +5,9 @@ import pathlib
import yaml
import tempfile
import shutil
import io
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple
from typing import Generator, Iterator
from typing import get_args as typing_get_args
from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags
@@ -49,7 +51,7 @@ def source_code(text: str, include_delims: bool = False) -> list[str]:
code_lines: list[str] = []
in_code_block = False
for line in text.split('\n'):
for line in str(text).split('\n'):
if line.strip().startswith('```'):
if include_delims:
code_lines.append(line)
@@ -142,30 +144,100 @@ class Answer(str):
txt_header: ClassVar[str] = '==== ANSWER ===='
yaml_key: ClassVar[str] = 'answer'
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
def __init__(self, data: Union[str, Generator[str, None, None]]) -> None:
# Indicator of whether all of data has been processed
self.is_exhausted: bool = False
# Initialize data
self.iterator: Iterator[str] = self._init_data(data)
# Set up the buffer to hold the 'Answer' content
self.buffer: io.StringIO = io.StringIO()
def _init_data(self, data: Union[str, Generator[str, None, None]]) -> Iterator[str]:
"""
Make sure the answer string does not contain the header as a whole line.
Process input data (either a string or a string generator)
"""
if cls.txt_header in string.split('\n'):
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
instance = super().__new__(cls, string)
return instance
if isinstance(data, str):
yield data
else:
yield from data
def __str__(self) -> str:
"""
Output all content when converted into a string
"""
# Ensure all data has been processed
for _ in self:
pass
# Return the 'Answer' content
return self.buffer.getvalue()
def __repr__(self) -> str:
return repr(str(self))
def __iter__(self) -> Generator[str, None, None]:
"""
Allows the object to be iterable
"""
# Generate content if not all data has been processed
if not self.is_exhausted:
yield from self.generator_iter()
else:
yield self.buffer.getvalue()
def generator_iter(self) -> Generator[str, None, None]:
"""
Main generator method to process data
"""
for piece in self.iterator:
# Write to buffer and yield piece for the iterator
self.buffer.write(piece)
yield piece
self.is_exhausted = True # Set the flag that all data has been processed
# If the header occurs in the 'Answer' content, raise an error
if f'\n{self.txt_header}' in self.buffer.getvalue() or self.buffer.getvalue().startswith(self.txt_header):
raise MessageError(f"Answer {repr(self.buffer.getvalue())} contains the header {repr(Answer.txt_header)}")
def __eq__(self, other: object) -> bool:
"""
Comparing the object to a string or another object
"""
if isinstance(other, str):
return str(self) == other # Compare the string value of this object to the other string
# Default behavior for comparing non-string objects
return super().__eq__(other)
def __hash__(self) -> int:
"""
Generate a hash for the object based on its string representation.
"""
return hash(str(self))
def __format__(self, format_spec: str) -> str:
"""
Return a formatted version of the string as per the format specification.
"""
return str(self).__format__(format_spec)
@classmethod
def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
"""
Build Question from a list of strings. Make sure strings do not contain the header.
Build Answer from a list of strings. Make sure strings do not contain the header.
"""
if cls.txt_header in strings:
raise MessageError(f"Question contains the header '{cls.txt_header}'")
instance = super().__new__(cls, '\n'.join(strings).strip())
return instance
def _gen() -> Generator[str, None, None]:
if len(strings) > 0:
yield strings[0]
for s in strings[1:]:
yield '\n'
yield s
return cls(_gen())
def source_code(self, include_delims: bool = False) -> list[str]:
"""
Extract and return all source code sections.
"""
return source_code(self, include_delims)
return source_code(str(self), include_delims)
class Question(str):
@@ -422,7 +494,7 @@ class Message():
except Exception:
raise MessageError(f"'{file_path}' does not contain a valid message")
def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str:
def to_str(self, with_metadata: bool = False, source_code_only: bool = False) -> str:
"""
Return the current Message as a string.
"""
@@ -432,15 +504,16 @@ class Message():
if self.answer:
output.extend(self.answer.source_code(include_delims=True))
return '\n'.join(output) if len(output) > 0 else ''
if with_tags:
if with_metadata:
output.append(self.tags_str())
if with_file:
output.append('FILE: ' + str(self.file_path))
output.append('AI: ' + str(self.ai))
output.append('MODEL: ' + str(self.model))
output.append(Question.txt_header)
output.append(self.question)
if self.answer:
output.append(Answer.txt_header)
output.append(self.answer)
output.append(str(self.answer))
return '\n'.join(output)
def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11
@@ -490,7 +563,7 @@ class Message():
temp_fd.write(f'{ModelLine.from_model(self.model)}\n')
temp_fd.write(f'{Question.txt_header}\n{self.question}\n')
if self.answer:
temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n')
temp_fd.write(f'{Answer.txt_header}\n{str(self.answer)}\n')
shutil.move(temp_file_path, file_path)
def __to_file_yaml(self, file_path: pathlib.Path) -> None:
@@ -559,7 +632,7 @@ class Message():
or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503
or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503
or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503
or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503
or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in str(self.answer))) # noqa: W503
or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503
or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503
or (mfilter.model_state == 'available' and not self.model) # noqa: W503
+1
View File
@@ -2,3 +2,4 @@ openai
PyYAML
argcomplete
pytest
tiktoken
+24 -24
View File
@@ -9,33 +9,32 @@ from chatmastermind.configuration import OpenAIConfig
class OpenAITest(unittest.TestCase):
@mock.patch('openai.ChatCompletion.create')
@mock.patch('chatmastermind.ais.openai.OpenAI._completions')
def test_request(self, mock_create: mock.MagicMock) -> None:
# Create a test instance of OpenAI
config = OpenAIConfig()
openai = OpenAI(config)
# Set up the mock response from openai.ChatCompletion.create
mock_response = {
'choices': [
{
'message': {
'content': 'Answer 1'
}
},
{
'message': {
'content': 'Answer 2'
}
}
],
'usage': {
'prompt_tokens': 10,
'completion_tokens': 20,
'total_tokens': 30
}
}
mock_create.return_value = mock_response
class mock_obj:
pass
mock_chunk1 = mock_obj()
mock_chunk1.choices = [mock_obj(), mock_obj()] # type: ignore
mock_chunk1.choices[0].index = 0 # type: ignore
mock_chunk1.choices[0].delta = mock_obj() # type: ignore
mock_chunk1.choices[0].delta.content = 'Answer 1' # type: ignore
mock_chunk1.choices[0].finish_reason = None # type: ignore
mock_chunk1.choices[1].index = 1 # type: ignore
mock_chunk1.choices[1].delta = mock_obj() # type: ignore
mock_chunk1.choices[1].delta.content = 'Answer 2' # type: ignore
mock_chunk1.choices[1].finish_reason = None # type: ignore
mock_chunk2 = mock_obj()
mock_chunk2.choices = [mock_obj(), mock_obj()] # type: ignore
mock_chunk2.choices[0].index = 0 # type: ignore
mock_chunk2.choices[0].finish_reason = 'stop' # type: ignore
mock_chunk2.choices[1].index = 1 # type: ignore
mock_chunk2.choices[1].finish_reason = 'stop' # type: ignore
mock_create.return_value = iter([mock_chunk1, mock_chunk2])
# Create test data
question = Message(Question('Question'))
@@ -57,9 +56,9 @@ class OpenAITest(unittest.TestCase):
self.assertIsNotNone(response.tokens)
self.assertIsInstance(response.tokens, Tokens)
assert response.tokens
self.assertEqual(response.tokens.prompt, 10)
self.assertEqual(response.tokens.completion, 20)
self.assertEqual(response.tokens.total, 30)
self.assertEqual(response.tokens.prompt, 53)
self.assertEqual(response.tokens.completion, 6)
self.assertEqual(response.tokens.total, 59)
# Assert the mock call to openai.ChatCompletion.create
mock_create.assert_called_once_with(
@@ -76,6 +75,7 @@ class OpenAITest(unittest.TestCase):
max_tokens=config.max_tokens,
top_p=config.top_p,
n=2,
stream=True,
frequency_penalty=config.frequency_penalty,
presence_penalty=config.presence_penalty
)
+60 -51
View File
@@ -7,7 +7,7 @@ from io import StringIO
from unittest.mock import patch
from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, ChatError
from chatmastermind.chat import Chat, ChatDB, ChatError, msg_location
msg_suffix: str = Message.file_suffix_write
@@ -41,10 +41,14 @@ class TestChat(TestChatBase):
self.message1 = Message(Question('Question 1'),
Answer('Answer 1'),
{Tag('atag1'), Tag('btag2')},
ai='FakeAI',
model='FakeModel',
file_path=pathlib.Path(f'0001{msg_suffix}'))
self.message2 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('btag2')},
ai='FakeAI',
model='FakeModel',
file_path=pathlib.Path(f'0002{msg_suffix}'))
self.maxDiff = None
@@ -143,7 +147,7 @@ class TestChat(TestChatBase):
@patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None:
self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False)
self.chat.print(paged=False, tight=True)
expected_output = f"""{Question.txt_header}
Question 1
{Answer.txt_header}
@@ -156,17 +160,21 @@ Answer 2
self.assertEqual(mock_stdout.getvalue(), expected_output)
@patch('sys.stdout', new_callable=StringIO)
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
def test_print_with_metadata(self, mock_stdout: StringIO) -> None:
self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False, with_tags=True, with_files=True)
self.chat.print(paged=False, with_metadata=True, tight=True)
expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: 0001{msg_suffix}
AI: FakeAI
MODEL: FakeModel
{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{TagLine.prefix} btag2
FILE: 0002{msg_suffix}
AI: FakeAI
MODEL: FakeModel
{Question.txt_header}
Question 2
{Answer.txt_header}
@@ -232,7 +240,8 @@ class TestChatDB(TestChatBase):
msg_to_file_force_suffix(duplicate_message)
with self.assertRaises(ChatError) as cm:
ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
pathlib.Path(self.db_path.name),
glob='*')
self.assertEqual(str(cm.exception), "Validation failed")
def test_file_path_ID_exists(self) -> None:
@@ -587,92 +596,92 @@ class TestChatDB(TestChatBase):
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# search for a DB file in memory
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.msg'], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.msg'], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], loc=msg_location.MEM), [self.message1])
# and on disk
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.msg'], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.msg'], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], loc=msg_location.DB), [self.message2])
# now search the cache -> expect empty result
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003.msg'], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), [])
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find(['0003.msg'], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find(['0003'], loc=msg_location.CACHE), [])
# search for multiple messages
# -> search one twice, expect result to be unique
search_names = ['0001', '0002.msg', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, loc='all')
result = chat_db.msg_find(search_names, loc=msg_location.ALL)
self.assert_messages_equal(result, expected_result)
def test_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.msg_latest(loc='mem'), self.message4)
self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)
self.assertEqual(chat_db.msg_latest(loc='disk'), self.message4)
self.assertEqual(chat_db.msg_latest(loc='all'), self.message4)
self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), self.message4)
self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4)
self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), self.message4)
self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), self.message4)
# the cache is currently empty:
self.assertIsNone(chat_db.msg_latest(loc='cache'))
self.assertIsNone(chat_db.msg_latest(loc=msg_location.CACHE))
# add new messages to the cache dir
new_message = Message(question=Question("New Question"),
answer=Answer("New Answer"))
chat_db.cache_add([new_message])
self.assertEqual(chat_db.msg_latest(loc='cache'), new_message)
self.assertEqual(chat_db.msg_latest(loc='mem'), new_message)
self.assertEqual(chat_db.msg_latest(loc='disk'), new_message)
self.assertEqual(chat_db.msg_latest(loc='all'), new_message)
self.assertEqual(chat_db.msg_latest(loc=msg_location.CACHE), new_message)
self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), new_message)
self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), new_message)
self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), new_message)
# the DB does not contain the new message
self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)
self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4)
def test_msg_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
# add a new message, but only to the internal list
new_message = Message(Question("What?"))
all_messages_mem = all_messages + [new_message]
chat_db.msg_add([new_message])
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages_mem)
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages_mem)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages_mem)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages_mem)
# the nr. of messages on disk did not change -> expect old result
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
# test with MessageFilter
self.assert_messages_equal(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})),
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL, mfilter=MessageFilter(tags_or={Tag('tag1')})),
[self.message1])
self.assert_messages_equal(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})),
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK, mfilter=MessageFilter(tags_or={Tag('tag2')})),
[self.message2])
self.assert_messages_equal(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})),
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE, mfilter=MessageFilter(tags_or={Tag('tag3')})),
[])
self.assert_messages_equal(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")),
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM, mfilter=MessageFilter(question_contains="What")),
[new_message])
def test_msg_move_and_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
# move first message to the cache
chat_db.cache_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [self.message1])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [self.message1])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4])
self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), [self.message2, self.message3, self.message4])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages)
# now move first message back to the DB
chat_db.db_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
+62
View File
@@ -0,0 +1,62 @@
import unittest
import argparse
import tempfile
import yaml
from pathlib import Path
from chatmastermind.message import Message, Question
from chatmastermind.chat import ChatDB, ChatError, msg_location
from chatmastermind.configuration import Config
from chatmastermind.commands.hist import convert_messages
msg_suffix = Message.file_suffix_write
class TestConvertMessages(unittest.TestCase):
def setUp(self) -> None:
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
self.db_path = Path(self.db_dir.name)
self.cache_path = Path(self.cache_dir.name)
self.args = argparse.Namespace()
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
# Prepare some messages
self.chat = ChatDB.from_dir(Path(self.cache_path),
Path(self.db_path))
self.messages = [Message(Question(f'Question {i}')) for i in range(0, 6)]
self.chat.db_write(self.messages[0:2])
self.chat.cache_write(self.messages[2:])
# Change some of the suffixes
assert self.messages[0].file_path
assert self.messages[1].file_path
self.messages[0].file_path.rename(self.messages[0].file_path.with_suffix('.txt'))
self.messages[1].file_path.rename(self.messages[1].file_path.with_suffix('.yaml'))
def tearDown(self) -> None:
self.db_dir.cleanup()
self.cache_dir.cleanup()
def test_convert_messages(self) -> None:
self.args.convert = 'yaml'
convert_messages(self.args, self.config)
msgs = self.chat.msg_gather(loc=msg_location.DISK, glob='*.*')
# Check if the number of messages is the same as before
self.assertEqual(len(msgs), len(self.messages))
# Check if all messages have the requested suffix
for msg in msgs:
assert msg.file_path
self.assertEqual(msg.file_path.suffix, msg_suffix)
# Check if the message IDs are correctly maintained
for m_new, m_old in zip(msgs, self.messages):
self.assertEqual(m_new.msg_id(), m_old.msg_id())
# check if all messages have the new format
for m in msgs:
with open(str(m.file_path), "r") as fd:
yaml.load(fd, Loader=yaml.FullLoader)
def test_convert_messages_wrong_format(self) -> None:
self.args.convert = 'foo'
with self.assertRaises(ChatError):
convert_messages(self.args, self.config)
+6 -2
View File
@@ -91,7 +91,7 @@ class QuestionTestCase(unittest.TestCase):
class AnswerTestCase(unittest.TestCase):
def test_answer_with_header(self) -> None:
with self.assertRaises(MessageError):
Answer(f"{Answer.txt_header}\nno")
str(Answer(f"{Answer.txt_header}\nno"))
def test_answer_with_legal_header(self) -> None:
answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.")
@@ -856,6 +856,8 @@ class MessageToStrTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
Answer('This is an answer.'),
ai=('FakeAI'),
model=('FakeModel'),
tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/foo/bla'))
@@ -869,11 +871,13 @@ This is an answer."""
def test_to_str_with_tags_and_file(self) -> None:
expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: /tmp/foo/bla
AI: FakeAI
MODEL: FakeModel
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer."""
self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output)
self.assertEqual(self.message.to_str(with_metadata=True), expected_output)
class MessageRmFileTestCase(unittest.TestCase):
+33 -14
View File
@@ -9,7 +9,7 @@ from chatmastermind.configuration import Config
from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat, ChatDB
from chatmastermind.chat import Chat, ChatDB, msg_location
from chatmastermind.ai import AIError
from .test_common import TestWithFakeAI
@@ -41,6 +41,8 @@ class TestMessageCreate(TestWithFakeAI):
self.args.AI = None
self.args.model = None
self.args.output_tags = None
self.args.ask = None
self.args.create = None
# File 1 : no source code block, only text
self.source_file1 = tempfile.NamedTemporaryFile(delete=False)
self.source_file1_content = """This is just text.
@@ -204,6 +206,21 @@ It is embedded code
"""))
class TestCreateOption(TestMessageCreate):
def test_message_file_created(self) -> None:
self.args.create = ["How does question --create work?"]
self.args.ask = None
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("How does question --create work?")) # type: ignore [union-attr]
class TestQuestionCmd(TestWithFakeAI):
def setUp(self) -> None:
@@ -217,6 +234,8 @@ class TestQuestionCmd(TestWithFakeAI):
# create a mock argparse.Namespace
self.args = argparse.Namespace(
ask=['What is the meaning of life?'],
glob=None,
location='db',
num_answers=1,
output_tags=['science'],
AI='FakeAI',
@@ -262,7 +281,7 @@ class TestQuestionCmdAsk(TestQuestionCmd):
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@@ -320,7 +339,7 @@ class TestQuestionCmdAsk(TestQuestionCmd):
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_question])
@@ -358,7 +377,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
# we expect the original message + the one with the new response
expected_responses = [message] + [expected_response]
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
print(self.message_list(self.cache_dir))
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@@ -379,7 +398,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache')
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
@@ -395,7 +414,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
tags=message.tags,
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed
@@ -418,7 +437,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache')
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
@@ -435,7 +454,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
tags=message.tags,
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed
@@ -458,7 +477,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache')
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path
# repeat the last question with new arguments (without overwriting)
@@ -476,7 +495,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_msgs_equal_except_file_path(cached_msg, [message] + [new_expected_response])
@@ -496,7 +515,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc='cache')
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path
# repeat the last question with new arguments
@@ -513,7 +532,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response])
@@ -569,8 +588,8 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
self.assertEqual(len(self.message_list(self.cache_dir)), 4)
self.assertEqual(len(self.message_list(self.db_dir)), 1)
expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
cached_msg = chat.msg_gather(loc='cache')
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
# check that the DB message has not been modified at all
db_msg = chat.msg_gather(loc='db')
db_msg = chat.msg_gather(loc=msg_location.DB)
self.assert_msgs_all_equal(db_msg, [message3])