Add --glob and --location flags to hist and question commands, to be able to specify the location and files they should use.
This commit is contained in:
@@ -325,7 +325,12 @@ class ChatDB(Chat):
|
||||
* 'mfilter': use with 'Message.from_file()' to filter messages
|
||||
when reading them.
|
||||
"""
|
||||
messages = read_dir(db_path, glob, mfilter)
|
||||
messages: list[Message] = []
|
||||
if loc in [msg_location.DB, msg_location.DISK, msg_location.ALL]:
|
||||
messages.extend(read_dir(db_path, glob, mfilter))
|
||||
if loc in [msg_location.CACHE, msg_location.DISK, msg_location.ALL]:
|
||||
messages.extend(read_dir(cache_path, glob, mfilter))
|
||||
messages.sort(key=lambda x: x.msg_id())
|
||||
return cls(messages, cache_path, db_path, mfilter, glob)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -15,7 +15,8 @@ def convert_messages(args: argparse.Namespace, config: Config) -> None:
|
||||
('.txt', '.yaml') to the latest default message file suffix ('.msg').
|
||||
"""
|
||||
chat = ChatDB.from_dir(Path(config.cache),
|
||||
Path(config.db))
|
||||
Path(config.db),
|
||||
glob='*')
|
||||
# read all known message files
|
||||
msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*')
|
||||
# make a set of all message IDs
|
||||
@@ -55,7 +56,9 @@ def print_chat(args: argparse.Namespace, config: Config) -> None:
|
||||
answer_contains=args.answer)
|
||||
chat = ChatDB.from_dir(Path(config.cache),
|
||||
Path(config.db),
|
||||
mfilter=mfilter)
|
||||
mfilter=mfilter,
|
||||
loc=msg_location(args.location),
|
||||
glob=args.glob)
|
||||
chat.print(args.source_code_only,
|
||||
args.with_metadata,
|
||||
paged=not args.no_paging,
|
||||
|
||||
@@ -186,7 +186,9 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
tags_not=args.exclude_tags)
|
||||
chat = ChatDB.from_dir(cache_path=Path(config.cache),
|
||||
db_path=Path(config.db),
|
||||
mfilter=mfilter)
|
||||
mfilter=mfilter,
|
||||
glob=args.glob,
|
||||
loc=msg_location(args.location))
|
||||
# if it's a new question, create and store it immediately
|
||||
if args.ask or args.create:
|
||||
message = create_message(chat, args)
|
||||
|
||||
@@ -14,6 +14,7 @@ from .commands.tags import tags_cmd
|
||||
from .commands.config import config_cmd
|
||||
from .commands.hist import hist_cmd
|
||||
from .commands.print import print_cmd
|
||||
from .chat import msg_location
|
||||
|
||||
|
||||
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
|
||||
@@ -65,6 +66,11 @@ def create_parser() -> argparse.ArgumentParser:
|
||||
question_group.add_argument('-c', '--create', nargs='+', help='Create a question', metavar='QUESTION')
|
||||
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question', metavar='MESSAGE')
|
||||
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions', metavar='MESSAGE')
|
||||
question_cmd_parser.add_argument('-l', '--location',
|
||||
choices=[x.value for x in msg_location],
|
||||
default='db',
|
||||
help='Select message location, default is \'db\'')
|
||||
question_cmd_parser.add_argument('-g', '--glob', help='Glob for message file names')
|
||||
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
|
||||
action='store_true')
|
||||
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query', metavar='FILE')
|
||||
@@ -87,6 +93,11 @@ def create_parser() -> argparse.ArgumentParser:
|
||||
hist_cmd_parser.add_argument('-Q', '--question', help='Print only questions with given substring', metavar='SUBSTRING')
|
||||
hist_cmd_parser.add_argument('-d', '--tight', help='Print without message separators', action='store_true')
|
||||
hist_cmd_parser.add_argument('-P', '--no-paging', help='Print without paging', action='store_true')
|
||||
hist_cmd_parser.add_argument('-l', '--location',
|
||||
choices=[x.value for x in msg_location],
|
||||
default='db',
|
||||
help='Select message location, default is \'db\'')
|
||||
hist_cmd_parser.add_argument('-g', '--glob', help='Glob for message file names')
|
||||
|
||||
# 'tags' command parser
|
||||
tags_cmd_parser = cmdparser.add_parser('tags',
|
||||
|
||||
+2
-1
@@ -240,7 +240,8 @@ class TestChatDB(TestChatBase):
|
||||
msg_to_file_force_suffix(duplicate_message)
|
||||
with self.assertRaises(ChatError) as cm:
|
||||
ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
pathlib.Path(self.db_path.name),
|
||||
glob='*')
|
||||
self.assertEqual(str(cm.exception), "Validation failed")
|
||||
|
||||
def test_file_path_ID_exists(self) -> None:
|
||||
|
||||
@@ -234,6 +234,8 @@ class TestQuestionCmd(TestWithFakeAI):
|
||||
# create a mock argparse.Namespace
|
||||
self.args = argparse.Namespace(
|
||||
ask=['What is the meaning of life?'],
|
||||
glob=None,
|
||||
location='db',
|
||||
num_answers=1,
|
||||
output_tags=['science'],
|
||||
AI='FakeAI',
|
||||
|
||||
Reference in New Issue
Block a user