Change type msg_location to an Enum instead of Literal to be able to get all values easy and improve type checks.

This commit is contained in:
Oleksandr Kozachuk
2023-10-19 16:00:44 +02:00
parent 5f29f60168
commit 9b0951cb3f
7 changed files with 95 additions and 86 deletions
+23 -14
View File
@@ -6,7 +6,8 @@ from pathlib import Path
from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass
from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union
from enum import Enum
from typing import TypeVar, Type, Optional, Any, Callable, Union
from .configuration import default_config_file
from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats
from .tags import Tag
@@ -16,10 +17,17 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next'
ignored_files = [db_next_file, default_config_file]
msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
msg_suffix = Message.file_suffix_write
class msg_location(Enum):
MEM = 'mem'
DISK = 'disk'
CACHE = 'cache'
DB = 'db'
ALL = 'all'
class ChatError(Exception):
pass
@@ -304,7 +312,8 @@ class ChatDB(Chat):
cache_path: Path,
db_path: Path,
glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
mfilter: Optional[MessageFilter] = None,
loc: msg_location = msg_location.DB) -> ChatDBInst:
"""
Create a 'ChatDB' instance from the given directory structure.
Reads all messages from 'db_path' into the local message list.
@@ -399,14 +408,14 @@ class ChatDB(Chat):
If 'require_file_path' is True, return only files with a valid file_path.
"""
loc_messages: list[Message] = []
if loc in ['mem', 'all']:
if loc in [msg_location.MEM, msg_location.ALL]:
if require_file_path:
loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
else:
loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
if loc in ['cache', 'disk', 'all']:
if loc in [msg_location.CACHE, msg_location.DISK, msg_location.ALL]:
loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter)
if loc in ['db', 'disk', 'all']:
if loc in [msg_location.DB, msg_location.DISK, msg_location.ALL]:
loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter)
# remove_duplicates and sort the list
unique_messages: list[Message] = []
@@ -422,7 +431,7 @@ class ChatDB(Chat):
def msg_find(self,
msg_names: list[str],
loc: msg_location = 'mem',
loc: msg_location = msg_location.MEM,
) -> list[Message]:
"""
Search and return the messages with the given names. Names can either be filenames
@@ -440,7 +449,7 @@ class ChatDB(Chat):
return [m for m in loc_messages
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def msg_remove(self, msg_names: list[str], loc: msg_location = 'mem') -> None:
def msg_remove(self, msg_names: list[str], loc: msg_location = msg_location.MEM) -> None:
"""
Remove the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Also deletes the
@@ -452,7 +461,7 @@ class ChatDB(Chat):
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
if loc != 'mem':
if loc != msg_location.MEM:
# delete the message files first
rm_messages = self.msg_find(msg_names, loc=loc)
for m in rm_messages:
@@ -463,7 +472,7 @@ class ChatDB(Chat):
def msg_latest(self,
mfilter: Optional[MessageFilter] = None,
loc: msg_location = 'mem') -> Optional[Message]:
loc: msg_location = msg_location.MEM) -> Optional[Message]:
"""
Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if loc is 'mem').
@@ -492,7 +501,7 @@ class ChatDB(Chat):
and message.file_path.parent.samefile(self.cache_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc='cache')) > 0
return len(self.msg_find([message], loc=msg_location.CACHE)) > 0
def msg_in_db(self, message: Union[Message, str]) -> bool:
"""
@@ -504,7 +513,7 @@ class ChatDB(Chat):
and message.file_path.parent.samefile(self.db_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc='db')) > 0
return len(self.msg_find([message], loc=msg_location.DB)) > 0
def cache_read(self, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> None:
"""
@@ -569,7 +578,7 @@ class ChatDB(Chat):
self.cache_write([message])
# remove the old one (if any)
if old_path:
self.msg_remove([str(old_path)], loc='db')
self.msg_remove([str(old_path)], loc=msg_location.DB)
# (re)add it to the internal list
self.msg_add([message])
@@ -628,6 +637,6 @@ class ChatDB(Chat):
self.db_write([message])
# remove the old one (if any)
if old_path:
self.msg_remove([str(old_path)], loc='cache')
self.msg_remove([str(old_path)], loc=msg_location.CACHE)
# (re)add it to the internal list
self.msg_add([message])
+4 -4
View File
@@ -2,7 +2,7 @@ import sys
import argparse
from pathlib import Path
from ..configuration import Config
from ..chat import ChatDB
from ..chat import ChatDB, msg_location
from ..message import MessageFilter, Message
@@ -17,7 +17,7 @@ def convert_messages(args: argparse.Namespace, config: Config) -> None:
chat = ChatDB.from_dir(Path(config.cache),
Path(config.db))
# read all known message files
msgs = chat.msg_gather(loc='disk', glob='*.*')
msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*')
# make a set of all message IDs
msg_ids = set([m.msg_id() for m in msgs])
# set requested format and write all messages
@@ -29,14 +29,14 @@ def convert_messages(args: argparse.Namespace, config: Config) -> None:
m.file_path = m.file_path.with_suffix('')
chat.msg_write(msgs)
# read all messages with the current default suffix
msgs = chat.msg_gather(loc='disk', glob=f'*{msg_suffix}')
msgs = chat.msg_gather(loc=msg_location.DISK, glob=f'*{msg_suffix}')
# make sure we converted all of the original messages
for mid in msg_ids:
if not any(mid == m.msg_id() for m in msgs):
print(f"Message '{mid}' has not been found after conversion. Aborting.")
sys.exit(1)
# delete messages with old suffixes
msgs = chat.msg_gather(loc='disk', glob='*.*')
msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*')
for m in msgs:
if m.file_path and m.file_path.suffix != msg_suffix:
m.rm_file()
+2 -2
View File
@@ -3,7 +3,7 @@ import argparse
from pathlib import Path
from ..configuration import Config
from ..message import Message, MessageError
from ..chat import ChatDB
from ..chat import ChatDB, msg_location
def print_message(message: Message, args: argparse.Namespace) -> None:
@@ -38,7 +38,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None:
# print latest message
elif args.latest:
chat = ChatDB.from_dir(Path(config.cache), Path(config.db))
latest = chat.msg_latest(loc='disk')
latest = chat.msg_latest(loc=msg_location.DISK)
if not latest:
print("No message found!")
sys.exit(1)
+3 -3
View File
@@ -4,7 +4,7 @@ from pathlib import Path
from itertools import zip_longest
from copy import deepcopy
from ..configuration import Config
from ..chat import ChatDB
from ..chat import ChatDB, msg_location
from ..message import Message, MessageFilter, MessageError, Question, source_code
from ..ai_factory import create_ai
from ..ai import AI, AIResponse
@@ -202,14 +202,14 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
repeat_msgs: list[Message] = []
# repeat latest message
if len(args.repeat) == 0:
lmessage = chat.msg_latest(loc='cache')
lmessage = chat.msg_latest(loc=msg_location.CACHE)
if lmessage is None:
print("No message found to repeat!")
sys.exit(1)
repeat_msgs.append(lmessage)
# repeat given message(s)
else:
repeat_msgs = chat.msg_find(args.repeat, loc='disk')
repeat_msgs = chat.msg_find(args.repeat, loc=msg_location.DISK)
repeat_messages(repeat_msgs, chat, args, config)
# === PROCESS ===
elif args.process is not None: