169 Commits

Author SHA1 Message Date
juk0de f17e76203a glossary cmd / main: added --print option 2024-02-26 16:34:30 +01:00
juk0de 6a77ec1d3b glossary test: added testcase for to_str() without description 2024-02-26 16:27:00 +01:00
juk0de f298a68140 glossary cmd test: added test for listing glossaries 2024-02-26 16:27:00 +01:00
juk0de 9bbf67af67 glossary: fixed printing of empty description 2024-02-26 16:27:00 +01:00
juk0de 284dd13201 glossary cmd: fixed globbing in 'list_glossaries' 2024-02-26 16:27:00 +01:00
juk0de 1932f8f6e9 configuration: improved error message when config file is missing 2024-02-26 16:27:00 +01:00
juk0de 15e8f8fd6b main: resolved conflicting short parameters 2024-02-26 16:27:00 +01:00
juk0de 9c683be994 glossary test: fixed cleanup of temporary files 2024-02-26 16:27:00 +01:00
juk0de 92fb2bbe15 glossary: added '__post_init__' 2024-02-26 16:27:00 +01:00
juk0de 2e0da31150 added test module for the 'glossary' command 2024-02-26 16:27:00 +01:00
juk0de ff6d4ded33 added 'glossary' command 2024-02-26 16:27:00 +01:00
juk0de 5377dc0784 main: missing directories are now created if user agrees 2024-02-26 16:27:00 +01:00
juk0de 3def4cb668 configuration: added 'glossaries' directory 2024-02-26 16:27:00 +01:00
juk0de 580c506483 glossary: now supports quoted and unquoted entries (incl. tests) 2024-02-26 16:27:00 +01:00
juk0de a1a090bcae glossary test: added testcases for 'to_str()' 2024-02-26 16:27:00 +01:00
juk0de 3cca32a40b glossary: added 'to_str()' function 2024-02-26 16:27:00 +01:00
juk0de 1b39fb1ac5 glossary test: added description test 2024-02-26 16:27:00 +01:00
juk0de b4ef2e43ca glossary: added description and removed useless input stripping 2024-02-26 16:27:00 +01:00
juk0de ff1e405991 glossary test: added suffix testcases 2024-02-26 16:27:00 +01:00
juk0de 4afd6d4e94 glossary: added suffix check 2024-02-26 16:27:00 +01:00
juk0de 94b812c31e glossary: added test module for glossaries 2024-02-26 16:27:00 +01:00
juk0de be873867ea added module 'glossary.py' 2024-02-26 16:27:00 +01:00
juk0de 82ad697b68 translation: added check for valid document format when using OpenAI 2024-02-26 16:27:00 +01:00
juk0de a185c0db7b translation: speficied / implemented the question format for OpenAI based translations 2024-02-26 16:27:00 +01:00
juk0de c1dc152f48 translation: some small required refactoring 2024-02-26 16:27:00 +01:00
juk0de f0129f7060 added new command 'translation' 2024-02-26 16:27:00 +01:00
Oleksandr Kozachuk 5d1bb1f9e4 Fix some of the commands. 2023-11-10 10:42:46 +01:00
Oleksandr Kozachuk 75a123eb72 Fix usage of the dynamic answer is some cases. 2023-10-24 12:59:13 +02:00
juk0de 7c1c67f8ff Merge pull request 'Dynamic Answer class and OpenAI streaming API' (#19) from dynamic_answer into main
Introduces several changes with the main objective of enabling OpenAI's streaming API in the chatmastermind application. This allows for the retrieval of AI responses gradually as a stream, which can significantly improve the user experience in interactions that involve large result sets.

* Added tiktoken import in 'openai.py' and modifications to the OpenAI class to support streaming. This includes the addition of a new class OpenAIAnswer to handle streaming API responses.
* Modified request function in the OpenAI class: the stream=True flag is added to the openai.ChatCompletion.create method to enable streaming API.
* Modified 'question.py' to print the answer parts as they are streamed.
* Replaced the Answer class's string data type with a generator which supports str and Generator[str, None, None] data types. Modifications are made to the Answer class methods to handle both data types accordingly.
* Updated the tests in 'test_ais_openai.py' and 'test_message.py' to reflect and validate these changes.
2023-10-21 15:50:45 +02:00
Oleksandr Kozachuk dbe72ff11c Activate and use OpenAI streaming API. 2023-10-21 14:21:48 +02:00
Oleksandr Kozachuk bbc1ab5a0a Fix source_code function with the dynamic answer class. 2023-10-20 14:02:09 +02:00
Oleksandr Kozachuk 2aee018708 Refactor message.Answer class in a way, that it can be constructed dynamically step by step, in preparation of using streaming API. 2023-10-20 13:43:31 +02:00
ok 17c6fa2453 Merge pull request 'Configurable glob and location on question and hist commands' (#18) from cust_loc_glob into main
Reviewed-on: #18
2023-10-20 09:47:03 +02:00
juk0de 5774278fb7 README: added new 'question' command parameters 2023-10-20 09:16:03 +02:00
juk0de 40d0de50de cmm: limited the message locations for the new cmm parameters to those that make sense 2023-10-20 09:16:03 +02:00
juk0de 72d31c26e9 main: improved parameter descriptions 2023-10-20 09:16:03 +02:00
juk0de 980e5ac51f chat: changed default glob to '*.msg' in all ChatDB functions 2023-10-20 09:00:58 +02:00
Oleksandr Kozachuk 114282dfd8 Add --glob and --location flags to hist and question commands, to be able to specify the location and files they should use. 2023-10-19 16:03:51 +02:00
Oleksandr Kozachuk 9a493b57da Per default use only files with .msg suffix ignoring other files. 2023-10-19 16:02:40 +02:00
Oleksandr Kozachuk 9b0951cb3f Change type msg_location to an Enum instead of Literal to be able to get all values easy and improve type checks. 2023-10-19 16:00:44 +02:00
Oleksandr Kozachuk 5f29f60168 Add .old/ to git ignore, I use that dir ofter for old files, I do not want to delete. 2023-10-17 11:53:49 +02:00
juk0de 3ea1f49027 cmm: added options '--tight' and '--no-paging' to the 'hist --print' cmd 2023-10-02 08:35:19 +02:00
juk0de 8f56399844 cmm: replaced options '--with-tags' and '--with-file' with '--with-metadata' 2023-10-01 10:11:16 +02:00
juk0de e4cb6eb22b README: updated 'hist' command description 2023-10-01 09:27:40 +02:00
juk0de e19c6bb1ea hist_cmd: added module 'test_hist_cmd.py' 2023-09-30 08:25:33 +02:00
juk0de 811b2e6830 hist_cmd: implemented '--convert' option 2023-09-29 18:53:12 +02:00
juk0de 2a8f01aee4 chat: 'msg_gather()' now supports globbing 2023-09-29 07:16:20 +02:00
juk0de efdb3cae2f question: moved around some code 2023-09-29 07:16:20 +02:00
juk0de aecfd1088d chat: added message file format as ChatDB class member 2023-09-29 07:16:20 +02:00
juk0de 140dbed809 message: added function 'rm_file()' and test 2023-09-29 07:16:20 +02:00
juk0de 01860ace2c test_question_cmd: modified tests to use '.msg' file suffix 2023-09-29 07:16:20 +02:00
juk0de df42bcee09 test_chat: added test for file_path collision detection 2023-09-29 07:16:20 +02:00
juk0de e34eab6519 test_chat: changed all tests to use the new '.msg' suffix 2023-09-29 07:16:20 +02:00
juk0de d07fd13e8e test_message: changed all tests to use the new '.msg' suffix 2023-09-29 07:16:20 +02:00
juk0de b8681e8274 message: fixed tag matching for YAML file format 2023-09-29 07:16:20 +02:00
juk0de d2be53aeab chat: switched to new message suffix and formats
- no longer using file suffix to choose the format
- added 'mformat' argument to 'write_xxx()' functions
- file suffix is now set by 'Message.to_file()' per default
2023-09-29 07:16:20 +02:00
juk0de 9ca9a23569 message: introduced file suffix '.msg'
- '.msg' suffix is always used for writing
- 'Message.to_file()' will set the file suffix if the given file_path has none
- added 'mformat' argument to 'Message.to_file()' for choosing the file format
- '.txt' and '.yaml' suffixes are only supported for reading
2023-09-29 07:16:20 +02:00
juk0de 6f3758e12e question_cmd: fixed '--create' option 2023-09-29 07:15:46 +02:00
juk0de dd836cd72d Merge pull request 'cmm question --repeat supports multiple questions, added tests and fixes' (#15) from repeat_multi into main
This PR primarily modifies the `cmm question --repeat` command to allow repeating multiple questions, instead of only the last one.

Additionally, this PR includes the following changes:

- In `ai_factory.py`, added optional parameters 'def_ai' and 'def_model' to the `create_ai` function which allows specifying a default AI and model.
- In `openai.py`, a potential bug was fixed where the 'tags' attribute was updated to ensure it is always a set, even when 'otags' is None.
- In `question.py`, a significant amount of new code was added to facilitate the 'repeat' functionality. This includes functions to create modified args based on an existing message (`create_msg_args`), to repeat a given list of messages (`repeat_messages`), and to invert the semantics of the INPUT tags for this command (`invert_input_tag_args`).
- In `main.py`, the 'nargs' parameter was changed from `+` to `*` in the 'or-tags', 'and-tags', and 'exclude-tags' arguments to accommodate the updated handling of tags in `question.py`.
- A new `test_common.py` file was added which includes a `FakeAI` class for testing purposes, and a `TestWithFakeAI` class which includes a number of methods for asserting various conditions about messages.

This PR also includes additional tests to verify the correct operation of the new 'repeat' functionality.
2023-09-26 18:04:27 +02:00
juk0de 601ebe731a test_question_cmd: added a new testcase and made the old cases more explicit (easier to read) 2023-09-24 08:53:37 +02:00
juk0de 87b25993be tests: moved 'FakeAI' and common functions to 'test_common.py' 2023-09-24 08:38:52 +02:00
juk0de a478408449 test_question_cmd: test fixes and cleanup 2023-09-23 08:53:26 +02:00
juk0de b83b396c7b question_cmd: fixed msg specific argument creation 2023-09-23 08:11:11 +02:00
juk0de 3c932aa88e openai: fixed assignment of output tags 2023-09-23 08:11:11 +02:00
juk0de b50caa345c test_question_cmd: introduced 'FakeAI' class 2023-09-23 08:11:11 +02:00
juk0de 80c5dcc801 question_cmd: input tag options without a tag (e. g. '-t') now select ALL tags 2023-09-23 08:11:11 +02:00
juk0de 33df84beaa ai_factory: added optional 'def_ai' and 'def_model' arguments to 'create_ai' 2023-09-22 13:43:31 +02:00
juk0de 0657a1bab8 question_cmd: fixed AI and model arguments when repeating messages 2023-09-22 13:43:31 +02:00
juk0de e9175aface test_question_cmd: added testcase for --repeat with multiple messages 2023-09-22 13:43:31 +02:00
juk0de 21f81f3569 question_cmd: implemented repetition of multiple messages 2023-09-22 13:43:31 +02:00
juk0de 4538624247 Merge pull request 'Implemented the 'question --repeat' command and other improvements' (#14) from repeat into main
Reviewed-on: #14
2023-09-21 07:25:47 +02:00
juk0de ac3c19739d README: updates and fixes 2023-09-20 10:18:06 +02:00
juk0de ed379ed535 print_cmd: added option to print latest message 2023-09-20 10:18:06 +02:00
juk0de c43bafe47a main: improved metavar names and descriptions 2023-09-20 10:18:06 +02:00
juk0de 7dd83428fb test_question_cmd: added more testcases for '--repeat' 2023-09-20 10:18:06 +02:00
juk0de 3ad4b96b8f test_question_cmd: added testclass for the 'question_cmd()' function 2023-09-20 10:17:59 +02:00
juk0de 561003aabe question_cmd: implemented repeating of the latest message 2023-09-20 10:17:59 +02:00
juk0de 59eb45a3ca chat: improved message equality checks 2023-09-20 10:17:59 +02:00
juk0de 29a20bd2d8 message: added 'equals()' function and improved robustness and debugging 2023-09-20 10:17:59 +02:00
juk0de 80a1457dd1 configuration: the cache folder can now be specified in the configuration file 2023-09-20 10:17:59 +02:00
juk0de f964c5471e Merge pull request 'Refactoring, fixes and new features for the 'chat.py' module' (#12) from chat_refactoring into main
Reviewed-on: #12
2023-09-18 14:23:51 +02:00
juk0de 25fffb6fea chat: db_read() and cache_read() now also support globbing and filtering 2023-09-17 10:59:29 +02:00
juk0de cf572e1882 chat: added functions db_move() and chat_move() (and tests) 2023-09-17 10:59:29 +02:00
juk0de 2fb7410b43 chat: added functions msg_in_cache() and msg_in_db(), also tests 2023-09-17 10:59:29 +02:00
juk0de 33ae27f00e chat: msg_remove() now supports multiple locations 2023-09-17 10:59:29 +02:00
juk0de f6a6e6036b chat: added validation during initialization 2023-09-17 10:59:29 +02:00
juk0de 525cdb92a1 message / chat: 'msg_id()' now returns 'file_path.stem' (removed suffix) 2023-09-17 10:59:29 +02:00
juk0de fc82f85b7c chat: added new functions: msg_unique_id(), msg_unique_content() and tests 2023-09-17 10:59:24 +02:00
juk0de d90845b58b chat: added new functions to ChatDB: msg_gather(), msg_find(), msg_remove() 2023-09-17 10:58:26 +02:00
juk0de 98777295d6 refactor: renamed (almost) all Chat/ChatDB functions 2023-09-17 10:58:26 +02:00
juk0de f6109949c8 chat: ChatDB now correctly ignores files that contain no valid messages 2023-09-17 10:58:10 +02:00
juk0de 071871f929 chat et al: '.next' and '.config.yaml' are now ignored by ChatDB 2023-09-14 16:06:00 +02:00
juk0de 5cb88dad1b chat: implemented special version of 'latest_message()' for the ChatDB class 2023-09-14 16:05:49 +02:00
juk0de 17a0264025 question_cmd: now also accepts Messages as source files 2023-09-13 17:44:39 +02:00
Oleksandr Kozachuk 7f4a16894e Add pre-commit checks into push webhook. 2023-09-13 11:08:02 +02:00
Oleksandr Kozachuk 26e3d38afb Add the Gitea web hooks. 2023-09-13 10:53:12 +02:00
juk0de b5af751193 openai: added test module 2023-09-13 09:01:00 +02:00
juk0de a7345cbc41 ai_factory: fixed argument parsing bug 2023-09-13 07:52:05 +02:00
juk0de 310cb9421e Merge pull request 'Cleanup after merge of restructurings #8' (#10) from cleanup into main
Reviewed-on: #10
2023-09-12 20:23:08 +02:00
Oleksandr Kozachuk 1ec3d6fcda Make it possible to specify the AI in config command. 2023-09-12 16:37:50 +02:00
Oleksandr Kozachuk 544bf0bf06 Improve README.md 2023-09-12 16:34:39 +02:00
Oleksandr Kozachuk f96e82bdd7 Implement the config -l and config -m commands. 2023-09-12 16:34:17 +02:00
Oleksandr Kozachuk 2b62cb8c4b Remove the -*terminal_width() to save space on screen. 2023-09-12 13:48:28 +02:00
juk0de a895c1fc6a Merge pull request 'ChatMasterMind Application Refactor and Enhancement' (#8) from restructurings into main
Reviewed-on: #8
2023-09-12 07:36:04 +02:00
Oleksandr Kozachuk ddfcc71510 Merge branch 'restructurings.main' into restructurings 2023-09-11 13:28:56 +02:00
Oleksandr Kozachuk 17de0b9967 Remove old code. 2023-09-11 13:17:59 +02:00
juk0de 33023d29f9 configuration: made 'default' AI ID optional 2023-09-11 13:09:45 +02:00
juk0de 481f9ecf7c configuration: improved config file format 2023-09-11 13:09:45 +02:00
juk0de 22fa187e5f question_cmd: when no tags are specified, no tags are selected 2023-09-11 13:09:45 +02:00
juk0de b840ebd792 message: to_file() now uses intermediate temporary file 2023-09-11 13:09:45 +02:00
juk0de 66908f5fed message: fixed matching with empty tag sets 2023-09-11 13:09:45 +02:00
juk0de 2e08ccf606 openai: stores AI.ID instead of AI.name in message 2023-09-11 13:09:44 +02:00
juk0de 595ff8e294 question_cmd: added message filtering by tags 2023-09-11 13:09:44 +02:00
juk0de faac42d3c2 question_cmd: fixed '--ask' command 2023-09-11 13:09:44 +02:00
juk0de 864ab7aeb1 chat: added check for existing files when creating new filenames 2023-09-11 13:09:44 +02:00
juk0de cc76da2ab3 chat: added 'update_messages()' function and test 2023-09-11 13:09:44 +02:00
juk0de f99cd3ed41 question_cmd: fixed source code extraction and added a testcase 2023-09-11 13:09:44 +02:00
Oleksandr Kozachuk 6f3ea98425 Small fixes. 2023-09-11 13:09:44 +02:00
Oleksandr Kozachuk 54ece6efeb Port print arguments -q/-a/-S from main to restructuring. 2023-09-11 13:09:44 +02:00
Oleksandr Kozachuk 86eebc39ea Allow in question -s for just sourcing file and -S to source file with ``` encapsulation. 2023-09-11 13:09:44 +02:00
juk0de 3eca53998b question cmd: added tests 2023-09-11 13:09:44 +02:00
juk0de c4f7bcc94e question_cmd: fixes 2023-09-11 13:09:44 +02:00
juk0de c52713c833 configuration: added tests 2023-09-11 13:09:44 +02:00
juk0de ecb6994783 configuration et al: implemented new Config format 2023-09-11 13:09:44 +02:00
juk0de 61e710a4b1 cmm: splitted commands into separate modules (and more cleanup) 2023-09-11 13:09:41 +02:00
juk0de 21d39c6c66 cmm: removed all the old code and modules 2023-09-11 13:08:45 +02:00
juk0de 6a4cc7a65d setup: added 'ais' subfolder 2023-09-11 13:07:46 +02:00
juk0de d6bb5800b1 test_main: temporarily disabled all testcases 2023-09-11 13:07:46 +02:00
juk0de 034e4093f1 cmm: added 'question' command 2023-09-11 13:07:46 +02:00
juk0de 7d15452242 added new module 'ai_factory' 2023-09-11 13:07:46 +02:00
juk0de 823d3bf7dc added new module 'openai.py' 2023-09-11 13:07:46 +02:00
juk0de 4bd144c4d7 added new module 'ai.py' 2023-09-11 13:07:46 +02:00
juk0de e186afbef0 cmm: the 'print' command now uses 'Message.from_file()' 2023-09-11 13:07:43 +02:00
juk0de 5e4ec70072 cmm: tags completion now uses 'Message.tags_from_dir' (fixes tag completion for me) 2023-09-11 13:06:22 +02:00
juk0de 4c378dde85 cmm: the 'hist' command now uses the new 'ChatDB' 2023-09-11 13:05:33 +02:00
juk0de 8923a13352 cmm: the 'tags' command now uses the new 'ChatDB' 2023-09-11 13:04:08 +02:00
juk0de e1414835c8 chat: added functions for finding and deleting messages 2023-09-11 13:04:08 +02:00
juk0de abb7fdacb6 message / chat: output improvements 2023-09-11 13:04:08 +02:00
juk0de 2e2228bd60 chat: new possibilites for adding messages and better tests 2023-09-11 13:04:08 +02:00
juk0de 713b55482a message: added rename_tags() function and test 2023-09-11 13:04:08 +02:00
juk0de d35de86c67 message: fixed Answer header for TXT format 2023-09-11 13:04:08 +02:00
juk0de aba3eb783d message: improved robustness of Question and Answer content checks and tests 2023-09-11 13:04:08 +02:00
juk0de 8e63831701 chat: added clear_cache() function and test 2023-09-11 13:04:08 +02:00
juk0de c318b99671 chat: improved history printing 2023-09-11 13:04:08 +02:00
juk0de 48c8e951e1 chat: fixed handling of unsupported files in DB and chache dir 2023-09-11 13:04:08 +02:00
juk0de b22a4b07ed chat: added tags_frequency() function and test 2023-09-11 13:04:08 +02:00
juk0de 33565d351d configuration: added AIConfig class 2023-09-11 13:04:08 +02:00
juk0de 6737fa98c7 added tokens() function to Message and Chat 2023-09-11 13:04:08 +02:00
juk0de 815a21893c added tests for 'chat.py' 2023-09-11 13:04:08 +02:00
juk0de 64893949a4 added new module 'chat.py' 2023-09-11 13:04:08 +02:00
juk0de a093f9b867 tags: some clarification and new tests 2023-09-11 13:04:08 +02:00
juk0de dc3f3dc168 added 'message_in()' function and test 2023-09-11 13:04:08 +02:00
juk0de 74c39070d6 fixed Message.filter_tags 2023-09-11 13:04:08 +02:00
juk0de fde0ae4652 fixed test case file cleanup 2023-09-11 13:04:08 +02:00
juk0de 238dbbee60 fixed handling empty tags in TXT file 2023-09-11 13:04:08 +02:00
juk0de 17f7b2fb45 Added tags filtering (prefix and contained string) to TagLine and Message 2023-09-11 13:04:08 +02:00
juk0de 9c2598a4b8 tests: added testcases for Message.from/to_file() and others 2023-09-11 13:04:08 +02:00
juk0de acec5f1d55 tests: splitted 'test_main.py' into 3 modules 2023-09-11 13:04:08 +02:00
juk0de c0f50bace5 gitignore: added vim session file 2023-09-11 13:04:08 +02:00
juk0de 30ccec2462 tags: TagLine constructor now supports multiline taglines and multiple spaces 2023-09-11 13:04:08 +02:00
juk0de 09da312657 configuration: added 'as_dict()' as an instance function 2023-09-11 13:04:08 +02:00
juk0de 33567df15f added testcases for messages.py 2023-09-11 13:04:08 +02:00
juk0de 264979a60d added new module 'message.py' 2023-09-11 13:04:08 +02:00
juk0de 061e5f8682 tags.py: converted most TagLine functions to module functions 2023-09-11 13:04:08 +02:00
juk0de 2d456e68f1 added testcases for Tag and TagLine classes 2023-09-11 13:04:08 +02:00
juk0de 8bd659e888 added new module 'tags.py' with classes 'Tag' and 'TagLine' 2023-09-11 13:04:08 +02:00
Oleksandr Kozachuk 3ef1339cc0 Fix extracting source file with type specification. 2023-09-09 11:53:32 +02:00
Oleksandr Kozachuk ed567afbea Make it possible to print just question or answer on printing files. 2023-09-08 15:54:29 +02:00
Oleksandr Kozachuk 6e447018d5 Fix tags_completter. 2023-09-07 18:11:32 +02:00
30 changed files with 3226 additions and 677 deletions
+1
View File
@@ -106,6 +106,7 @@ celerybeat.pid
.venv .venv
env/ env/
venv/ venv/
.old/
ENV/ ENV/
env.bak/ env.bak/
venv.bak/ venv.bak/
+109 -64
View File
@@ -37,63 +37,101 @@ cmm [global options] command [command options]
### Global Options ### Global Options
- `-c`, `--config`: Config file name (defaults to `.config.yaml`). - `-C`, `--config`: Config file name (defaults to `.config.yaml`).
### Commands
- `ask`: Ask a question.
- `hist`: Print chat history.
- `tag`: Manage tags.
- `config`: Manage configuration.
- `print`: Print files.
### Command Options ### Command Options
#### `ask` Command Options #### Question
- `-q`, `--question`: Question to ask (required). The `question` command is used to ask, create, and process questions.
- `-m`, `--max-tokens`: Max tokens to use.
- `-T`, `--temperature`: Temperature to use.
- `-M`, `--model`: Model to use.
- `-n`, `--number`: Number of answers to produce (default is 3).
- `-s`, `--source`: Add content of a file to the query.
- `-S`, `--only-source-code`: Add pure source code to the chat history.
- `-t`, `--tags`: List of tag names.
- `-e`, `--extags`: List of tag names to exclude.
- `-o`, `--output-tags`: List of output tag names (default is the input tags).
- `-a`, `--match-all-tags`: All given tags must match when selecting chat history entries.
#### `hist` Command Options ```bash
cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI_ID] [-M MODEL] [-n NUM] [-m MAX] [-T TEMP] (-a QUESTION | -c QUESTION | -r [MESSAGE ...] | -p [MESSAGE ...]) [-O] [-s FILE]... [-S FILE]...
```
- `-d`, `--dump`: Print chat history as Python structure. * `-t, --or-tags OTAGS`: List of tags (one must match)
- `-w`, `--with-tags`: Print chat history with tags. * `-k, --and-tags ATAGS`: List of tags (all must match)
- `-W`, `--with-files`: Print chat history with filenames. * `-x, --exclude-tags XTAGS`: List of tags to exclude
- `-S`, `--only-source-code`: Print only source code. * `-o, --output-tags OUTTAGS`: List of output tags (default: use input tags)
- `-t`, `--tags`: List of tag names. * `-A, --AI AI_ID`: AI ID to use
- `-e`, `--extags`: List of tag names to exclude. * `-M, --model MODEL`: Model to use
- `-a`, `--match-all-tags`: All given tags must match when selecting chat history entries. * `-n, --num-answers NUM`: Number of answers to request
* `-m, --max-tokens MAX`: Max. number of tokens
* `-T, --temperature TEMP`: Temperature value
* `-a, --ask QUESTION`: Ask a question
* `-c, --create QUESTION`: Create a question
* `-r, --repeat [MESSAGE ...]`: Repeat a question
* `-p, --process [MESSAGE ...]`: Process existing questions
* `-O, --overwrite`: Overwrite existing messages when repeating them
* `-s, --source-text FILE`: Add content of a file to the query
* `-S, --source-code FILE`: Add source code file content to the chat history
* `-l, --location {cache,db,all}`: Use given location when building the chat history (default: 'db')
* `-g, --glob GLOB`: Filter message files using the given glob pattern
#### `tag` Command Options #### Hist
- `-l`, `--list`: List all tags and their frequency. The `hist` command is used to print and manage the chat history.
#### `config` Command Options ```bash
cmm hist [--print | --convert FORMAT] [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A SUBSTRING] [-Q SUBSTRING]
```
- `-l`, `--list-models`: List all available models. * `-p, --print`: Print the DB chat history
- `-m`, `--print-model`: Print the currently configured model. * `-c, --convert FORMAT`: Convert all messages to the given format
- `-M`, `--model`: Set model in the config file. * `-t, --or-tags OTAGS`: List of tags (one must match)
* `-k, --and-tags ATAGS`: List of tags (all must match)
* `-x, --exclude-tags XTAGS`: List of tags to exclude
* `-w, --with-metadata`: Print chat history with metadata (tags, filenames, AI, etc.)
* `-S, --source-code-only`: Only print embedded source code
* `-A, --answer SUBSTRING`: Filter for answer substring
* `-Q, --question SUBSTRING`: Filter for question substring
* `-l, --location {cache,db,all}`: Use given location when building the chat history (default: 'db')
* `-g, --glob GLOB`: Filter message files using the given glob pattern
#### `print` Command Options #### Tags
- `-f`, `--file`: File to print (required). The `tags` command is used to manage tags.
- `-S`, `--only-source-code`: Print only source code.
```bash
cmm tags (-l | -p PREFIX | -c SUBSTRING)
```
* `-l, --list`: List all tags and their frequency
* `-p, --prefix PREFIX`: Filter tags by prefix
* `-c, --contain SUBSTRING`: Filter tags by contained substring
#### Config
The `config` command is used to manage the configuration.
```bash
cmm config (-l | -m | -c FILE)
```
* `-l, --list-models`: List all available models
* `-m, --print-model`: Print the currently configured model
* `-c, --create FILE`: Create config with default settings in the given file
#### Print
The `print` command is used to print message files.
```bash
cmm print (-f FILE | -l) [-q | -a | -S]
```
* `-f, --file FILE`: Print given file
* `-l, --latest`: Print latest message
* `-q, --question`: Only print the question
* `-a, --answer`: Only print the answer
* `-S, --only-source-code`: Only print embedded source code
### Examples ### Examples
1. Ask a question: 1. Ask a question:
```bash ```bash
cmm ask -q "What is the meaning of life?" -t philosophy -e religion cmm question -a "What is the meaning of life?" -t philosophy -x religion
``` ```
2. Display the chat history: 2. Display the chat history:
@@ -105,19 +143,19 @@ cmm hist
3. Filter chat history by tags: 3. Filter chat history by tags:
```bash ```bash
cmm hist -t tag1 tag2 cmm hist --or-tags tag1 tag2
``` ```
4. Exclude chat history by tags: 4. Exclude chat history by tags:
```bash ```bash
cmm hist -e tag3 tag4 cmm hist --exclude-tags tag3 tag4
``` ```
5. List all tags and their frequency: 5. List all tags and their frequency:
```bash ```bash
cmm tag -l cmm tags -l
``` ```
6. Print the contents of a file: 6. Print the contents of a file:
@@ -128,18 +166,27 @@ cmm print -f example.yaml
## Configuration ## Configuration
The configuration file (`.config.yaml`) should contain the following fields: The default configuration filename is `.config.yaml` (it is searched in the current working directory).
Use the command `cmm config --create <FILENAME>` to create a default configuration:
- `openai`: ```
- `api_key`: Your OpenAI API key. cache: .
- `model`: The name of the OpenAI model to use (e.g. "text-davinci-002"). db: ./db/
- `temperature`: The temperature value for the model. ais:
- `max_tokens`: The maximum number of tokens for the model. myopenai:
- `top_p`: The top P value for the model. name: openai
- `frequency_penalty`: The frequency penalty value. model: gpt-3.5-turbo-16k
- `presence_penalty`: The presence penalty value. api_key: 0123456789
- `system`: The system message used to set the behavior of the AI. temperature: 1.0
- `db`: The directory where the question-answer pairs are stored in YAML files. max_tokens: 4000
top_p: 1.0
frequency_penalty: 0.0
presence_penalty: 0.0
system: You are an assistant
```
Each AI has its own section and the name of that section is called the 'AI ID' (in the example above it is `myopenai`).
The AI ID can be any string, as long as it's unique within the `ais` section. The AI ID is used for all commands that support the `AI` parameter and it's also stored within each message file.
## Autocompletion ## Autocompletion
@@ -154,33 +201,33 @@ After adding this line, restart your shell or run `source <your-shell-config-fil
## Contributing ## Contributing
### Enable commit hooks ### Enable commit hooks
``` ```bash
pip install pre-commit pip install pre-commit
pre-commit install pre-commit install
``` ```
### Execute tests before opening a PR ### Execute tests before opening a PR
``` ```bash
pytest pytest
``` ```
### Consider using `pyenv` / `pyenv-virtualenv` ### Consider using `pyenv` / `pyenv-virtualenv`
Short installation instructions: Short installation instructions:
* install `pyenv`: * install `pyenv`:
``` ```bash
cd ~ cd ~
git clone https://github.com/pyenv/pyenv .pyenv git clone https://github.com/pyenv/pyenv .pyenv
cd ~/.pyenv && src/configure && make -C src cd ~/.pyenv && src/configure && make -C src
``` ```
* make sure that `~/.pyenv/shims` and `~/.pyenv/bin` are the first entries in your `PATH`, e. g. by setting it in `~/.bashrc` * make sure that `~/.pyenv/shims` and `~/.pyenv/bin` are the first entries in your `PATH`, e.g., by setting it in `~/.bashrc`
* add the following to your `~/.bashrc` (after setting `PATH`): `eval "$(pyenv init -)"` * add the following to your `~/.bashrc` (after setting `PATH`): `eval "$(pyenv init -)"`
* create a new terminal or source the changes (e. g. `source ~/.bashrc`) * create a new terminal or source the changes (e.g., `source ~/.bashrc`)
* install `virtualenv` * install `virtualenv`
``` ```bash
git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv
``` ```
* add the following to your `~/.bashrc` (after the commands above): `eval "$(pyenv virtualenv-init -)` * add the following to your `~/.bashrc` (after the commands above): `eval "$(pyenv virtualenv-init -)`
* create a new terminal or source the changes (e. g. `source ~/.bashrc`) * create a new terminal or source the changes (e.g., `source ~/.bashrc`)
* go back to the `ChatMasterMind` repo and create a virtual environment with the latest `Python`, e. g. `3.11.4`: * go back to the `ChatMasterMind` repo and create a virtual environment with the latest `Python`, e.g., `3.11.4`:
``` ```bash
cd <CMM_REPO_PATH> cd <CMM_REPO_PATH>
pyenv install 3.11.4 pyenv install 3.11.4
pyenv virtualenv 3.11.4 py311 pyenv virtualenv 3.11.4 py311
@@ -191,5 +238,3 @@ pyenv activate py311
## License ## License
This project is licensed under the terms of the WTFPL License. This project is licensed under the terms of the WTFPL License.
+6
View File
@@ -59,6 +59,12 @@ class AI(Protocol):
""" """
raise NotImplementedError raise NotImplementedError
def print_models(self) -> None:
"""
Print all models supported by this AI.
"""
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int: def tokens(self, data: Union[Message, Chat]) -> int:
""" """
Computes the nr. of AI language tokens for the given message Computes the nr. of AI language tokens for the given message
+16 -10
View File
@@ -3,25 +3,29 @@ Creates different AI instances, based on the given configuration.
""" """
import argparse import argparse
from typing import cast from typing import cast, Optional
from .configuration import Config, AIConfig, OpenAIConfig from .configuration import Config, AIConfig, OpenAIConfig
from .ai import AI, AIError from .ai import AI, AIError
from .ais.openai import OpenAI from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 def create_ai(args: argparse.Namespace, config: Config, # noqa: 11
def_ai: Optional[str] = None,
def_model: Optional[str] = None) -> AI:
""" """
Creates an AI subclass instance from the given arguments Creates an AI subclass instance from the given arguments and configuration file.
and configuration file. If AI has not been set in the If AI has not been set in the arguments, it searches for the ID 'default'. If
arguments, it searches for the ID 'default'. If that that is not found, it uses the first AI in the list. It's also possible to
is not found, it uses the first AI in the list. specify a default AI and model using 'def_ai' and 'def_model'.
""" """
ai_conf: AIConfig ai_conf: AIConfig
if args.AI: if hasattr(args, 'AI') and args.AI:
try: try:
ai_conf = config.ais[args.AI] ai_conf = config.ais[args.AI]
except KeyError: except KeyError:
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
elif def_ai:
ai_conf = config.ais[def_ai]
elif 'default' in config.ais: elif 'default' in config.ais:
ai_conf = config.ais['default'] ai_conf = config.ais['default']
else: else:
@@ -32,11 +36,13 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
if ai_conf.name == 'openai': if ai_conf.name == 'openai':
ai = OpenAI(cast(OpenAIConfig, ai_conf)) ai = OpenAI(cast(OpenAIConfig, ai_conf))
if args.model: if hasattr(args, 'model') and args.model:
ai.config.model = args.model ai.config.model = args.model
if args.max_tokens: elif def_model:
ai.config.model = def_model
if hasattr(args, 'max_tokens') and args.max_tokens:
ai.config.max_tokens = args.max_tokens ai.config.max_tokens = args.max_tokens
if args.temperature: if hasattr(args, 'temperature') and args.temperature:
ai.config.temperature = args.temperature ai.config.temperature = args.temperature
return ai return ai
else: else:
+76 -18
View File
@@ -2,7 +2,8 @@
Implements the OpenAI client classes and functions. Implements the OpenAI client classes and functions.
""" """
import openai import openai
from typing import Optional, Union import tiktoken
from typing import Optional, Union, Generator
from ..tags import Tag from ..tags import Tag
from ..message import Message, Answer from ..message import Message, Answer
from ..chat import Chat from ..chat import Chat
@@ -12,6 +13,52 @@ from ..configuration import OpenAIConfig
ChatType = list[dict[str, str]] ChatType = list[dict[str, str]]
class OpenAIAnswer:
def __init__(self,
idx: int,
streams: dict[int, 'OpenAIAnswer'],
response: openai.ChatCompletion,
tokens: Tokens,
encoding: tiktoken.core.Encoding) -> None:
self.idx = idx
self.streams = streams
self.response = response
self.position: int = 0
self.encoding = encoding
self.data: list[str] = []
self.finished: bool = False
self.tokens = tokens
def stream(self) -> Generator[str, None, None]:
while True:
if not self.next():
continue
if len(self.data) <= self.position:
break
yield self.data[self.position]
self.position += 1
def next(self) -> bool:
if self.finished:
return True
try:
chunk = next(self.response)
except StopIteration:
self.finished = True
if not self.finished:
found_choice = False
for choice in chunk['choices']:
if not choice['finish_reason']:
self.streams[choice['index']].data.append(choice['delta']['content'])
self.tokens.completion += len(self.encoding.encode(choice['delta']['content']))
self.tokens.total = self.tokens.prompt + self.tokens.completion
if choice['index'] == self.idx:
found_choice = True
if not found_choice:
return False
return True
class OpenAI(AI): class OpenAI(AI):
""" """
The OpenAI AI client. The OpenAI AI client.
@@ -21,7 +68,7 @@ class OpenAI(AI):
self.ID = config.ID self.ID = config.ID
self.name = config.name self.name = config.name
self.config = config self.config = config
openai.api_key = config.api_key openai.api_key = self.config.api_key
def request(self, def request(self,
question: Message, question: Message,
@@ -33,7 +80,9 @@ class OpenAI(AI):
chat history. The nr. of requested answers corresponds to the chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'. nr. of messages in the 'AIResponse'.
""" """
oai_chat = self.openai_chat(chat, self.config.system, question) self.encoding = tiktoken.encoding_for_model(self.config.model)
oai_chat, prompt_tokens = self.openai_chat(chat, self.config.system, question)
tokens: Tokens = Tokens(prompt_tokens, 0, prompt_tokens)
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model=self.config.model, model=self.config.model,
messages=oai_chat, messages=oai_chat,
@@ -41,28 +90,35 @@ class OpenAI(AI):
max_tokens=self.config.max_tokens, max_tokens=self.config.max_tokens,
top_p=self.config.top_p, top_p=self.config.top_p,
n=num_answers, n=num_answers,
stream=True,
frequency_penalty=self.config.frequency_penalty, frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty) presence_penalty=self.config.presence_penalty)
question.answer = Answer(response['choices'][0]['message']['content']) streams: dict[int, OpenAIAnswer] = {}
question.tags = otags for n in range(num_answers):
streams[n] = OpenAIAnswer(n, streams, response, tokens, self.encoding)
question.answer = Answer(streams[0].stream())
question.tags = set(otags) if otags is not None else None
question.ai = self.ID question.ai = self.ID
question.model = self.config.model question.model = self.config.model
answers: list[Message] = [question] answers: list[Message] = [question]
for choice in response['choices'][1:]: # type: ignore for idx in range(1, num_answers):
answers.append(Message(question=question.question, answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']), answer=Answer(streams[idx].stream()),
tags=otags, tags=otags,
ai=self.ID, ai=self.ID,
model=self.config.model)) model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt_tokens'], return AIResponse(answers, tokens)
response['usage']['completion_tokens'],
response['usage']['total_tokens']))
def models(self) -> list[str]: def models(self) -> list[str]:
""" """
Return all models supported by this AI. Return all models supported by this AI.
""" """
raise NotImplementedError ret = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']:
ret.append(engine['id'])
ret.sort()
return ret
def print_models(self) -> None: def print_models(self) -> None:
""" """
@@ -78,24 +134,26 @@ class OpenAI(AI):
print('\nNot ready: ' + ', '.join(not_ready)) print('\nNot ready: ' + ', '.join(not_ready))
def openai_chat(self, chat: Chat, system: str, def openai_chat(self, chat: Chat, system: str,
question: Optional[Message] = None) -> ChatType: question: Optional[Message] = None) -> tuple[ChatType, int]:
""" """
Create a chat history with system message in OpenAI format. Create a chat history with system message in OpenAI format.
Optionally append a new question. Optionally append a new question.
""" """
oai_chat: ChatType = [] oai_chat: ChatType = []
prompt_tokens: int = 0
def append(role: str, content: str) -> None: def append(role: str, content: str) -> int:
oai_chat.append({'role': role, 'content': content.replace("''", "'")}) oai_chat.append({'role': role, 'content': content.replace("''", "'")})
return len(self.encoding.encode(', '.join(['role:', oai_chat[-1]['role'], 'content:', oai_chat[-1]['content']])))
append('system', system) prompt_tokens += append('system', system)
for message in chat.messages: for message in chat.messages:
if message.answer: if message.answer:
append('user', message.question) prompt_tokens += append('user', message.question)
append('assistant', message.answer) prompt_tokens += append('assistant', str(message.answer))
if question: if question:
append('user', question.question) prompt_tokens += append('user', question.question)
return oai_chat return oai_chat, prompt_tokens
def tokens(self, data: Union[Message, Chat]) -> int: def tokens(self, data: Union[Message, Chat]) -> int:
raise NotImplementedError raise NotImplementedError
+395 -153
View File
@@ -6,13 +6,27 @@ from pathlib import Path
from pprint import PrettyPrinter from pprint import PrettyPrinter
from pydoc import pager from pydoc import pager
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable from enum import Enum
from .message import Message, MessageFilter, MessageError, message_in from typing import TypeVar, Type, Optional, Any, Callable, Union
from .configuration import default_config_file
from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats
from .tags import Tag from .tags import Tag
ChatInst = TypeVar('ChatInst', bound='Chat') ChatInst = TypeVar('ChatInst', bound='Chat')
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next'
ignored_files = [db_next_file, default_config_file]
msg_suffix = Message.file_suffix_write
class msg_location(Enum):
MEM = 'mem'
DISK = 'disk'
CACHE = 'cache'
DB = 'db'
ALL = 'all'
class ChatError(Exception): class ChatError(Exception):
pass pass
@@ -38,40 +52,40 @@ def read_dir(dir_path: Path,
Parameters: Parameters:
* 'dir_path': source directory * 'dir_path': source directory
* 'glob': if specified, files will be filtered using 'path.glob()', * 'glob': if specified, files will be filtered using 'path.glob()',
otherwise it uses 'path.iterdir()'. otherwise it reads all files with the default message suffix
* 'mfilter': use with 'Message.from_file()' to filter messages * 'mfilter': use with 'Message.from_file()' to filter messages
when reading them. when reading them.
""" """
messages: list[Message] = [] messages: list[Message] = []
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() file_iter = dir_path.glob(glob) if glob else dir_path.glob(f'*{msg_suffix}')
for file_path in sorted(file_iter): for file_path in sorted(file_iter):
if file_path.is_file() and file_path.suffix in Message.file_suffixes: if (file_path.is_file()
and file_path.name not in ignored_files # noqa: W503
and file_path.suffix in Message.file_suffixes_read): # noqa: W503
try: try:
message = Message.from_file(file_path, mfilter) message = Message.from_file(file_path, mfilter)
if message: if message:
messages.append(message) messages.append(message)
except MessageError as e: except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}") print(f"WARNING: Skipping message in '{file_path}': {str(e)}")
return messages return messages
def make_file_path(dir_path: Path, def make_file_path(dir_path: Path,
file_suffix: str,
next_fid: Callable[[], int]) -> Path: next_fid: Callable[[], int]) -> Path:
""" """
Create a file_path for the given directory using the Create a file_path for the given directory using the given ID generator function.
given file_suffix and ID generator function.
""" """
file_path = dir_path / f"{next_fid():04d}{file_suffix}" file_path = dir_path / f"{next_fid():04d}{msg_suffix}"
while file_path.exists(): while file_path.exists():
file_path = dir_path / f"{next_fid():04d}{file_suffix}" file_path = dir_path / f"{next_fid():04d}{msg_suffix}"
return file_path return file_path
def write_dir(dir_path: Path, def write_dir(dir_path: Path,
messages: list[Message], messages: list[Message],
file_suffix: str, next_fid: Callable[[], int],
next_fid: Callable[[], int]) -> None: mformat: MessageFormat = Message.default_format) -> None:
""" """
Write all messages to the given directory. If a message has no file_path, Write all messages to the given directory. If a message has no file_path,
a new one will be created. If message.file_path exists, it will be modified a new one will be created. If message.file_path exists, it will be modified
@@ -79,18 +93,17 @@ def write_dir(dir_path: Path,
Parameters: Parameters:
* 'dir_path': destination directory * 'dir_path': destination directory
* 'messages': list of messages to write * 'messages': list of messages to write
* 'file_suffix': suffix for the message files ['.txt'|'.yaml']
* 'next_fid': callable that returns the next file ID * 'next_fid': callable that returns the next file ID
""" """
for message in messages: for message in messages:
file_path = message.file_path file_path = message.file_path
# message has no file_path: create one # message has no file_path: create one
if not file_path: if not file_path:
file_path = make_file_path(dir_path, file_suffix, next_fid) file_path = make_file_path(dir_path, next_fid)
# file_path does not point to given directory: modify it # file_path does not point to given directory: modify it
elif not file_path.parent.samefile(dir_path): elif not file_path.parent.samefile(dir_path):
file_path = dir_path / file_path.name file_path = dir_path / file_path.name
message.to_file(file_path) message.to_file(file_path, mformat=mformat)
def clear_dir(dir_path: Path, def clear_dir(dir_path: Path,
@@ -100,7 +113,9 @@ def clear_dir(dir_path: Path,
""" """
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in file_iter: for file_path in file_iter:
if file_path.is_file() and file_path.suffix in Message.file_suffixes: if (file_path.is_file()
and file_path.name not in ignored_files # noqa: W503
and file_path.suffix in Message.file_suffixes_read): # noqa: W503
file_path.unlink(missing_ok=True) file_path.unlink(missing_ok=True)
@@ -112,14 +127,43 @@ class Chat:
messages: list[Message] messages: list[Message]
def filter(self, mfilter: MessageFilter) -> None: def __post_init__(self) -> None:
self.validate()
def validate(self) -> None:
"""
Validate this Chat instance.
"""
def msg_paths(stem: str) -> list[str]:
return [str(fp) for fp in file_paths if fp.stem == stem]
file_paths: set[Path] = {m.file_path for m in self.messages if m.file_path is not None}
file_stems = [m.file_path.stem for m in self.messages if m.file_path is not None]
error = False
for fp in file_paths:
if file_stems.count(fp.stem) > 1:
print(f"ERROR: Found multiple copies of message '{fp.stem}': {msg_paths(fp.stem)}")
error = True
if error:
raise ChatError("Validation failed")
def msg_name_matches(self, file_path: Path, name: str) -> bool:
"""
Return True if the given name matches the given file_path.
Matching is True if:
* 'name' matches the full 'file_path'
* 'name' matches 'file_path.name' (i. e. including the suffix)
* 'name' matches 'file_path.stem' (i. e. without the suffix)
"""
return Path(name) == file_path or name == file_path.name or name == file_path.stem
def msg_filter(self, mfilter: MessageFilter) -> None:
""" """
Use 'Message.match(mfilter) to remove all messages that Use 'Message.match(mfilter) to remove all messages that
don't fulfill the filter requirements. don't fulfill the filter requirements.
""" """
self.messages = [m for m in self.messages if m.match(mfilter)] self.messages = [m for m in self.messages if m.match(mfilter)]
def sort(self, reverse: bool = False) -> None: def msg_sort(self, reverse: bool = False) -> None:
""" """
Sort the messages according to 'Message.msg_id()'. Sort the messages according to 'Message.msg_id()'.
""" """
@@ -129,48 +173,71 @@ class Chat:
except MessageError: except MessageError:
pass pass
def clear(self) -> None: def msg_unique_id(self) -> None:
"""
Remove duplicates from the internal messages, based on the msg_id (i. e. file_path).
Messages without a file_path are kept.
"""
old_msgs = self.messages.copy()
self.messages = []
for m in old_msgs:
if not message_in(m, self.messages):
self.messages.append(m)
self.msg_sort()
def msg_unique_content(self) -> None:
"""
Remove duplicates from the internal messages, based on the content (i. e. question + answer).
"""
self.messages = list(set(self.messages))
self.msg_sort()
def msg_clear(self) -> None:
""" """
Delete all messages. Delete all messages.
""" """
self.messages = [] self.messages = []
def add_messages(self, messages: list[Message]) -> None: def msg_add(self, messages: list[Message]) -> None:
""" """
Add new messages and sort them if possible. Add new messages and sort them if possible.
""" """
self.messages += messages self.messages += messages
self.sort() self.msg_sort()
def latest_message(self) -> Optional[Message]: def msg_latest(self, mfilter: Optional[MessageFilter] = None) -> Optional[Message]:
""" """
Returns the last added message (according to the file ID). Return the last added message (according to the file ID) that matches the given filter.
When containing messages without a valid file_path, it returns the latest message in
the internal list.
""" """
if len(self.messages) > 0: if len(self.messages) > 0:
self.sort() self.msg_sort()
return self.messages[-1] for m in reversed(self.messages):
else: if mfilter is None or m.match(mfilter):
return m
return None return None
def find_messages(self, msg_names: list[str]) -> list[Message]: def msg_find(self, msg_names: list[str]) -> list[Message]:
""" """
Search and return the messages with the given names. Names can either be filenames Search and return the messages with the given names. Names can either be filenames
(incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the (with or without suffix), full paths or Message.msg_id(). Messages that can't be
caller should check the result if he requires all messages). found are ignored (i. e. the caller should check the result if they require all
messages).
""" """
return [m for m in self.messages return [m for m in self.messages
if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def remove_messages(self, msg_names: list[str]) -> None: def msg_remove(self, msg_names: list[str]) -> None:
""" """
Remove the messages with the given names. Names can either be filenames Remove the messages with the given names. Names can either be filenames
(incl. the suffix) or full paths. (with or without suffix), full paths or Message.msg_id().
""" """
self.messages = [m for m in self.messages self.messages = [m for m in self.messages
if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] if not any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
self.sort() self.msg_sort()
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: def msg_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
""" """
Get the tags of all messages, optionally filtered by prefix or substring. Get the tags of all messages, optionally filtered by prefix or substring.
""" """
@@ -179,7 +246,7 @@ class Chat:
tags |= m.filter_tags(prefix, contain) tags |= m.filter_tags(prefix, contain)
return set(sorted(tags)) return set(sorted(tags))
def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]: def msg_tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
""" """
Get the frequency of all tags of all messages, optionally filtered by prefix or substring. Get the frequency of all tags of all messages, optionally filtered by prefix or substring.
""" """
@@ -196,14 +263,16 @@ class Chat:
return sum(m.tokens() for m in self.messages) return sum(m.tokens() for m in self.messages)
def print(self, source_code_only: bool = False, def print(self, source_code_only: bool = False,
with_tags: bool = False, with_files: bool = False, with_metadata: bool = False,
paged: bool = True) -> None: paged: bool = True,
tight: bool = False) -> None:
output: list[str] = [] output: list[str] = []
for message in self.messages: for message in self.messages:
if source_code_only: if source_code_only:
output.append(message.to_str(source_code_only=True)) output.append(message.to_str(source_code_only=True))
continue continue
output.append(message.to_str(with_tags, with_files)) output.append(message.to_str(with_metadata))
if not tight:
output.append('\n' + ('-' * terminal_width()) + '\n') output.append('\n' + ('-' * terminal_width()) + '\n')
if paged: if paged:
print_paged('\n'.join(output)) print_paged('\n'.join(output))
@@ -221,43 +290,50 @@ class ChatDB(Chat):
persistently. persistently.
""" """
default_file_suffix: ClassVar[str] = '.txt'
cache_path: Path cache_path: Path
db_path: Path db_path: Path
# a MessageFilter that all messages must match (if given) # a MessageFilter that all messages must match (if given)
mfilter: Optional[MessageFilter] = None mfilter: Optional[MessageFilter] = None
file_suffix: str = default_file_suffix
# the glob pattern for all messages # the glob pattern for all messages
glob: Optional[str] = None glob: str = f'*{msg_suffix}'
# message format (for writing)
mformat: MessageFormat = Message.default_format
def __post_init__(self) -> None: def __post_init__(self) -> None:
# contains the latest message ID # contains the latest message ID
self.next_fname = self.db_path / '.next' self.next_path = self.db_path / db_next_file
# make all paths absolute # make all paths absolute
self.cache_path = self.cache_path.absolute() self.cache_path = self.cache_path.absolute()
self.db_path = self.db_path.absolute() self.db_path = self.db_path.absolute()
self.validate()
@classmethod @classmethod
def from_dir(cls: Type[ChatDBInst], def from_dir(cls: Type[ChatDBInst],
cache_path: Path, cache_path: Path,
db_path: Path, db_path: Path,
glob: Optional[str] = None, glob: str = f'*{msg_suffix}',
mfilter: Optional[MessageFilter] = None) -> ChatDBInst: mfilter: Optional[MessageFilter] = None,
loc: msg_location = msg_location.DB) -> ChatDBInst:
""" """
Create a 'ChatDB' instance from the given directory structure. Create a 'ChatDB' instance from the given directory structure.
Reads all messages from 'db_path' into the local message list. Reads all messages from 'db_path' into the local message list.
Parameters: Parameters:
* 'cache_path': path to the directory for temporary messages * 'cache_path': path to the directory for temporary messages
* 'db_path': path to the directory for persistent messages * 'db_path': path to the directory for persistent messages
* 'glob': if specified, files will be filtered using 'path.glob()', * 'glob': if specified, files will be filtered using 'path.glob()'
otherwise it uses 'path.iterdir()'.
* 'mfilter': use with 'Message.from_file()' to filter messages * 'mfilter': use with 'Message.from_file()' to filter messages
when reading them. when reading them.
* 'loc': read messages from given location instead of 'db_path'
""" """
messages = read_dir(db_path, glob, mfilter) if loc == msg_location.MEM:
return cls(messages, cache_path, db_path, mfilter, raise ChatError(f"Can't build ChatDB from message location '{loc}'")
cls.default_file_suffix, glob) messages: list[Message] = []
if loc in [msg_location.DB, msg_location.DISK, msg_location.ALL]:
messages.extend(read_dir(db_path, glob, mfilter))
if loc in [msg_location.CACHE, msg_location.DISK, msg_location.ALL]:
messages.extend(read_dir(cache_path, glob, mfilter))
messages.sort(key=lambda x: x.msg_id())
return cls(messages, cache_path, db_path, mfilter, glob)
@classmethod @classmethod
def from_messages(cls: Type[ChatDBInst], def from_messages(cls: Type[ChatDBInst],
@@ -272,7 +348,7 @@ class ChatDB(Chat):
def get_next_fid(self) -> int: def get_next_fid(self) -> int:
try: try:
with open(self.next_fname, 'r') as f: with open(self.next_path, 'r') as f:
next_fid = int(f.read()) + 1 next_fid = int(f.read()) + 1
self.set_next_fid(next_fid) self.set_next_fid(next_fid)
return next_fid return next_fid
@@ -281,105 +357,22 @@ class ChatDB(Chat):
return 1 return 1
def set_next_fid(self, fid: int) -> None: def set_next_fid(self, fid: int) -> None:
with open(self.next_fname, 'w') as f: with open(self.next_path, 'w') as f:
f.write(f'{fid}') f.write(f'{fid}')
def read_db(self) -> None: def set_msg_format(self, mformat: MessageFormat) -> None:
""" """
Reads new messages from the DB directory. New ones are added to the internal list, Set message format for writing messages.
existing ones are replaced. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list.
""" """
new_messages = read_dir(self.db_path, self.glob, self.mfilter) if mformat not in message_valid_formats:
# remove all messages from self.messages that are in the new list raise ChatError(f"Message format '{mformat}' is not supported")
self.messages = [m for m in self.messages if not message_in(m, new_messages)] self.mformat = mformat
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.sort()
def read_cache(self) -> None: def msg_write(self,
messages: Optional[list[Message]] = None,
mformat: Optional[MessageFormat] = None) -> None:
""" """
Reads new messages from the cache directory. New ones are added to the internal list, Write either the given messages or the internal ones to their CURRENT file_path.
existing ones are replaced. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.cache_path, self.glob, self.mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.sort()
def write_db(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the DB directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point
to the DB directory.
"""
write_dir(self.db_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def write_cache(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the cache directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point to
the cache directory.
"""
write_dir(self.cache_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def clear_cache(self) -> None:
"""
Deletes all Message files from the cache dir and removes those messages from
the internal list.
"""
clear_dir(self.cache_path, self.glob)
# only keep messages from DB dir (or those that have not yet been written)
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
def add_to_db(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the DB directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.db_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.sort()
def add_to_cache(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the cache directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.cache_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.sort()
def write_messages(self, messages: Optional[list[Message]] = None) -> None:
"""
Write either the given messages or the internal ones to their current file_path.
If messages are given, they all must have a valid file_path. When writing the If messages are given, they all must have a valid file_path. When writing the
internal messages, the ones with a valid file_path are written, the others internal messages, the ones with a valid file_path are written, the others
are ignored. are ignored.
@@ -388,20 +381,269 @@ class ChatDB(Chat):
raise ChatError("Can't write files without a valid file_path") raise ChatError("Can't write files without a valid file_path")
msgs = iter(messages if messages else self.messages) msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)): while (m := next(msgs, None)):
m.to_file() m.to_file(mformat=mformat if mformat else self.mformat)
def update_messages(self, messages: list[Message], write: bool = True) -> None: def msg_update(self, messages: list[Message], write: bool = True) -> None:
""" """
Update existing messages. A message is determined as 'existing' if a message with Update EXISTING messages. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list. Only accepts the same base filename (i. e. 'file_path.name') is already in the list.
existing messages. Only accepts existing messages.
""" """
if any(not message_in(m, self.messages) for m in messages): if any(not message_in(m, self.messages) for m in messages):
raise ChatError("Can't update messages that are not in the internal list") raise ChatError("Can't update messages that are not in the internal list")
# remove old versions and add new ones # remove old versions and add new ones
self.messages = [m for m in self.messages if not message_in(m, messages)] self.messages = [m for m in self.messages if not message_in(m, messages)]
self.messages += messages self.messages += messages
self.sort() self.msg_sort()
# write the UPDATED messages if requested # write the UPDATED messages if requested
if write: if write:
self.write_messages(messages) self.msg_write(messages)
def msg_gather(self,
loc: msg_location,
require_file_path: bool = False,
glob: str = f'*{msg_suffix}',
mfilter: Optional[MessageFilter] = None) -> list[Message]:
"""
Gather and return messages from the given locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
If 'require_file_path' is True, return only files with a valid file_path.
"""
loc_messages: list[Message] = []
if loc in [msg_location.MEM, msg_location.ALL]:
if require_file_path:
loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
else:
loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
if loc in [msg_location.CACHE, msg_location.DISK, msg_location.ALL]:
loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter)
if loc in [msg_location.DB, msg_location.DISK, msg_location.ALL]:
loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter)
# remove_duplicates and sort the list
unique_messages: list[Message] = []
for m in loc_messages:
if not message_in(m, unique_messages):
unique_messages.append(m)
try:
unique_messages.sort(key=lambda m: m.msg_id())
# messages in 'mem' can have an empty file_path
except MessageError:
pass
return unique_messages
def msg_find(self,
msg_names: list[str],
loc: msg_location = msg_location.MEM,
) -> list[Message]:
"""
Search and return the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Messages that can't be
found are ignored (i. e. the caller should check the result if they require all
messages).
Searches one of the following locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
loc_messages = self.msg_gather(loc, require_file_path=True)
return [m for m in loc_messages
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def msg_remove(self, msg_names: list[str], loc: msg_location = msg_location.MEM) -> None:
"""
Remove the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Also deletes the
files of all given messages with a valid file_path.
Delete files from one of the following locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
if loc != msg_location.MEM:
# delete the message files first
rm_messages = self.msg_find(msg_names, loc=loc)
for m in rm_messages:
if (m.file_path):
m.file_path.unlink()
# then remove them from the internal list
super().msg_remove(msg_names)
def msg_latest(self,
mfilter: Optional[MessageFilter] = None,
loc: msg_location = msg_location.MEM) -> Optional[Message]:
"""
Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if loc is 'mem').
Searches one of the following locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
# only consider messages with a valid file_path so they can be sorted
loc_messages = self.msg_gather(loc, require_file_path=True)
loc_messages.sort(key=lambda m: m.msg_id(), reverse=True)
for m in loc_messages:
if mfilter is None or m.match(mfilter):
return m
return None
def msg_in_cache(self, message: Union[Message, str]) -> bool:
"""
Return true if the given Message (or filename or Message.msg_id())
is located in the cache directory. False otherwise.
"""
if isinstance(message, Message):
return (message.file_path is not None
and message.file_path.parent.samefile(self.cache_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc=msg_location.CACHE)) > 0
def msg_in_db(self, message: Union[Message, str]) -> bool:
"""
Return true if the given Message (or filename or Message.msg_id())
is located in the DB directory. False otherwise.
"""
if isinstance(message, Message):
return (message.file_path is not None
and message.file_path.parent.samefile(self.db_path) # noqa: W503
and message.file_path.exists()) # noqa: W503
else:
return len(self.msg_find([message], loc=msg_location.DB)) > 0
def cache_read(self, glob: str = f'*{msg_suffix}', mfilter: Optional[MessageFilter] = None) -> None:
"""
Read messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message
with the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.cache_path, glob, mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.msg_sort()
def cache_write(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the cache directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point to
the cache directory.
Does NOT add the messages to the internal list (use 'cache_add()' for that)!
"""
write_dir(self.cache_path,
messages if messages else self.messages,
self.get_next_fid,
self.mformat)
def cache_add(self, messages: list[Message], write: bool = True) -> None:
"""
Add NEW messages and set the file_path to the cache directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.cache_path,
messages,
self.get_next_fid,
self.mformat)
else:
for m in messages:
m.file_path = make_file_path(self.cache_path, self.get_next_fid)
self.messages += messages
self.msg_sort()
def cache_clear(self, glob: str = f'*{msg_suffix}') -> None:
"""
Delete all message files from the cache dir and remove them from the internal list.
"""
clear_dir(self.cache_path, glob)
# only keep messages from DB dir (or those that have not yet been written)
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
def cache_move(self, message: Message) -> None:
"""
Moves the given messages to the cache directory.
"""
# remember the old path (if any)
old_path: Optional[Path] = None
if message.file_path:
old_path = message.file_path
# write message to the new destination
self.cache_write([message])
# remove the old one (if any)
if old_path:
self.msg_remove([str(old_path)], loc=msg_location.DB)
# (re)add it to the internal list
self.msg_add([message])
def db_read(self, glob: str = f'*{msg_suffix}', mfilter: Optional[MessageFilter] = None) -> None:
"""
Read messages from the DB directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message
with the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.db_path, self.glob, self.mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.msg_sort()
def db_write(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the DB directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point
to the DB directory.
Does NOT add the messages to the internal list (use 'db_add()' for that)!
"""
write_dir(self.db_path,
messages if messages else self.messages,
self.get_next_fid,
self.mformat)
def db_add(self, messages: list[Message], write: bool = True) -> None:
"""
Add NEW messages and set the file_path to the DB directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.db_path,
messages,
self.get_next_fid,
self.mformat)
else:
for m in messages:
m.file_path = make_file_path(self.db_path, self.get_next_fid)
self.messages += messages
self.msg_sort()
def db_move(self, message: Message) -> None:
"""
Moves the given messages to the db directory.
"""
# remember the old path (if any)
old_path: Optional[Path] = None
if message.file_path:
old_path = message.file_path
# write message to the new destination
self.db_write([message])
# remove the old one (if any)
if old_path:
self.msg_remove([str(old_path)], loc=msg_location.CACHE)
# (re)add it to the internal list
self.msg_add([message])
+69
View File
@@ -0,0 +1,69 @@
"""
Contains shared functions for the various CMM subcommands.
"""
import argparse
from pathlib import Path
from ..message import Message, MessageError, source_code
def read_text_file(file: Path) -> str:
with open(file) as r:
content = r.read().strip()
return content
def add_file_as_text(question_parts: list[str], file: str) -> None:
"""
Add the given file as plain text to the question part list.
If the file is a Message, add the answer.
"""
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
content = read_text_file(Path(file))
if len(content) > 0:
question_parts.append(content)
def add_file_as_code(question_parts: list[str], file: str) -> None:
"""
Add all source code from the given file. If no code segments can be extracted,
the whole content is added as source code segment. If the file is a Message,
extract the source code from the answer.
"""
file_path = Path(file)
content: str
try:
message = Message.from_file(file_path)
if message and message.answer:
content = message.answer
except MessageError:
with open(file) as r:
content = r.read().strip()
# extract and add source code
code_parts = source_code(content, include_delims=True)
if len(code_parts) > 0:
question_parts += code_parts
else:
question_parts.append(f"```\n{content}\n```")
def invert_input_tag_args(args: argparse.Namespace) -> None:
"""
Changes the semantics of the INPUT tags for this command:
* not tags specified on the CLI -> no tags are selected
* empty tags specified on the CLI -> all tags are selected
"""
if args.or_tags is None:
args.or_tags = set()
elif len(args.or_tags) == 0:
args.or_tags = None
if args.and_tags is None:
args.and_tags = set()
elif len(args.and_tags) == 0:
args.and_tags = None
+9
View File
@@ -1,6 +1,8 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
from ..configuration import Config from ..configuration import Config
from ..ai import AI
from ..ai_factory import create_ai
def config_cmd(args: argparse.Namespace) -> None: def config_cmd(args: argparse.Namespace) -> None:
@@ -9,3 +11,10 @@ def config_cmd(args: argparse.Namespace) -> None:
""" """
if args.create: if args.create:
Config.create_default(Path(args.create)) Config.create_default(Path(args.create))
elif args.list_models or args.print_model:
config: Config = Config.from_file(args.config)
ai: AI = create_ai(args, config)
if args.list_models:
ai.print_models()
else:
print(ai.config.model)
+95
View File
@@ -0,0 +1,95 @@
import sys
import argparse
from pathlib import Path
from pydoc import pager
from ..configuration import Config
from ..glossary import Glossary
class GlossaryCmdError(Exception):
pass
def print_paged(text: str) -> None:
pager(text)
def get_glossary_file_path(name: str, config: Config) -> Path:
"""
Get the complete filename for a glossary with the given path.
"""
if not config.glossaries:
raise GlossaryCmdError("Can't create glossary name without a glossary directory")
return Path(config.glossaries, name).with_suffix(Glossary.file_suffix).absolute()
def list_glossaries(args: argparse.Namespace, config: Config) -> None:
"""
List existing glossaries in the 'glossaries' directory.
"""
if not config.glossaries:
raise GlossaryCmdError("Glossaries directory missing in the configuration file")
glossaries = Path(config.glossaries).glob(f'*{Glossary.file_suffix}')
for glo in sorted(glossaries):
print(Glossary.from_file(glo).to_str())
def print_glossary(args: argparse.Namespace, config: Config) -> None:
"""
Print an existing glossary.
"""
# sanity checks
if args.name is None:
raise GlossaryCmdError("Missing glossary name")
if config.glossaries is None and args.file is None:
raise GlossaryCmdError("Glossaries directory missing in the configuration file")
# create file path or use the given one
glo_file = Path(args.file) if args.file else get_glossary_file_path(args.name, config)
if not glo_file.exists():
raise GlossaryCmdError(f"Glossary '{glo_file}' does not exist")
# read glossary
glo = Glossary.from_file(glo_file)
print_paged(glo.to_str(with_entries=True))
def create_glossary(args: argparse.Namespace, config: Config) -> None:
"""
Create a new glossary and write it either to the glossaries directory
or the given file.
"""
# sanity checks
if args.name is None:
raise GlossaryCmdError("Missing glossary name")
if args.source_lang is None:
raise GlossaryCmdError("Missing source language")
if args.target_lang is None:
raise GlossaryCmdError("Missing target language")
if config.glossaries is None and args.file is None:
raise GlossaryCmdError("Glossaries directory missing in the configuration file")
# create file path or use the given one
glo_file = Path(args.file) if args.file else get_glossary_file_path(args.name, config)
if glo_file.exists():
raise GlossaryCmdError(f"Glossary '{glo_file}' already exists")
glo = Glossary(name=args.name,
source_lang=args.source_lang,
target_lang=args.target_lang,
desc=args.description,
file_path=glo_file)
glo.to_file()
print(f"Successfully created new glossary '{glo_file}'.")
def glossary_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'glossary' command.
"""
try:
if args.create:
create_glossary(args, config)
elif args.list:
list_glossaries(args, config)
elif args.print:
print_glossary(args, config)
except GlossaryCmdError as err:
print(f"Error: {err}")
sys.exit(1)
+60 -8
View File
@@ -1,13 +1,52 @@
import sys
import argparse import argparse
from pathlib import Path from pathlib import Path
from ..configuration import Config from ..configuration import Config
from ..chat import ChatDB from ..chat import ChatDB, msg_location
from ..message import MessageFilter from ..message import MessageFilter, Message
def hist_cmd(args: argparse.Namespace, config: Config) -> None: msg_suffix = Message.file_suffix_write # currently '.msg'
def convert_messages(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'hist' command. Convert messages to a new format. Also used to change old suffixes
('.txt', '.yaml') to the latest default message file suffix ('.msg').
"""
chat = ChatDB.from_dir(Path(config.cache),
Path(config.db),
glob='*')
# read all known message files
msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*')
# make a set of all message IDs
msg_ids = set([m.msg_id() for m in msgs])
# set requested format and write all messages
chat.set_msg_format(args.convert)
# delete the current suffix
# -> a new one will automatically be created
for m in msgs:
if m.file_path:
m.file_path = m.file_path.with_suffix('')
chat.msg_write(msgs)
# read all messages with the current default suffix
msgs = chat.msg_gather(loc=msg_location.DISK, glob=f'*{msg_suffix}')
# make sure we converted all of the original messages
for mid in msg_ids:
if not any(mid == m.msg_id() for m in msgs):
print(f"Message '{mid}' has not been found after conversion. Aborting.")
sys.exit(1)
# delete messages with old suffixes
msgs = chat.msg_gather(loc=msg_location.DISK, glob='*.*')
for m in msgs:
if m.file_path and m.file_path.suffix != msg_suffix:
m.rm_file()
print(f"Successfully converted {len(msg_ids)} messages.")
def print_chat(args: argparse.Namespace, config: Config) -> None:
"""
Print the DB chat history.
""" """
mfilter = MessageFilter(tags_or=args.or_tags, mfilter = MessageFilter(tags_or=args.or_tags,
@@ -15,9 +54,22 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None:
tags_not=args.exclude_tags, tags_not=args.exclude_tags,
question_contains=args.question, question_contains=args.question,
answer_contains=args.answer) answer_contains=args.answer)
chat = ChatDB.from_dir(Path('.'), chat = ChatDB.from_dir(Path(config.cache),
Path(config.db), Path(config.db),
mfilter=mfilter) mfilter=mfilter,
loc=msg_location(args.location),
glob=args.glob)
chat.print(args.source_code_only, chat.print(args.source_code_only,
args.with_tags, args.with_metadata,
args.with_files) paged=not args.no_paging,
tight=args.tight)
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'hist' command.
"""
if args.print:
print_chat(args, config)
elif args.convert:
convert_messages(args, config)
+24 -6
View File
@@ -3,16 +3,13 @@ import argparse
from pathlib import Path from pathlib import Path
from ..configuration import Config from ..configuration import Config
from ..message import Message, MessageError from ..message import Message, MessageError
from ..chat import ChatDB, msg_location
def print_cmd(args: argparse.Namespace, config: Config) -> None: def print_message(message: Message, args: argparse.Namespace) -> None:
""" """
Handler for the 'print' command. Print given message according to give arguments.
""" """
fname = Path(args.file)
try:
message = Message.from_file(fname)
if message:
if args.question: if args.question:
print(message.question) print(message.question)
elif args.answer: elif args.answer:
@@ -22,6 +19,27 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None:
print(code) print(code)
else: else:
print(message.to_str()) print(message.to_str())
def print_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'print' command.
"""
# print given file
if args.file is not None:
fname = Path(args.file)
try:
message = Message.from_file(fname)
if message:
print_message(message, args)
except MessageError: except MessageError:
print(f"File is not a valid message: {args.file}") print(f"File is not a valid message: {args.file}")
sys.exit(1) sys.exit(1)
# print latest message
elif args.latest:
chat = ChatDB.from_dir(Path(config.cache), Path(config.db))
latest = chat.msg_latest(loc=msg_location.DISK)
if not latest:
print("No message found!")
sys.exit(1)
print_message(latest, args)
+125 -52
View File
@@ -1,93 +1,166 @@
import sys
import argparse import argparse
from pathlib import Path from pathlib import Path
from itertools import zip_longest from itertools import zip_longest
from copy import deepcopy
from .common import invert_input_tag_args, add_file_as_code, add_file_as_text
from ..configuration import Config from ..configuration import Config
from ..chat import ChatDB from ..chat import ChatDB, msg_location
from ..message import Message, MessageFilter, Question, source_code from ..message import Message, MessageFilter, Question
from ..ai_factory import create_ai from ..ai_factory import create_ai
from ..ai import AI, AIResponse from ..ai import AI, AIResponse
class QuestionCmdError(Exception):
pass
def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
"""
Takes an existing message and CLI arguments, and returns modified args based
on the members of the given message. Used e.g. when repeating messages, where
it's necessary to determine the correct AI, module and output tags to use
(either from the existing message or the given args).
"""
msg_args = args
# if AI, model or output tags have not been specified,
# use those from the original message
if (args.AI is None
or args.model is None # noqa: W503
or args.output_tags is None): # noqa: W503
msg_args = deepcopy(args)
if args.AI is None and msg.ai is not None:
msg_args.AI = msg.ai
if args.model is None and msg.model is not None:
msg_args.model = msg.model
if args.output_tags is None and msg.tags is not None:
msg_args.output_tags = msg.tags
return msg_args
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
""" """
Creates (and writes) a new message from the given arguments. Create a new message from the given arguments and write it
to the cache directory.
""" """
question_parts = [] question_parts = []
question_list = args.ask if args.ask is not None else [] if args.create is not None:
question_list = args.create
elif args.ask is not None:
question_list = args.ask
else:
raise QuestionCmdError("No question found")
text_files = args.source_text if args.source_text is not None else [] text_files = args.source_text if args.source_text is not None else []
code_files = args.source_code if args.source_code is not None else [] code_files = args.source_code if args.source_code is not None else []
for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None): for question, text_file, code_file in zip_longest(question_list, text_files, code_files, fillvalue=None):
if question is not None and len(question.strip()) > 0: if question is not None and len(question.strip()) > 0:
question_parts.append(question) question_parts.append(question)
if source is not None and len(source) > 0: if text_file is not None and len(text_file) > 0:
with open(source) as r: add_file_as_text(question_parts, text_file)
content = r.read().strip() if code_file is not None and len(code_file) > 0:
if len(content) > 0: add_file_as_code(question_parts, code_file)
question_parts.append(content)
if code is not None and len(code) > 0:
with open(code) as r:
content = r.read().strip()
if len(content) == 0:
continue
# try to extract and add source code
code_parts = source_code(content, include_delims=True)
if len(code_parts) > 0:
question_parts += code_parts
# if there's none, add the whole file
else:
question_parts.append(f"```\n{content}\n```")
full_question = '\n\n'.join(question_parts) full_question = '\n\n'.join([str(s) for s in question_parts])
message = Message(question=Question(full_question), message = Message(question=Question(full_question),
tags=args.output_tags, # FIXME tags=args.output_tags,
ai=args.AI, ai=args.AI,
model=args.model) model=args.model)
chat.add_to_cache([message]) # only write the new message to the cache,
# don't add it to the internal list
chat.cache_write([message])
return message return message
def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespace) -> None:
"""
Make an AI request with the given AI, chat history, message and arguments.
Write the response(s) to the cache directory, without appending it to the
given chat history. Then print the response(s).
"""
# print history and message question before making the request
ai.print()
chat.print(paged=False)
print(message.to_str())
response: AIResponse = ai.request(message,
chat,
args.num_answers,
args.output_tags)
# only write the response messages to the cache,
# don't add them to the internal list
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===", flush=True)
if msg.answer:
for piece in msg.answer:
print(piece, end='', flush=True)
print()
if response.tokens:
print("===============")
print(response.tokens)
chat.cache_write(response.messages)
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
"""
Repeat the given messages using the given arguments.
"""
ai: AI
for msg in messages:
msg_args = create_msg_args(msg, args)
ai = create_ai(msg_args, config)
print(f"--------- Repeating message '{msg.msg_id()}': ---------")
# overwrite the latest message if requested or empty
# -> but not if it's in the DB!
if ((msg.answer is None or msg_args.overwrite is True)
and (not chat.msg_in_db(msg))): # noqa: W503
msg.clear_answer()
make_request(ai, chat, msg, msg_args)
# otherwise create a new one
else:
msg_args.ask = [msg.question]
message = create_message(chat, msg_args)
make_request(ai, chat, message, msg_args)
def question_cmd(args: argparse.Namespace, config: Config) -> None: def question_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'question' command. Handler for the 'question' command.
""" """
mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(), invert_input_tag_args(args)
tags_and=args.and_tags if args.and_tags is not None else set(), mfilter = MessageFilter(tags_or=args.or_tags,
tags_not=args.exclude_tags if args.exclude_tags is not None else set()) tags_and=args.and_tags,
chat = ChatDB.from_dir(cache_path=Path('.'), tags_not=args.exclude_tags)
chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db), db_path=Path(config.db),
mfilter=mfilter) mfilter=mfilter,
glob=args.glob,
loc=msg_location(args.location))
# if it's a new question, create and store it immediately # if it's a new question, create and store it immediately
if args.ask or args.create: if args.ask or args.create:
message = create_message(chat, args) message = create_message(chat, args)
if args.create: if args.create:
return return
# create the correct AI instance # === ASK ===
ai: AI = create_ai(args, config)
if args.ask: if args.ask:
ai.print() ai: AI = create_ai(args, config)
chat.print(paged=False) make_request(ai, chat, message, args)
response: AIResponse = ai.request(message, # === REPEAT ===
chat,
args.num_answers, # FIXME
args.output_tags) # FIXME
chat.update_messages([response.messages[0]])
chat.add_to_cache(response.messages[1:])
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
elif args.repeat is not None: elif args.repeat is not None:
lmessage = chat.latest_message() repeat_msgs: list[Message] = []
assert lmessage # repeat latest message
# TODO: repeat either the last question or the if len(args.repeat) == 0:
# one(s) given in 'args.repeat' (overwrite lmessage = chat.msg_latest(loc=msg_location.CACHE)
# existing ones if 'args.overwrite' is True) if lmessage is None:
pass print("No message found to repeat!")
sys.exit(1)
repeat_msgs.append(lmessage)
# repeat given message(s)
else:
repeat_msgs = chat.msg_find(args.repeat, loc=msg_location.DISK)
repeat_messages(repeat_msgs, chat, args, config)
# === PROCESS ===
elif args.process is not None: elif args.process is not None:
# TODO: process either all questions without an # TODO: process either all questions without an
# answer or the one(s) given in 'args.process' # answer or the one(s) given in 'args.process'
+2 -2
View File
@@ -8,10 +8,10 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'tags' command. Handler for the 'tags' command.
""" """
chat = ChatDB.from_dir(cache_path=Path('.'), chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db)) db_path=Path(config.db))
if args.list: if args.list:
tags_freq = chat.tags_frequency(args.prefix, args.contain) tags_freq = chat.msg_tags_frequency(args.prefix, args.contain)
for tag, freq in tags_freq.items(): for tag, freq in tags_freq.items():
print(f"- {tag}: {freq}") print(f"- {tag}: {freq}")
# TODO: add renaming # TODO: add renaming
+105
View File
@@ -0,0 +1,105 @@
import argparse
import mimetypes
from pathlib import Path
from .common import invert_input_tag_args, read_text_file
from ..configuration import Config
from ..message import MessageFilter, Message, Question
from ..chat import ChatDB, msg_location
class TranslationCmdError(Exception):
pass
text_separator: str = 'TEXT:'
def assert_document_type_supported_openai(document_file: Path) -> None:
doctype = mimetypes.guess_type(document_file)
if doctype != 'text/plain':
raise TranslationCmdError("AI 'OpenAI' only supports document type 'text/plain''")
def translation_prompt_openai(source_lang: str, target_lang: str) -> str:
"""
Return the prompt for GPT that tells it to do the translation.
"""
return f"Translate the text below the line {text_separator} from {source_lang} to {target_lang}."
def create_message_openai(chat: ChatDB, args: argparse.Namespace) -> Message:
"""
Create a new message from the given arguments and write it to the cache directory.
Message format
1. Translation prompt (tells GPT to do a translation)
2. Glossary (if specified as an argument)
3. User provided prompt enhancements
4. Translation separator
5. User provided text to be translated
The text to be translated is determined as a follows:
- if a document is provided in the arguments, translate its content
- if no document is provided, translate the last text argument
The other text arguments will be put into the "header" and can be used
to improve the translation prompt.
"""
text_args: list[str] = []
if args.create is not None:
text_args = args.create
elif args.ask is not None:
text_args = args.ask
else:
raise TranslationCmdError("No input text found")
# extract user prompt and user text to be translated
user_text: str
user_prompt: str
if args.input_document is not None:
assert_document_type_supported_openai(Path(args.input_document))
user_text = read_text_file(Path(args.input_document))
user_prompt = '\n\n'.join([str(s) for s in text_args])
else:
user_text = text_args[-1]
user_prompt = '\n\n'.join([str(s) for s in text_args[:-1]])
# build full question string
# FIXME: add glossaries if given
question_text: str = '\n\n'.join([translation_prompt_openai(args.source_lang, args.target_lang),
user_prompt,
text_separator,
user_text])
# create and write the message
message = Message(question=Question(question_text),
tags=args.output_tags,
ai=args.AI,
model=args.model)
# only write the new message to the cache,
# don't add it to the internal list
chat.cache_write([message])
return message
def translation_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'translation' command. Creates and executes translation
requests based on the input and selected AI. Depending on the AI, the
whole process may be significantly different (e.g. DeepL vs OpenAI).
"""
invert_input_tag_args(args)
mfilter = MessageFilter(tags_or=args.or_tags,
tags_and=args.and_tags,
tags_not=args.exclude_tags)
chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db),
mfilter=mfilter,
glob=args.glob,
loc=msg_location(args.location))
# if it's a new translation, create and store it immediately
# FIXME: check AI type
if args.ask or args.create:
# message = create_message(chat, args)
create_message_openai(chat, args)
if args.create:
return
+9 -2
View File
@@ -9,7 +9,7 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
supported_ais: list[str] = ['openai'] supported_ais: list[str] = ['openai']
default_config_path = '.config.yaml' default_config_file = '.config.yaml'
class ConfigError(Exception): class ConfigError(Exception):
@@ -39,6 +39,7 @@ class AIConfig:
name: ClassVar[str] name: ClassVar[str]
# a user-defined ID for an AI configuration entry # a user-defined ID for an AI configuration entry
ID: str ID: str
model: str = 'n/a'
# the name must not be changed # the name must not be changed
def __setattr__(self, name: str, value: Any) -> None: def __setattr__(self, name: str, value: Any) -> None:
@@ -115,7 +116,9 @@ class Config:
""" """
# all members have default values, so we can easily create # all members have default values, so we can easily create
# a default configuration # a default configuration
cache: str = '.'
db: str = './db/' db: str = './db/'
glossaries: str | None = './glossaries/'
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs) ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
@classmethod @classmethod
@@ -131,8 +134,10 @@ class Config:
ai_conf = ai_config_instance(conf['name'], conf) ai_conf = ai_config_instance(conf['name'], conf)
ais[ID] = ai_conf ais[ID] = ai_conf
return cls( return cls(
cache=str(source['cache']) if 'cache' in source else '.',
db=str(source['db']), db=str(source['db']),
ais=ais ais=ais,
glossaries=str(source['glossaries']) if 'glossaries' in source else None
) )
@classmethod @classmethod
@@ -145,6 +150,8 @@ class Config:
@classmethod @classmethod
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
if not Path(path).exists():
raise ConfigError(f"Configuration file '{path}' not found. Use 'cmm config --create' to create one.")
with open(path, 'r') as f: with open(path, 'r') as f:
source = yaml.load(f, Loader=yaml.FullLoader) source = yaml.load(f, Loader=yaml.FullLoader)
return cls.from_dict(source) return cls.from_dict(source)
+165
View File
@@ -0,0 +1,165 @@
"""
Module implementing glossaries for translations.
"""
import yaml
import tempfile
import shutil
import csv
from pathlib import Path
from dataclasses import dataclass, field
from typing import Type, TypeVar, ClassVar
GlossaryInst = TypeVar('GlossaryInst', bound='Glossary')
class GlossaryError(Exception):
pass
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
"""
Changes the YAML dump style to multiline syntax for multiline strings.
"""
if len(data.splitlines()) > 1:
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
return dumper.represent_scalar('tag:yaml.org,2002:str', data)
@dataclass
class Glossary:
"""
A glossary consists of the following parameters:
- Name (freely selectable)
- Path (full file path, suffix is automatically generated)
- Source language
- Target language
- Description (optional)
- Entries (pairs of source lang and target lang terms)
- ID (automatically generated / modified, required by DeepL)
"""
name: str
source_lang: str
target_lang: str
file_path: Path | None = None
desc: str | None = None
entries: dict[str, str] = field(default_factory=lambda: dict())
ID: str | None = None
file_suffix: ClassVar[str] = '.glo'
def __post_init__(self) -> None:
# FIXME: check for valid languages
pass
@classmethod
def from_file(cls: Type[GlossaryInst], file_path: Path) -> GlossaryInst:
"""
Create a glossary from the given file.
"""
if not file_path.exists():
raise GlossaryError(f"Glossary file '{file_path}' does not exist")
if file_path.suffix != cls.file_suffix:
raise GlossaryError(f"File type '{file_path.suffix}' is not supported")
with open(file_path, "r") as fd:
try:
# use BaseLoader so every entry is read as a string
# - disables automatic conversions
# - makes it possible to omit quoting for YAML keywords in entries (e. g. 'yes')
# - also correctly reads quoted entries
data = yaml.load(fd, Loader=yaml.BaseLoader)
clean_entries = data['Entries']
return cls(name=data['Name'],
source_lang=data['SourceLang'],
target_lang=data['TargetLang'],
file_path=file_path,
desc=data['Description'],
entries=clean_entries,
ID=data['ID'] if data['ID'] != 'None' else None)
except Exception:
raise GlossaryError(f"'{file_path}' does not contain a valid glossary")
def to_file(self, file_path: Path | None = None) -> None:
"""
Write glossary to given file.
"""
if file_path:
self.file_path = file_path
if not self.file_path:
raise GlossaryError("Got no valid path to write glossary")
# check / add valid suffix
if not self.file_path.suffix:
self.file_path = self.file_path.with_suffix(self.file_suffix)
elif self.file_path.suffix != self.file_suffix:
raise GlossaryError(f"File suffix '{self.file_path.suffix}' is not supported")
# write YAML
with tempfile.NamedTemporaryFile(dir=self.file_path.parent, prefix=self.file_path.name, mode="w", delete=False) as temp_fd:
temp_file_path = Path(temp_fd.name)
data = {'Name': self.name,
'Description': str(self.desc),
'ID': str(self.ID),
'SourceLang': self.source_lang,
'TargetLang': self.target_lang,
'Entries': self.entries}
yaml.dump(data, temp_fd, sort_keys=False)
shutil.move(temp_file_path, self.file_path)
def export_csv(self, dictionary: dict[str, str], file_path: Path) -> None:
"""
Export the 'entries' of this glossary to a file in CSV format (compatible with DeepL).
"""
with open(file_path, 'w', newline='', encoding='utf-8') as csvfile:
writer = csv.writer(csvfile, delimiter=',', quotechar='"', quoting=csv.QUOTE_ALL)
for source_entry, target_entry in self.entries.items():
writer.writerow([source_entry, target_entry])
def export_tsv(self, entries: dict[str, str], file_path: Path) -> None:
"""
Export the 'entries' of this glossary to a file in TSV format (compatible with DeepL).
"""
with open(file_path, 'w', encoding='utf-8') as file:
for source_entry, target_entry in self.entries.items():
file.write(f"{source_entry}\t{target_entry}\n")
def import_csv(self, file_path: Path) -> None:
"""
Import the entries from the given CSV file to those of the current glossary.
Existing entries are overwritten.
"""
try:
with open(file_path, mode='r', encoding='utf-8') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
self.entries = {rows[0]: rows[1] for rows in reader if len(rows) >= 2}
except Exception as e:
raise GlossaryError(f"Error importing CSV: {e}")
def import_tsv(self, file_path: Path) -> None:
"""
Import the entries from the given CSV file to those of the current glossary.
Existing entries are overwritten.
"""
try:
with open(file_path, mode='r', encoding='utf-8') as tsvfile:
self.entries = {}
for line in tsvfile:
parts = line.strip().split('\t')
if len(parts) == 2:
self.entries[parts[0]] = parts[1]
except Exception as e:
raise GlossaryError(f"Error importing TSV: {e}")
def to_str(self, with_entries: bool = False) -> str:
"""
Return the current glossary as a string.
"""
output: list[str] = []
output.append(f'{self.name} (ID: {self.ID}):')
if self.desc and self.desc != 'None':
output.append('- ' + self.desc)
output.append(f'- Languages: {self.source_lang} -> {self.target_lang}')
if with_entries:
output.append('- Entries:')
for source, target in self.entries.items():
output.append(f' {source} : {target}')
else:
output.append(f'- Entries: {len(self.entries)}')
return '\n'.join(output)
+126 -30
View File
@@ -3,17 +3,21 @@
# vim: set fileencoding=utf-8 : # vim: set fileencoding=utf-8 :
import sys import sys
import os
import argcomplete import argcomplete
import argparse import argparse
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from .configuration import Config, default_config_path from .configuration import Config, default_config_file, ConfigError
from .message import Message from .message import Message
from .commands.question import question_cmd from .commands.question import question_cmd
from .commands.tags import tags_cmd from .commands.tags import tags_cmd
from .commands.config import config_cmd from .commands.config import config_cmd
from .commands.hist import hist_cmd from .commands.hist import hist_cmd
from .commands.print import print_cmd from .commands.print import print_cmd
from .commands.translation import translation_cmd
from .commands.glossary import glossary_cmd
from .chat import msg_location
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
@@ -24,7 +28,7 @@ def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
def create_parser() -> argparse.ArgumentParser: def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="ChatMastermind is a Python application that automates conversation with AI") description="ChatMastermind is a Python application that automates conversation with AI")
parser.add_argument('-C', '--config', help='Config file name.', default=default_config_path) parser.add_argument('-C', '--config', help='Config file name.', default=default_config_file)
# subcommand-parser # subcommand-parser
cmdparser = parser.add_subparsers(dest='command', cmdparser = parser.add_subparsers(dest='command',
@@ -34,24 +38,24 @@ def create_parser() -> argparse.ArgumentParser:
# a parent parser for all commands that support tag selection # a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False) tag_parser = argparse.ArgumentParser(add_help=False)
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+', tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='*',
help='List of tags (one must match)', metavar='OTAGS') help='List of tags (one must match)', metavar='OTAGS')
tag_arg.completer = tags_completer # type: ignore tag_arg.completer = tags_completer # type: ignore
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+', atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='*',
help='List of tags (all must match)', metavar='ATAGS') help='List of tags (all must match)', metavar='ATAGS')
atag_arg.completer = tags_completer # type: ignore atag_arg.completer = tags_completer # type: ignore
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+', etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='*',
help='List of tags to exclude', metavar='XTAGS') help='List of tags to exclude', metavar='XTAGS')
etag_arg.completer = tags_completer # type: ignore etag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
help='List of output tags (default: use input tags)', metavar='OUTTAGS') help='List of output tags (default: use input tags)', metavar='OUTAGS')
otag_arg.completer = tags_completer # type: ignore otag_arg.completer = tags_completer # type: ignore
# a parent parser for all commands that support AI configuration # a parent parser for all commands that support AI configuration
ai_parser = argparse.ArgumentParser(add_help=False) ai_parser = argparse.ArgumentParser(add_help=False)
ai_parser.add_argument('-A', '--AI', help='AI ID to use') ai_parser.add_argument('-A', '--AI', help='AI ID to use', metavar='AI_ID')
ai_parser.add_argument('-M', '--model', help='Model to use') ai_parser.add_argument('-M', '--model', help='Model to use', metavar='MODEL')
ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1) ai_parser.add_argument('-N', '--num-answers', help='Number of answers to request', type=int, default=1)
ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int) ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int)
ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float) ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float)
@@ -61,67 +65,154 @@ def create_parser() -> argparse.ArgumentParser:
aliases=['q']) aliases=['q'])
question_cmd_parser.set_defaults(func=question_cmd) question_cmd_parser.set_defaults(func=question_cmd)
question_group = question_cmd_parser.add_mutually_exclusive_group(required=True) question_group = question_cmd_parser.add_mutually_exclusive_group(required=True)
question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question') question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question', metavar='QUESTION')
question_group.add_argument('-c', '--create', nargs='+', help='Create a question') question_group.add_argument('-c', '--create', nargs='+', help='Create a question', metavar='QUESTION')
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question') question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question', metavar='MESSAGE')
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions', metavar='MESSAGE')
question_cmd_parser.add_argument('-l', '--location',
choices=[x.value for x in msg_location if x not in [msg_location.MEM, msg_location.DISK]],
default='db',
help='Use given location when building the chat history (default: \'db\')')
question_cmd_parser.add_argument('-g', '--glob', help='Filter message files using the given glob pattern')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true') action='store_true')
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query') question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query', metavar='FILE')
question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history') question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history',
metavar='FILE')
# 'hist' command parser # 'hist' command parser
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
help="Print chat history.", help="Print and manage chat history.",
aliases=['h']) aliases=['h'])
hist_cmd_parser.set_defaults(func=hist_cmd) hist_cmd_parser.set_defaults(func=hist_cmd)
hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", hist_group = hist_cmd_parser.add_mutually_exclusive_group(required=True)
hist_group.add_argument('-p', '--print', help='Print the DB chat history', action='store_true')
hist_group.add_argument('-c', '--convert', help='Convert all message files to the given format [txt|yaml]', metavar='FORMAT')
hist_cmd_parser.add_argument('-w', '--with-metadata', help="Print chat history with metadata (tags, filename, AI, etc.).",
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.", hist_cmd_parser.add_argument('-S', '--source-code-only', help='Only print embedded source code',
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', hist_cmd_parser.add_argument('-A', '--answer', help='Print only answers with given substring', metavar='SUBSTRING')
action='store_true') hist_cmd_parser.add_argument('-Q', '--question', help='Print only questions with given substring', metavar='SUBSTRING')
hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring') hist_cmd_parser.add_argument('-d', '--tight', help='Print without message separators', action='store_true')
hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring') hist_cmd_parser.add_argument('-P', '--no-paging', help='Print without paging', action='store_true')
hist_cmd_parser.add_argument('-l', '--location',
choices=[x.value for x in msg_location if x not in [msg_location.MEM, msg_location.DISK]],
default='db',
help='Use given location when building the chat history (default: \'db\')')
hist_cmd_parser.add_argument('-g', '--glob', help='Filter message files using the given glob pattern')
# 'tags' command parser # 'tags' command parser
tags_cmd_parser = cmdparser.add_parser('tags', tags_cmd_parser = cmdparser.add_parser('tags',
help="Manage tags.", help="Manage tags.",
aliases=['t']) aliases=['T'])
tags_cmd_parser.set_defaults(func=tags_cmd) tags_cmd_parser.set_defaults(func=tags_cmd)
tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True) tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True)
tags_group.add_argument('-l', '--list', help="List all tags and their frequency", tags_group.add_argument('-l', '--list', help="List all tags and their frequency",
action='store_true') action='store_true')
tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix") tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix", metavar='PREFIX')
tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring") tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring", metavar='SUBSTRING')
# 'config' command parser # 'config' command parser
config_cmd_parser = cmdparser.add_parser('config', config_cmd_parser = cmdparser.add_parser('config',
help="Manage configuration", help="Manage configuration",
aliases=['c']) aliases=['c'])
config_cmd_parser.set_defaults(func=config_cmd) config_cmd_parser.set_defaults(func=config_cmd)
config_cmd_parser.add_argument('-A', '--AI', help='AI ID to use')
config_group = config_cmd_parser.add_mutually_exclusive_group(required=True) config_group = config_cmd_parser.add_mutually_exclusive_group(required=True)
config_group.add_argument('-l', '--list-models', help="List all available models", config_group.add_argument('-l', '--list-models', help="List all available models",
action='store_true') action='store_true')
config_group.add_argument('-m', '--print-model', help="Print the currently configured model", config_group.add_argument('-m', '--print-model', help="Print the currently configured model",
action='store_true') action='store_true')
config_group.add_argument('-c', '--create', help="Create config with default settings in the given file") config_group.add_argument('-c', '--create', help="Create config with default settings in the given file", metavar='FILE')
# 'print' command parser # 'print' command parser
print_cmd_parser = cmdparser.add_parser('print', print_cmd_parser = cmdparser.add_parser('print',
help="Print message files.", help="Print message files.",
aliases=['p']) aliases=['p'])
print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.set_defaults(func=print_cmd)
print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) print_group = print_cmd_parser.add_mutually_exclusive_group(required=True)
print_group.add_argument('-f', '--file', help='Print given message file', metavar='FILE')
print_group.add_argument('-l', '--latest', help='Print latest message', action='store_true')
print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group() print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group()
print_cmd_modes.add_argument('-q', '--question', help='Print only question', action='store_true') print_cmd_modes.add_argument('-q', '--question', help='Only print the question', action='store_true')
print_cmd_modes.add_argument('-a', '--answer', help='Print only answer', action='store_true') print_cmd_modes.add_argument('-a', '--answer', help='Only print the answer', action='store_true')
print_cmd_modes.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') print_cmd_modes.add_argument('-S', '--only-source-code', help='Only print embedded source code', action='store_true')
# 'translation' command parser
translation_cmd_parser = cmdparser.add_parser('translation', parents=[ai_parser, tag_parser],
help="Ask, create and repeat translations.",
aliases=['t'])
translation_cmd_parser.set_defaults(func=translation_cmd)
translation_group = translation_cmd_parser.add_mutually_exclusive_group(required=True)
translation_group.add_argument('-a', '--ask', nargs='+', help='Ask to translate the given text', metavar='TEXT')
translation_group.add_argument('-c', '--create', nargs='+', help='Create a translation', metavar='TEXT')
translation_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a translation', metavar='MESSAGE')
translation_cmd_parser.add_argument('-l', '--source-lang', help="Source language", metavar="LANGUAGE", required=True)
translation_cmd_parser.add_argument('-L', '--target-lang', help="Target language", metavar="LANGUAGE", required=True)
translation_cmd_parser.add_argument('-G', '--glossaries', nargs='+', help="List of glossary names", metavar="GLOSSARY")
translation_cmd_parser.add_argument('-d', '--input-document', help="Document to translate", metavar="FILE")
translation_cmd_parser.add_argument('-D', '--output-document', help="Path for the translated document", metavar="FILE")
# 'glossary' command parser
glossary_cmd_parser = cmdparser.add_parser('glossary', parents=[ai_parser],
help="Manage glossaries.",
aliases=['g'])
glossary_cmd_parser.set_defaults(func=glossary_cmd)
glossary_group = glossary_cmd_parser.add_mutually_exclusive_group(required=True)
glossary_group.add_argument('-c', '--create', help='Create a glossary', action='store_true')
glossary_cmd_parser.add_argument('-n', '--name', help="Glossary name (not ID)", metavar="NAME")
glossary_cmd_parser.add_argument('-l', '--source-lang', help="Source language", metavar="LANGUAGE")
glossary_cmd_parser.add_argument('-L', '--target-lang', help="Target language", metavar="LANGUAGE")
glossary_cmd_parser.add_argument('-f', '--file', help='File path of the goven glossary', metavar='GLOSSARY_FILE')
glossary_cmd_parser.add_argument('-D', '--description', help="Glossary description", metavar="DESCRIPTION")
glossary_group.add_argument('-i', '--list', help='List existing glossaries', action='store_true')
glossary_group.add_argument('-p', '--print', help='Print an existing glossary', action='store_true')
argcomplete.autocomplete(parser) argcomplete.autocomplete(parser)
return parser return parser
def create_directories(config: Config) -> None: # noqa: 11
"""
Create the directories in the given configuration if they don't exist.
"""
def make_dir(path: Path) -> None:
try:
os.makedirs(path.absolute())
except Exception as e:
print(f"Creating directory '{path.absolute()}' failed with: {e}")
sys.exit(1)
# Cache
cache_path = Path(config.cache)
if not cache_path.exists():
answer = input(f"Cache directory '{cache_path}' does not exist. Create it? [y/n]")
if answer.lower() in ['y', 'yes']:
make_dir(cache_path.absolute())
else:
print("Can't continue without a valid cache directory!")
sys.exit(1)
# DB
db_path = Path(config.db)
if not db_path.exists():
answer = input(f"DB directory '{db_path}' does not exist. Create it? [y/n]")
if answer.lower() in ['y', 'yes']:
make_dir(db_path.absolute())
else:
print("Can't continue without a valid DB directory!")
sys.exit(1)
# Glossaries
if config.glossaries:
glossaries_path = Path(config.glossaries)
if not glossaries_path.exists():
answer = input(f"Glossaries directory '{glossaries_path}' does not exist. Create it? [y/n]")
if answer.lower() in ['y', 'yes']:
make_dir(glossaries_path.absolute())
else:
print("Can't continue without a valid glossaries directory. Create it or remove it from the configuration.")
sys.exit(1)
def main() -> int: def main() -> int:
parser = create_parser() parser = create_parser()
args = parser.parse_args() args = parser.parse_args()
@@ -130,7 +221,12 @@ def main() -> int:
if command.func == config_cmd: if command.func == config_cmd:
command.func(command) command.func(command)
else: else:
try:
config = Config.from_file(args.config) config = Config.from_file(args.config)
except ConfigError as err:
print(f"{err}")
return 1
create_directories(config)
command.func(command, config) command.func(command, config)
return 0 return 0
+172 -55
View File
@@ -5,7 +5,10 @@ import pathlib
import yaml import yaml
import tempfile import tempfile
import shutil import shutil
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable import io
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple
from typing import Generator, Iterator
from typing import get_args as typing_get_args
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags from .tags import Tag, TagLine, TagError, match_tags, rename_tags
@@ -15,6 +18,9 @@ MessageInst = TypeVar('MessageInst', bound='Message')
AILineInst = TypeVar('AILineInst', bound='AILine') AILineInst = TypeVar('AILineInst', bound='AILine')
ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine') ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine')
YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]] YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]]
MessageFormat = Literal['txt', 'yaml']
message_valid_formats: Final[Tuple[MessageFormat, ...]] = typing_get_args(MessageFormat)
message_default_format: Final[MessageFormat] = 'txt'
class MessageError(Exception): class MessageError(Exception):
@@ -45,7 +51,7 @@ def source_code(text: str, include_delims: bool = False) -> list[str]:
code_lines: list[str] = [] code_lines: list[str] = []
in_code_block = False in_code_block = False
for line in text.split('\n'): for line in str(text).split('\n'):
if line.strip().startswith('```'): if line.strip().startswith('```'):
if include_delims: if include_delims:
code_lines.append(line) code_lines.append(line)
@@ -92,7 +98,7 @@ class MessageFilter:
class AILine(str): class AILine(str):
""" """
A line that represents the AI name in a '.txt' file.. A line that represents the AI name in the 'txt' format.
""" """
prefix: Final[str] = 'AI:' prefix: Final[str] = 'AI:'
@@ -112,7 +118,7 @@ class AILine(str):
class ModelLine(str): class ModelLine(str):
""" """
A line that represents the model name in a '.txt' file.. A line that represents the model name in the 'txt' format.
""" """
prefix: Final[str] = 'MODEL:' prefix: Final[str] = 'MODEL:'
@@ -138,30 +144,100 @@ class Answer(str):
txt_header: ClassVar[str] = '==== ANSWER ====' txt_header: ClassVar[str] = '==== ANSWER ===='
yaml_key: ClassVar[str] = 'answer' yaml_key: ClassVar[str] = 'answer'
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: def __init__(self, data: Union[str, Generator[str, None, None]]) -> None:
# Indicator of whether all of data has been processed
self.is_exhausted: bool = False
# Initialize data
self.iterator: Iterator[str] = self._init_data(data)
# Set up the buffer to hold the 'Answer' content
self.buffer: io.StringIO = io.StringIO()
def _init_data(self, data: Union[str, Generator[str, None, None]]) -> Iterator[str]:
""" """
Make sure the answer string does not contain the header as a whole line. Process input data (either a string or a string generator)
""" """
if cls.txt_header in string.split('\n'): if isinstance(data, str):
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") yield data
instance = super().__new__(cls, string) else:
return instance yield from data
def __str__(self) -> str:
"""
Output all content when converted into a string
"""
# Ensure all data has been processed
for _ in self:
pass
# Return the 'Answer' content
return self.buffer.getvalue()
def __repr__(self) -> str:
return repr(str(self))
def __iter__(self) -> Generator[str, None, None]:
"""
Allows the object to be iterable
"""
# Generate content if not all data has been processed
if not self.is_exhausted:
yield from self.generator_iter()
else:
yield self.buffer.getvalue()
def generator_iter(self) -> Generator[str, None, None]:
"""
Main generator method to process data
"""
for piece in self.iterator:
# Write to buffer and yield piece for the iterator
self.buffer.write(piece)
yield piece
self.is_exhausted = True # Set the flag that all data has been processed
# If the header occurs in the 'Answer' content, raise an error
if f'\n{self.txt_header}' in self.buffer.getvalue() or self.buffer.getvalue().startswith(self.txt_header):
raise MessageError(f"Answer {repr(self.buffer.getvalue())} contains the header {repr(Answer.txt_header)}")
def __eq__(self, other: object) -> bool:
"""
Comparing the object to a string or another object
"""
if isinstance(other, str):
return str(self) == other # Compare the string value of this object to the other string
# Default behavior for comparing non-string objects
return super().__eq__(other)
def __hash__(self) -> int:
"""
Generate a hash for the object based on its string representation.
"""
return hash(str(self))
def __format__(self, format_spec: str) -> str:
"""
Return a formatted version of the string as per the format specification.
"""
return str(self).__format__(format_spec)
@classmethod @classmethod
def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
""" """
Build Question from a list of strings. Make sure strings do not contain the header. Build Answer from a list of strings. Make sure strings do not contain the header.
""" """
if cls.txt_header in strings: def _gen() -> Generator[str, None, None]:
raise MessageError(f"Question contains the header '{cls.txt_header}'") if len(strings) > 0:
instance = super().__new__(cls, '\n'.join(strings).strip()) yield strings[0]
return instance for s in strings[1:]:
yield '\n'
yield s
return cls(_gen())
def source_code(self, include_delims: bool = False) -> list[str]: def source_code(self, include_delims: bool = False) -> list[str]:
""" """
Extract and return all source code sections. Extract and return all source code sections.
""" """
return source_code(self, include_delims) return source_code(str(self), include_delims)
class Question(str): class Question(str):
@@ -216,18 +292,44 @@ class Message():
model: Optional[str] = field(default=None, compare=False) model: Optional[str] = field(default=None, compare=False)
file_path: Optional[pathlib.Path] = field(default=None, compare=False) file_path: Optional[pathlib.Path] = field(default=None, compare=False)
# class variables # class variables
file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml'] file_suffixes_read: ClassVar[list[str]] = ['.msg', '.txt', '.yaml']
file_suffix_write: ClassVar[str] = '.msg'
default_format: ClassVar[MessageFormat] = message_default_format
tags_yaml_key: ClassVar[str] = 'tags' tags_yaml_key: ClassVar[str] = 'tags'
file_yaml_key: ClassVar[str] = 'file_path' file_yaml_key: ClassVar[str] = 'file_path'
ai_yaml_key: ClassVar[str] = 'ai' ai_yaml_key: ClassVar[str] = 'ai'
model_yaml_key: ClassVar[str] = 'model' model_yaml_key: ClassVar[str] = 'model'
def __post_init__(self) -> None:
# convert some types that are often set wrong
if self.tags is not None and not isinstance(self.tags, set):
self.tags = set(self.tags)
if self.file_path is not None and not isinstance(self.file_path, pathlib.Path):
self.file_path = pathlib.Path(self.file_path)
def __hash__(self) -> int: def __hash__(self) -> int:
""" """
The hash value is computed based on immutable members. The hash value is computed based on immutable members.
""" """
return hash((self.question, self.answer)) return hash((self.question, self.answer))
def equals(self, other: MessageInst, tags: bool = True, ai: bool = True,
model: bool = True, file_path: bool = True, verbose: bool = False) -> bool:
"""
Compare this message with another one, including the metadata.
Return True if everything is identical, False otherwise.
"""
equal: bool = ((not tags or (self.tags == other.tags))
and (not ai or (self.ai == other.ai)) # noqa: W503
and (not model or (self.model == other.model)) # noqa: W503
and (not file_path or (self.file_path == other.file_path)) # noqa: W503
and (self == other)) # noqa: W503
if not equal and verbose:
print("Messages not equal:")
print(self)
print(other)
return equal
@classmethod @classmethod
def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst: def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
""" """
@@ -252,16 +354,8 @@ class Message():
tags: set[Tag] = set() tags: set[Tag] = set()
if not file_path.exists(): if not file_path.exists():
raise MessageError(f"Message file '{file_path}' does not exist") raise MessageError(f"Message file '{file_path}' does not exist")
if file_path.suffix not in cls.file_suffixes: if file_path.suffix not in cls.file_suffixes_read:
raise MessageError(f"File type '{file_path.suffix}' is not supported") raise MessageError(f"File type '{file_path.suffix}' is not supported")
# for TXT, it's enough to read the TagLine
if file_path.suffix == '.txt':
with open(file_path, "r") as fd:
try:
tags = TagLine(fd.readline()).tags(prefix, contain)
except TagError:
pass # message without tags
else: # '.yaml'
try: try:
message = cls.from_file(file_path) message = cls.from_file(file_path)
if message: if message:
@@ -304,15 +398,16 @@ class Message():
""" """
if not file_path.exists(): if not file_path.exists():
raise MessageError(f"Message file '{file_path}' does not exist") raise MessageError(f"Message file '{file_path}' does not exist")
if file_path.suffix not in cls.file_suffixes: if file_path.suffix not in cls.file_suffixes_read:
raise MessageError(f"File type '{file_path.suffix}' is not supported") raise MessageError(f"File type '{file_path.suffix}' is not supported")
# try TXT first
if file_path.suffix == '.txt': try:
message = cls.__from_file_txt(file_path, message = cls.__from_file_txt(file_path,
mfilter.tags_or if mfilter else None, mfilter.tags_or if mfilter else None,
mfilter.tags_and if mfilter else None, mfilter.tags_and if mfilter else None,
mfilter.tags_not if mfilter else None) mfilter.tags_not if mfilter else None)
else: # then YAML
except MessageError:
message = cls.__from_file_yaml(file_path) message = cls.__from_file_yaml(file_path)
if message and (mfilter is None or message.match(mfilter)): if message and (mfilter is None or message.match(mfilter)):
return message return message
@@ -349,10 +444,6 @@ class Message():
tags = TagLine(fd.readline()).tags() tags = TagLine(fd.readline()).tags()
except TagError: except TagError:
fd.seek(pos) fd.seek(pos)
if tags_or or tags_and or tags_not:
# match with an empty set if the file has no tags
if not match_tags(tags, tags_or, tags_and, tags_not):
return None
# AILine (Optional) # AILine (Optional)
try: try:
pos = fd.tell() pos = fd.tell()
@@ -370,13 +461,19 @@ class Message():
try: try:
question_idx = text.index(Question.txt_header) + 1 question_idx = text.index(Question.txt_header) + 1
except ValueError: except ValueError:
raise MessageError(f"Question header '{Question.txt_header}' not found in '{file_path}'") raise MessageError(f"'{file_path}' does not contain a valid message")
try: try:
answer_idx = text.index(Answer.txt_header) answer_idx = text.index(Answer.txt_header)
question = Question.from_list(text[question_idx:answer_idx]) question = Question.from_list(text[question_idx:answer_idx])
answer = Answer.from_list(text[answer_idx + 1:]) answer = Answer.from_list(text[answer_idx + 1:])
except ValueError: except ValueError:
question = Question.from_list(text[question_idx:]) question = Question.from_list(text[question_idx:])
# match tags AFTER reading the whole file
# -> make sure it's a valid 'txt' file format
if tags_or or tags_and or tags_not:
# match with an empty set if the file has no tags
if not match_tags(tags, tags_or, tags_and, tags_not):
return None
return cls(question, answer, tags, ai, model, file_path) return cls(question, answer, tags, ai, model, file_path)
@classmethod @classmethod
@@ -390,11 +487,14 @@ class Message():
* Message.model_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional]
""" """
with open(file_path, "r") as fd: with open(file_path, "r") as fd:
try:
data = yaml.load(fd, Loader=yaml.FullLoader) data = yaml.load(fd, Loader=yaml.FullLoader)
data[cls.file_yaml_key] = file_path data[cls.file_yaml_key] = file_path
return cls.from_dict(data) return cls.from_dict(data)
except Exception:
raise MessageError(f"'{file_path}' does not contain a valid message")
def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str: def to_str(self, with_metadata: bool = False, source_code_only: bool = False) -> str:
""" """
Return the current Message as a string. Return the current Message as a string.
""" """
@@ -404,35 +504,41 @@ class Message():
if self.answer: if self.answer:
output.extend(self.answer.source_code(include_delims=True)) output.extend(self.answer.source_code(include_delims=True))
return '\n'.join(output) if len(output) > 0 else '' return '\n'.join(output) if len(output) > 0 else ''
if with_tags: if with_metadata:
output.append(self.tags_str()) output.append(self.tags_str())
if with_file:
output.append('FILE: ' + str(self.file_path)) output.append('FILE: ' + str(self.file_path))
output.append('AI: ' + str(self.ai))
output.append('MODEL: ' + str(self.model))
output.append(Question.txt_header) output.append(Question.txt_header)
output.append(self.question) output.append(self.question)
if self.answer: if self.answer:
output.append(Answer.txt_header) output.append(Answer.txt_header)
output.append(self.answer) output.append(str(self.answer))
return '\n'.join(output) return '\n'.join(output)
def __str__(self) -> str: def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11
return self.to_str(True, True, False)
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
""" """
Write a Message to the given file. Type is determined based on the suffix. Write a Message to the given file. Supported message file formats are 'txt' and 'yaml'.
Currently supported suffixes: ['.txt', '.yaml'] Suffix is always '.msg'.
""" """
if file_path: if file_path:
self.file_path = file_path self.file_path = file_path
if not self.file_path: if not self.file_path:
raise MessageError("Got no valid path to write message") raise MessageError("Got no valid path to write message")
if self.file_path.suffix not in self.file_suffixes: if mformat not in message_valid_formats:
raise MessageError(f"File type '{self.file_path.suffix}' is not supported") raise MessageError(f"File format '{mformat}' is not supported")
# check for valid suffix
# -> add one if it's empty
# -> refuse old or otherwise unsupported suffixes
if not self.file_path.suffix:
self.file_path = self.file_path.with_suffix(self.file_suffix_write)
elif self.file_path.suffix != self.file_suffix_write:
raise MessageError(f"File suffix '{self.file_path.suffix}' is not supported")
# TXT # TXT
if self.file_path.suffix == '.txt': if mformat == 'txt':
return self.__to_file_txt(self.file_path) return self.__to_file_txt(self.file_path)
elif self.file_path.suffix == '.yaml': # YAML
elif mformat == 'yaml':
return self.__to_file_yaml(self.file_path) return self.__to_file_yaml(self.file_path)
def __to_file_txt(self, file_path: pathlib.Path) -> None: def __to_file_txt(self, file_path: pathlib.Path) -> None:
@@ -444,8 +550,8 @@ class Message():
* Model [Optional] * Model [Optional]
* Question.txt_header * Question.txt_header
* Question * Question
* Answer.txt_header * Answer.txt_header [Optional]
* Answer * Answer [Optional]
""" """
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
temp_file_path = pathlib.Path(temp_fd.name) temp_file_path = pathlib.Path(temp_fd.name)
@@ -457,7 +563,7 @@ class Message():
temp_fd.write(f'{ModelLine.from_model(self.model)}\n') temp_fd.write(f'{ModelLine.from_model(self.model)}\n')
temp_fd.write(f'{Question.txt_header}\n{self.question}\n') temp_fd.write(f'{Question.txt_header}\n{self.question}\n')
if self.answer: if self.answer:
temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n') temp_fd.write(f'{Answer.txt_header}\n{str(self.answer)}\n')
shutil.move(temp_file_path, file_path) shutil.move(temp_file_path, file_path)
def __to_file_yaml(self, file_path: pathlib.Path) -> None: def __to_file_yaml(self, file_path: pathlib.Path) -> None:
@@ -484,6 +590,13 @@ class Message():
yaml.dump(data, temp_fd, sort_keys=False) yaml.dump(data, temp_fd, sort_keys=False)
shutil.move(temp_file_path, file_path) shutil.move(temp_file_path, file_path)
def rm_file(self) -> None:
"""
Delete the message file. Ignore empty file_path and not existing files.
"""
if self.file_path is not None:
self.file_path.unlink(missing_ok=True)
def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
""" """
Filter tags based on their prefix (i. e. the tag starts with a given string) Filter tags based on their prefix (i. e. the tag starts with a given string)
@@ -519,7 +632,7 @@ class Message():
or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503
or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503
or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503 or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503
or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503 or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in str(self.answer))) # noqa: W503
or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503 or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503
or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503 or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503
or (mfilter.model_state == 'available' and not self.model) # noqa: W503 or (mfilter.model_state == 'available' and not self.model) # noqa: W503
@@ -537,13 +650,17 @@ class Message():
if self.tags: if self.tags:
self.tags = rename_tags(self.tags, tags_rename) self.tags = rename_tags(self.tags, tags_rename)
def clear_answer(self) -> None:
self.answer = None
def msg_id(self) -> str: def msg_id(self) -> str:
""" """
Returns an ID that is unique throughout all messages in the same (DB) directory. Returns an ID that is unique throughout all messages in the same (DB) directory.
Currently this is the file name. The ID is also used for sorting messages. Currently this is the file name without suffix. The ID is also used for sorting
messages.
""" """
if self.file_path: if self.file_path:
return self.file_path.name return self.file_path.stem
else: else:
raise MessageError("Can't create file ID without a file path") raise MessageError("Can't create file ID without a file path")
+56
View File
@@ -0,0 +1,56 @@
<?php
$secret_key = '123';
// check for POST request
if ($_SERVER['REQUEST_METHOD'] != 'POST') {
error_log('FAILED - not POST - '. $_SERVER['REQUEST_METHOD']);
exit();
}
// get content type
$content_type = isset($_SERVER['CONTENT_TYPE']) ? strtolower(trim($_SERVER['CONTENT_TYPE'])) : '';
if ($content_type != 'application/json') {
error_log('FAILED - not application/json - '. $content_type);
exit();
}
// get payload
$payload = trim(file_get_contents("php://input"));
if (empty($payload)) {
error_log('FAILED - no payload');
exit();
}
// get header signature
$header_signature = isset($_SERVER['HTTP_X_GITEA_SIGNATURE']) ? $_SERVER['HTTP_X_GITEA_SIGNATURE'] : '';
if (empty($header_signature)) {
error_log('FAILED - header signature missing');
exit();
}
// calculate payload signature
$payload_signature = hash_hmac('sha256', $payload, $secret_key, false);
// check payload signature against header signature
if ($header_signature !== $payload_signature) {
error_log('FAILED - payload signature');
exit();
}
// convert json to array
$decoded = json_decode($payload, true);
// check for json decode errors
if (json_last_error() !== JSON_ERROR_NONE) {
error_log('FAILED - json decode - '. json_last_error());
exit();
}
// success, do something
$output = shell_exec('/home/kaizen/repos/ChatMastermind/hooks/push_hook.sh');
echo "$output";
?>
+8
View File
@@ -0,0 +1,8 @@
#!/usr/bin/bash
. /home/kaizen/.bashrc
set -e
cd /home/kaizen/repos/ChatMastermind
git pull
pre-commit run -a
pytest
+1
View File
@@ -2,3 +2,4 @@ openai
PyYAML PyYAML
argcomplete argcomplete
pytest pytest
tiktoken
+93
View File
@@ -0,0 +1,93 @@
import unittest
from unittest import mock
from chatmastermind.ais.openai import OpenAI
from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat
from chatmastermind.ai import AIResponse, Tokens
from chatmastermind.configuration import OpenAIConfig
class OpenAITest(unittest.TestCase):
@mock.patch('openai.ChatCompletion.create')
def test_request(self, mock_create: mock.MagicMock) -> None:
# Create a test instance of OpenAI
config = OpenAIConfig()
openai = OpenAI(config)
# Set up the mock response from openai.ChatCompletion.create
mock_chunk1 = {
'choices': [
{
'index': 0,
'delta': {
'content': 'Answer 1'
},
'finish_reason': None
},
{
'index': 1,
'delta': {
'content': 'Answer 2'
},
'finish_reason': None
}
],
}
mock_chunk2 = {
'choices': [
{
'index': 0,
'finish_reason': 'stop'
},
{
'index': 1,
'finish_reason': 'stop'
}
],
}
mock_create.return_value = iter([mock_chunk1, mock_chunk2])
# Create test data
question = Message(Question('Question'))
chat = Chat([
Message(Question('Question 1'), answer=Answer('Answer 1')),
Message(Question('Question 2'), answer=Answer('Answer 2')),
# add message without an answer -> expect to be skipped
Message(Question('Question 3'))
])
# Make the request
response = openai.request(question, chat, num_answers=2)
# Assert the AIResponse
self.assertIsInstance(response, AIResponse)
self.assertEqual(len(response.messages), 2)
self.assertEqual(response.messages[0].answer, 'Answer 1')
self.assertEqual(response.messages[1].answer, 'Answer 2')
self.assertIsNotNone(response.tokens)
self.assertIsInstance(response.tokens, Tokens)
assert response.tokens
self.assertEqual(response.tokens.prompt, 53)
self.assertEqual(response.tokens.completion, 6)
self.assertEqual(response.tokens.total, 59)
# Assert the mock call to openai.ChatCompletion.create
mock_create.assert_called_once_with(
model=f'{config.model}',
messages=[
{'role': 'system', 'content': f'{config.system}'},
{'role': 'user', 'content': 'Question 1'},
{'role': 'assistant', 'content': 'Answer 1'},
{'role': 'user', 'content': 'Question 2'},
{'role': 'assistant', 'content': 'Answer 2'},
{'role': 'user', 'content': 'Question'}
],
temperature=config.temperature,
max_tokens=config.max_tokens,
top_p=config.top_p,
n=2,
stream=True,
frequency_penalty=config.frequency_penalty,
presence_penalty=config.presence_penalty
)
+354 -155
View File
@@ -2,179 +2,267 @@ import unittest
import pathlib import pathlib
import tempfile import tempfile
import time import time
import yaml
from io import StringIO from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from chatmastermind.tags import TagLine from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError from chatmastermind.chat import Chat, ChatDB, ChatError, msg_location
class TestChat(unittest.TestCase): msg_suffix: str = Message.file_suffix_write
def msg_to_file_force_suffix(msg: Message) -> None:
"""
Force writing a message file with illegal suffixes.
"""
def_suffix = Message.file_suffix_write
assert msg.file_path
Message.file_suffix_write = msg.file_path.suffix
msg.to_file()
Message.file_suffix_write = def_suffix
class TestChatBase(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using more than just Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
class TestChat(TestChatBase):
def setUp(self) -> None: def setUp(self) -> None:
self.chat = Chat([]) self.chat = Chat([])
self.message1 = Message(Question('Question 1'), self.message1 = Message(Question('Question 1'),
Answer('Answer 1'), Answer('Answer 1'),
{Tag('atag1'), Tag('btag2')}, {Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('0001.txt')) ai='FakeAI',
model='FakeModel',
file_path=pathlib.Path(f'0001{msg_suffix}'))
self.message2 = Message(Question('Question 2'), self.message2 = Message(Question('Question 2'),
Answer('Answer 2'), Answer('Answer 2'),
{Tag('btag2')}, {Tag('btag2')},
file_path=pathlib.Path('0002.txt')) ai='FakeAI',
model='FakeModel',
file_path=pathlib.Path(f'0002{msg_suffix}'))
self.maxDiff = None
def test_unique_id(self) -> None:
# test with two identical messages
self.chat.msg_add([self.message1, self.message1])
self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_id()
self.assert_messages_equal(self.chat.messages, [self.message1])
# test with two different messages
self.chat.msg_add([self.message2])
self.chat.msg_unique_id()
self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])
def test_unique_content(self) -> None:
# test with two identical messages
self.chat.msg_add([self.message1, self.message1])
self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_content()
self.assert_messages_equal(self.chat.messages, [self.message1])
# test with two different messages
self.chat.msg_add([self.message2])
self.chat.msg_unique_content()
self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])
def test_filter(self) -> None: def test_filter(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
self.chat.filter(MessageFilter(answer_contains='Answer 1')) self.chat.msg_filter(MessageFilter(answer_contains='Answer 1'))
self.assertEqual(len(self.chat.messages), 1) self.assertEqual(len(self.chat.messages), 1)
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
def test_sort(self) -> None: def test_sort(self) -> None:
self.chat.add_messages([self.message2, self.message1]) self.chat.msg_add([self.message2, self.message1])
self.chat.sort() self.chat.msg_sort()
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2') self.assertEqual(self.chat.messages[1].question, 'Question 2')
self.chat.sort(reverse=True) self.chat.msg_sort(reverse=True)
self.assertEqual(self.chat.messages[0].question, 'Question 2') self.assertEqual(self.chat.messages[0].question, 'Question 2')
self.assertEqual(self.chat.messages[1].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 1')
def test_clear(self) -> None: def test_clear(self) -> None:
self.chat.add_messages([self.message1]) self.chat.msg_add([self.message1])
self.chat.clear() self.chat.msg_clear()
self.assertEqual(len(self.chat.messages), 0) self.assertEqual(len(self.chat.messages), 0)
def test_add_messages(self) -> None: def test_add_messages(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
self.assertEqual(len(self.chat.messages), 2) self.assertEqual(len(self.chat.messages), 2)
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2') self.assertEqual(self.chat.messages[1].question, 'Question 2')
def test_tags(self) -> None: def test_tags(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
tags_all = self.chat.tags() tags_all = self.chat.msg_tags()
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
tags_pref = self.chat.tags(prefix='a') tags_pref = self.chat.msg_tags(prefix='a')
self.assertSetEqual(tags_pref, {Tag('atag1')}) self.assertSetEqual(tags_pref, {Tag('atag1')})
tags_cont = self.chat.tags(contain='2') tags_cont = self.chat.msg_tags(contain='2')
self.assertSetEqual(tags_cont, {Tag('btag2')}) self.assertSetEqual(tags_cont, {Tag('btag2')})
def test_tags_frequency(self) -> None: def test_tags_frequency(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
tags_freq = self.chat.tags_frequency() tags_freq = self.chat.msg_tags_frequency()
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
def test_find_remove_messages(self) -> None: def test_find_remove_messages(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
msgs = self.chat.find_messages(['0001.txt']) msgs = self.chat.msg_find(['0001'])
self.assertListEqual(msgs, [self.message1]) self.assertListEqual(msgs, [self.message1])
msgs = self.chat.find_messages(['0001.txt', '0002.txt']) msgs = self.chat.msg_find(['0001', '0002'])
self.assertListEqual(msgs, [self.message1, self.message2]) self.assertListEqual(msgs, [self.message1, self.message2])
# add new Message with full path # add new Message with full path
message3 = Message(Question('Question 2'), message3 = Message(Question('Question 2'),
Answer('Answer 2'), Answer('Answer 2'),
{Tag('btag2')}, {Tag('btag2')},
file_path=pathlib.Path('/foo/bla/0003.txt')) file_path=pathlib.Path(f'/foo/bla/0003{msg_suffix}'))
self.chat.add_messages([message3]) self.chat.msg_add([message3])
# find new Message by full path # find new Message by full path
msgs = self.chat.find_messages(['/foo/bla/0003.txt']) msgs = self.chat.msg_find([f'/foo/bla/0003{msg_suffix}'])
self.assertListEqual(msgs, [message3]) self.assertListEqual(msgs, [message3])
# find Message with full path only by filename # find Message with full path only by filename
msgs = self.chat.find_messages(['0003.txt']) msgs = self.chat.msg_find([f'0003{msg_suffix}'])
self.assertListEqual(msgs, [message3]) self.assertListEqual(msgs, [message3])
# remove last message # remove last message
self.chat.remove_messages(['0003.txt']) self.chat.msg_remove(['0003'])
self.assertListEqual(self.chat.messages, [self.message1, self.message2]) self.assertListEqual(self.chat.messages, [self.message1, self.message2])
def test_latest_message(self) -> None:
self.assertIsNone(self.chat.msg_latest())
self.chat.msg_add([self.message1])
self.assertEqual(self.chat.msg_latest(), self.message1)
self.chat.msg_add([self.message2])
self.assertEqual(self.chat.msg_latest(), self.message2)
@patch('sys.stdout', new_callable=StringIO) @patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None: def test_print(self, mock_stdout: StringIO) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False) self.chat.print(paged=False, tight=True)
expected_output = f"""{Question.txt_header} expected_output = f"""{Question.txt_header}
Question 1 Question 1
{Answer.txt_header} {Answer.txt_header}
Answer 1 Answer 1
{'-'*terminal_width()}
{Question.txt_header} {Question.txt_header}
Question 2 Question 2
{Answer.txt_header} {Answer.txt_header}
Answer 2 Answer 2
{'-'*terminal_width()}
""" """
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
@patch('sys.stdout', new_callable=StringIO) @patch('sys.stdout', new_callable=StringIO)
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: def test_print_with_metadata(self, mock_stdout: StringIO) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False, with_tags=True, with_files=True) self.chat.print(paged=False, with_metadata=True, tight=True)
expected_output = f"""{TagLine.prefix} atag1 btag2 expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: 0001.txt FILE: 0001{msg_suffix}
AI: FakeAI
MODEL: FakeModel
{Question.txt_header} {Question.txt_header}
Question 1 Question 1
{Answer.txt_header} {Answer.txt_header}
Answer 1 Answer 1
{'-'*terminal_width()}
{TagLine.prefix} btag2 {TagLine.prefix} btag2
FILE: 0002.txt FILE: 0002{msg_suffix}
AI: FakeAI
MODEL: FakeModel
{Question.txt_header} {Question.txt_header}
Question 2 Question 2
{Answer.txt_header} {Answer.txt_header}
Answer 2 Answer 2
{'-'*terminal_width()}
""" """
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
class TestChatDB(unittest.TestCase): class TestChatDB(TestChatBase):
def setUp(self) -> None: def setUp(self) -> None:
self.db_path = tempfile.TemporaryDirectory() self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory() self.cache_path = tempfile.TemporaryDirectory()
self.message1 = Message(Question('Question 1'), self.message1 = Message(Question('Question 1'),
Answer('Answer 1'), Answer('Answer 1'),
{Tag('tag1')}, {Tag('tag1')})
file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'), self.message2 = Message(Question('Question 2'),
Answer('Answer 2'), Answer('Answer 2'),
{Tag('tag2')}, {Tag('tag2')})
file_path=pathlib.Path('0002.yaml'))
self.message3 = Message(Question('Question 3'), self.message3 = Message(Question('Question 3'),
Answer('Answer 3'), Answer('Answer 3'),
{Tag('tag3')}, {Tag('tag3')})
file_path=pathlib.Path('0003.txt'))
self.message4 = Message(Question('Question 4'), self.message4 = Message(Question('Question 4'),
Answer('Answer 4'), Answer('Answer 4'),
{Tag('tag4')}, {Tag('tag4')})
file_path=pathlib.Path('0004.yaml'))
self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt')) self.message1.to_file(pathlib.Path(self.db_path.name, '0001'), mformat='txt')
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml')) self.message2.to_file(pathlib.Path(self.db_path.name, '0002'), mformat='yaml')
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt')) self.message3.to_file(pathlib.Path(self.db_path.name, '0003'), mformat='txt')
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml')) self.message4.to_file(pathlib.Path(self.db_path.name, '0004'), mformat='yaml')
# make the next FID match the current state # make the next FID match the current state
next_fname = pathlib.Path(self.db_path.name) / '.next' next_fname = pathlib.Path(self.db_path.name) / '.next'
with open(next_fname, 'w') as f: with open(next_fname, 'w') as f:
f.write('4') f.write('4')
# add some "trash" in order to test if it's correctly handled / ignored
self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt', 'fubar.msg']
for file in self.trash_files:
with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
f.write('test trash')
# also create a file with actual yaml content
with open(pathlib.Path(self.db_path.name) / 'content.yaml', 'w') as f:
yaml.dump({'key': 'value'}, f)
self.trash_files.append('content.yaml')
self.maxDiff = None
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]: def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
""" """
List all Message files in the given TemporaryDirectory. List all Message files in the given TemporaryDirectory.
""" """
# exclude '.next' # exclude '.next'
return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*')) return [f for f in pathlib.Path(tmp_dir.name).glob('*.[tym]*') if f.name not in self.trash_files]
def tearDown(self) -> None: def tearDown(self) -> None:
self.db_path.cleanup() self.db_path.cleanup()
self.cache_path.cleanup() self.cache_path.cleanup()
pass pass
def test_chat_db_from_dir(self) -> None: def test_validate(self) -> None:
duplicate_message = Message(Question('Question 4'),
Answer('Answer 4'),
{Tag('tag4')},
file_path=pathlib.Path(self.db_path.name, '0004.txt'))
msg_to_file_force_suffix(duplicate_message)
with self.assertRaises(ChatError) as cm:
ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
glob='*')
self.assertEqual(str(cm.exception), "Validation failed")
def test_file_path_ID_exists(self) -> None:
"""
Tests if the CacheDB chooses another ID if a file path with
the given one exists.
"""
# create a new and empty CacheDB
db_path = tempfile.TemporaryDirectory()
cache_path = tempfile.TemporaryDirectory()
chat_db = ChatDB.from_dir(pathlib.Path(cache_path.name),
pathlib.Path(db_path.name))
# add a message file
message = Message(Question('What?'),
file_path=pathlib.Path(cache_path.name) / f'0001{msg_suffix}')
message.to_file()
message1 = Message(Question('Where?'))
chat_db.cache_write([message1])
self.assertEqual(message1.msg_id(), '0002')
def test_from_dir(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
self.assertEqual(len(chat_db.messages), 4) self.assertEqual(len(chat_db.messages), 4)
@@ -182,27 +270,25 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
# check that the files are sorted # check that the files are sorted
self.assertEqual(chat_db.messages[0].file_path, self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt')) pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0002.yaml')) pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, self.assertEqual(chat_db.messages[2].file_path,
pathlib.Path(self.db_path.name, '0003.txt')) pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, self.assertEqual(chat_db.messages[3].file_path,
pathlib.Path(self.db_path.name, '0004.yaml')) pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
def test_chat_db_from_dir_glob(self) -> None: def test_from_dir_glob(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
glob='*.txt') glob='*1.*')
self.assertEqual(len(chat_db.messages), 2) self.assertEqual(len(chat_db.messages), 1)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path, self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt')) pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0003.txt'))
def test_chat_db_from_dir_filter_tags(self) -> None: def test_from_dir_filter_tags(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or={Tag('tag1')})) mfilter=MessageFilter(tags_or={Tag('tag1')}))
@@ -210,9 +296,9 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path, self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt')) pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
def test_chat_db_from_dir_filter_tags_empty(self) -> None: def test_from_dir_filter_tags_empty(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or=set(), mfilter=MessageFilter(tags_or=set(),
@@ -220,7 +306,7 @@ class TestChatDB(unittest.TestCase):
tags_not=set())) tags_not=set()))
self.assertEqual(len(chat_db.messages), 0) self.assertEqual(len(chat_db.messages), 0)
def test_chat_db_from_dir_filter_answer(self) -> None: def test_from_dir_filter_answer(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
mfilter=MessageFilter(answer_contains='Answer 2')) mfilter=MessageFilter(answer_contains='Answer 2'))
@@ -228,10 +314,10 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path, self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0002.yaml')) pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[0].answer, 'Answer 2') self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
def test_chat_db_from_messages(self) -> None: def test_from_messages(self) -> None:
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name), pathlib.Path(self.db_path.name),
messages=[self.message1, self.message2, messages=[self.message1, self.message2,
@@ -240,39 +326,58 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
def test_chat_db_fids(self) -> None: def test_fids(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.get_next_fid(), 5) self.assertEqual(chat_db.get_next_fid(), 5)
self.assertEqual(chat_db.get_next_fid(), 6) self.assertEqual(chat_db.get_next_fid(), 6)
self.assertEqual(chat_db.get_next_fid(), 7) self.assertEqual(chat_db.get_next_fid(), 7)
with open(chat_db.next_fname, 'r') as f: with open(chat_db.next_path, 'r') as f:
self.assertEqual(f.read(), '7') self.assertEqual(f.read(), '7')
def test_chat_db_write(self) -> None: def test_msg_in_db_or_cache(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertTrue(chat_db.msg_in_db(self.message1))
self.assertTrue(chat_db.msg_in_db(str(self.message1.file_path)))
self.assertTrue(chat_db.msg_in_db(self.message1.msg_id()))
self.assertFalse(chat_db.msg_in_cache(self.message1))
self.assertFalse(chat_db.msg_in_cache(str(self.message1.file_path)))
self.assertFalse(chat_db.msg_in_cache(self.message1.msg_id()))
# add new message to the cache dir
cache_message = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
chat_db.cache_add([cache_message])
self.assertTrue(chat_db.msg_in_cache(cache_message))
self.assertTrue(chat_db.msg_in_cache(cache_message.msg_id()))
self.assertFalse(chat_db.msg_in_db(cache_message))
self.assertFalse(chat_db.msg_in_db(str(cache_message.file_path)))
def test_db_write(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
# check that Message.file_path is correct # check that Message.file_path is correct
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
# write the messages to the cache directory # write the messages to the cache directory
chat_db.write_cache() chat_db.cache_write()
# check if the written files are in the cache directory # check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4) self.assertEqual(len(cache_dir_files), 4)
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'), cache_dir_files)
# check that Message.file_path has been correctly updated # check that Message.file_path has been correctly updated
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt')) self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml')) self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt')) self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'))
# check the timestamp of the files in the DB directory # check the timestamp of the files in the DB directory
db_dir_files = self.message_list(self.db_path) db_dir_files = self.message_list(self.db_path)
@@ -280,24 +385,24 @@ class TestChatDB(unittest.TestCase):
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
# overwrite the messages in the db directory # overwrite the messages in the db directory
time.sleep(0.05) time.sleep(0.05)
chat_db.write_db() chat_db.db_write()
# check if the written files are in the DB directory # check if the written files are in the DB directory
db_dir_files = self.message_list(self.db_path) db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4) self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files)
# check if all files in the DB dir have actually been overwritten # check if all files in the DB dir have actually been overwritten
for file in db_dir_files: for file in db_dir_files:
self.assertGreater(file.stat().st_mtime, old_timestamps[file]) self.assertGreater(file.stat().st_mtime, old_timestamps[file])
# check that Message.file_path has been correctly updated (again) # check that Message.file_path has been correctly updated (again)
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
def test_chat_db_read(self) -> None: def test_db_read(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@@ -310,80 +415,80 @@ class TestChatDB(unittest.TestCase):
new_message2 = Message(Question('Question 6'), new_message2 = Message(Question('Question 6'),
Answer('Answer 6'), Answer('Answer 6'),
{Tag('tag6')}) {Tag('tag6')})
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
# read and check them # read and check them
chat_db.read_db() chat_db.db_read()
self.assertEqual(len(chat_db.messages), 6) self.assertEqual(len(chat_db.messages), 6)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
# create 2 new files in the cache directory # create 2 new files in the cache directory
new_message3 = Message(Question('Question 7'), new_message3 = Message(Question('Question 7'),
Answer('Answer 5'), Answer('Answer 7'),
{Tag('tag7')}) {Tag('tag7')})
new_message4 = Message(Question('Question 8'), new_message4 = Message(Question('Question 8'),
Answer('Answer 6'), Answer('Answer 8'),
{Tag('tag8')}) {Tag('tag8')})
new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt')) new_message3.to_file(pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'), mformat='txt')
new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml')) new_message4.to_file(pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'), mformat='yaml')
# read and check them # read and check them
chat_db.read_cache() chat_db.cache_read()
self.assertEqual(len(chat_db.messages), 8) self.assertEqual(len(chat_db.messages), 8)
# check that the new message have the cache dir path # check that the new message have the cache dir path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt')) self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml')) self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'))
# an the old ones keep their path (since they have not been replaced) # an the old ones keep their path (since they have not been replaced)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
# now overwrite two messages in the DB directory # now overwrite two messages in the DB directory
new_message1.question = Question('New Question 1') new_message1.question = Question('New Question 1')
new_message2.question = Question('New Question 2') new_message2.question = Question('New Question 2')
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
# read from the DB dir and check if the modified messages have been updated # read from the DB dir and check if the modified messages have been updated
chat_db.read_db() chat_db.db_read()
self.assertEqual(len(chat_db.messages), 8) self.assertEqual(len(chat_db.messages), 8)
self.assertEqual(chat_db.messages[4].question, 'New Question 1') self.assertEqual(chat_db.messages[4].question, 'New Question 1')
self.assertEqual(chat_db.messages[5].question, 'New Question 2') self.assertEqual(chat_db.messages[5].question, 'New Question 2')
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
# now write the messages from the cache to the DB directory # now write the messages from the cache to the DB directory
new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt')) new_message3.to_file(pathlib.Path(self.db_path.name, f'0007{msg_suffix}'))
new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml')) new_message4.to_file(pathlib.Path(self.db_path.name, f'0008{msg_suffix}'))
# read and check them # read and check them
chat_db.read_db() chat_db.db_read()
self.assertEqual(len(chat_db.messages), 8) self.assertEqual(len(chat_db.messages), 8)
# check that they now have the DB path # check that they now have the DB path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, f'0007{msg_suffix}'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, f'0008{msg_suffix}'))
def test_chat_db_clear(self) -> None: def test_cache_clear(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
# check that Message.file_path is correct # check that Message.file_path is correct
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
# write the messages to the cache directory # write the messages to the cache directory
chat_db.write_cache() chat_db.cache_write()
# check if the written files are in the cache directory # check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4) self.assertEqual(len(cache_dir_files), 4)
# now rewrite them to the DB dir and check for modified paths # now rewrite them to the DB dir and check for modified paths
chat_db.write_db() chat_db.db_write()
db_dir_files = self.message_list(self.db_path) db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4) self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files)
# add a new message with empty file_path # add a new message with empty file_path
message_empty = Message(question=Question("What the hell am I doing here?"), message_empty = Message(question=Question("What the hell am I doing here?"),
@@ -391,11 +496,11 @@ class TestChatDB(unittest.TestCase):
# and one for the cache dir # and one for the cache dir
message_cache = Message(question=Question("What the hell am I doing here?"), message_cache = Message(question=Question("What the hell am I doing here?"),
answer=Answer("You're a creep!"), answer=Answer("You're a creep!"),
file_path=pathlib.Path(self.cache_path.name, '0005.txt')) file_path=pathlib.Path(self.cache_path.name, '0005'))
chat_db.add_messages([message_empty, message_cache]) chat_db.msg_add([message_empty, message_cache])
# clear the cache and check the cache dir # clear the cache and check the cache dir
chat_db.clear_cache() chat_db.cache_clear()
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0) self.assertEqual(len(cache_dir_files), 0)
# make sure that the DB messages (and the new message) are still there # make sure that the DB messages (and the new message) are still there
@@ -405,7 +510,7 @@ class TestChatDB(unittest.TestCase):
# but not the message with the cache dir path # but not the message with the cache dir path
self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages)) self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))
def test_chat_db_add(self) -> None: def test_add(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@@ -416,7 +521,7 @@ class TestChatDB(unittest.TestCase):
# add new messages to the cache dir # add new messages to the cache dir
message1 = Message(question=Question("Question 1"), message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1")) answer=Answer("Answer 1"))
chat_db.add_to_cache([message1]) chat_db.cache_add([message1])
# check if the file_path has been correctly set # check if the file_path has been correctly set
self.assertIsNotNone(message1.file_path) self.assertIsNotNone(message1.file_path)
self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
@@ -426,7 +531,7 @@ class TestChatDB(unittest.TestCase):
# add new messages to the DB dir # add new messages to the DB dir
message2 = Message(question=Question("Question 2"), message2 = Message(question=Question("Question 2"),
answer=Answer("Answer 2")) answer=Answer("Answer 2"))
chat_db.add_to_db([message2]) chat_db.db_add([message2])
# check if the file_path has been correctly set # check if the file_path has been correctly set
self.assertIsNotNone(message2.file_path) self.assertIsNotNone(message2.file_path)
self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
@@ -434,9 +539,9 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(len(db_dir_files), 5) self.assertEqual(len(db_dir_files), 5)
with self.assertRaises(ChatError): with self.assertRaises(ChatError):
chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))]) chat_db.cache_add([Message(Question("?"), file_path=pathlib.Path("foo"))])
def test_chat_db_write_messages(self) -> None: def test_msg_write(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@@ -450,16 +555,16 @@ class TestChatDB(unittest.TestCase):
message = Message(question=Question("Question 1"), message = Message(question=Question("Question 1"),
answer=Answer("Answer 1")) answer=Answer("Answer 1"))
with self.assertRaises(ChatError): with self.assertRaises(ChatError):
chat_db.write_messages([message]) chat_db.msg_write([message])
# write a message with a valid file_path # write a message with a valid file_path
message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt' message.file_path = pathlib.Path(self.cache_path.name) / '123456'
chat_db.write_messages([message]) chat_db.msg_write([message])
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1) self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, f'123456{msg_suffix}'), cache_dir_files)
def test_chat_db_update_messages(self) -> None: def test_msg_update(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@@ -472,17 +577,111 @@ class TestChatDB(unittest.TestCase):
message = chat_db.messages[0] message = chat_db.messages[0]
message.answer = Answer("New answer") message.answer = Answer("New answer")
# update message without writing # update message without writing
chat_db.update_messages([message], write=False) chat_db.msg_update([message], write=False)
self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# re-read the message and check for old content # re-read the message and check for old content
chat_db.read_db() chat_db.db_read()
self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1")) self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1"))
# now check with writing (message should be overwritten) # now check with writing (message should be overwritten)
chat_db.update_messages([message], write=True) chat_db.msg_update([message], write=True)
chat_db.read_db() chat_db.db_read()
self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# test without file_path -> expect error # test without file_path -> expect error
message1 = Message(question=Question("Question 1"), message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1")) answer=Answer("Answer 1"))
with self.assertRaises(ChatError): with self.assertRaises(ChatError):
chat_db.update_messages([message1]) chat_db.msg_update([message1])
def test_msg_find(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# search for a DB file in memory
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.msg'], loc=msg_location.MEM), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], loc=msg_location.MEM), [self.message1])
# and on disk
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.msg'], loc=msg_location.DB), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], loc=msg_location.DB), [self.message2])
# now search the cache -> expect empty result
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find(['0003.msg'], loc=msg_location.CACHE), [])
self.assertEqual(chat_db.msg_find(['0003'], loc=msg_location.CACHE), [])
# search for multiple messages
# -> search one twice, expect result to be unique
search_names = ['0001', '0002.msg', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, loc=msg_location.ALL)
self.assert_messages_equal(result, expected_result)
def test_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), self.message4)
self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4)
self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), self.message4)
self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), self.message4)
# the cache is currently empty:
self.assertIsNone(chat_db.msg_latest(loc=msg_location.CACHE))
# add new messages to the cache dir
new_message = Message(question=Question("New Question"),
answer=Answer("New Answer"))
chat_db.cache_add([new_message])
self.assertEqual(chat_db.msg_latest(loc=msg_location.CACHE), new_message)
self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), new_message)
self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), new_message)
self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), new_message)
# the DB does not contain the new message
self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4)
def test_msg_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
# add a new message, but only to the internal list
new_message = Message(Question("What?"))
all_messages_mem = all_messages + [new_message]
chat_db.msg_add([new_message])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages_mem)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages_mem)
# the nr. of messages on disk did not change -> expect old result
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
# test with MessageFilter
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL, mfilter=MessageFilter(tags_or={Tag('tag1')})),
[self.message1])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK, mfilter=MessageFilter(tags_or={Tag('tag2')})),
[self.message2])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE, mfilter=MessageFilter(tags_or={Tag('tag3')})),
[])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM, mfilter=MessageFilter(question_contains="What")),
[new_message])
def test_msg_move_and_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
# move first message to the cache
chat_db.cache_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [self.message1])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), [self.message2, self.message3, self.message4])
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages)
# now move first message back to the DB
chat_db.db_move(self.message1)
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
+100
View File
@@ -0,0 +1,100 @@
import unittest
import argparse
from typing import Union, Optional
from chatmastermind.configuration import Config, AIConfig
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Answer
from chatmastermind.chat import Chat
from chatmastermind.ai import AI, AIResponse, Tokens, AIError
class FakeAI(AI):
"""
A mocked version of the 'AI' class.
"""
ID: str
name: str
config: AIConfig
def models(self) -> list[str]:
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int:
return 123
def print(self) -> None:
pass
def print_models(self) -> None:
pass
def __init__(self, ID: str, model: str, error: bool = False):
self.ID = ID
self.model = model
self.error = error
def request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Mock the 'ai.request()' function by either returning fake
answers or raising an exception.
"""
if self.error:
raise AIError
question.answer = Answer("Answer 0")
question.tags = set(otags) if otags is not None else None
question.ai = self.ID
question.model = self.model
answers: list[Message] = [question]
for n in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(f"Answer {n}"),
tags=otags,
ai=self.ID,
model=self.model))
return AIResponse(answers, Tokens(10, 10, 20))
class TestWithFakeAI(unittest.TestCase):
"""
Base class for all tests that need to use the FakeAI.
"""
def assert_msgs_equal_except_file_path(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using Question, Answer and all metadata excecot for the file_path.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
def assert_msgs_all_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using Question, Answer and ALL metadata.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
self.assertTrue(m1.equals(m2, verbose=True))
def assert_msgs_content_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using only Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
self.assertEqual(m1, m2)
def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance.
"""
return FakeAI(args.AI, args.model)
def mock_create_ai_with_error(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance.
"""
return FakeAI(args.AI, args.model, error=True)
+13 -1
View File
@@ -57,6 +57,7 @@ class TestConfig(unittest.TestCase):
def test_from_dict_should_create_config_from_dict(self) -> None: def test_from_dict_should_create_config_from_dict(self) -> None:
source_dict = { source_dict = {
'cache': '.',
'db': './test_db/', 'db': './test_db/',
'ais': { 'ais': {
'myopenai': { 'myopenai': {
@@ -70,10 +71,13 @@ class TestConfig(unittest.TestCase):
'frequency_penalty': 0.7, 'frequency_penalty': 0.7,
'presence_penalty': 0.2 'presence_penalty': 0.2
} }
} },
'glossaries': './glossaries/'
} }
config = Config.from_dict(source_dict) config = Config.from_dict(source_dict)
self.assertEqual(config.cache, '.')
self.assertEqual(config.db, './test_db/') self.assertEqual(config.db, './test_db/')
self.assertEqual(config.glossaries, './glossaries/')
self.assertEqual(len(config.ais), 1) self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['myopenai'].name, 'openai') self.assertEqual(config.ais['myopenai'].name, 'openai')
self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system') self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system')
@@ -89,6 +93,7 @@ class TestConfig(unittest.TestCase):
def test_from_file_should_load_config_from_file(self) -> None: def test_from_file_should_load_config_from_file(self) -> None:
source_dict = { source_dict = {
'cache': './test_cache/',
'db': './test_db/', 'db': './test_db/',
'ais': { 'ais': {
'default': { 'default': {
@@ -102,19 +107,24 @@ class TestConfig(unittest.TestCase):
'frequency_penalty': 0.7, 'frequency_penalty': 0.7,
'presence_penalty': 0.2 'presence_penalty': 0.2
} }
# omit glossaries, since it's optional
} }
} }
with open(self.test_file.name, 'w') as f: with open(self.test_file.name, 'w') as f:
yaml.dump(source_dict, f) yaml.dump(source_dict, f)
config = Config.from_file(self.test_file.name) config = Config.from_file(self.test_file.name)
self.assertIsInstance(config, Config) self.assertIsInstance(config, Config)
self.assertEqual(config.cache, './test_cache/')
self.assertEqual(config.db, './test_db/') self.assertEqual(config.db, './test_db/')
# missing 'glossaries' should result in 'None'
self.assertEqual(config.glossaries, None)
self.assertEqual(len(config.ais), 1) self.assertEqual(len(config.ais), 1)
self.assertIsInstance(config.ais['default'], AIConfig) self.assertIsInstance(config.ais['default'], AIConfig)
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
def test_to_file_should_save_config_to_file(self) -> None: def test_to_file_should_save_config_to_file(self) -> None:
config = Config( config = Config(
cache='./test_cache/',
db='./test_db/', db='./test_db/',
ais={ ais={
'myopenai': OpenAIConfig( 'myopenai': OpenAIConfig(
@@ -133,12 +143,14 @@ class TestConfig(unittest.TestCase):
config.to_file(Path(self.test_file.name)) config.to_file(Path(self.test_file.name))
with open(self.test_file.name, 'r') as f: with open(self.test_file.name, 'r') as f:
saved_config = yaml.load(f, Loader=yaml.FullLoader) saved_config = yaml.load(f, Loader=yaml.FullLoader)
self.assertEqual(saved_config['cache'], './test_cache/')
self.assertEqual(saved_config['db'], './test_db/') self.assertEqual(saved_config['db'], './test_db/')
self.assertEqual(len(saved_config['ais']), 1) self.assertEqual(len(saved_config['ais']), 1)
self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system') self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system')
def test_from_file_error_unknown_ai(self) -> None: def test_from_file_error_unknown_ai(self) -> None:
source_dict = { source_dict = {
'cache': './test_cache/',
'db': './test_db/', 'db': './test_db/',
'ais': { 'ais': {
'default': { 'default': {
+209
View File
@@ -0,0 +1,209 @@
import unittest
import tempfile
from pathlib import Path
from chatmastermind.glossary import Glossary, GlossaryError
glossary_suffix: str = Glossary.file_suffix
class TestGlossary(unittest.TestCase):
def test_from_file_yaml_unquoted(self) -> None:
"""
Test glossary creatiom from YAML with unquoted entries.
"""
with tempfile.NamedTemporaryFile('w', delete=False, suffix=glossary_suffix) as yaml_file:
yaml_file.write("Name: Sample\n"
"Description: A brief description\n"
"ID: '123'\n"
"SourceLang: en\n"
"TargetLang: es\n"
"Entries:\n"
" hello: hola\n"
" goodbye: adiós\n"
# 'yes' is a YAML keyword and would normally be quoted
" yes: sí\n"
" I'm going home: me voy a casa\n")
yaml_file_path = Path(yaml_file.name)
# create and check valid glossary
glossary = Glossary.from_file(yaml_file_path)
self.assertEqual(glossary.name, "Sample")
self.assertEqual(glossary.desc, "A brief description")
self.assertEqual(glossary.ID, "123")
self.assertEqual(glossary.source_lang, "en")
self.assertEqual(glossary.target_lang, "es")
self.assertEqual(glossary.entries, {"hello": "hola",
"goodbye": "adiós",
"yes": "",
"I'm going home": "me voy a casa"})
yaml_file_path.unlink() # Remove the temporary file
def test_from_file_yaml_quoted(self) -> None:
"""
Test glossary creatiom from YAML with quoted entries.
"""
with tempfile.NamedTemporaryFile('w', delete=False, suffix=glossary_suffix) as yaml_file:
yaml_file.write("Name: Sample\n"
"Description: A brief description\n"
"ID: '123'\n"
"SourceLang: en\n"
"TargetLang: es\n"
"Entries:\n"
" 'hello': 'hola'\n"
" 'goodbye': 'adiós'\n"
" 'yes': ''\n"
" \"I'm going home\": 'me voy a casa'\n")
yaml_file_path = Path(yaml_file.name)
# create and check valid glossary
glossary = Glossary.from_file(yaml_file_path)
self.assertEqual(glossary.name, "Sample")
self.assertEqual(glossary.desc, "A brief description")
self.assertEqual(glossary.ID, "123")
self.assertEqual(glossary.source_lang, "en")
self.assertEqual(glossary.target_lang, "es")
self.assertEqual(glossary.entries, {"hello": "hola",
"goodbye": "adiós",
"yes": "",
"I'm going home": "me voy a casa"})
yaml_file_path.unlink() # Remove the temporary file
def test_to_file_writes_yaml(self) -> None:
# Create glossary instance
glossary = Glossary(name="Test",
desc="Test description",
ID="666",
source_lang="en",
target_lang="fr",
entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', suffix=glossary_suffix) as tmp_file:
file_path = Path(tmp_file.name)
glossary.to_file(file_path)
# read and check valid YAML
with open(file_path, 'r') as file:
content = file.read()
self.assertIn("Name: Test", content)
self.assertIn("Description: Test description", content)
self.assertIn("ID: '666'", content)
self.assertIn("SourceLang: en", content)
self.assertIn("TargetLang: fr", content)
self.assertIn("Entries", content)
# 'yes' is a YAML keyword and therefore quoted
self.assertIn("'yes': oui", content)
def test_write_read_glossary(self) -> None:
# Create glossary instance
# -> use 'yes' in order to test if the YAML quoting is correctly removed when reading the file
glossary_write = Glossary(name="Test", source_lang="en", target_lang="fr", entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', suffix=glossary_suffix) as tmp_file:
file_path = Path(tmp_file.name)
glossary_write.to_file(file_path)
# create new instance from glossary file
glossary_read = Glossary.from_file(file_path)
self.assertEqual(glossary_write.name, glossary_read.name)
self.assertEqual(glossary_write.source_lang, glossary_read.source_lang)
self.assertEqual(glossary_write.target_lang, glossary_read.target_lang)
self.assertDictEqual(glossary_write.entries, glossary_read.entries)
def test_import_export_csv(self) -> None:
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={})
# First export to CSV
with tempfile.NamedTemporaryFile('w', suffix=glossary_suffix) as csvfile:
csv_file_path = Path(csvfile.name)
glossary.entries = {"hello": "salut", "goodbye": "au revoir"}
glossary.export_csv(glossary.entries, csv_file_path)
# Now import CSV
glossary.import_csv(csv_file_path)
self.assertEqual(glossary.entries, {"hello": "salut", "goodbye": "au revoir"})
def test_import_export_tsv(self) -> None:
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={})
# First export to TSV
with tempfile.NamedTemporaryFile('w', suffix=glossary_suffix) as tsvfile:
tsv_file_path = Path(tsvfile.name)
glossary.entries = {"hello": "salut", "goodbye": "au revoir"}
glossary.export_tsv(glossary.entries, tsv_file_path)
# Now import TSV
glossary.import_tsv(tsv_file_path)
self.assertEqual(glossary.entries, {"hello": "salut", "goodbye": "au revoir"})
def test_to_file_wrong_suffix(self) -> None:
"""
Test for exception if suffix is wrong.
"""
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', suffix='.wrong') as tmp_file:
file_path = Path(tmp_file.name)
with self.assertRaises(GlossaryError) as err:
glossary.to_file(file_path)
self.assertEqual(str(err.exception), "File suffix '.wrong' is not supported")
def test_to_file_auto_suffix(self) -> None:
"""
Test if suffix is auto-generated if omitted.
"""
glossary = Glossary(name="Test", source_lang="en", target_lang="fr", entries={"yes": "oui"})
with tempfile.NamedTemporaryFile('w', suffix='') as tmp_file:
file_path = Path(tmp_file.name)
glossary.to_file(file_path)
assert glossary.file_path is not None
self.assertEqual(glossary.file_path.suffix, glossary_suffix)
# remove glossary file (differs from 'tmp_file' because of the added suffix
glossary.file_path.unlink()
def test_to_str_with_id(self) -> None:
# Create a Glossary instance with an ID
glossary_with_id = Glossary(name="TestGlossary", source_lang="en", target_lang="fr",
desc="A simple test glossary", ID="1001", entries={"one": "un"})
glossary_str = glossary_with_id.to_str()
self.assertIn("TestGlossary (ID: 1001):", glossary_str)
self.assertIn("- A simple test glossary", glossary_str)
self.assertIn("- Languages: en -> fr", glossary_str)
self.assertIn("- Entries: 1", glossary_str)
def test_to_str_with_id_and_entries(self) -> None:
# Create a Glossary instance with an ID and include entries
glossary_with_entries = Glossary(name="TestGlossaryWithEntries", source_lang="en", target_lang="fr",
desc="Another test glossary", ID="2002",
entries={"hello": "salut", "goodbye": "au revoir"})
glossary_str_with_entries = glossary_with_entries.to_str(with_entries=True)
self.assertIn("TestGlossaryWithEntries (ID: 2002):", glossary_str_with_entries)
self.assertIn("- Entries:", glossary_str_with_entries)
self.assertIn(" hello : salut", glossary_str_with_entries)
self.assertIn(" goodbye : au revoir", glossary_str_with_entries)
def test_to_str_without_id(self) -> None:
# Create a Glossary instance without an ID
glossary_without_id = Glossary(name="TestGlossaryNoID", source_lang="en", target_lang="fr",
desc="A test glossary without an ID", ID=None, entries={"yes": "oui"})
glossary_str_no_id = glossary_without_id.to_str()
self.assertIn("TestGlossaryNoID (ID: None):", glossary_str_no_id)
self.assertIn("- A test glossary without an ID", glossary_str_no_id)
self.assertIn("- Languages: en -> fr", glossary_str_no_id)
self.assertIn("- Entries: 1", glossary_str_no_id)
def test_to_str_without_id_and_no_entries(self) -> None:
# Create a Glossary instance without an ID and no entries
glossary_no_id_no_entries = Glossary(name="EmptyGlossary", source_lang="en", target_lang="fr",
desc="An empty test glossary", ID=None, entries={})
glossary_str_no_id_no_entries = glossary_no_id_no_entries.to_str()
self.assertIn("EmptyGlossary (ID: None):", glossary_str_no_id_no_entries)
self.assertIn("- An empty test glossary", glossary_str_no_id_no_entries)
self.assertIn("- Languages: en -> fr", glossary_str_no_id_no_entries)
self.assertIn("- Entries: 0", glossary_str_no_id_no_entries)
def test_to_str_no_description(self) -> None:
# Create a Glossary instance with an ID
glossary_with_id = Glossary(name="TestGlossary", source_lang="en", target_lang="fr",
ID="1001", entries={"one": "un"})
glossary_str = glossary_with_id.to_str()
expected_str = """TestGlossary (ID: 1001):
- Languages: en -> fr
- Entries: 1"""
self.assertEqual(expected_str, glossary_str)
+149
View File
@@ -0,0 +1,149 @@
import unittest
import argparse
import tempfile
import io
from contextlib import redirect_stdout
from chatmastermind.configuration import Config
from chatmastermind.commands.glossary import (
Glossary,
GlossaryCmdError,
glossary_cmd,
get_glossary_file_path,
create_glossary,
print_glossary,
list_glossaries
)
class TestGlossaryCmdNoGlossaries(unittest.TestCase):
def setUp(self) -> None:
# create DB and cache
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
self.glossaries_dir = tempfile.TemporaryDirectory()
# create configuration
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
self.config.glossaries = self.glossaries_dir.name
# create a mock argparse.Namespace
self.args = argparse.Namespace(
create=True,
list=False,
print=False,
name='new_glossary',
file=None,
source_lang='en',
target_lang='de',
description=False,
)
def test_glossary_create_no_glossaries_err(self) -> None:
self.config.glossaries = None
with self.assertRaises(GlossaryCmdError) as err:
create_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "glossaries directory missing")
def test_glossary_create_no_name_err(self) -> None:
self.args.name = None
with self.assertRaises(GlossaryCmdError) as err:
create_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "missing glossary name")
def test_glossary_create_no_source_lang_err(self) -> None:
self.args.source_lang = None
with self.assertRaises(GlossaryCmdError) as err:
create_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "missing source language")
def test_glossary_create_no_target_lang_err(self) -> None:
self.args.target_lang = None
with self.assertRaises(GlossaryCmdError) as err:
create_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "missing target language")
def test_glossary_print_no_name_err(self) -> None:
self.args.name = None
with self.assertRaises(GlossaryCmdError) as err:
print_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "missing glossary name")
def test_glossary_list_no_glossaries_err(self) -> None:
self.config.glossaries = None
with self.assertRaises(GlossaryCmdError) as err:
list_glossaries(self.args, self.config)
self.assertIn(str(err.exception).lower(), "glossaries directory missing")
def test_glossary_create(self) -> None:
self.args.create = True
self.args.list = False
self.args.print = False
glossary_cmd(self.args, self.config)
expected_path = get_glossary_file_path(self.args.name, self.config)
glo = Glossary.from_file(expected_path)
self.assertEqual(glo.name, self.args.name)
expected_path.unlink()
def test_glossary_create_twice_err(self) -> None:
self.args.create = True
self.args.list = False
self.args.print = False
glossary_cmd(self.args, self.config)
expected_path = get_glossary_file_path(self.args.name, self.config)
glo = Glossary.from_file(expected_path)
self.assertEqual(glo.name, self.args.name)
# create glossary with the same name again
with self.assertRaises(GlossaryCmdError) as err:
create_glossary(self.args, self.config)
self.assertIn(str(err.exception).lower(), "already exists")
expected_path.unlink()
class TestGlossaryCmdWithGlossaries(unittest.TestCase):
def setUp(self) -> None:
# create DB and cache
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
self.glossaries_dir = tempfile.TemporaryDirectory()
# create configuration
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
self.config.glossaries = self.glossaries_dir.name
# create a mock argparse.Namespace
self.args = argparse.Namespace(
create=True,
list=False,
print=False,
name='Glossary1',
file=None,
source_lang='en',
target_lang='de',
description=False,
)
# create Glossary1
glossary_cmd(self.args, self.config)
self.Glossary1_path = get_glossary_file_path('Glossary1', self.config)
# create Glossary2
self.args.name = 'Glossary2'
glossary_cmd(self.args, self.config)
self.Glossary2_path = get_glossary_file_path('Glossary2', self.config)
def test_glossaries_exist(self) -> None:
"""
Test if the default glossaries created in setUp exist.
"""
glo = Glossary.from_file(self.Glossary1_path)
self.assertEqual(glo.name, 'Glossary1')
glo = Glossary.from_file(self.Glossary2_path)
self.assertEqual(glo.name, 'Glossary2')
def test_glossaries_list(self) -> None:
self.args.create = False
self.args.list = True
with redirect_stdout(io.StringIO()) as list_output:
glossary_cmd(self.args, self.config)
self.assertIn('Glossary1', list_output.getvalue())
self.assertIn('Glossary2', list_output.getvalue())
+62
View File
@@ -0,0 +1,62 @@
import unittest
import argparse
import tempfile
import yaml
from pathlib import Path
from chatmastermind.message import Message, Question
from chatmastermind.chat import ChatDB, ChatError, msg_location
from chatmastermind.configuration import Config
from chatmastermind.commands.hist import convert_messages
msg_suffix = Message.file_suffix_write
class TestConvertMessages(unittest.TestCase):
def setUp(self) -> None:
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
self.db_path = Path(self.db_dir.name)
self.cache_path = Path(self.cache_dir.name)
self.args = argparse.Namespace()
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
# Prepare some messages
self.chat = ChatDB.from_dir(Path(self.cache_path),
Path(self.db_path))
self.messages = [Message(Question(f'Question {i}')) for i in range(0, 6)]
self.chat.db_write(self.messages[0:2])
self.chat.cache_write(self.messages[2:])
# Change some of the suffixes
assert self.messages[0].file_path
assert self.messages[1].file_path
self.messages[0].file_path.rename(self.messages[0].file_path.with_suffix('.txt'))
self.messages[1].file_path.rename(self.messages[1].file_path.with_suffix('.yaml'))
def tearDown(self) -> None:
self.db_dir.cleanup()
self.cache_dir.cleanup()
def test_convert_messages(self) -> None:
self.args.convert = 'yaml'
convert_messages(self.args, self.config)
msgs = self.chat.msg_gather(loc=msg_location.DISK, glob='*.*')
# Check if the number of messages is the same as before
self.assertEqual(len(msgs), len(self.messages))
# Check if all messages have the requested suffix
for msg in msgs:
assert msg.file_path
self.assertEqual(msg.file_path.suffix, msg_suffix)
# Check if the message IDs are correctly maintained
for m_new, m_old in zip(msgs, self.messages):
self.assertEqual(m_new.msg_id(), m_old.msg_id())
# check if all messages have the new format
for m in msgs:
with open(str(m.file_path), "r") as fd:
yaml.load(fd, Loader=yaml.FullLoader)
def test_convert_messages_wrong_format(self) -> None:
self.args.convert = 'foo'
with self.assertRaises(ChatError):
convert_messages(self.args, self.config)
+104 -41
View File
@@ -1,11 +1,16 @@
import unittest import unittest
import pathlib import pathlib
import tempfile import tempfile
import itertools
from typing import cast from typing import cast
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine,\
MessageFilter, message_in, message_valid_formats
from chatmastermind.tags import Tag, TagLine from chatmastermind.tags import Tag, TagLine
msg_suffix: str = Message.file_suffix_write
class SourceCodeTestCase(unittest.TestCase): class SourceCodeTestCase(unittest.TestCase):
def test_source_code_with_include_delims(self) -> None: def test_source_code_with_include_delims(self) -> None:
text = """ text = """
@@ -86,7 +91,7 @@ class QuestionTestCase(unittest.TestCase):
class AnswerTestCase(unittest.TestCase): class AnswerTestCase(unittest.TestCase):
def test_answer_with_header(self) -> None: def test_answer_with_header(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
Answer(f"{Answer.txt_header}\nno") str(Answer(f"{Answer.txt_header}\nno"))
def test_answer_with_legal_header(self) -> None: def test_answer_with_legal_header(self) -> None:
answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.")
@@ -101,7 +106,7 @@ class AnswerTestCase(unittest.TestCase):
class MessageToFileTxtTestCase(unittest.TestCase): class MessageToFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
self.message_complete = Message(Question('This is a question.'), self.message_complete = Message(Question('This is a question.'),
Answer('This is an answer.'), Answer('This is an answer.'),
@@ -117,7 +122,7 @@ class MessageToFileTxtTestCase(unittest.TestCase):
self.file_path.unlink() self.file_path.unlink()
def test_to_file_txt_complete(self) -> None: def test_to_file_txt_complete(self) -> None:
self.message_complete.to_file(self.file_path) self.message_complete.to_file(self.file_path, mformat='txt')
with open(self.file_path, "r") as fd: with open(self.file_path, "r") as fd:
content = fd.read() content = fd.read()
@@ -132,7 +137,7 @@ This is an answer.
self.assertEqual(content, expected_content) self.assertEqual(content, expected_content)
def test_to_file_txt_min(self) -> None: def test_to_file_txt_min(self) -> None:
self.message_min.to_file(self.file_path) self.message_min.to_file(self.file_path, mformat='txt')
with open(self.file_path, "r") as fd: with open(self.file_path, "r") as fd:
content = fd.read() content = fd.read()
@@ -141,11 +146,17 @@ This is a question.
""" """
self.assertEqual(content, expected_content) self.assertEqual(content, expected_content)
def test_to_file_unsupported_file_type(self) -> None: def test_to_file_unsupported_file_suffix(self) -> None:
unsupported_file_path = pathlib.Path("example.doc") unsupported_file_path = pathlib.Path("example.doc")
with self.assertRaises(MessageError) as cm: with self.assertRaises(MessageError) as cm:
self.message_complete.to_file(unsupported_file_path) self.message_complete.to_file(unsupported_file_path)
self.assertEqual(str(cm.exception), "File type '.doc' is not supported") self.assertEqual(str(cm.exception), "File suffix '.doc' is not supported")
def test_to_file_unsupported_file_format(self) -> None:
unsupported_file_format = pathlib.Path(f"example{msg_suffix}")
with self.assertRaises(MessageError) as cm:
self.message_complete.to_file(unsupported_file_format, mformat='doc') # type: ignore [arg-type]
self.assertEqual(str(cm.exception), "File format 'doc' is not supported")
def test_to_file_no_file_path(self) -> None: def test_to_file_no_file_path(self) -> None:
""" """
@@ -159,10 +170,24 @@ This is a question.
# reset the internal file_path # reset the internal file_path
self.message_complete.file_path = self.file_path self.message_complete.file_path = self.file_path
def test_to_file_txt_auto_suffix(self) -> None:
"""
Test if suffix is auto-generated if omitted.
"""
file_path_no_suffix = self.file_path.with_suffix('')
# test with file_path member
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(mformat='txt')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
# test with explicit file_path
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(file_path=file_path_no_suffix, mformat='txt')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
class MessageToFileYamlTestCase(unittest.TestCase): class MessageToFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
self.message_complete = Message(Question('This is a question.'), self.message_complete = Message(Question('This is a question.'),
Answer('This is an answer.'), Answer('This is an answer.'),
@@ -184,7 +209,7 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.file_path.unlink() self.file_path.unlink()
def test_to_file_yaml_complete(self) -> None: def test_to_file_yaml_complete(self) -> None:
self.message_complete.to_file(self.file_path) self.message_complete.to_file(self.file_path, mformat='yaml')
with open(self.file_path, "r") as fd: with open(self.file_path, "r") as fd:
content = fd.read() content = fd.read()
@@ -199,7 +224,7 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.assertEqual(content, expected_content) self.assertEqual(content, expected_content)
def test_to_file_yaml_multiline(self) -> None: def test_to_file_yaml_multiline(self) -> None:
self.message_multiline.to_file(self.file_path) self.message_multiline.to_file(self.file_path, mformat='yaml')
with open(self.file_path, "r") as fd: with open(self.file_path, "r") as fd:
content = fd.read() content = fd.read()
@@ -218,17 +243,31 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.assertEqual(content, expected_content) self.assertEqual(content, expected_content)
def test_to_file_yaml_min(self) -> None: def test_to_file_yaml_min(self) -> None:
self.message_min.to_file(self.file_path) self.message_min.to_file(self.file_path, mformat='yaml')
with open(self.file_path, "r") as fd: with open(self.file_path, "r") as fd:
content = fd.read() content = fd.read()
expected_content = f"{Question.yaml_key}: This is a question.\n" expected_content = f"{Question.yaml_key}: This is a question.\n"
self.assertEqual(content, expected_content) self.assertEqual(content, expected_content)
def test_to_file_yaml_auto_suffix(self) -> None:
"""
Test if suffix is auto-generated if omitted.
"""
file_path_no_suffix = self.file_path.with_suffix('')
# test with file_path member
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(mformat='yaml')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
# test with explicit file_path
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(file_path=file_path_no_suffix, mformat='yaml')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
class MessageFromFileTxtTestCase(unittest.TestCase): class MessageFromFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd: with open(self.file_path, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2 fd.write(f"""{TagLine.prefix} tag1 tag2
@@ -239,7 +278,7 @@ This is a question.
{Answer.txt_header} {Answer.txt_header}
This is an answer. This is an answer.
""") """)
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path_min = pathlib.Path(self.file_min.name) self.file_path_min = pathlib.Path(self.file_min.name)
with open(self.file_path_min, "w") as fd: with open(self.file_path_min, "w") as fd:
fd.write(f"""{Question.txt_header} fd.write(f"""{Question.txt_header}
@@ -259,7 +298,7 @@ This is a question.
message = Message.from_file(self.file_path) message = Message.from_file(self.file_path)
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
if message: # mypy bug assert message
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.') self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
@@ -274,7 +313,7 @@ This is a question.
message = Message.from_file(self.file_path_min) message = Message.from_file(self.file_path_min)
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
if message: # mypy bug assert message
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.file_path, self.file_path_min) self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer) self.assertIsNone(message.answer)
@@ -284,7 +323,7 @@ This is a question.
MessageFilter(tags_or={Tag('tag1')})) MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
if message: # mypy bug assert message
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.') self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
@@ -311,13 +350,13 @@ This is a question.
MessageFilter(tags_not={Tag('tag1')})) MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
if message: # mypy bug assert message
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set()) self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min) self.assertEqual(message.file_path, self.file_path_min)
def test_from_file_not_exists(self) -> None: def test_from_file_not_exists(self) -> None:
file_not_exists = pathlib.Path("example.txt") file_not_exists = pathlib.Path(f"example{msg_suffix}")
with self.assertRaises(MessageError) as cm: with self.assertRaises(MessageError) as cm:
Message.from_file(file_not_exists) Message.from_file(file_not_exists)
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
@@ -396,7 +435,7 @@ This is a question.
class MessageFromFileYamlTestCase(unittest.TestCase): class MessageFromFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd: with open(self.file_path, "w") as fd:
fd.write(f""" fd.write(f"""
@@ -410,7 +449,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
- tag1 - tag1
- tag2 - tag2
""") """)
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path_min = pathlib.Path(self.file_min.name) self.file_path_min = pathlib.Path(self.file_min.name)
with open(self.file_path_min, "w") as fd: with open(self.file_path_min, "w") as fd:
fd.write(f""" fd.write(f"""
@@ -431,7 +470,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
message = Message.from_file(self.file_path) message = Message.from_file(self.file_path)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
self.assertIsNotNone(message) self.assertIsNotNone(message)
if message: # mypy bug assert message
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.') self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
@@ -446,14 +485,14 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
message = Message.from_file(self.file_path_min) message = Message.from_file(self.file_path_min)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
self.assertIsNotNone(message) self.assertIsNotNone(message)
if message: # mypy bug assert message
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set()) self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min) self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer) self.assertIsNone(message.answer)
def test_from_file_not_exists(self) -> None: def test_from_file_not_exists(self) -> None:
file_not_exists = pathlib.Path("example.yaml") file_not_exists = pathlib.Path(f"example{msg_suffix}")
with self.assertRaises(MessageError) as cm: with self.assertRaises(MessageError) as cm:
Message.from_file(file_not_exists) Message.from_file(file_not_exists)
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
@@ -463,7 +502,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
MessageFilter(tags_or={Tag('tag1')})) MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
if message: # mypy bug assert message
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.') self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
@@ -484,7 +523,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
MessageFilter(tags_not={Tag('tag1')})) MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
if message: # mypy bug assert message
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set()) self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min) self.assertEqual(message.file_path, self.file_path_min)
@@ -563,7 +602,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
class TagsFromFileTestCase(unittest.TestCase): class TagsFromFileTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path_txt = pathlib.Path(self.file_txt.name) self.file_path_txt = pathlib.Path(self.file_txt.name)
with open(self.file_path_txt, "w") as fd: with open(self.file_path_txt, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3 fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3
@@ -572,7 +611,7 @@ This is a question.
{Answer.txt_header} {Answer.txt_header}
This is an answer. This is an answer.
""") """)
self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name) self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name)
with open(self.file_path_txt_no_tags, "w") as fd: with open(self.file_path_txt_no_tags, "w") as fd:
fd.write(f"""{Question.txt_header} fd.write(f"""{Question.txt_header}
@@ -580,7 +619,7 @@ This is a question.
{Answer.txt_header} {Answer.txt_header}
This is an answer. This is an answer.
""") """)
self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name) self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name)
with open(self.file_path_txt_tags_empty, "w") as fd: with open(self.file_path_txt_tags_empty, "w") as fd:
fd.write(f"""TAGS: fd.write(f"""TAGS:
@@ -589,7 +628,7 @@ This is a question.
{Answer.txt_header} {Answer.txt_header}
This is an answer. This is an answer.
""") """)
self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path_yaml = pathlib.Path(self.file_yaml.name) self.file_path_yaml = pathlib.Path(self.file_yaml.name)
with open(self.file_path_yaml, "w") as fd: with open(self.file_path_yaml, "w") as fd:
fd.write(f""" fd.write(f"""
@@ -602,7 +641,7 @@ This is an answer.
- tag2 - tag2
- ptag3 - ptag3
""") """)
self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name) self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name)
with open(self.file_path_yaml_no_tags, "w") as fd: with open(self.file_path_yaml_no_tags, "w") as fd:
fd.write(f""" fd.write(f"""
@@ -679,24 +718,25 @@ class TagsFromDirTestCase(unittest.TestCase):
{Tag('ctag5'), Tag('ctag6')} {Tag('ctag5'), Tag('ctag6')}
] ]
self.files = [ self.files = [
pathlib.Path(self.temp_dir.name, 'file1.txt'), pathlib.Path(self.temp_dir.name, f'file1{msg_suffix}'),
pathlib.Path(self.temp_dir.name, 'file2.yaml'), pathlib.Path(self.temp_dir.name, f'file2{msg_suffix}'),
pathlib.Path(self.temp_dir.name, 'file3.txt') pathlib.Path(self.temp_dir.name, f'file3{msg_suffix}')
] ]
self.files_no_tags = [ self.files_no_tags = [
pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'), pathlib.Path(self.temp_dir_no_tags.name, f'file4{msg_suffix}'),
pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'), pathlib.Path(self.temp_dir_no_tags.name, f'file5{msg_suffix}'),
pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt') pathlib.Path(self.temp_dir_no_tags.name, f'file6{msg_suffix}')
] ]
mformats = itertools.cycle(message_valid_formats)
for file, tags in zip(self.files, self.tag_sets): for file, tags in zip(self.files, self.tag_sets):
message = Message(Question('This is a question.'), message = Message(Question('This is a question.'),
Answer('This is an answer.'), Answer('This is an answer.'),
tags) tags)
message.to_file(file) message.to_file(file, next(mformats))
for file in self.files_no_tags: for file in self.files_no_tags:
message = Message(Question('This is a question.'), message = Message(Question('This is a question.'),
Answer('This is an answer.')) Answer('This is an answer.'))
message.to_file(file) message.to_file(file, next(mformats))
def tearDown(self) -> None: def tearDown(self) -> None:
self.temp_dir.cleanup() self.temp_dir.cleanup()
@@ -719,7 +759,7 @@ class TagsFromDirTestCase(unittest.TestCase):
class MessageIDTestCase(unittest.TestCase): class MessageIDTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
self.message = Message(Question('This is a question.'), self.message = Message(Question('This is a question.'),
file_path=self.file_path) file_path=self.file_path)
@@ -730,7 +770,7 @@ class MessageIDTestCase(unittest.TestCase):
self.file_path.unlink() self.file_path.unlink()
def test_msg_id_txt(self) -> None: def test_msg_id_txt(self) -> None:
self.assertEqual(self.message.msg_id(), self.file_path.name) self.assertEqual(self.message.msg_id(), self.file_path.stem)
def test_msg_id_txt_exception(self) -> None: def test_msg_id_txt_exception(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
@@ -816,6 +856,8 @@ class MessageToStrTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message = Message(Question('This is a question.'), self.message = Message(Question('This is a question.'),
Answer('This is an answer.'), Answer('This is an answer.'),
ai=('FakeAI'),
model=('FakeModel'),
tags={Tag('atag1'), Tag('btag2')}, tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/foo/bla')) file_path=pathlib.Path('/tmp/foo/bla'))
@@ -829,8 +871,29 @@ This is an answer."""
def test_to_str_with_tags_and_file(self) -> None: def test_to_str_with_tags_and_file(self) -> None:
expected_output = f"""{TagLine.prefix} atag1 btag2 expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: /tmp/foo/bla FILE: /tmp/foo/bla
AI: FakeAI
MODEL: FakeModel
{Question.txt_header} {Question.txt_header}
This is a question. This is a question.
{Answer.txt_header} {Answer.txt_header}
This is an answer.""" This is an answer."""
self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output) self.assertEqual(self.message.to_str(with_metadata=True), expected_output)
class MessageRmFileTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_path = pathlib.Path(self.file.name)
self.message = Message(Question('This is a question.'),
file_path=self.file_path)
self.message.to_file()
def tearDown(self) -> None:
self.file.close()
self.file_path.unlink(missing_ok=True)
def test_rm_file(self) -> None:
assert self.message.file_path
self.assertTrue(self.message.file_path.exists())
self.message.rm_file()
self.assertFalse(self.message.file_path.exists())
+446 -13
View File
@@ -1,25 +1,39 @@
import os import os
import unittest
import argparse import argparse
import tempfile import tempfile
from copy import copy
from pathlib import Path from pathlib import Path
from unittest.mock import MagicMock from unittest import mock
from chatmastermind.commands.question import create_message from unittest.mock import MagicMock, call
from chatmastermind.message import Message, Question from chatmastermind.configuration import Config
from chatmastermind.chat import ChatDB from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat, ChatDB, msg_location
from chatmastermind.ai import AIError
from .test_common import TestWithFakeAI
class TestMessageCreate(unittest.TestCase): msg_suffix = Message.file_suffix_write
class TestMessageCreate(TestWithFakeAI):
""" """
Test if messages created by the 'question' command have Test if messages created by the 'question' command have
the correct format. the correct format.
""" """
def setUp(self) -> None: def setUp(self) -> None:
# create ChatDB structure # create ChatDB structure
self.db_path = tempfile.TemporaryDirectory() self.db_dir = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory() self.cache_dir = tempfile.TemporaryDirectory()
self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name), self.chat = ChatDB.from_dir(cache_path=Path(self.cache_dir.name),
db_path=Path(self.db_path.name)) db_path=Path(self.db_dir.name))
# create some messages
self.message_text = Message(Question("What is this?"),
Answer("It is pure text"))
self.message_code = Message(Question("What is this?"),
Answer("Text\n```\nIt is embedded code\n```\ntext"))
self.chat.db_add([self.message_text, self.message_code])
# create arguments mock # create arguments mock
self.args = MagicMock(spec=argparse.Namespace) self.args = MagicMock(spec=argparse.Namespace)
self.args.source_text = None self.args.source_text = None
@@ -27,6 +41,8 @@ class TestMessageCreate(unittest.TestCase):
self.args.AI = None self.args.AI = None
self.args.model = None self.args.model = None
self.args.output_tags = None self.args.output_tags = None
self.args.ask = None
self.args.create = None
# File 1 : no source code block, only text # File 1 : no source code block, only text
self.source_file1 = tempfile.NamedTemporaryFile(delete=False) self.source_file1 = tempfile.NamedTemporaryFile(delete=False)
self.source_file1_content = """This is just text. self.source_file1_content = """This is just text.
@@ -68,17 +84,18 @@ Aaaand again some text."""
os.remove(self.source_file1.name) os.remove(self.source_file1.name)
os.remove(self.source_file2.name) os.remove(self.source_file2.name)
os.remove(self.source_file3.name) os.remove(self.source_file3.name)
os.remove(self.source_file4.name)
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next' # exclude '.next'
return list(Path(tmp_dir.name).glob('*.[ty]*')) return list(Path(tmp_dir.name).glob(f'*{msg_suffix}'))
def test_message_file_created(self) -> None: def test_message_file_created(self) -> None:
self.args.ask = ["What is this?"] self.args.ask = ["What is this?"]
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 0) self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args) create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 1) self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0]) message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
@@ -160,3 +177,419 @@ This is embedded source code.
This is embedded source code. This is embedded source code.
``` ```
""")) """))
def test_single_question_with_text_only_message(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_text = [f"{self.chat.messages[0].file_path}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains no source code (only text)
# -> don't expect any in the question
self.assertEqual(len(message.question.source_code()), 0)
self.assertEqual(message.question, Question(f"""What is this?
{self.message_text.answer}"""))
def test_single_question_with_message_and_embedded_code(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.chat.messages[1].file_path}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# answer contains 1 source code block
# -> expect it in the question
self.assertEqual(len(message.question.source_code()), 1)
self.assertEqual(message.question, Question("""What is this?
```
It is embedded code
```
"""))
class TestCreateOption(TestMessageCreate):
def test_message_file_created(self) -> None:
self.args.create = ["How does question --create work?"]
self.args.ask = None
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("How does question --create work?")) # type: ignore [union-attr]
class TestQuestionCmd(TestWithFakeAI):
def setUp(self) -> None:
# create DB and cache
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
# create configuration
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
# create a mock argparse.Namespace
self.args = argparse.Namespace(
ask=['What is the meaning of life?'],
glob=None,
location='db',
num_answers=1,
output_tags=['science'],
AI='FakeAI',
model='FakeModel',
or_tags=None,
and_tags=None,
exclude_tags=None,
source_text=None,
source_code=None,
create=None,
repeat=None,
process=None,
overwrite=None
)
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return sorted([f for f in Path(tmp_dir.name).glob(f'*{msg_suffix}')])
class TestQuestionCmdAsk(TestQuestionCmd):
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
"""
Test single answer with no errors.
"""
mock_create_ai.side_effect = self.mock_create_ai
expected_question = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path('<NOT COMPARED>'))
fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command
question_cmd(self.args, self.config)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None:
"""
Test single answer with no errors (mocked ChatDB version).
"""
chat = MagicMock(spec=ChatDB)
mock_from_dir.return_value = chat
mock_create_ai.side_effect = self.mock_create_ai
expected_question = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path('<NOT COMPARED>'))
fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command
question_cmd(self.args, self.config)
# check for the correct ChatDB calls:
# - initial question has been written (prior to the actual request)
# - responses have been written (after the request)
chat.cache_write.assert_has_calls([call([expected_question]),
call(expected_responses)],
any_order=False)
# check that the messages have not been added to the internal message list
chat.cache_add.assert_not_called()
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_with_error(self, mock_create_ai: MagicMock) -> None:
"""
Provoke an error during the AI request and verify that the question
has been correctly stored in the cache.
"""
mock_create_ai.side_effect = self.mock_create_ai_with_error
expected_question = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path('<NOT COMPARED>'))
# execute the command
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_question])
class TestQuestionCmdRepeat(TestQuestionCmd):
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
# repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path)
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai=message.ai,
model=message.model,
tags=message.tags,
file_path=Path('<NOT COMPARED>'))
# we expect the original message + the one with the new response
expected_responses = [message] + [expected_response]
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
print(self.message_list(self.cache_dir))
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question and overwrite the old one.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
# repeat the last question (WITH overwriting)
# -> expect a single message afterwards (with a new answer)
self.args.ask = None
self.args.repeat = []
self.args.overwrite = True
expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai=message.ai,
model=message.model,
tags=message.tags,
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_after_error(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question after an error.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a question WITHOUT an answer
# -> just like after an error, which is tested above
message = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
# repeat the last question (without overwriting)
# -> expect a single message because if the original has
# no answer, it should be overwritten by default
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai=message.ai,
model=message.model,
tags=message.tags,
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_new_args(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question with new arguments.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path
# repeat the last question with new arguments (without overwriting)
# -> expect two messages with identical question but different metadata and new answer
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
self.args.output_tags = ['newtag']
self.args.AI = 'newai'
self.args.model = 'newmodel'
new_expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai='newai',
model='newmodel',
tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_msgs_equal_except_file_path(cached_msg, [message] + [new_expected_response])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_new_args_overwrite(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question with new arguments, overwriting the old one.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
assert cached_msg[0].file_path
# repeat the last question with new arguments
self.args.ask = None
self.args.repeat = []
self.args.overwrite = True
self.args.output_tags = ['newtag']
self.args.AI = 'newai'
self.args.model = 'newmodel'
new_expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai='newai',
model='newmodel',
tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None:
"""
Repeat multiple questions.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# 1. === create three questions ===
# cached message without an answer
message1 = Message(Question(self.args.ask[0]),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
# cached message with an answer
message2 = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0002{msg_suffix}')
# DB message without an answer
message3 = Message(Question(self.args.ask[0]),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.db_dir.name) / f'0003{msg_suffix}')
chat.msg_write([message1, message2, message3])
questions = [message1, message2, message3]
expected_responses: list[Message] = []
fake_ai = self.mock_create_ai(self.args, self.config)
for question in questions:
# since the message's answer is modified, we use a copy
# -> the original is used for comparison below
expected_responses += fake_ai.request(copy(question),
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
# 2. === repeat all three questions (without overwriting) ===
self.args.ask = None
self.args.repeat = ['0001', '0002', '0003']
self.args.overwrite = False
question_cmd(self.args, self.config)
# two new files should be in the cache directory
# * the repeated cached message with answer
# * the repeated DB message
# -> the cached message without answer should be overwritten
self.assertEqual(len(self.message_list(self.cache_dir)), 4)
self.assertEqual(len(self.message_list(self.db_dir)), 1)
expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
cached_msg = chat.msg_gather(loc=msg_location.CACHE)
self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
# check that the DB message has not been modified at all
db_msg = chat.msg_gather(loc=msg_location.DB)
self.assert_msgs_all_equal(db_msg, [message3])