Refactoring, fixes and new features for the 'chat.py' module #12
+10
-4
@@ -7,12 +7,16 @@ from pprint import PrettyPrinter
|
|||||||
from pydoc import pager
|
from pydoc import pager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal
|
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal
|
||||||
|
from .configuration import default_config_file
|
||||||
from .message import Message, MessageFilter, MessageError, message_in
|
from .message import Message, MessageFilter, MessageError, message_in
|
||||||
from .tags import Tag
|
from .tags import Tag
|
||||||
|
|
||||||
ChatInst = TypeVar('ChatInst', bound='Chat')
|
ChatInst = TypeVar('ChatInst', bound='Chat')
|
||||||
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
|
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
|
||||||
|
|
||||||
|
db_next_file = '.next'
|
||||||
|
ignored_files = [db_next_file, default_config_file]
|
||||||
|
|
||||||
|
|
||||||
class ChatError(Exception):
|
class ChatError(Exception):
|
||||||
pass
|
pass
|
||||||
@@ -45,7 +49,9 @@ def read_dir(dir_path: Path,
|
|||||||
messages: list[Message] = []
|
messages: list[Message] = []
|
||||||
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
|
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
|
||||||
for file_path in sorted(file_iter):
|
for file_path in sorted(file_iter):
|
||||||
if file_path.is_file() and file_path.suffix in Message.file_suffixes:
|
if (file_path.is_file()
|
||||||
|
and file_path.name not in ignored_files # noqa: W503
|
||||||
|
and file_path.suffix in Message.file_suffixes): # noqa: W503
|
||||||
try:
|
try:
|
||||||
message = Message.from_file(file_path, mfilter)
|
message = Message.from_file(file_path, mfilter)
|
||||||
if message:
|
if message:
|
||||||
@@ -235,7 +241,7 @@ class ChatDB(Chat):
|
|||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# contains the latest message ID
|
# contains the latest message ID
|
||||||
self.next_fname = self.db_path / '.next'
|
self.next_path = self.db_path / db_next_file
|
||||||
# make all paths absolute
|
# make all paths absolute
|
||||||
self.cache_path = self.cache_path.absolute()
|
self.cache_path = self.cache_path.absolute()
|
||||||
self.db_path = self.db_path.absolute()
|
self.db_path = self.db_path.absolute()
|
||||||
@@ -274,7 +280,7 @@ class ChatDB(Chat):
|
|||||||
|
|
||||||
def get_next_fid(self) -> int:
|
def get_next_fid(self) -> int:
|
||||||
try:
|
try:
|
||||||
with open(self.next_fname, 'r') as f:
|
with open(self.next_path, 'r') as f:
|
||||||
next_fid = int(f.read()) + 1
|
next_fid = int(f.read()) + 1
|
||||||
self.set_next_fid(next_fid)
|
self.set_next_fid(next_fid)
|
||||||
return next_fid
|
return next_fid
|
||||||
@@ -283,7 +289,7 @@ class ChatDB(Chat):
|
|||||||
return 1
|
return 1
|
||||||
|
|
||||||
def set_next_fid(self, fid: int) -> None:
|
def set_next_fid(self, fid: int) -> None:
|
||||||
with open(self.next_fname, 'w') as f:
|
with open(self.next_path, 'w') as f:
|
||||||
f.write(f'{fid}')
|
f.write(f'{fid}')
|
||||||
|
|
||||||
def read_db(self) -> None:
|
def read_db(self) -> None:
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
|
|||||||
|
|
||||||
|
|
||||||
supported_ais: list[str] = ['openai']
|
supported_ais: list[str] = ['openai']
|
||||||
default_config_path = '.config.yaml'
|
default_config_file = '.config.yaml'
|
||||||
|
|
||||||
|
|
||||||
class ConfigError(Exception):
|
class ConfigError(Exception):
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import argcomplete
|
|||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from .configuration import Config, default_config_path
|
from .configuration import Config, default_config_file
|
||||||
from .message import Message
|
from .message import Message
|
||||||
from .commands.question import question_cmd
|
from .commands.question import question_cmd
|
||||||
from .commands.tags import tags_cmd
|
from .commands.tags import tags_cmd
|
||||||
@@ -24,7 +24,7 @@ def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
|
|||||||
def create_parser() -> argparse.ArgumentParser:
|
def create_parser() -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="ChatMastermind is a Python application that automates conversation with AI")
|
description="ChatMastermind is a Python application that automates conversation with AI")
|
||||||
parser.add_argument('-C', '--config', help='Config file name.', default=default_config_path)
|
parser.add_argument('-C', '--config', help='Config file name.', default=default_config_file)
|
||||||
|
|
||||||
# subcommand-parser
|
# subcommand-parser
|
||||||
cmdparser = parser.add_subparsers(dest='command',
|
cmdparser = parser.add_subparsers(dest='command',
|
||||||
|
|||||||
+1
-1
@@ -241,7 +241,7 @@ class TestChatDB(unittest.TestCase):
|
|||||||
self.assertEqual(chat_db.get_next_fid(), 5)
|
self.assertEqual(chat_db.get_next_fid(), 5)
|
||||||
self.assertEqual(chat_db.get_next_fid(), 6)
|
self.assertEqual(chat_db.get_next_fid(), 6)
|
||||||
self.assertEqual(chat_db.get_next_fid(), 7)
|
self.assertEqual(chat_db.get_next_fid(), 7)
|
||||||
with open(chat_db.next_fname, 'r') as f:
|
with open(chat_db.next_path, 'r') as f:
|
||||||
self.assertEqual(f.read(), '7')
|
self.assertEqual(f.read(), '7')
|
||||||
|
|
||||||
def test_chat_db_write(self) -> None:
|
def test_chat_db_write(self) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user