added typ hints for all functions in 'main.py', 'utils.py', 'storage.py' and 'api_client.py'

This commit is contained in:
2023-08-15 23:36:45 +02:00
parent ba41794f4e
commit 4303fb414f
4 changed files with 39 additions and 26 deletions
+11 -11
View File
@@ -1,11 +1,11 @@
import yaml
import io
import pathlib
from .utils import terminal_width, append_message, message_to_chat, ConfigType
from typing import List, Dict, Any, Optional
from .utils import terminal_width, append_message, message_to_chat, ConfigType, ChatType
from typing import Any, Optional
def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]:
def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]:
with open(fname, "r") as fd:
tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip()
# also support tags separated by ',' (old format)
@@ -33,7 +33,7 @@ def write_config(path: str, config: ConfigType) -> None:
yaml.dump(config, f)
def dump_data(data: Dict[str, Any]) -> str:
def dump_data(data: dict[str, Any]) -> str:
with io.StringIO() as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
@@ -41,7 +41,7 @@ def dump_data(data: Dict[str, Any]) -> str:
return fd.getvalue()
def write_file(fname: str, data: Dict[str, Any]) -> None:
def write_file(fname: str, data: dict[str, Any]) -> None:
with open(fname, "w") as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
@@ -75,14 +75,14 @@ def save_answers(question: str,
def create_chat_hist(question: Optional[str],
tags: Optional[List[str]],
extags: Optional[List[str]],
tags: Optional[list[str]],
extags: Optional[list[str]],
config: ConfigType,
match_all_tags: bool = False,
with_tags: bool = False,
with_file: bool = False
) -> List[Dict[str, str]]:
chat: List[Dict[str, str]] = []
) -> ChatType:
chat: ChatType = []
append_message(chat, 'system', str(config['system']).strip())
for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
if file.suffix == '.yaml':
@@ -108,7 +108,7 @@ def create_chat_hist(question: Optional[str],
return chat
def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]:
def get_tags(config: ConfigType, prefix: Optional[str]) -> list[str]:
result = []
for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
if file.suffix == '.yaml':
@@ -127,5 +127,5 @@ def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]:
return result
def get_tags_unique(config: ConfigType, prefix: Optional[str]) -> List[str]:
def get_tags_unique(config: ConfigType, prefix: Optional[str]) -> list[str]:
return list(set(get_tags(config, prefix)))