Change storage to use text format per default, instead of yaml, but still support yaml.

This commit is contained in:
Oleksandr Kozachuk
2023-05-07 14:31:17 +02:00
parent 57caba5360
commit c5fd466dda
2 changed files with 57 additions and 40 deletions
+47 -28
View File
@@ -5,6 +5,32 @@ from .utils import terminal_width, append_message, message_to_chat
from typing import List, Dict, Any, Optional
def read_file(fname: str) -> Dict[str, Any]:
with open(fname, "r") as fd:
text = fd.read().split('\n')
tags = [x.strip() for x in text.pop(0).split(':')[1].strip().split(',')]
question_idx = text.index("=== QUESTION ===") + 1
answer_idx = text.index("==== ANSWER ====")
question = "\n".join(text[question_idx:answer_idx]).strip()
answer = "\n".join(text[answer_idx + 1:]).strip()
return {"question": question, "answer": answer, "tags": tags}
def dump_data(data: Dict[str, Any]) -> str:
with io.StringIO() as fd:
fd.write(f'TAGS: {", ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
fd.write(f'==== ANSWER ====\n{data["answer"]}\n')
return fd.getvalue()
def write_file(fname: str, data: Dict[str, Any]) -> None:
with open(fname, "w") as fd:
fd.write(f'TAGS: {", ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
fd.write(f'==== ANSWER ====\n{data["answer"]}\n')
def save_answers(question: str,
answers: list[str],
tags: list[str],
@@ -26,22 +52,7 @@ def save_answers(question: str,
title_end = '-' * (terminal_width() - len(title))
print(f'{title}{title_end}')
print(answer)
with open(f"{num:04d}.yaml", "w") as fd:
with io.StringIO() as f:
yaml.dump({'question': question},
f,
default_style="|",
default_flow_style=False)
fd.write(f.getvalue().replace('"question":', "question:", 1))
with io.StringIO() as f:
yaml.dump({'answer': answer},
f,
default_style="|",
default_flow_style=False)
fd.write(f.getvalue().replace('"answer":', "answer:", 1))
yaml.dump({'tags': wtags},
fd,
default_flow_style=False)
write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags})
with open(next_fname, 'w') as f:
f.write(f'{num}')
@@ -57,13 +68,17 @@ def create_chat(question: Optional[str],
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
data_tags = set(data.get('tags', []))
tags_match = \
not tags or data_tags.intersection(tags)
extags_do_not_match = \
not extags or not data_tags.intersection(extags)
if tags_match and extags_do_not_match:
message_to_chat(data, chat)
elif file.suffix == '.txt':
data = read_file(file)
else:
continue
data_tags = set(data.get('tags', []))
tags_match = \
not tags or data_tags.intersection(tags)
extags_do_not_match = \
not extags or not data_tags.intersection(extags)
if tags_match and extags_do_not_match:
message_to_chat(data, chat)
if question:
append_message(chat, 'user', question)
return chat
@@ -75,10 +90,14 @@ def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]:
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
for tag in data.get('tags', []):
if prefix and len(prefix) > 0:
if tag.startswith(prefix):
result.append(tag)
else:
elif file.suffix == '.txt':
data = read_file(file)
else:
continue
for tag in data.get('tags', []):
if prefix and len(prefix) > 0:
if tag.startswith(prefix):
result.append(tag)
else:
result.append(tag)
return list(set(result))