Compare commits

...

327 Commits

Author SHA1 Message Date
sabaimran
d607ad7a27 Release Khoj version 1.29.1 2024-11-12 10:32:56 -08:00
sabaimran
8ec1764e42 Handle size calculation more gracefully for converted documents, depending on type 2024-11-12 02:00:29 -08:00
sabaimran
b6714c202f Increase the title character limit to 500 for conversations 2024-11-12 01:51:19 -08:00
sabaimran
f05e64cf8c Release Khoj version 1.29.0 2024-11-11 21:46:25 -08:00
sabaimran
47d3c8c235 Remove email query parameter from subscription patch api 2024-11-11 21:39:49 -08:00
sabaimran
d7027109a5 And null handling for response output_files in code output 2024-11-11 21:14:56 -08:00
sabaimran
d68243a3fb Revert clean_json logic temporarily. Eventually, we should do better validation here to extract markdown-formatted json. 2024-11-11 21:05:17 -08:00
sabaimran
1cab6c081f Add better error handling for diagram output, and fix chat history construct
- Make the `clean_json` method more robust as well
2024-11-11 20:44:19 -08:00
sabaimran
7bd2f83f97 Wrap test in suggestionCard 2024-11-11 20:12:46 -08:00
Debanjum
48862a8400 Enable Passing External Documents for Analysis in Code Sandbox (#960)
- Allow passing user files as input into code sandbox for analysis
- Update prompt to give more example of complex, multi-line code
- Simplify logic for model. Run one program at a time, 
  instead of allowing model to run multiple programs in parallel
- Show Code generated charts and docs in Reference pane of web app and make them downloaded
2024-11-11 19:37:17 -08:00
Debanjum
5078ac0ce2 Await on conversation save when generate conversation title via API 2024-11-11 19:17:39 -08:00
Debanjum
e1d0015248 Allow disabling Khoj telemetry via KHOJ_TELEMETRY_DISABLE env var 2024-11-11 19:17:39 -08:00
Debanjum
a52500d289 Show generated code artifacts before notes and online references 2024-11-11 18:00:22 -08:00
Debanjum
218eed83cd Show output file not code on hover. Remove reference card title border 2024-11-11 18:00:22 -08:00
Debanjum
b970cfd4b3 Align styling of reference panel card across code, docs, web results
- Add a border below heading
- Show code snippet in pre block
- Overflow-x when reference side panel open to allow seeing whole text
  via x-scroll
- Align header, body position of reference cards with each other
- Only show filename in doc reference cards at message bottom.
  Show full file path in hover and reference side panel
2024-11-11 18:00:22 -08:00
Debanjum
8e9f4262a9 Render code output files with code references in reference section
- Improve rendering code reference with better icons, smaller text and
  different line clamps for better visibility
- Show code output files as sub card of code card in reference section
- Allow downloading files generated by code instead of rendering it in
  chat message directly
- Show executed code before online references in reference panel
2024-11-11 18:00:22 -08:00
Debanjum
92c1efe6ee Fixes to render & save code context with non text based output modes
- Fix to render code generated chart with images, excalidraw diagrams
- Fix to save code context to chat history in image, diagram output modes
- Fix bug in image markdown being wrapped twice in markdown syntax
- Render newline in code references shown on chat page of web app
  Previously newlines weren't getting rendered. This made the code
  executed by Khoj hard to read in references. This changes fixes that.

  `dangerouslySetInnerHTML' usage is justified as rendered code
  snippet is being sanitized by DOMPurify before rendering.
2024-11-11 18:00:22 -08:00
Debanjum
af0215765c Decode code text output files from b64 to str to ease client processing 2024-11-11 18:00:22 -08:00
Debanjum
7b39f2014a Enable analysing user documents in code sandbox and other improvements
- Run one program at a time, instead of allowing model to pass
  multiple programs to run in parallel to simplify logic for model
- Update prompt to give more example of complex, multi-line code
- Allow passing user files as input into code sandbox for analysis

- Log code execution timer at info level to evaluate execution latencies
  in production

- Type the generated code for easier processing by caller functions
2024-11-11 17:59:37 -08:00
sabaimran
dc109559d4 Research mode gray when off, colored when on 2024-11-11 16:35:07 -08:00
sabaimran
cdda9c2e73 Improve text wrapping for attached files and preview context
For the research mode toggle, make it not fill when it's off
2024-11-11 13:32:10 -08:00
sabaimran
dd36303bb7 Fix sending file attachments in save_to_conversation method
- When files attached but upload fails, don't update the state variables
- Make removing null characters in pdf extraction more space efficient
2024-11-11 12:53:06 -08:00
Debanjum
ba2471dc02 Do not CRUD on entries, files & conversations in DB for null user (#958)
Increase defense-in-depth by reducing paths to create, read, update or delete entries, files and conversations in DB when user is unset.
2024-11-11 12:47:22 -08:00
Debanjum
536fe994be Remove unused db adapter methods, like for fact checker data store 2024-11-11 12:22:34 -08:00
Debanjum
10bca6fa8f Convert required user param check into decorator. Use with more adapters 2024-11-11 12:22:32 -08:00
Debanjum
ff5c10c221 Do not CRUD on entries, files & conversations in DB for null user
Increase defense-in-depth by reducing paths to create, read, update or
delete entries, files and conversations in DB when user is unset.
2024-11-11 12:20:07 -08:00
sabaimran
27fa39353e Make custom agent creation flow available to everyone
- For private agents, add guardrails to prevent against any misuse or violation of terms of service.
2024-11-11 11:54:59 -08:00
sabaimran
b563f46a2e Merge pull request #957 from khoj-ai/features/include-full-file-in-convo-with-filter
Support including file attachments in the chat message

Now that models have much larger context windows, we can reasonably include full texts of certain files in the messages. Do this when an explicit file filter is set in a conversation. Do so in a separate user message in order to mitigate any confusion in the operation.

Pipe the relevant attached_files context through all methods calling into models.

This breaks certain prior behaviors. We will no longer automatically be processing/generating embeddings on the backend and adding documents to the "brain". You'll have to go to settings and go through the upload documents flow there in order to add docs to the brain (i.e., have search include them during question / response).
2024-11-11 11:34:42 -08:00
sabaimran
2bb2ff27a4 Rename attached_files to query_files. Update relevant backend and client-side code. 2024-11-11 11:21:26 -08:00
sabaimran
47937d5148 Merge branch 'features/include-full-file-in-convo-with-filter' of github.com:khoj-ai/khoj into features/include-full-file-in-convo-with-filter 2024-11-11 09:34:08 -08:00
sabaimran
ae4eb96d48 Consolidate file name to icon mapping 2024-11-11 09:34:04 -08:00
Debanjum
7954f39633 Use accept param to file input to indicate supported file types in web app
Remove unused total size calculations in chat input
2024-11-11 04:06:17 -08:00
Debanjum
4223b355dc Use python stdlib methods to write pdf, docx to temp files for loaders
Use python standard method tempfile.NamedTemporaryFile to write,
delete temporary files safely.
2024-11-11 03:24:50 -08:00
Debanjum
fd15fc1e59 Move construct chat history back to it's original position in file
Keep function where it original was allows tracking diffs and change
history more easily
2024-11-11 03:24:50 -08:00
Debanjum
35d6c792e4 Show snippet of truncated messages in debug logs to avoid log flooding 2024-11-11 02:30:38 -08:00
sabaimran
8805e731fd Merge branch 'master' of github.com:khoj-ai/khoj into features/include-full-file-in-convo-with-filter 2024-11-10 19:24:11 -08:00
sabaimran
a5e2b9e745 Exit early when running an automation if the conversation for the automation does not exist. 2024-11-10 19:22:21 -08:00
sabaimran
55200be4fa Apply agent color fill to the toggle both in off and on states 2024-11-10 19:16:43 -08:00
Debanjum
7468f6a6ed Deduplicate online references returned by chat API to clients
This will ensure only unique online references are shown in all
clients.

The duplication issue was exacerbated in research mode as even with
different online search queries, you can get previously seen results.

This change does a global deduplication across all online results seen
across research iterations before returning them in client reponse.
2024-11-10 16:10:32 -08:00
Debanjum
137687ee49 Deduplicate searches in normal mode & across research iterations
- Deduplicate online, doc search queries across research iterations.
  This avoids running previously run online, doc searches again and
  dedupes online, doc context seen by model to generate response.
- Deduplicate online search queries generated by chat model for each
  user query.
- Do not pass online, docs, code context separately when generate
  response in research mode. These are already collected in the meta
  research passed with the user query
- Improve formatting of context passed to generate research response
  - Use xml tags to delimit context. Pass per iteration queries in each
    iteration result
  - Put user query before meta research results in user message passed
    for generating response

This deduplications will improve speed, cost & quality of research mode
2024-11-10 16:10:32 -08:00
Debanjum
306f7a2132 Show error in picking next tool to researcher llm in next iteration
Previously the whole research mode response would fail if the pick
next tool call to chat model failed. Now instead of it completely
failing, the researcher actor is told to try again in next iteration.

This allows for a more graceful degradation in answering a research
question even if a (few?) calls to the chat model fail.
2024-11-10 14:52:02 -08:00
Debanjum
eb492f3025 Only keep webpage content requested, even if Jina API gets more data
Jina search API returns content of all webpages in search results.
Previously code wouldn't remove content beyond max_webpages_to_read
limit set. Now, webpage content in organic results aree explicitly
removed beyond the requested max_webpage_to_read limit.

This should align behavior of online results from Jina with other
online search providers. And restrict llm context to a reasonable size
when using Jina for online search.
2024-11-10 14:51:16 -08:00
Debanjum
8ef7892c5e Exclude non-dictionary doc context from chat history sent to chat models
This fixes chat with old chat sessions. Fixes issue with old Whatsapp
users can't chat with Khoj because chat history doc context was
stored as a list earlier
2024-11-10 14:51:16 -08:00
Debanjum
d892ab3174 Fix handling of command rate limit and improve rate limit messages
Command rate limit wouldn't be shown to user as server wouldn't be
able to handle HTTP exception in the middle of streaming.

Catch exception and render it as LLM response message instead for
visibility into command rate limiting to user on client

Log rate limmit messages for all rate limit events on server as info
messages

Convert exception messages into first person responses by Khoj to
prevent breaking the third wall and provide more details on wht
happened and possible ways to resolve them.
2024-11-10 14:51:16 -08:00
Debanjum
80ee35b9b1 Wrap messages in web, obsidian UI to stay within screen when long links
Wrap long links etc. in chat messages and train of thought lists on
web app app and obsidian plugin by breaking them into newlines by word
2024-11-10 14:49:51 -08:00
Debanjum
f967bdf702 Show correct example index being currently processed in frames eval
Previously the batch start index wasn't being passed so all batches
started in parallel were showing the same processing example index

This change doesn't impact the evaluation itself, just the index shown
of the example currently being evaluated
2024-11-10 14:49:51 -08:00
Debanjum
84a8088c2b Only evaluate non-empty responses to reduce eval script latency, cost
Empty responses by Khoj will always be an incorrect response, so no
need to make call to an evaluator agent to check that
2024-11-10 14:49:51 -08:00
sabaimran
170d959feb Handle offline messages differently, as they don't respond well to the structured messages 2024-11-09 19:52:46 -08:00
sabaimran
2c543bedd7 Add typing to the constructed messages listed 2024-11-09 19:40:27 -08:00
sabaimran
79b15e4594 Only add images when they're present and vision enabled 2024-11-09 19:37:30 -08:00
sabaimran
bd55028115 Fix randint import from random when creating filenames for tmp 2024-11-09 19:17:18 -08:00
sabaimran
92b6b3ef7b Add attached files to latest structured message in chat ml format 2024-11-09 19:17:00 -08:00
sabaimran
835fa80a4b Allow docx conversion in the chatFunction.ts 2024-11-09 18:51:00 -08:00
sabaimran
459318be13 And random suffixes to decreases any clash probability when writing tmp files to disc 2024-11-09 18:46:34 -08:00
sabaimran
dbf0c26247 Remove _summary_ description in function descriptions 2024-11-09 18:42:42 -08:00
sabaimran
e5ac076fc4 Move construct_chat_history method back to conversation.utils.py 2024-11-09 18:27:46 -08:00
sabaimran
bc95a99fb4 Make tracer the last input parameter for all the relevant chat helper methods 2024-11-09 18:22:46 -08:00
sabaimran
ceb29eae74 Add phone number verification and remove telemetry update call from place where authentication middleware isn't yet installed (in the middleware itself). 2024-11-09 12:25:36 -08:00
sabaimran
3badb27744 Remove stored uploaded files after they're processed. 2024-11-08 23:28:02 -08:00
sabaimran
78630603f4 Delete the fact checker application 2024-11-08 17:27:42 -08:00
sabaimran
807687a0ac Automatically generate titles for conversations from history 2024-11-08 16:02:34 -08:00
sabaimran
7159b0b735 Enforce limits on file size when converting to text 2024-11-08 15:27:28 -08:00
sabaimran
4695174149 Add support for file preview in the chat input area (before message sent) 2024-11-08 15:12:48 -08:00
sabaimran
ad46b0e718 Label pages when extract text from pdf, docs content. Fix scroll area in doc preview. 2024-11-08 14:53:20 -08:00
sabaimran
ee062d1c48 Fix parsing for PDFs via content indexing API 2024-11-07 18:17:29 -08:00
sabaimran
623a97a9ee Merge branch 'master' of github.com:khoj-ai/khoj into features/include-full-file-in-convo-with-filter 2024-11-07 17:18:23 -08:00
sabaimran
33498d876b Simplify the share chat page. Don't need it to maintain its own conversation history
- When chatting on a shared page, fork and redirect to a new conversation page
2024-11-07 17:14:11 -08:00
sabaimran
4b8be55958 Convert UUID to string when forking a conversation 2024-11-07 17:13:04 -08:00
sabaimran
9bbe27fe36 Set default value of attached files to empty list 2024-11-07 17:12:45 -08:00
sabaimran
3a51996f64 Process attached files in the chat history and add them to the chat message 2024-11-07 16:06:58 -08:00
sabaimran
a89160e2f7 Add support for converting an attached doc and chatting with it
- Document is first converted in the chatinputarea, then sent to the chat component. From there, it's sent in the chat API body and then processed by the backend
- We couldn't directly use a UploadFile type in the backend API because we'd have to convert the api type to a multipart form. This would require other client side migrations without uniform benefit, which is why we do it in this two-phase process. This also gives us capacity to repurpose the moe generic interface down the road.
2024-11-07 16:06:37 -08:00
sabaimran
e521853895 Remove unnecessary console.log statements 2024-11-07 16:03:31 -08:00
sabaimran
92c3b9c502 Add function to get an icon from a file type 2024-11-07 16:02:53 -08:00
sabaimran
140c67f6b5 Remove focus ring from the text area component 2024-11-07 16:02:02 -08:00
sabaimran
b8ed98530f Accept attached files in the chat API
- weave through all subsequent subcalls to models, where relevant, and save to conversation log
2024-11-07 16:01:48 -08:00
sabaimran
ecc81e06a7 Add separate methods for docx and pdf files to just convert files to raw text, before further processing 2024-11-07 16:01:08 -08:00
sabaimran
394035136d Add an api that gets a document, and converts it to just text 2024-11-07 16:00:10 -08:00
sabaimran
3b1e8462cd Include attach files in calls to extract questions 2024-11-07 15:59:15 -08:00
sabaimran
de73cbc610 Add support for relaying attached files through backend calls to models 2024-11-07 15:58:52 -08:00
Debanjum
4cad96ded6 Add Script to Evaluate Khoj on Google's FRAMES benchmark (#955)
- Why
We need better, automated evals to measure performance shifts of Khoj
across prompt, model and capability changes.

Google's FRAMES benchmark evaluates multi-step retrieval and reasoning
capabilities of AI agents. It's a good starter benchmark to evaluate Khoj.

- Details
This PR adds an eval script to evaluate Khoj responses on the the FRAMES
benchmark prompts against the ground truth provided by it.

Script allows configuring sample size, batch size, sampling queries from the
eval dataset.

Gemini is used as an LLM Judge to auto grade Khoj responses vs ground truth 
data from the benchmark.
2024-11-06 17:52:01 -08:00
Debanjum
8679294bed Remove need to set server chat settings from use openai proxies docs
This was previously required, but now it's only usefuly for more
advanced settings, not typical for self-hosting users.

With recent updates, the user's selected chat model is used for both
Khoj's train of thought and response. This makes it easy to
switch your preferred chat model directly from the user settings
page and not have to update this in the admin panel as well.

Reflect these code changse in the docs, by removing the unnecessary
step for self-hosted users to create a server chat setting when using
an OpenAI proxy service like Ollama, LiteLLM etc.
2024-11-05 17:10:53 -08:00
Debanjum
05a93fcbed v-align attach, send buttons with chat input text area on web app
Otherwise, those buttons look off-center when images are attached to
the chat input area
2024-11-05 17:10:53 -08:00
sabaimran
a0480d5f6c use fill weight for the toggle right (enabled state) for research mode 2024-11-04 22:01:09 -08:00
sabaimran
dc26da0a12 Add uploaded files in the conversation file filter for a new convo 2024-11-04 22:00:47 -08:00
Debanjum
b51ee644aa Fix escaping filename when normalizing in org node parser 2024-11-04 20:24:57 -08:00
Debanjum
5724d16a6f Fix passing images to anthropic chat models to extract questions 2024-11-04 20:24:57 -08:00
sabaimran
cf0bcec0e7 Revert SKIP_TESTS flag in offline chat director tests 2024-11-04 19:06:54 -08:00
sabaimran
1f372bf2b1 Update file summarization unit tests now that multiple files are allowed 2024-11-04 17:45:54 -08:00
sabaimran
7543360210 Merge branch 'master' of github.com:khoj-ai/khoj into features/include-full-file-in-convo-with-filter 2024-11-04 16:55:48 -08:00
sabaimran
b6145df3be Handle file retrieval when agent is None 2024-11-04 16:55:22 -08:00
sabaimran
3dc9139cee Add additional handling for when file_object comes back empty 2024-11-04 16:53:07 -08:00
sabaimran
a27b8d3e54 Remove summarize condition for only 1 file filter 2024-11-04 16:51:37 -08:00
sabaimran
362bdebd02 Add methods for reading full files by name and including context
Now that models have much larger context windows, we can reasonably include full texts of certain files in the messages. Do this when an explicit file filter is set in a conversation. Do so in a separate user message in order to mitigate any confusion in the operation.

Pipe the relevant attached_files context through all methods calling into models.

We'll want to limit the file sizes for which this is used and provide more helpful UI indicators that this sort of behavior is taking place.
2024-11-04 16:37:13 -08:00
sabaimran
e3ca52b7cb Use .get() to get text accompanying image url, instead of subindexing 2024-11-04 16:09:16 -08:00
sabaimran
1e89baca7b Deprecate the UserSearchModelConfig and remove all references
- The server has moved to a model of standardization for the embeddings generation workflow. Remove references to the support for differentiated models.
- The migration script fo ra new model needs to be updated to accommodate full regeneration.
2024-11-04 12:24:41 -08:00
Debanjum
1ccbf72752 Use logger instead of print to track eval 2024-11-04 00:40:26 -08:00
sabaimran
99c1d2831a Release Khoj version 1.28.3 2024-11-02 12:23:11 -07:00
sabaimran
075b4ecf15 Call subscription_to_state with sync_to_async wrapper when getting user subscription state
- This is needed in case the renewal_date is not set and we need to reset it for the user
2024-11-02 12:22:35 -07:00
sabaimran
ec44cbe1e7 Release Khoj version 1.28.2 2024-11-02 07:53:51 -07:00
Debanjum
791eb205f6 Run prompt batches in parallel for faster eval runs 2024-11-02 04:58:03 -07:00
Debanjum
96904e0769 Add script to evaluate khoj on Google's FRAMES benchmark
Google's FRAMES benchmark evaluates multi-step retrieval and reasoning
capabilities of an agent.

The script uses Gemini as an LLM Judge to evaluate Khoj responses to
the FRAMES benchmark prompts against the ground truth provided by it.
2024-11-02 04:57:42 -07:00
Debanjum
31b5fde163 Only enable prompt tracer if git python is installed 2024-11-02 02:07:02 -07:00
sabaimran
5b18dc96e0 Release Khoj version 1.28.1 2024-11-01 22:51:51 -07:00
sabaimran
8d1b1bc78e Move the git python dependency into top level dependencies 2024-11-01 22:51:00 -07:00
Debanjum
e85dd59295 Release Khoj version 1.28.0 2024-11-01 19:06:59 -07:00
Debanjum
1f79a10541 Fix link to code execution feature in docs 2024-11-01 18:22:21 -07:00
Debanjum
cff8e02b60 Research Mode [Part 2]: Improve Prompts, Edit Chat Messages. Set LLM Seed for Reproducibility (#954)
- Improve chat actors and their prompts for research mode.
- Add documentation to enable the code tool when self-hosting Khoj
- Edit Chat Messages
  - Store Turn Id in each chat message. 
  - Expose API to delete chat message.
  - Expose delete chat message button to turn delete chat message from web app
- Set LLM Generation Seed for Reproducible Debugging and Testing
  - Setting seed for LLM generation is supported by Llama.cpp and OpenAI models. 
    This can (somewhat) restrain LLM output
  - Getting fixed responses for fixed inputs helps test, debug longer reasoning chains like used in advanced reasoning
2024-11-01 18:16:42 -07:00
Debanjum
14e453039d Add prompt tracing, agent personality to infer webpage urls chat actor 2024-11-01 18:12:50 -07:00
Debanjum
ab321dc518 Expect query before tool in response to give think space in research prompt 2024-11-01 17:51:41 -07:00
Debanjum
1a83bbcc94 Clean API chat router. Move FeedbackData response type to router helper 2024-11-01 17:51:41 -07:00
sabaimran
e6eb87bbb5 Merge branch 'improve-debug-reasoning-and-other-misc-fixes' of github.com:khoj-ai/khoj into improve-debug-reasoning-and-other-misc-fixes 2024-11-01 16:48:39 -07:00
sabaimran
a213b593e8 Limit the number of urls the webscraper can extract for scraping 2024-11-01 16:48:36 -07:00
sabaimran
327fcb8f62 create defiltered query after conversation command is extracted 2024-11-01 16:48:03 -07:00
sabaimran
b79a9ec36d Clarify description of the code evaluation environment: not for document creation 2024-11-01 16:47:27 -07:00
Debanjum
9c7b36dc69 Use standard per minute rate limits across user types 2024-11-01 16:16:06 -07:00
Debanjum
ac21b10dd5 Simplify logic to get default search model. Remove unused import 2024-11-01 15:14:00 -07:00
sabaimran
2b35790165 Merge branch 'master' of github.com:khoj-ai/khoj into improve-debug-reasoning-and-other-misc-fixes 2024-11-01 14:51:26 -07:00
Debanjum
22f3ed3f5d Research Mode: Give Khoj the ability to perform more advanced reasoning (#952)
## Overview
Khoj can now go into research mode and use a python code interpreter. These are experimental features that are being released early for feedback and testing.

- Research mode allows Khoj to dynamically select the tools it needs to best answer the question. It is also allowed more iterations to get to a satisfactory answer. Its more dynamic train of thought is shown to improve visibility into its thinking.
- Adding ability for Khoj to use a python code interpreter is an adjacent capability. It can help Khoj do some data analysis and generate charts for you. A sandboxed python to run code is provided using [cohere-terrarium](https://github.com/cohere-ai/cohere-terrarium?tab=readme-ov-file), [pyodide](https://pyodide.org/).

## Analysis
Research mode (significantly?) improves Khoj's information retrieval for more complex queries requiring multi-step lookups but takes longer to run. It can research for longer, requiring less back-n-forth with the user to find an answer.

Research mode gives most gains when used with more advanced chat models (like o1, 4o, new claude sonnet and gemini-pro-002). Smaller models improve their response quality but tend to get into repetitive loops more often. 

## Next Steps
- Get community feedback on research mode. What works, what fails, what is confusing, what'd be cool to have.
- Tune Khoj's capabilities for longer autonomous runs and to generalize across a larger range of model sizes

## Miscellaneous Improvements
- Khoj's train of thought is saved and shown for all messages, not just the latest one
- Render charts generated by Khoj and code running using the code tool on the web app
- Align chat input color to currently selected agent color
2024-11-01 14:46:29 -07:00
sabaimran
baa939f4ce When running code, strip any code delimiters. Disable application json type specification in Gemini request. 2024-11-01 13:47:39 -07:00
sabaimran
8fd2fe162f Determine if research mode is enabled by checking the conversation commands and 'linting' them in the selection phase 2024-11-01 13:12:34 -07:00
sabaimran
cead1598b9 Don't reset research mode after completing research execution 2024-11-01 13:00:11 -07:00
Debanjum
c1c779a7ef Do not yaml format raw code results in context for LLM. It's confusing 2024-11-01 12:45:26 -07:00
sabaimran
b3dad1f393 Standardize rate limits to 1/6 ratio 2024-11-01 12:21:09 -07:00
sabaimran
23a49b6b95 Add documentation for python code execution capability 2024-11-01 12:14:33 -07:00
Debanjum
cd75151431 Do not allow auto selecting research mode as tool for now.
You are required to manually turning it on. This takes longer and
should be a high intent activity initiated by user
2024-11-01 12:07:52 -07:00
Debanjum
0b0cfb35e6 Simplify in research mode check in api_chat.
- Dedent code for readability
- Use better name for in research mode check
- Continue to remove inferred summarize command when multiple files in
  file filter even when not in research mode
- Continue to show select information source train of thought.
  It was removed by mistake earlier
2024-11-01 12:07:08 -07:00
sabaimran
ffa7f95559 Add template for a code sandbox to the docker-compose configuration 2024-11-01 11:50:58 -07:00
Debanjum
73750ef286 Merge branch 'master' into features/advanced-reasoning 2024-11-01 11:42:01 -07:00
sabaimran
1fc280db35 Handle case where infer_webpage_url returns no valid urls 2024-11-01 11:41:32 -07:00
Debanjum
1c920273dd Add Prompt Tracer to Visualize, Analyze and Debug Khoj's Train of Thought (#951)
## Overview
Use git to capture prompt traces of khoj's train of thought. View, analyze and debug them using your favorite git client (e.g vscode, magit).

- Each commit captures an interaction with an LLM
  The commit writes the query, response and system message each to a separate file in the repo.
  The commit message captures the chat model, Khoj version and other metadata
- Each conversation turn can have multiple interactions with an LLM (e.g Khoj's train of thought)
- Each new conversation turn forks from and merges back into its conversation branch
- Each new conversation branches from the user branch
- Each new user branches from root commit on the main branch

## Usage
1. Set `KHOJ_DEBUG=true` or start khoj in very verbose mode with `khoj -vv` to turn on prompt tracing
2. Chat with Khoj as usual 
3. Open the promptrace git repo to view the generated prompt traces using your favorite git porcelain. 
   The Khoj prompt trace git repo is created at `/tmp/khoj_promptrace` by default. You can configure the prompt trace directory by setting the `PROMPTRACE_DIR`environment variable.

## Implementation
- Add utility functions to capture prompt traces using git (via `gitpython`)
- Make each model provider in Khoj commit their LLM interactions with promptrace
- Weave chat metadata from chat API through all chat actors and commit it to the prompt trace
2024-11-01 11:33:54 -07:00
sabaimran
33d36ee58c Add experimental notice to research mode tooltip 2024-11-01 11:00:27 -07:00
sabaimran
0145b2a366 Set usage limits on the research mode 2024-11-01 10:29:33 -07:00
sabaimran
3ea94ac972 Only include inferred-queries in chat history when present 2024-10-31 22:01:41 -07:00
sabaimran
149cbe1019 Use bottom anchor for the commandbar popover 2024-10-31 20:40:38 -07:00
sabaimran
21858acccc Remove conversation command always in query, filter out inferred queries that were not with selected tool when going through tool selection iterations 2024-10-31 20:27:38 -07:00
sabaimran
19241805ee Merge branch 'master' of github.com:khoj-ai/khoj into improve-debug-reasoning-and-other-misc-fixes 2024-10-31 18:20:23 -07:00
Debanjum
302bd51d17 Improve online chat actor prompt for research and normal mode
- Match the online query generator prompt to match the formatting of
  extract questions
- Separate iteration results by newline
- Improve webpage and online tool descriptions
2024-10-31 18:17:12 -07:00
Debanjum
52163fe299 Improve research planner prompt to reduce looping 2024-10-31 18:17:01 -07:00
sabaimran
7ebf999688 Merge branch 'master' of github.com:khoj-ai/khoj into features/advanced-reasoning 2024-10-31 18:15:13 -07:00
sabaimran
159ea44883 Remove frame references in the diagramming prompts 2024-10-31 18:14:51 -07:00
Debanjum
89597aefe9 Json dump contents in prompt tracer to make structure discernable 2024-10-31 18:08:42 -07:00
Debanjum
5b15176e20 Only add /research prefix in research mode if not already in user query 2024-10-31 18:08:42 -07:00
sabaimran
559601dd0a Do not exit if/else loop in research loop when notes not found 2024-10-31 13:51:10 -07:00
sabaimran
a13760640c Only show trash can when turnId is present 2024-10-31 13:19:16 -07:00
sabaimran
8d1ecb9bd8 Add optional brew steps for docker install 2024-10-31 12:41:53 -07:00
Debanjum
adca6cbe9d Merge branch 'master' into add-prompt-tracer-for-observability 2024-10-31 02:28:34 -07:00
Debanjum
e17dc9f7b5 Put train of thought ui before Khoj response on web app 2024-10-31 02:24:53 -07:00
Debanjum
e8e6ead39f Fix deleting new messages generated after conversation load 2024-10-30 20:56:38 -07:00
Debanjum
cb90abc660 Resolve train of thought component needs unique key id error on web app 2024-10-30 14:00:21 -07:00
Debanjum
ca5a6831b6 Add ability to delete messages from the web app 2024-10-30 14:00:21 -07:00
Debanjum
ba15686682 Store turn id with each chat message. Expose API to delete chat turn
Each chat turn is a user query, khoj response message pair
2024-10-30 14:00:21 -07:00
Debanjum
f64f5b3b6e Handle add/delete file filter operation on non-existent conversation 2024-10-30 14:00:21 -07:00
Debanjum
b3a63017b5 Support setting seed for reproducible LLM response generation
Anthropic models do not support seed. But offline, gemini and openai
models do. Use these to debug and test Khoj via KHOJ_LLM_SEED env var
2024-10-30 14:00:21 -07:00
Debanjum
d44e68ba01 Improve handling embedding model config from admin interface
- Allow server to start if loading embedding model fails with an error.
  This allows fixing the embedding model config via admin panel.

  Previously server failed to start if embedding model was configured
  incorrectly. This prevented fixing the model config via admin panel.

- Convert boolean string in config json to actual booleans when passed
  via admin panel as json before passing to model, query configs

- Only create default model if no search model configured by admin.
  Return first created search model if its been configured by admin.
2024-10-30 14:00:21 -07:00
Debanjum
358a6ce95d Defer turning cursor color to selected agents color for later
Capability exists but idea needs to be investigated further
2024-10-30 14:00:21 -07:00
Debanjum
2ac840e3f2 Make cursor in chat input take on selected agent color 2024-10-30 14:00:21 -07:00
Debanjum
1448b8b3fc Use 3rd person for user in research prompt to reduce person confusion
Models were getting a bit confused about who is search for who's
information. Using third person to explicitly call out on who's behalf
these searches are running seems to perform better across
models (gemini's, gpt etc.), even if the role of the message is user.
2024-10-30 13:49:48 -07:00
Debanjum
b8c6989677 Separate example from actual question in extract question prompt 2024-10-30 13:49:48 -07:00
Debanjum
86ffd7a7a2 Handle \n, dedupe json cleaning into single function for reusability
Use placeholder for newline in json object values until json parsed
and values extracted. This is useful when research mode models outputs
multi-line codeblocks in queries etc.
2024-10-30 13:49:48 -07:00
Debanjum
83ca820abe Encourage Anthropic models to output json object using { prefill
Anthropic API doesn't have ability to enforce response with valid json
object, unlike all the other model types.

While the model will usually adhere to json output instructions.

This step is meant to more strongly encourage it to just output json
object when response_type of json_object is requested.
2024-10-30 13:49:48 -07:00
Debanjum
dc8e89b5de Pass tool AIs iteration history as chat history for better context
Separate conversation history with user from the conversation history
between the tool AIs and the researcher AI.

Tools AIs don't need top level conversation history, that context is
meant for the researcher AI.

The invoked tool AIs need previous attempts at using the tool in this
research runs iteration history to better tune their next run.

Or at least that is the hypothesis to break the models looping.
2024-10-30 13:49:48 -07:00
Debanjum
d865994062 Rename code tool arg previous_iteration_history' to context' 2024-10-30 13:49:48 -07:00
Debanjum
06aeca2670 Make researcher, docs search AIs ask more diverse retrieval questions
Models weren't generating a diverse enough set of questions. They'd do
minor variations on the original query. What is required is asking
queries from a bunch of different lenses to retrieve the requisite
information.

This prompt updates shows the AIs the breadth of questions to by
example and instruction. Seem like performance improved based on vibes
2024-10-30 13:49:48 -07:00
Debanjum
01881dc7a2 Revert "Make extract question prompt in 1st person wrt user as its a user message"
This reverts commit 6d3602798aa1b95a30c557576fd4f93ddef2ae76.
2024-10-30 13:49:48 -07:00
Debanjum
3e695df198 Make extract question prompt in 1st person wrt user as its a user message
Divide Example from Actual chat history section in prompt
2024-10-30 13:49:48 -07:00
Debanjum
a3751d6a04 Make extract relevant information system prompt work for any document
Previously it was too strongly tuned for extracting information from
only webpages. This shouldn't be necessary
2024-10-30 13:49:48 -07:00
Debanjum
a39e747d07 Improve passing user name in pick next research tool prompt 2024-10-30 13:49:48 -07:00
Debanjum
deff512baa Improve research mode prompts to reduce looping, increase webpage reads 2024-10-30 13:49:48 -07:00
Debanjum
d3184ae39a Simplify storing and displaying document results in research mode
- Mention count of notes and files disovered
- Store query associated with each compiled reference retrieved for
  easier referencing
2024-10-30 13:49:48 -07:00
Debanjum
8bd94bf855 Do not use a message branch if no msg id provided to prompt tracer 2024-10-30 13:49:48 -07:00
sabaimran
b63fbc5345 Add a simple badget to the dropdown menu that shows subscription status 2024-10-30 13:00:16 -07:00
sabaimran
82f3d79064 Merge branch 'master' of github.com:khoj-ai/khoj into features/advanced-reasoning 2024-10-30 11:32:10 -07:00
sabaimran
2b2564257e Handle subscription case where it's set to trial, but renewal_date is not set. set the renewal_date for LENGTH_OF_FREE_TRIAL days from subscription creation. 2024-10-30 11:05:31 -07:00
Debanjum
9935d4db0b Do not use a message branch if no msg id provided to prompt tracer 2024-10-28 17:50:27 -07:00
Debanjum
d184498038 Pass context in separate message from user query to research chat actor 2024-10-28 15:37:28 -07:00
Debanjum
d75ce4a9e3 Format online, notes, code context with YAML to be legibile for LLM 2024-10-28 15:37:28 -07:00
sabaimran
5bea0c705b Use break-words in the train of thought for better formatting 2024-10-28 15:36:06 -07:00
sabaimran
1f1b182461 Automatically carry over research mode from home page to chat
- Improve mobile friendliness with new research mode toggle, since chat input area is now taking up more space
- Remove clunky title from the suggestion card
- Fix fk lookup error for agent.creator
2024-10-28 15:29:24 -07:00
sabaimran
ebaed53069 Merge branch 'master' of github.com:khoj-ai/khoj into features/advanced-reasoning 2024-10-28 12:39:00 -07:00
sabaimran
889dbd738a Add keyword diagram to diagram output mode description 2024-10-28 12:20:46 -07:00
Debanjum
50ffd7f199 Merge branch 'master' into features/advanced-reasoning 2024-10-28 04:10:59 -07:00
Debanjum
a5d0ca6e1c Use selected agent color to theme the chat input area on home page 2024-10-28 03:47:40 -07:00
Debanjum
aad7528d1b Render slash commands popup below chat input text area on home page 2024-10-28 02:06:04 -07:00
Debanjum
3e17ab438a Separate notes, online context from user message sent to chat models (#950)
Overview
---
- Put context into separate user message before sending to chat model.
  This should improve model response quality and truncation logic in code
- Pass online context from chat history to chat model for response.
  This should improve response speed when previous online context can be reused
- Improve format of notes, online context passed to chat models in prompt.
  This should improve model response quality

Details
---
The document, online search context are now passed as separate user
messages to chat model, instead of being added to the final user message.

This will improve
- Models ability to differentiate data from user query.
  That should improve response quality and reduce prompt injection
  probability
- Make truncation logic simpler and more robust
  When context window hit, can simply pop messages to auto truncate
  context in order of context, user, assistant message for each
  conversation turn in history until reach current user query

  The complex, brittle logic to extract user query from context in
  last user message isn't required.
2024-10-28 02:03:18 -07:00
Debanjum
8ddd70f3a9 Put context into separate message before sending to offline chat model
Align context passed to offline chat model with other chat models

- Pass context in separate message for better separation between user
  query and the shared context
- Pass filename in context
- Add online results for webpage conversation command
2024-10-28 00:22:21 -07:00
Debanjum
ee0789eb3d Mark context messages with user role as context role isn't being used
Context role was added to allow change message truncation order based
on context role as well.

Revert it for now since currently this is not currently being done.
2024-10-28 00:04:14 -07:00
Debanjum
4e39088f5b Make agent name in home page carousel not text wrap on mobile 2024-10-27 23:03:53 -07:00
Debanjum
94074b7007 Focus chat input on toggle research mode. v-align it with send button 2024-10-27 22:54:55 -07:00
sabaimran
a691ce4aa6 Batch entries into smaller groups to process 2024-10-27 20:43:41 -07:00
sabaimran
2924909692 Add a research mode toggle to the chat input area 2024-10-27 16:37:40 -07:00
sabaimran
68499e253b Auto-collapse train of thought, show after chat response in history 2024-10-27 15:48:13 -07:00
sabaimran
101ea6efb1 Add research mode as a slash command, remove from default path 2024-10-27 15:47:44 -07:00
sabaimran
0bd78791ca Let user exit from command mode with esc, click out, etc. 2024-10-27 15:01:49 -07:00
sabaimran
a121d67b10 Persist the train of thought in the conversation history 2024-10-26 23:46:15 -07:00
sabaimran
9e8ac7f89e Fix input/output mismatches in the /summarize command 2024-10-26 16:37:58 -07:00
sabaimran
e4285941d1 Use the advanced chat model if the user is subscribed 2024-10-26 16:00:54 -07:00
sabaimran
33e48aa27e Merge branch 'add-prompt-tracer-for-observability' of github.com:khoj-ai/khoj into features/advanced-reasoning 2024-10-26 14:09:00 -07:00
sabaimran
fd71a4b086 Add better exception handling in the prompt trace logic, use default value from parameters 2024-10-26 14:08:00 -07:00
Debanjum
3e5b5ec122 Encourage model to read webpages more often after online search
Previously model would rarely read webpages after webpage search. Need
the model to webpages more regularly for deeper research and to stop
getting stuck in repetitive online search loops
2024-10-26 10:49:09 -07:00
Debanjum
bf96d81943 Format online results as YAML to pass it in more readable form to model
Previous passing of online results as json dump in prompts was less
readable for humans, and I'm guessing less readable for
models (trained on human data) as well?
2024-10-26 10:49:09 -07:00
Debanjum
3e97ebf0c7 Unescape special characters in prompt traces for better readability 2024-10-26 10:49:09 -07:00
Debanjum
8af9dc3ee1 Unescape special characters in prompt traces for better readability 2024-10-26 10:45:42 -07:00
Debanjum Singh Solanky
0f3927e810 Send gathered references to client after code results calculated 2024-10-26 05:59:10 -07:00
Debanjum Singh Solanky
f04f871a72 Merge branch 'add-prompt-tracer-for-observability' of github.com:khoj-ai/khoj into features/advanced-reasoning
- Start from this branches src/khoj/routers/api_chat.py
    Add tracer to all old and new chat actors that don't have it set
    when they are called.
  - Update the new chat actors like apick next tool etc to use tracer too
2024-10-26 05:56:13 -07:00
Debanjum Singh Solanky
ddc6ccde2d Merge branch 'master' into features/advanced-reasoning
- Conflicts:
  Combine both sides of the conflict in all 3 files below
  - src/khoj/processor/conversation/utils.py
  - src/khoj/routers/helpers.py
  - src/khoj/utils/helpers.py
2024-10-26 05:15:51 -07:00
Debanjum Singh Solanky
ea0712424b Commit conversation traces using user, chat, message branch hierarchy
- Message train of thought forks and merges from its conversation branch
- Conversation branches from user branch
- User branches from root commit on the main branch

- Weave chat tracer metadata from api endpoint through all chat actors
  and commit it to the prompt trace
2024-10-26 05:08:47 -07:00
Debanjum Singh Solanky
a3022b7556 Allow Offline Chat model calling functions to save conversation traces 2024-10-26 05:08:47 -07:00
Debanjum Singh Solanky
eb6424f14d Allow Anthropic API calling functions to save conversation traces 2024-10-26 05:08:47 -07:00
Debanjum Singh Solanky
6fcd6a5659 Allow Gemini API calling functions to save conversation traces 2024-10-26 05:08:47 -07:00
Debanjum Singh Solanky
384f394336 Allow OpenAI API calling functions to save conversation traces 2024-10-26 04:59:21 -07:00
Debanjum Singh Solanky
10c8fd3b2a Save conversation traces to git for visualization 2024-10-26 04:59:19 -07:00
sabaimran
7e0a692d16 Release Khoj version 1.27.1 2024-10-25 15:23:07 -07:00
sabaimran
b257fa1884 Add a None check before doing a DT comparison when getting subscription type 2024-10-25 15:22:48 -07:00
sabaimran
0f6f282c30 Release Khoj version 1.27.0 2024-10-25 14:11:14 -07:00
sabaimran
479e156168 Add to the ConversationCommand.Image description to LLM 2024-10-25 09:14:32 -07:00
sabaimran
a11b5293fb Add uploaded images to research mode, code slash command, include code references 2024-10-24 23:56:24 -07:00
sabaimran
5acf40c440 Clean up summarization code paths
Use assumption of summarization response being a str
2024-10-24 23:56:24 -07:00
sabaimran
12b32a3d04 Resolve merge conflicts 2024-10-24 23:43:55 -07:00
Debanjum
adee5a3e20 Give Vision to Anthropic models in Khoj (#948)
### Major
- Give Vision to Anthropic models in Khoj

### Minor
- Reuse logic to format messages for chat with anthropic models
- Make the get image from url function more versatile and reusable
- Encourage output mode chat actor to output only json and nothing else
2024-10-24 18:02:38 -07:00
Debanjum Singh Solanky
01d740debd Return typed image from image_with_url function for readability 2024-10-24 17:58:46 -07:00
Debanjum Singh Solanky
37317e321d Dedupe user location passed in image, diagram generation prompts 2024-10-24 01:03:29 -07:00
Debanjum Singh Solanky
2a32836d1a Log more descriptive error when image gen fails with Replicate 2024-10-24 01:03:29 -07:00
sabaimran
30f9225021 Merge branch 'master' of github.com:khoj-ai/khoj into features/advanced-reasoning 2024-10-23 19:15:51 -07:00
sabaimran
5120597d4e Remove user customized search model (#946)
- Use a single standard search model across the server. There's diminishing benefits for having multiple user-customizable search models. 
- We may want to add server-level customization for specific tasks
- Store the search model used to generate a given entry on the `Entry` object
- Remove user-facing APIs and view
- Add a management command for migrating the default search model on the server

In a future PR (after running the migration), we'll also remove the `UserSearchModelConfig`
2024-10-23 17:38:37 -07:00
Debanjum Singh Solanky
8d588e0765 Encourage output mode chat actor to output only json and nothing else
Latest claude model wanted to say more than just give the json output.
The updated prompt encourages the model to ouput just json. This is
similar to what is already being done for other prompts
2024-10-23 17:19:21 -07:00
Debanjum Singh Solanky
abad5348a0 Give Vision to Anthropic models in Khoj 2024-10-23 17:19:21 -07:00
Debanjum Singh Solanky
6fd50a5956 Reuse logic to format messages for chat with anthropic models 2024-10-23 17:19:21 -07:00
Debanjum Singh Solanky
82eac5a043 Make the get image from url function more versatile and reusable
It was previously added under the google utils. Now it can be used by
other conversation processors as well.

The updated function
- can get both base64 encoded and PIL formatted images from url
- will return the media type of the image as well in response
2024-10-23 17:19:20 -07:00
sabaimran
f3ce47b445 Create explicit flow to enable the free trial (#944)
* Create explicit flow to enable the free trial

The current design is confusing. It obfuscates the fact that the user is on a free trial. This design will make the opt-in explicit and more intuitive.

* Use the Subscription Type enum instead of hardcoded strings everywhere

* Use length of free trial in the frontend code as well
2024-10-23 15:29:23 -07:00
Debanjum Singh Solanky
bc059eeb0b Merge branch 'master' into put-retrieved-context-in-separate-chatml-message 2024-10-23 12:55:18 -07:00
Debanjum Singh Solanky
3b978b9b67 Fix chat history construction when generating chatml msgs with context 2024-10-23 12:55:12 -07:00
sabaimran
c5e91c346a Fix Docker desktop link for Linux 2024-10-23 11:24:54 -07:00
Debanjum Singh Solanky
9f2c02d9f7 Chat with the default agent by default from web app home
Had temporarily updated the default selected agent to last used.
Revert for now as
1. The previous logic was buggy. It didn't select the default agent
   even when the last used agent was the default agent. Which would
   require more work.
2. It maybe too early anyway to set the default agent to last used.
2024-10-23 03:43:57 -07:00
Debanjum Singh Solanky
218946edda Fix copying message with user images on web app
Adding div elements to message to render degraded text copied to
clipboard for messages with user uploaded images.

This change fixes that by separating message to render from message
for clipboard. It ensures differently formatted forms of the user
images are added to the two to allow proper rendering while still
having decently formatted text copied to clipboard
2024-10-23 03:41:25 -07:00
Debanjum Singh Solanky
7d9a06c8ab Merge branch 'master' into put-retrieved-context-in-separate-chatml-message 2024-10-23 00:13:38 -07:00
sabaimran
7c29af9745 Add link to self-hosted admin page and add docs for building front-end assets. Close #901 2024-10-22 22:42:27 -07:00
Debanjum Singh Solanky
2a50694089 Allow typing multi-line queries from a phone with Enter key
Add newline instead of sending message when hit Enter key on mobile
displays. As on phones shift key doesn't exist and send button is easily
clickable.

Limit hitting Enter key to send message to computers = larger display
= expected to have full fledged keyboards.
2024-10-22 21:20:22 -07:00
Debanjum Singh Solanky
a134cd835c Focus on chat input area to enter text after file uploads on web app 2024-10-22 21:19:17 -07:00
Debanjum
c81e708833 Show all agents, smart sorted, in carousel on home screen of web app (#943)
## Overview
Allow quickly selecting, switching agents from agents pane on home page of web app

## Details
- Show all agents in carousel on home screen agent pane of web app
- Smart Sort
  1. Pin default agent as first for ease of access
  2. Show used agents by MRU for ease of access
  3. Shuffle unused agents for discoverability
- Select most recently used agent to chat with by default
- Push smart sort logic down to API
  - Common logic can be reused across clients
  - Agent sort was previously done in web app
- Focus on chat input on agent select
- Double click agent on home page to open edit agent card on agents page
2024-10-22 21:18:17 -07:00
Debanjum Singh Solanky
750fbce0c2 Merge branch 'master' into improve-agent-pane-on-home-screen 2024-10-22 20:05:29 -07:00
Debanjum Singh Solanky
3be505db48 Only show type of error when image generation fails to clients
Rather than showing raw error message from the underlying service as it
could contain sensitive information
2024-10-22 20:03:20 -07:00
Debanjum
c6f3253ebd Chat with Multiple Images. Support Vision with Gemini (#942)
## Overview
- Add vision support for Gemini models in Khoj
- Allow sharing multiple images as part of user query from the web app
- Handle multiple images shared in query to chat API
2024-10-22 19:59:18 -07:00
Debanjum Singh Solanky
b3fff43542 Sanitize user attached images. Constrain chat input width on home page
Set max combined images size to 20mb to allow multiple photos to be shared
2024-10-22 19:42:40 -07:00
Debanjum Singh Solanky
6c393800cc Merge branch 'master' into multi-image-chat-and-vision-for-gemini 2024-10-22 18:38:49 -07:00
Debanjum Singh Solanky
91bbd19333 Close the agent detail hover card when scroll on agent pane 2024-10-22 18:03:17 -07:00
Debanjum Singh Solanky
110c67f083 Improve agent pill, detail card styling. Handle null chatInputRef
- Remove border from agent detail hover card on home page
- Do not wrap long agent names in agent pills on home page
- Handle scenario where chatInputRef is null
2024-10-22 18:03:17 -07:00
Debanjum Singh Solanky
aca8bef024 Only use recent chat sessions for agent MRU. Handle null agent chats 2024-10-22 17:46:45 -07:00
sabaimran
0dad4212fa Generate dynamic diagrams (via Excalidraw) (#940)
Add support for generating dynamic diagrams in flow with Excalidraw (https://github.com/excalidraw/excalidraw). This happens in three steps:
1. Default information collection & intent determination step.
2. Improving the overall guidance of the prompt for generating a JSON, Excalidraw-compatible declaration.
3. Generation of the diagram to output to the final UI.

Add support in the web UI.
2024-10-22 16:13:46 -07:00
sabaimran
1e993d561b Release Khoj version 1.26.4 2024-10-22 13:50:08 -07:00
Debanjum Singh Solanky
e8fb79a369 Rate limit the count and total size of images shared via API 2024-10-22 04:37:54 -07:00
Debanjum Singh Solanky
39a613d3bc Fix up openai chat actor tests 2024-10-22 03:09:36 -07:00
Debanjum Singh Solanky
0847fb0102 Pass online context from chat history to chat model for response
Previously only notes context from chat history was included.
This change includes online context from chat history for model to use
for response generation.

This can reduce need for online lookups by reusing previous online
context for faster responses. But will increase overall response time
when not reusing past online context, as faster context buildup per
conversation.

Unsure if inclusion of context is preferrable. If not, both notes and
online context should be removed.
2024-10-22 03:09:36 -07:00
Debanjum Singh Solanky
0c52a1169a Put context into separate user message before sending to chat model
The document, online search context are now passed as separate user
messages to chat model, instead of being added to the final user message.

This will improve

- Models ability to differentiate data from user query.
  That should improve response quality and reduce prompt injection
  probability

- Make truncation logic simpler and more robust
  When context window hit, can simply pop messages to auto truncate
  context in order of context, user, assistant message for each
  conversation turn in history until reach current user query

  The complex, brittle logic to extract user query from context in
  last user message isn't required.

Marking the context message with assistant role doesn't translate well
across chat models. E.g
- Gemini can't handle consecutive messages by role = model well
- Claude will merge consecutive messages by same role. In current
  message ordering the context message will result get merged into the
  previous assistant response. And if move context message after user
  query. The truncation logic will have to hop and skip while doing
  deletions
- GPT seems to handle consecutive roles of any type fine

Using context role = user generalizes better across chat models for
now and aligns with previous behavior.
2024-10-22 03:09:36 -07:00
Debanjum Singh Solanky
7ac241b766 Improve format of notes, online context passed to chat models in prompt
Improve separation of note snippets and show its origin file in notes
prompt to have more readable, contextualized text shared with model.

Previously the references dict was being directly passed as a string.
The documents don't look well formatted and are less intelligible.

- Passing file path along with notes snippets will help contextualize
  the notes better.
- Better formatting should help with making notes more readable by the
  chat model.
2024-10-22 03:09:36 -07:00
sabaimran
892040972f Replace user_id with server_id in telemetry 2024-10-21 20:47:52 -07:00
sabaimran
db959a504d Fix the version of pymupdf to avert build errors 2024-10-21 12:56:51 -07:00
sabaimran
21e69b506d Release Khoj version 1.26.3 2024-10-21 08:19:05 -07:00
Debanjum Singh Solanky
9b554feb91 Show agent details card on hover on agent pill on web app home page
- Double click on agent to open edit agent card
- Focus on chat input pane when agent selected/clicked
  for quick, smooth agent switch and message flow
- Hover on agent to see agent detail card on non-mobile displays
  - Use debounce to only show when hover on card for a bit
2024-10-21 00:08:01 -07:00
Debanjum Singh Solanky
220ff1df62 Set chatInputArea forward ref from parent components for control 2024-10-21 00:02:48 -07:00
Debanjum Singh Solanky
54b92eaf73 Extract isUserSubscribed check from Agents page to make it resusable 2024-10-20 23:31:48 -07:00
Debanjum Singh Solanky
bdbe8f003e Move agent details and edit card out into reusable components on web app 2024-10-20 23:31:47 -07:00
sabaimran
ad197be70c Fix PDFs unit test, skip OCR 2024-10-20 22:25:41 -07:00
sabaimran
59fec37943 Improve agents management, and limit agents view to private and official agents
- Default to None for the input_tools and output_modes so that they can be managed in the admin panel
- Hold off on showing off all Public Agents until we have a better experience for user profiles etc.
2024-10-20 22:24:51 -07:00
sabaimran
a979457442 Add unit tests for agents
- Add permutations of testing for with, without knowledge base. Private, public, different users.
2024-10-20 20:04:50 -07:00
Debanjum Singh Solanky
5fca41cc29 Show agents sorted by mru, Select mru agent by default on web app
Have get agents API return agents ordered intelligently
- Put the default agent first
- Sort used agents by most recently chatted with agent for ease of access
- Randomly shuffle the remaining unused agents for discoverability
2024-10-20 15:21:25 -07:00
Debanjum Singh Solanky
a6bfdbdbfe Show all agents in carousel on home screen agent pane of web app
This change wraps the agent pane in a scroll area with all agents shown.
It allows selecting an agent to chat with directly from the home
screen without breaking flow and having to jump to the agents page.

The previous flow was not convenient to quickly and consistently start
chat with one of your standard agents.

This was because a random subet of agents were shown on the home page.
To start chat with an agent not shown on home screen load you had to
open the agents page and initiate the conversation from there.
2024-10-20 15:21:25 -07:00
Debanjum Singh Solanky
7646ac6779 Style user attached images as carousel on chat input area of web app 2024-10-20 00:40:08 -07:00
sabaimran
5d5bea6a5f Ensure images are reset after messages processed 2024-10-19 22:02:06 -07:00
sabaimran
1ad6e1749f Move window redirect to after relevant data is dropped in localStorage on the homage page
One limitation of this methodology is that localStorage has a limit in how much data it can take. Should add more graceful error handling here as well.
2024-10-19 20:36:13 -07:00
sabaimran
cb6b3ec1e9 Improve mode description given to LLM when determining how to respond.
Currently experiencing difficulty instruction following when an image is shared. It's more likely to try and output an image. Update to make a clearer distinction.
2024-10-19 20:35:32 -07:00
sabaimran
545259e308 Remove unused icons in chatInputArea 2024-10-19 16:54:21 -07:00
Debanjum Singh Solanky
3cc1426edf Style user attached images with fixed height, in a single row on web app 2024-10-19 16:48:36 -07:00
Debanjum Singh Solanky
58a331227d Display the attached images inside the chat input area on the web app
- Put the attached images display div inside the same parent div as
  the text area
- Keep the attachment, microphone/send message buttons aligned with
  the text area. So the attached images just show up at the top of the
  text area but everything else stays at the same horizontal height as
  before.

- This improves the UX by
  - Ensuring that the attached images do not obscure the agents pane
    above the chat input area
  - The attached images visually look like they are inside the actual
    input area, rather than floating above it. So the visual aligns
    with the semantics
2024-10-19 16:29:45 -07:00
Debanjum Singh Solanky
3e39fac455 Add vision support for Gemini models in Khoj 2024-10-19 15:47:03 -07:00
Debanjum Singh Solanky
0d6a54c10f Allow sharing multiple images as part of user query from the web app
Previously the web app only expected a single image to be shared by
the user as part of their query.

This change allows sharing multiple images from the web app.

Closes #921
2024-10-19 15:47:03 -07:00
Debanjum Singh Solanky
e2abc1a257 Handle multiple images shared in query to chat API
Previously Khoj could respond to a single shared image at a time.

This changes updates the chat API to accept multiple images shared by
the user and send it to the appropriate chat actors including the
openai response generation chat actor for getting an image aware
response
2024-10-19 14:53:33 -07:00
Debanjum Singh Solanky
c6c48cfc18 Fix arg to generate_summary_from_file and type of this_iteration 2024-10-17 13:38:48 -07:00
Debanjum Singh Solanky
feb6d65ef8 Merge branch 'master' into features/advanced-reasoning 2024-10-15 09:37:56 -07:00
Debanjum Singh Solanky
336c6c3689 Show tool to use decision for next iteration in train of thought 2024-10-15 01:12:18 -07:00
Debanjum Singh Solanky
81fb65fa0a Return data sources to use if exception in data source chat actor
Previously no value was returned if an exception got triggered when
collecting information sources to search.
2024-10-14 18:20:20 -07:00
Debanjum Singh Solanky
3c93f07b3f Try respond even if web search, webpage read fails during chat
Khoj shouldn't refuse to respond to user if web lookups fail.
It should transparently mention that online search etc. failed.
But try respond as best as it can without those references

This change ensures a response to the users query is attempted even
when web info retrieval fails.
2024-10-14 18:13:26 -07:00
Debanjum Singh Solanky
07ab7ebf07 Try respond even if document search via inference endpoint fails
The huggingface endpoint can be flaky. Khoj shouldn't refuse to
respond to user if document search fails.
It should transparently mention that document lookup failed.
But try respond as best as it can without the document references

This changes provides graceful failover when inference endpoint
requests fail either when encoding query or reranking retrieved docs
2024-10-14 18:13:26 -07:00
Debanjum Singh Solanky
d6206aa80c Remove deprecated GET chat API endpoint 2024-10-14 18:13:26 -07:00
Debanjum Singh Solanky
263eee4351 Fix default chat model to use user model if no server chat model set
- Advanced chat model should also fallback to user chat model if set
- Get conversation config should falback to user chat model if set

These assume no server chat model settings is configured
2024-10-14 18:13:26 -07:00
Debanjum Singh Solanky
abcd11cfc0 Merge branch 'master' into features/advanced-reasoning 2024-10-13 03:06:23 -07:00
Debanjum Singh Solanky
9356e66b94 Fix default chat model to use user model if no server chat model set
- Advanced chat model should also fallback to user chat model if set
- Get conversation config should falback to user chat model if set

These assume no server chat model settings is configured
2024-10-13 03:02:29 -07:00
Debanjum Singh Solanky
9314f0a398 Fix default chat configs to use user model if no server chat model set
Post merge cleanup in advanced reasoning to fallback to user chat
model if no server chat model defined for advanced and default
2024-10-13 02:59:10 -07:00
Debanjum Singh Solanky
a2200466b7 Merge branch 'master' into features/advanced-reasoning 2024-10-12 21:01:22 -07:00
Debanjum Singh Solanky
9daaae0fdb Render inline any image files output by code in message
Update regex to also include any links to code generated images that
aren't explicitly meant to be displayed inline. This allows folks to
download the image (unlike the fake link that doesn't work created by
model)
2024-10-12 10:34:57 -07:00
Debanjum Singh Solanky
20d495c43a Update the iterative chat director prompt to generalize across chat models
These prompts work across o1 and standard openai model. Works with
anthropic and google models as well
2024-10-12 10:34:57 -07:00
sabaimran
01a58b71a5 Skip image, code generation if in research mode 2024-10-10 18:06:29 -07:00
Debanjum Singh Solanky
1b13d069f5 Pass data collected from various sources to code tool in normal flow too 2024-10-10 05:19:27 -07:00
Debanjum Singh Solanky
f462d34547 Render images files output by code interpreter in message on web app 2024-10-10 05:17:53 -07:00
Debanjum Singh Solanky
564491e164 Extract date filters quoted with non-ascii quotes in query 2024-10-10 04:45:00 -07:00
Debanjum Singh Solanky
6a8fd9bf33 Reorder embeddings search arguments based on argument importance 2024-10-10 04:45:00 -07:00
Debanjum Singh Solanky
0eacc0b2b0 Use consistent name for user, planner to not miss current user query
Previously Khoj would start answering the previous query. This maybe
because the prompt uses User for prompt in chat history but was using
Q for current user prompt.
2024-10-10 04:45:00 -07:00
Debanjum Singh Solanky
284c8c331b Increase default max iterations for research chat director to 5 2024-10-10 04:45:00 -07:00
Debanjum Singh Solanky
1e390325d2 Let research chat director decide which webpage to read, if any
Make webpages to read automatically on search_online configurable via
a argument.

Set it to default to 1, so other callers of the function
are unaffected.

But iterative chat director can still decide which, if
any, webpages to read based on the online search it performs
2024-10-10 04:45:00 -07:00
Debanjum Singh Solanky
5a699a52d2 Improve webpage summarization prompt to better extract links, excerpts
This change allows the iterative director to dive deeper into its
research as the data extracted contains relevant links from the webpage

Previous summarization prompt didn't extract relevant links from the
webpage which limited further explorations from webpages
2024-10-10 04:45:00 -07:00
Debanjum Singh Solanky
61df1d5db8 Pass previous iteration results to code interpreter chat actors
This improves the code interpreter chat actors abilitiy to generate
code with data collected during the previous iterations
2024-10-10 04:45:00 -07:00
Debanjum Singh Solanky
9e7025b330 Set python interpret sandbox url via environment variable 2024-10-10 04:45:00 -07:00
Debanjum Singh Solanky
2dc5804571 Extract defilter query into conversation utils for reuse 2024-10-10 04:45:00 -07:00
sabaimran
e69a8382f2 Add a code icon for code-related train of thought 2024-10-09 23:56:57 -07:00
sabaimran
536422a40c Include code snippets in the reference panel 2024-10-09 23:54:11 -07:00
Debanjum Singh Solanky
8d33c764b7 Allow iterative chat director to use python interpreter as a tool 2024-10-09 23:38:20 -07:00
Debanjum Singh Solanky
b373073f47 Show executed code in web app chat message references 2024-10-09 22:13:18 -07:00
Debanjum Singh Solanky
a98f97ed5e Refactor Run Code tool into separate module and modularize code functions
Move construct_chat_history and ChatEvent enum into conversation.utils
and move send_message_to_model_wrapper to conversation.helper to
modularize code. And start thinning out the bloated routers.helper

- conversation.util components are shared functions that conversation
  child packages can use.
- conversation.helper components can't be imported by conversation
  packages but it can use these child packages

This division allows better modularity while avoiding circular
import dependencies
2024-10-09 22:13:17 -07:00
Debanjum Singh Solanky
8044733201 Give Khoj ability to run python code as a tool triggered via chat API
Create python code executing chat actor
- The chat actor generate python code within sandbox constraints
- Run the generated python code in the cohere terrarium, pyodide
  based sandbox accessible at sandbox url
2024-10-09 21:37:22 -07:00
Debanjum Singh Solanky
4d33239af6 Improve prompts for the iterative chat director 2024-10-09 21:23:18 -07:00
Debanjum Singh Solanky
6ad85e2275 Fix to continue showing retrieved documents in train of thought 2024-10-09 21:20:22 -07:00
sabaimran
a6f6e4f418 Fix notes references and passage of user query in the chat flow 2024-10-09 20:34:20 -07:00
Debanjum Singh Solanky
ec248efd31 Allow iterative chat director to do notes search 2024-10-09 19:04:59 -07:00
Debanjum Singh Solanky
a6905a9f0c Pass background context to iterating chat director 2024-10-09 19:04:59 -07:00
sabaimran
028b6e6379 Fix yield for scraping direct web page 2024-10-09 18:14:08 -07:00
sabaimran
717d9da8d8 Handle when summarize result is not present, rename variable in for loop from query 2024-10-09 17:57:08 -07:00
sabaimran
03544efde2 Ignore typing of the result dict for online, web page scrape 2024-10-09 17:48:24 -07:00
sabaimran
ab81b01fcb Fix typing of direct_web_pages and remove the deprecated chat API 2024-10-09 17:46:28 -07:00
sabaimran
5b8d663cf1 Add intermediate summarization of results when planning with o1 2024-10-09 17:40:56 -07:00
sabaimran
7b288a1179 Clean up the function planning prompt a little bit 2024-10-09 16:59:20 -07:00
sabaimran
f71e4969d3 Skip summarize while it's broken, and snip some other parts of the workflow while under construction 2024-10-09 16:40:06 -07:00
sabaimran
f7e6f99a32 add typing for extract document references 2024-10-09 16:05:34 -07:00
sabaimran
6960fb097c update types of prev iterations response 2024-10-09 16:04:39 -07:00
sabaimran
4978360852 Fix type of previous_iterations 2024-10-09 16:02:41 -07:00
sabaimran
46ef205a75 Add additional type annotations for compiled_references et al 2024-10-09 16:01:52 -07:00
sabaimran
4fbaef10e9 Correct usage of the summarize function 2024-10-09 15:58:05 -07:00
sabaimran
c91678078d Correct the usage of query passed to summarize function 2024-10-09 15:55:55 -07:00
sabaimran
f867d5ed72 Working prototype of meta-level chain of reasoning and execution
- Create a more dynamic reasoning agent that can evaluate information and understand what it doesn't know, making moves to get that information
- Lots of hacks and code that needs to be reversed later on before submission
2024-10-09 15:54:25 -07:00
124 changed files with 7538 additions and 3855 deletions

View File

@@ -14,6 +14,10 @@ services:
interval: 30s
timeout: 10s
retries: 5
sandbox:
image: ghcr.io/khoj-ai/terrarium:latest
ports:
- "8080:8080"
server:
depends_on:
database:
@@ -57,6 +61,10 @@ services:
# - KHOJ_NO_HTTPS=True
# - KHOJ_DOMAIN=192.168.0.104
# - KHOJ_DOMAIN=khoj.example.com
# Uncomment the line below to disable telemetry.
# Telemetry helps us prioritize feature development and understand how people are using Khoj
# Read more at https://docs.khoj.dev/miscellaneous/telemetry
# - KHOJ_TELEMETRY_DISABLE=True
command: --host="0.0.0.0" --port=42110 -vv --anonymous-mode --non-interactive

View File

@@ -1,6 +1,8 @@
# Admin Panel
> Describes the Khoj settings configurable via the admin panel
By default, you admin panel is available at `http://localhost:42110/server/admin/`. You can access the admin panel by logging in with your admin credentials (this would be your `KHOJ_ADMIN_EMAIL` and `KHOJ_ADMIN_PASSWORD`). The admin panel allows you to configure various settings for your Khoj server.
## App Settings
### Agents
Add all the agents you want to use for your different use-cases like Writer, Researcher, Therapist etc.

View File

@@ -31,7 +31,4 @@ Using LiteLLM with Khoj makes it possible to turn any LLM behind an API into you
- Openai Config: `<the proxy config you created in step 3>`
- Max prompt size: `20000` (replace with the max prompt size of your model)
- Tokenizer: *Do not set for OpenAI, Mistral, Llama3 based models*
5. Create a new [Server Chat Setting](http://localhost:42110/server/admin/database/serverchatsettings/add/) on your Khoj admin panel
- Default model: `<name of chat model option you created in step 4>`
- Summarizer model: `<name of chat model option you created in step 4>`
6. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.
5. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.

View File

@@ -24,7 +24,4 @@ LM Studio can expose an [OpenAI API compatible server](https://lmstudio.ai/docs/
- Openai Config: `<the proxy config you created in step 3>`
- Max prompt size: `20000` (replace with the max prompt size of your model)
- Tokenizer: *Do not set for OpenAI, mistral, llama3 based models*
5. Create a new [Server Chat Setting](http://localhost:42110/server/admin/database/serverchatsettings/add/) on your Khoj admin panel
- Default model: `<name of chat model option you created in step 4>`
- Summarizer model: `<name of chat model option you created in step 4>`
6. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.
5. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.

View File

@@ -28,9 +28,6 @@ Ollama exposes a local [OpenAI API compatible server](https://github.com/ollama/
- Model Type: `Openai`
- Openai Config: `<the ollama config you created in step 3>`
- Max prompt size: `20000` (replace with the max prompt size of your model)
5. Create a new [Server Chat Setting](http://localhost:42110/server/admin/database/serverchatsettings/add/) on your Khoj admin panel
- Default model: `<name of chat model option you created in step 4>`
- Summarizer model: `<name of chat model option you created in step 4>`
6. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.
5. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.
That's it! You should now be able to chat with your Ollama model from Khoj. If you want to add additional models running on Ollama, repeat step 6 for each model.

View File

@@ -31,7 +31,4 @@ For specific integrations, see our [Ollama](/advanced/ollama), [LMStudio](/advan
- Openai Config: `<the proxy config you created in step 2>`
- Max prompt size: `2000` (replace with the max prompt size of your model)
- Tokenizer: *Do not set for OpenAI, mistral, llama3 based models*
4. Create a new [Server Chat Setting](http://localhost:42110/server/admin/database/serverchatsettings/add/) on your Khoj admin panel
- Default model: `<name of chat model option you created in step 3>`
- Summarizer model: `<name of chat model option you created in step 3>`
5. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.
4. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.

View File

@@ -12,7 +12,7 @@ Without any desktop clients, you can start chatting with Khoj on WhatsApp. Bear
In order to use Khoj on WhatsApp with your own data, you need to setup a Khoj Cloud account and connect your WhatsApp account to it. This is a one time setup and you can do it from the [Khoj Cloud config page](https://app.khoj.dev/settings).
If you hit usage limits for the WhatsApp bot, upgrade to [a paid plan](https://khoj.dev/pricing) on Khoj Cloud.
If you hit usage limits for the WhatsApp bot, upgrade to [a paid plan](https://khoj.dev/#pricing) on Khoj Cloud.
<img src="https://khoj-web-bucket.s3.amazonaws.com/khojwhatsapp.png" alt="WhatsApp QR Code" width="300" height="300" />

View File

@@ -102,7 +102,19 @@ sudo -u postgres createdb khoj --password
</TabItem>
</Tabs>
#### 3. Run
#### 3. Build the front-end assets
```shell
cd src/interface/web/
yarn install
yarn export
```
You can optionally use `yarn dev` to start a development server for the front-end which will be available at http://localhost:3000. This is especially useful if you're making changes to the front-end code, but not necessary for running Khoj. Note that streaming does not work on the dev server due to how it is handled with SSR in Next.js.
Always run `yarn export` to test your front-end changes on http://localhost:42110 before creating a PR.
#### 4. Run
1. Start Khoj
```bash
khoj -vv

View File

@@ -43,3 +43,6 @@ Slash commands allows you to change what Khoj uses to respond to your query
- **/image**: Generate an image in response to your query.
- **/help**: Use /help to get all available commands and general information about Khoj
- **/summarize**: Can be used to summarize 1 selected file filter for that conversation. Refer to [File Summarization](summarization) for details.
- **/diagram**: Generate a diagram in response to your query. This is built on [Excalidraw](https://excalidraw.com/).
- **/code**: Generate and run very simple Python code snippets. Refer to [Code Execution](code_execution) for details.
- **/research**: Go deeper in a topic for more accurate, in-depth responses.

View File

@@ -0,0 +1,30 @@
---
---
# Code Execution
Khoj can generate and run very simple Python code snippets as well. This is useful if you want to generate a plot, run a simple calculation, or do some basic data manipulation. LLMs by default aren't skilled at complex quantitative tasks. Code generation & execution can come in handy for such tasks.
Just use `/code` in your chat command.
### Setup (Self-Hosting)
Run [Cohere's Terrarium](https://github.com/cohere-ai/cohere-terrarium) on your machine to enable code generation and execution.
Check the [instructions](https://github.com/cohere-ai/cohere-terrarium?tab=readme-ov-file#development) for running from source.
For running with Docker, you can use our [docker-compose.yml](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml), or start it manually like this:
```bash
docker pull ghcr.io/khoj-ai/terrarium:latest
docker run -d -p 8080:8080 ghcr.io/khoj-ai/terrarium:latest
```
#### Verify
Verify that it's running, by evaluating a simple Python expression:
```bash
curl -X POST -H "Content-Type: application/json" \
--url http://localhost:8080 \
--data-raw '{"code": "1 + 1"}' \
--no-buffer
```

View File

@@ -27,7 +27,28 @@ If you want to use the offline chat model and you have a GPU, you should use Ins
<Tabs groupId="operating-systems" queryString="os">
<TabItem value="macos" label="MacOS">
<h3>Prerequisites</h3>
Install [Docker Desktop](https://docs.docker.com/desktop/install/mac-install/)
<h4>Docker</h4>
(Option 1) Click here to install [Docker Desktop](https://docs.docker.com/desktop/install/mac-install/). Make sure you also install the [Docker Compose](https://docs.docker.com/desktop/install/mac-install/) tool.
(Option 2) Use [Homebrew](https://brew.sh/) to install Docker and Docker Compose.
```shell
brew install --cask docker
brew install docker-compose
```
<h3>Setup</h3>
1. Download the Khoj docker-compose.yml file [from Github](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml)
```shell
mkdir ~/.khoj && cd ~/.khoj
wget https://raw.githubusercontent.com/khoj-ai/khoj/master/docker-compose.yml
```
2. Configure the environment variables in the docker-compose.yml
- Set `KHOJ_ADMIN_PASSWORD`, `KHOJ_DJANGO_SECRET_KEY` (and optionally the `KHOJ_ADMIN_EMAIL`) to something secure. This allows you to customize Khoj later via the admin panel.
- Set `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, or `GEMINI_API_KEY` to your API key if you want to use OpenAI, Anthropic or Gemini chat models respectively.
3. Start Khoj by running the following command in the same directory as your docker-compose.yml file.
```shell
cd ~/.khoj
docker-compose up
```
</TabItem>
<TabItem value="windows" label="Windows">
<h3>Prerequisites</h3>
@@ -37,30 +58,47 @@ If you want to use the offline chat model and you have a GPU, you should use Ins
wsl --install
```
2. Install [Docker Desktop](https://docs.docker.com/desktop/install/windows-install/) with **[WSL2 backend](https://docs.docker.com/desktop/wsl/#turn-on-docker-desktop-wsl-2)** (default)
<h3>Setup</h3>
1. Download the Khoj docker-compose.yml file [from Github](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml)
```shell
# Windows users should use their WSL2 terminal to run these commands
mkdir ~/.khoj && cd ~/.khoj
wget https://raw.githubusercontent.com/khoj-ai/khoj/master/docker-compose.yml
```
2. Configure the environment variables in the docker-compose.yml
- Set `KHOJ_ADMIN_PASSWORD`, `KHOJ_DJANGO_SECRET_KEY` (and optionally the `KHOJ_ADMIN_EMAIL`) to something secure. This allows you to customize Khoj later via the admin panel.
- Set `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, or `GEMINI_API_KEY` to your API key if you want to use OpenAI, Anthropic or Gemini chat models respectively.
3. Start Khoj by running the following command in the same directory as your docker-compose.yml file.
```shell
# Windows users should use their WSL2 terminal to run these commands
cd ~/.khoj
docker-compose up
```
</TabItem>
<TabItem value="linux" label="Linux">
<h3>Prerequisites</h3>
Install [Docker Desktop](https://docs.docker.com/desktop/install/windows-install/).
Install [Docker Desktop](https://docs.docker.com/desktop/install/linux/).
You can also use your package manager to install Docker Engine & Docker Compose.
<h3>Setup</h3>
1. Download the Khoj docker-compose.yml file [from Github](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml)
```shell
mkdir ~/.khoj && cd ~/.khoj
wget https://raw.githubusercontent.com/khoj-ai/khoj/master/docker-compose.yml
```
2. Configure the environment variables in the docker-compose.yml
- Set `KHOJ_ADMIN_PASSWORD`, `KHOJ_DJANGO_SECRET_KEY` (and optionally the `KHOJ_ADMIN_EMAIL`) to something secure. This allows you to customize Khoj later via the admin panel.
- Set `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, or `GEMINI_API_KEY` to your API key if you want to use OpenAI, Anthropic or Gemini chat models respectively.
3. Start Khoj by running the following command in the same directory as your docker-compose.yml file.
```shell
cd ~/.khoj
docker-compose up
```
</TabItem>
</Tabs>
<h3>Setup</h3>
1. Download the Khoj docker-compose.yml file [from Github](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml)
```shell
# Windows users should use their WSL2 terminal to run these commands
mkdir ~/.khoj && cd ~/.khoj
wget https://raw.githubusercontent.com/khoj-ai/khoj/master/docker-compose.yml
```
2. Configure the environment variables in the docker-compose.yml
- Set `KHOJ_ADMIN_PASSWORD`, `KHOJ_DJANGO_SECRET_KEY` (and optionally the `KHOJ_ADMIN_EMAIL`) to something secure. This allows you to customize Khoj later via the admin panel.
- Set `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, or `GEMINI_API_KEY` to your API key if you want to use OpenAI, Anthropic or Gemini chat models respectively.
3. Start Khoj by running the following command in the same directory as your docker-compose.yml file.
```shell
# Windows users should use their WSL2 terminal to run these commands
cd ~/.khoj
docker-compose up
```
:::info[Remote Access]
By default Khoj is only accessible on the machine it is running. To access Khoj from a remote machine see [Remote Access Docs](/advanced/remote).

View File

@@ -14,13 +14,6 @@ We don't send any personal information or any information from/about your conten
## Disable Telemetry
If you're self-hosting Khoj, you can opt out of telemetry at any time. To do so,
1. Open `~/.khoj/khoj.yml`
2. Add the following configuration:
```
app:
should-log-telemetry: false
```
3. Save the file and restart Khoj
If you're self-hosting Khoj, you can opt out of telemetry at any time by setting the `KHOJ_TELEMETRY_DISABLE` environment variable to `True` via shell or `docker-compose.yml`
If you have any questions or concerns, please reach out to us on [Discord](https://discord.gg/BDgyabRM6e).

View File

@@ -1,7 +1,7 @@
{
"id": "khoj",
"name": "Khoj",
"version": "1.26.2",
"version": "1.29.1",
"minAppVersion": "0.15.0",
"description": "Your Second Brain",
"author": "Khoj Inc.",

View File

@@ -62,7 +62,7 @@ dependencies = [
"requests >= 2.26.0",
"tenacity == 8.3.0",
"anyio == 3.7.1",
"pymupdf >= 1.23.5",
"pymupdf == 1.24.11",
"django == 5.0.9",
"authlib == 1.2.1",
"llama-cpp-python == 0.2.88",
@@ -87,7 +87,7 @@ dependencies = [
"django_apscheduler == 0.6.2",
"anthropic == 0.26.1",
"docx2txt == 0.8",
"google-generativeai == 0.7.2"
"google-generativeai == 0.8.3",
]
dynamic = ["version"]
@@ -119,6 +119,9 @@ dev = [
"mypy >= 1.0.1",
"black >= 23.1.0",
"pre-commit >= 3.0.4",
"gitpython ~= 3.1.43",
"datasets",
"pandas",
]
[tool.hatch.version]

View File

@@ -326,7 +326,7 @@
entries.forEach(entry => {
// If the element is in the viewport, fetch the remaining message and unobserve the element
if (entry.isIntersecting) {
fetchRemainingChatMessages(chatHistoryUrl, headers);
fetchRemainingChatMessages(chatHistoryUrl, headers, chatBody.dataset.conversation_id, hostURL);
observer.unobserve(entry.target);
}
});
@@ -342,7 +342,11 @@
new Date(chat_log.created),
chat_log.onlineContext,
chat_log.intent?.type,
chat_log.intent?.["inferred-queries"]);
chat_log.intent?.["inferred-queries"],
chatBody.dataset.conversationId ?? "",
hostURL,
);
chatBody.appendChild(messageElement);
// When the 4th oldest message is within viewing distance (~60% scrolled up)
@@ -421,7 +425,7 @@
}
}
function fetchRemainingChatMessages(chatHistoryUrl, headers) {
function fetchRemainingChatMessages(chatHistoryUrl, headers, conversationId, hostURL) {
// Create a new IntersectionObserver
let observer = new IntersectionObserver((entries, observer) => {
entries.forEach(entry => {
@@ -435,7 +439,9 @@
new Date(chat_log.created),
chat_log.onlineContext,
chat_log.intent?.type,
chat_log.intent?.["inferred-queries"]
chat_log.intent?.["inferred-queries"],
chatBody.dataset.conversationId ?? "",
hostURL,
);
entry.target.replaceWith(messageElement);

View File

@@ -189,11 +189,19 @@ function processOnlineReferences(referenceSection, onlineContext) { //same
return numOnlineReferences;
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) { //same
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null, conversationId=null, hostURL=null) {
let chatEl;
if (intentType?.includes("text-to-image")) {
let imageMarkdown = generateImageMarkdown(message, intentType, inferredQueries);
chatEl = renderMessage(imageMarkdown, by, dt, null, false, "return");
} else if (intentType === "excalidraw") {
let domain = hostURL ?? "https://app.khoj.dev/";
if (!domain.endsWith("/")) domain += "/";
let excalidrawMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in the web app at ${domain}chat?conversationId=${conversationId}`;
chatEl = renderMessage(excalidrawMessage, by, dt, null, false, "return");
} else {
chatEl = renderMessage(message, by, dt, null, false, "return");
}
@@ -312,7 +320,6 @@ function formatHTMLMessage(message, raw=false, willReplace=true) { //same
}
function createReferenceSection(references, createLinkerSection=false) {
console.log("linker data: ", createLinkerSection);
let referenceSection = document.createElement('div');
referenceSection.classList.add("reference-section");
referenceSection.classList.add("collapsed");
@@ -417,7 +424,11 @@ function handleImageResponse(imageJson, rawResponse) {
rawResponse += `![generated_image](${imageJson.image})`;
} else if (imageJson.intentType === "text-to-image-v3") {
rawResponse = `![](data:image/webp;base64,${imageJson.image})`;
} else if (imageJson.intentType === "excalidraw") {
const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in the web app`;
rawResponse += redirectMessage;
}
if (inferredQuery) {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}

View File

@@ -1,6 +1,6 @@
{
"name": "Khoj",
"version": "1.26.2",
"version": "1.29.1",
"description": "Your Second Brain",
"author": "Khoj Inc. <team@khoj.dev>",
"license": "GPL-3.0-or-later",

View File

@@ -6,7 +6,7 @@
;; Saba Imran <saba@khoj.dev>
;; Description: Your Second Brain
;; Keywords: search, chat, ai, org-mode, outlines, markdown, pdf, image
;; Version: 1.26.2
;; Version: 1.29.1
;; Package-Requires: ((emacs "27.1") (transient "0.3.0") (dash "2.19.1"))
;; URL: https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs

View File

@@ -1,7 +1,7 @@
{
"id": "khoj",
"name": "Khoj",
"version": "1.26.2",
"version": "1.29.1",
"minAppVersion": "0.15.0",
"description": "Your Second Brain",
"author": "Khoj Inc.",

View File

@@ -1,6 +1,6 @@
{
"name": "Khoj",
"version": "1.26.2",
"version": "1.29.1",
"description": "Your Second Brain",
"author": "Debanjum Singh Solanky, Saba Imran <team@khoj.dev>",
"license": "GPL-3.0-or-later",

View File

@@ -484,12 +484,13 @@ export class KhojChatView extends KhojPaneView {
dt?: Date,
intentType?: string,
inferredQueries?: string[],
conversationId?: string,
) {
if (!message) return;
let chatMessageEl;
if (intentType?.includes("text-to-image")) {
let imageMarkdown = this.generateImageMarkdown(message, intentType, inferredQueries);
if (intentType?.includes("text-to-image") || intentType === "excalidraw") {
let imageMarkdown = this.generateImageMarkdown(message, intentType, inferredQueries, conversationId);
chatMessageEl = this.renderMessage(chatEl, imageMarkdown, sender, dt);
} else {
chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
@@ -509,7 +510,7 @@ export class KhojChatView extends KhojPaneView {
chatMessageBodyEl.appendChild(this.createReferenceSection(references));
}
generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[]) {
generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string): string {
let imageMarkdown = "";
if (intentType === "text-to-image") {
imageMarkdown = `![](data:image/png;base64,${message})`;
@@ -517,6 +518,10 @@ export class KhojChatView extends KhojPaneView {
imageMarkdown = `![](${message})`;
} else if (intentType === "text-to-image-v3") {
imageMarkdown = `![](data:image/webp;base64,${message})`;
} else if (intentType === "excalidraw") {
const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}chat?conversationId=${conversationId}`;
imageMarkdown = redirectMessage;
}
if (inferredQueries) {
imageMarkdown += "\n\n**Inferred Query**:";
@@ -884,6 +889,7 @@ export class KhojChatView extends KhojPaneView {
new Date(chatLog.created),
chatLog.intent?.type,
chatLog.intent?.["inferred-queries"],
chatBodyEl.dataset.conversationId ?? "",
);
// push the user messages to the chat history
if(chatLog.by === "you"){
@@ -1354,6 +1360,10 @@ export class KhojChatView extends KhojPaneView {
rawResponse += `![generated_image](${imageJson.image})`;
} else if (imageJson.intentType === "text-to-image-v3") {
rawResponse = `![](data:image/webp;base64,${imageJson.image})`;
} else if (imageJson.intentType === "excalidraw") {
const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}`;
rawResponse += redirectMessage;
}
if (inferredQuery) {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;

View File

@@ -78,6 +78,7 @@ If your plugin does not need CSS, delete this file.
user-select: text;
color: var(--text-normal);
background-color: var(--active-bg);
word-break: break-word;
}
/* color chat bubble by khoj blue */
.khoj-chat-message-text.khoj {

View File

@@ -80,5 +80,15 @@
"1.25.0": "0.15.0",
"1.26.0": "0.15.0",
"1.26.1": "0.15.0",
"1.26.2": "0.15.0"
"1.26.2": "0.15.0",
"1.26.3": "0.15.0",
"1.26.4": "0.15.0",
"1.27.0": "0.15.0",
"1.27.1": "0.15.0",
"1.28.0": "0.15.0",
"1.28.1": "0.15.0",
"1.28.2": "0.15.0",
"1.28.3": "0.15.0",
"1.29.0": "0.15.0",
"1.29.1": "0.15.0"
}

File diff suppressed because it is too large Load Diff

View File

@@ -79,7 +79,7 @@ div.titleBar {
div.chatBoxBody {
display: grid;
height: 100%;
width: 70%;
width: 95%;
margin: auto;
}

View File

@@ -47,7 +47,14 @@ export default function RootLayout({
child-src 'none';
object-src 'none';"
></meta>
<body className={inter.className}>{children}</body>
<body className={inter.className}>
{children}
<script
dangerouslySetInnerHTML={{
__html: `window.EXCALIDRAW_ASSET_PATH = 'https://assets.khoj.dev/@excalidraw/excalidraw/dist/';`,
}}
/>
</body>
</html>
);
}

View File

@@ -1,24 +1,31 @@
"use client";
import styles from "./chat.module.css";
import React, { Suspense, useEffect, useState } from "react";
import React, { Suspense, useEffect, useRef, useState } from "react";
import SidePanel, { ChatSessionActionMenu } from "../components/sidePanel/chatHistorySidePanel";
import ChatHistory from "../components/chatHistory/chatHistory";
import { useSearchParams } from "next/navigation";
import Loading from "../components/loading/loading";
import { processMessageChunk } from "../common/chatFunctions";
import { generateNewTitle, processMessageChunk } from "../common/chatFunctions";
import "katex/dist/katex.min.css";
import { Context, OnlineContext, StreamMessage } from "../components/chatMessage/chatMessage";
import {
CodeContext,
Context,
OnlineContext,
StreamMessage,
} from "../components/chatMessage/chatMessage";
import { useIPLocationData, useIsMobileWidth, welcomeConsole } from "../common/utils";
import ChatInputArea, { ChatOptions } from "../components/chatInputArea/chatInputArea";
import {
AttachedFileText,
ChatInputArea,
ChatOptions,
} from "../components/chatInputArea/chatInputArea";
import { useAuthenticatedData } from "../common/auth";
import { AgentData } from "../agents/page";
import { DotsThreeVertical } from "@phosphor-icons/react";
import { Button } from "@/components/ui/button";
interface ChatBodyDataProps {
chatOptionsData: ChatOptions | null;
@@ -26,43 +33,72 @@ interface ChatBodyDataProps {
onConversationIdChange?: (conversationId: string) => void;
setQueryToProcess: (query: string) => void;
streamedMessages: StreamMessage[];
setUploadedFiles: (files: string[]) => void;
setStreamedMessages: (messages: StreamMessage[]) => void;
setUploadedFiles: (files: AttachedFileText[] | undefined) => void;
isMobileWidth?: boolean;
isLoggedIn: boolean;
setImage64: (image64: string) => void;
setImages: (images: string[]) => void;
}
function ChatBodyData(props: ChatBodyDataProps) {
const searchParams = useSearchParams();
const conversationId = searchParams.get("conversationId");
const [message, setMessage] = useState("");
const [image, setImage] = useState<string | null>(null);
const [images, setImages] = useState<string[]>([]);
const [processingMessage, setProcessingMessage] = useState(false);
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
const [isInResearchMode, setIsInResearchMode] = useState(false);
const chatInputRef = useRef<HTMLTextAreaElement>(null);
const setQueryToProcess = props.setQueryToProcess;
const onConversationIdChange = props.onConversationIdChange;
useEffect(() => {
if (image) {
props.setImage64(encodeURIComponent(image));
}
}, [image, props.setImage64]);
const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";
useEffect(() => {
const storedImage = localStorage.getItem("image");
if (storedImage) {
setImage(storedImage);
props.setImage64(encodeURIComponent(storedImage));
localStorage.removeItem("image");
if (images.length > 0) {
const encodedImages = images.map((image) => encodeURIComponent(image));
props.setImages(encodedImages);
}
}, [images, props.setImages]);
useEffect(() => {
const storedImages = localStorage.getItem("images");
if (storedImages) {
const parsedImages: string[] = JSON.parse(storedImages);
setImages(parsedImages);
const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
props.setImages(encodedImages);
localStorage.removeItem("images");
}
const storedMessage = localStorage.getItem("message");
if (storedMessage) {
setProcessingMessage(true);
setQueryToProcess(storedMessage);
if (storedMessage.trim().startsWith("/research")) {
setIsInResearchMode(true);
}
}
}, [setQueryToProcess]);
const storedUploadedFiles = localStorage.getItem("uploadedFiles");
if (storedUploadedFiles) {
const parsedFiles = storedUploadedFiles ? JSON.parse(storedUploadedFiles) : [];
const uploadedFiles: AttachedFileText[] = [];
for (const file of parsedFiles) {
uploadedFiles.push({
name: file.name,
file_type: file.file_type,
content: file.content,
size: file.size,
});
}
localStorage.removeItem("uploadedFiles");
props.setUploadedFiles(uploadedFiles);
}
}, [setQueryToProcess, props.setImages, conversationId]);
useEffect(() => {
if (message) {
@@ -84,6 +120,8 @@ function ChatBodyData(props: ChatBodyDataProps) {
props.streamedMessages[props.streamedMessages.length - 1].completed
) {
setProcessingMessage(false);
setImages([]); // Reset images after processing
props.setUploadedFiles(undefined); // Reset uploaded files after processing
} else {
setMessage("");
}
@@ -103,21 +141,25 @@ function ChatBodyData(props: ChatBodyDataProps) {
setAgent={setAgentMetadata}
pendingMessage={processingMessage ? message : ""}
incomingMessages={props.streamedMessages}
setIncomingMessages={props.setStreamedMessages}
customClassName={chatHistoryCustomClassName}
/>
</div>
<div
className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl h-fit`}
className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl h-fit ${chatHistoryCustomClassName} mr-auto ml-auto`}
>
<ChatInputArea
agentColor={agentMetadata?.color}
isLoggedIn={props.isLoggedIn}
sendMessage={(message) => setMessage(message)}
sendImage={(image) => setImage(image)}
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
sendDisabled={processingMessage}
chatOptionsData={props.chatOptionsData}
conversationId={conversationId}
isMobileWidth={props.isMobileWidth}
setUploadedFiles={props.setUploadedFiles}
ref={chatInputRef}
isResearchModeEnabled={isInResearchMode}
/>
</div>
</>
@@ -133,8 +175,8 @@ export default function Chat() {
const [messages, setMessages] = useState<StreamMessage[]>([]);
const [queryToProcess, setQueryToProcess] = useState<string>("");
const [processQuerySignal, setProcessQuerySignal] = useState(false);
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
const [image64, setImage64] = useState<string>("");
const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | undefined>(undefined);
const [images, setImages] = useState<string[]>([]);
const locationData = useIPLocationData() || {
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
@@ -167,10 +209,12 @@ export default function Chat() {
trainOfThought: [],
context: [],
onlineContext: {},
codeContext: {},
completed: false,
timestamp: new Date().toISOString(),
rawQuery: queryToProcess || "",
uploadedImageData: decodeURIComponent(image64),
images: images,
queryFiles: uploadedFiles,
};
setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
setProcessQuerySignal(true);
@@ -195,13 +239,17 @@ export default function Chat() {
// Track context used for chat response
let context: Context[] = [];
let onlineContext: OnlineContext = {};
let codeContext: CodeContext = {};
while (true) {
const { done, value } = await reader.read();
if (done) {
setQueryToProcess("");
setProcessQuerySignal(false);
setImage64("");
setImages([]);
if (conversationId) generateNewTitle(conversationId, setTitle);
break;
}
@@ -221,11 +269,12 @@ export default function Chat() {
}
// Track context used for chat response. References are rendered at the end of the chat
({ context, onlineContext } = processMessageChunk(
({ context, onlineContext, codeContext } = processMessageChunk(
event,
currentMessage,
context,
onlineContext,
codeContext,
));
setMessages([...messages]);
@@ -249,7 +298,8 @@ export default function Chat() {
country_code: locationData.countryCode,
timezone: locationData.timezone,
}),
...(image64 && { image: image64 }),
...(images.length > 0 && { images: images }),
...(uploadedFiles && { files: uploadedFiles }),
};
const response = await fetch(chatAPI, {
@@ -263,7 +313,8 @@ export default function Chat() {
try {
await readChatStream(response);
} catch (err) {
console.error(err);
const apiError = await response.json();
console.error(apiError);
// Retrieve latest message being processed
const currentMessage = messages.find((message) => !message.completed);
if (!currentMessage) return;
@@ -272,7 +323,11 @@ export default function Chat() {
const errorMessage = (err as Error).message;
if (errorMessage.includes("Error in input stream"))
currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`;
else
else if (response.status === 429) {
"detail" in apiError
? (currentMessage.rawResponse = `${apiError.detail}`)
: (currentMessage.rawResponse = `I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?`);
} else
currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`;
// Complete message streaming teardown properly
@@ -297,7 +352,7 @@ export default function Chat() {
<div>
<SidePanel
conversationId={conversationId}
uploadedFiles={uploadedFiles}
uploadedFiles={[]}
isMobileWidth={isMobileWidth}
/>
</div>
@@ -325,13 +380,14 @@ export default function Chat() {
<ChatBodyData
isLoggedIn={authenticatedData !== null}
streamedMessages={messages}
setStreamedMessages={setMessages}
chatOptionsData={chatOptionsData}
setTitle={setTitle}
setQueryToProcess={setQueryToProcess}
setUploadedFiles={setUploadedFiles}
isMobileWidth={isMobileWidth}
onConversationIdChange={handleConversationIdChange}
setImage64={setImage64}
setImages={setImages}
/>
</Suspense>
</div>

View File

@@ -68,7 +68,8 @@ export interface UserConfig {
selected_voice_model_config: number;
// user billing info
subscription_state: SubscriptionStates;
subscription_renewal_date: string;
subscription_renewal_date: string | undefined;
subscription_enabled_trial_at: string | undefined;
// server settings
khoj_cloud_subscription_url: string | undefined;
billing_enabled: boolean;
@@ -78,6 +79,7 @@ export interface UserConfig {
anonymous_mode: boolean;
notion_oauth_url: string;
detail: string;
length_of_free_trial: number;
}
export function useUserConfig(detailed: boolean = false) {
@@ -93,3 +95,15 @@ export function useUserConfig(detailed: boolean = false) {
return { userConfig, isLoadingUserConfig };
}
export function isUserSubscribed(userConfig: UserConfig | null): boolean {
return (
(userConfig?.subscription_state &&
[
SubscriptionStates.SUBSCRIBED.valueOf(),
SubscriptionStates.TRIAL.valueOf(),
SubscriptionStates.UNSUBSCRIBED.valueOf(),
].includes(userConfig.subscription_state)) ||
false
);
}

View File

@@ -1,14 +1,25 @@
import { Context, OnlineContext, StreamMessage } from "../components/chatMessage/chatMessage";
import {
CodeContext,
Context,
OnlineContext,
StreamMessage,
} from "../components/chatMessage/chatMessage";
export interface RawReferenceData {
context?: Context[];
onlineContext?: OnlineContext;
codeContext?: CodeContext;
}
export interface ResponseWithReferences {
context?: Context[];
online?: OnlineContext;
response?: string;
export interface MessageMetadata {
conversationId: string;
turnId: string;
}
export interface ResponseWithIntent {
intentType: string;
response: string;
inferredQueries?: string[];
}
interface MessageChunk {
@@ -49,10 +60,14 @@ export function convertMessageChunkToJson(chunk: string): MessageChunk {
function handleJsonResponse(chunkData: any) {
const jsonData = chunkData as any;
if (jsonData.image || jsonData.detail) {
let responseWithReference = handleImageResponse(chunkData, true);
if (responseWithReference.response) return responseWithReference.response;
let responseWithIntent = handleImageResponse(chunkData, true);
return responseWithIntent;
} else if (jsonData.response) {
return jsonData.response;
return {
response: jsonData.response,
intentType: "",
inferredQueries: [],
};
} else {
throw new Error("Invalid JSON response");
}
@@ -63,10 +78,11 @@ export function processMessageChunk(
currentMessage: StreamMessage,
context: Context[] = [],
onlineContext: OnlineContext = {},
): { context: Context[]; onlineContext: OnlineContext } {
codeContext: CodeContext = {},
): { context: Context[]; onlineContext: OnlineContext; codeContext: CodeContext } {
const chunk = convertMessageChunkToJson(rawChunk);
if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext };
if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext, codeContext };
if (chunk.type === "status") {
console.log(`status: ${chunk.data}`);
@@ -77,11 +93,25 @@ export function processMessageChunk(
if (references.context) context = references.context;
if (references.onlineContext) onlineContext = references.onlineContext;
return { context, onlineContext };
if (references.codeContext) codeContext = references.codeContext;
return { context, onlineContext, codeContext };
} else if (chunk.type === "metadata") {
const messageMetadata = chunk.data as MessageMetadata;
currentMessage.turnId = messageMetadata.turnId;
} else if (chunk.type === "message") {
const chunkData = chunk.data;
// Here, handle if the response is a JSON response with an image, but the intentType is excalidraw
if (chunkData !== null && typeof chunkData === "object") {
currentMessage.rawResponse += handleJsonResponse(chunkData);
let responseWithIntent = handleJsonResponse(chunkData);
if (responseWithIntent.intentType && responseWithIntent.intentType === "excalidraw") {
currentMessage.rawResponse = responseWithIntent.response;
} else {
currentMessage.rawResponse += responseWithIntent.response;
}
currentMessage.intentType = responseWithIntent.intentType;
currentMessage.inferredQueries = responseWithIntent.inferredQueries;
} else if (
typeof chunkData === "string" &&
chunkData.trim()?.startsWith("{") &&
@@ -89,7 +119,10 @@ export function processMessageChunk(
) {
try {
const jsonData = JSON.parse(chunkData.trim());
currentMessage.rawResponse += handleJsonResponse(jsonData);
let responseWithIntent = handleJsonResponse(jsonData);
currentMessage.rawResponse += responseWithIntent.response;
currentMessage.intentType = responseWithIntent.intentType;
currentMessage.inferredQueries = responseWithIntent.inferredQueries;
} catch (e) {
currentMessage.rawResponse += JSON.stringify(chunkData);
}
@@ -102,51 +135,56 @@ export function processMessageChunk(
console.log(`Completed streaming: ${new Date()}`);
// Append any references after all the data has been streamed
if (codeContext) currentMessage.codeContext = codeContext;
if (onlineContext) currentMessage.onlineContext = onlineContext;
if (context) currentMessage.context = context;
// Mark current message streaming as completed
currentMessage.completed = true;
}
return { context, onlineContext };
return { context, onlineContext, codeContext };
}
export function handleImageResponse(imageJson: any, liveStream: boolean): ResponseWithReferences {
export function handleImageResponse(imageJson: any, liveStream: boolean): ResponseWithIntent {
let rawResponse = "";
if (imageJson.image) {
const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
// If response has image field, response is a generated image.
if (imageJson.intentType === "text-to-image") {
rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`;
} else if (imageJson.intentType === "text-to-image2") {
rawResponse += `![generated_image](${imageJson.image})`;
} else if (imageJson.intentType === "text-to-image-v3") {
rawResponse = `![](data:image/webp;base64,${imageJson.image})`;
}
if (inferredQuery && !liveStream) {
rawResponse += `\n\n${inferredQuery}`;
}
// If response has image field, response may be a generated image
rawResponse = imageJson.image;
}
let reference: ResponseWithReferences = {};
let responseWithIntent: ResponseWithIntent = {
intentType: imageJson.intentType,
response: rawResponse,
inferredQueries: imageJson.inferredQueries,
};
if (imageJson.context && imageJson.context.length > 0) {
const rawReferenceAsJson = imageJson.context;
if (rawReferenceAsJson instanceof Array) {
reference.context = rawReferenceAsJson;
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
reference.online = rawReferenceAsJson;
}
}
if (imageJson.detail) {
// The detail field contains the improved image prompt
rawResponse += imageJson.detail;
}
reference.response = rawResponse;
return reference;
return responseWithIntent;
}
export function renderCodeGenImageInline(message: string, codeContext: CodeContext) {
if (!codeContext) return message;
Object.values(codeContext).forEach((contextData) => {
contextData.results.output_files?.forEach((file) => {
const regex = new RegExp(`!?\\[.*?\\]\\(.*${file.filename}\\)`, "g");
if (file.filename.match(/\.(png|jpg|jpeg)$/i)) {
const replacement = `![${file.filename}](data:image/${file.filename.split(".").pop()};base64,${file.b64_data})`;
message = message.replace(regex, replacement);
} else if (file.filename.match(/\.(txt|org|md|csv|json)$/i)) {
// render output files generated by codegen as downloadable links
const replacement = `![${file.filename}](data:text/plain;base64,${file.b64_data})`;
message = message.replace(regex, replacement);
}
});
});
return message;
}
export function modifyFileFilterForConversation(
@@ -206,6 +244,78 @@ export async function createNewConversation(slug: string) {
}
}
export async function packageFilesForUpload(files: FileList): Promise<FormData> {
const formData = new FormData();
const fileReadPromises = Array.from(files).map((file) => {
return new Promise<void>((resolve, reject) => {
let reader = new FileReader();
reader.onload = function (event) {
if (event.target === null) {
reject();
return;
}
let fileContents = event.target.result;
let fileType = file.type;
let fileName = file.name;
if (fileType === "") {
let fileExtension = fileName.split(".").pop();
if (fileExtension === "org") {
fileType = "text/org";
} else if (fileExtension === "md") {
fileType = "text/markdown";
} else if (fileExtension === "txt") {
fileType = "text/plain";
} else if (fileExtension === "html") {
fileType = "text/html";
} else if (fileExtension === "pdf") {
fileType = "application/pdf";
} else if (fileExtension === "docx") {
fileType =
"application/vnd.openxmlformats-officedocument.wordprocessingml.document";
} else {
// Skip this file if its type is not supported
resolve();
return;
}
}
if (fileContents === null) {
reject();
return;
}
let fileObj = new Blob([fileContents], { type: fileType });
formData.append("files", fileObj, file.name);
resolve();
};
reader.onerror = reject;
reader.readAsArrayBuffer(file);
});
});
await Promise.all(fileReadPromises);
return formData;
}
export function generateNewTitle(conversationId: string, setTitle: (title: string) => void) {
fetch(`/api/chat/title?conversation_id=${conversationId}`, {
method: "POST",
})
.then((res) => {
if (!res.ok) throw new Error(`Failed to call API with error ${res.statusText}`);
return res.json();
})
.then((data) => {
setTitle(data.title);
})
.catch((err) => {
console.error(err);
return;
});
}
export function uploadDataForIndexing(
files: FileList,
setWarning: (warning: string) => void,

View File

@@ -42,6 +42,20 @@ export function converColorToBgGradient(color: string) {
return `${convertToBGGradientClass(color)} dark:border dark:border-neutral-700`;
}
export function convertColorToCaretClass(color: string | undefined) {
if (color && tailwindColors.includes(color)) {
return `caret-${color}-500`;
}
return `caret-orange-500`;
}
export function convertColorToRingClass(color: string | undefined) {
if (color && tailwindColors.includes(color)) {
return `focus-visible:ring-${color}-500`;
}
return `focus-visible:ring-orange-500`;
}
export function convertColorToBorderClass(color: string) {
if (tailwindColors.includes(color)) {
return `border-${color}-500`;

View File

@@ -40,7 +40,6 @@ import {
Leaf,
NewspaperClipping,
OrangeSlice,
Rainbow,
SmileyMelting,
YinYang,
SneakerMove,
@@ -48,8 +47,12 @@ import {
Oven,
Gavel,
Broadcast,
KeyReturn,
FilePdf,
FileMd,
MicrosoftWordLogo,
} from "@phosphor-icons/react";
import { Markdown, OrgMode, Pdf, Word } from "@/app/components/logo/fileLogo";
import { OrgMode } from "@/app/components/logo/fileLogo";
interface IconMap {
[key: string]: (color: string, width: string, height: string) => JSX.Element | null;
@@ -193,6 +196,10 @@ export function getIconForSlashCommand(command: string, customClassName: string
}
if (command.includes("default")) {
return <KeyReturn className={className} />;
}
if (command.includes("diagram")) {
return <Shapes className={className} />;
}
@@ -208,6 +215,10 @@ export function getIconForSlashCommand(command: string, customClassName: string
return <PencilLine className={className} />;
}
if (command.includes("code")) {
return <Code className={className} />;
}
return <ArrowRight className={className} />;
}
@@ -233,11 +244,19 @@ function getIconFromFilename(
return <OrgMode className={className} />;
case "markdown":
case "md":
return <Markdown className={className} />;
return <FileMd className={className} />;
case "pdf":
return <Pdf className={className} />;
return <FilePdf className={className} />;
case "doc":
return <Word className={className} />;
case "docx":
return <MicrosoftWordLogo className={className} />;
case "csv":
case "json":
return <MathOperations className={className} />;
case "txt":
return <Notebook className={className} />;
case "py":
return <Code className={className} />;
case "jpg":
case "jpeg":
case "png":

View File

@@ -70,3 +70,29 @@ export function useIsMobileWidth() {
return isMobileWidth;
}
export const convertBytesToText = (fileSize: number) => {
if (fileSize < 1024) {
return `${fileSize} B`;
} else if (fileSize < 1024 * 1024) {
return `${(fileSize / 1024).toFixed(2)} KB`;
} else {
return `${(fileSize / (1024 * 1024)).toFixed(2)} MB`;
}
};
export function useDebounce<T>(value: T, delay: number): T {
const [debouncedValue, setDebouncedValue] = useState<T>(value);
useEffect(() => {
const handler = setTimeout(() => {
setDebouncedValue(value);
}, delay);
return () => {
clearTimeout(handler);
};
}, [value, delay]);
return debouncedValue;
}

View File

@@ -0,0 +1,20 @@
.agentPersonality p {
white-space: inherit;
overflow: hidden;
height: 77px;
line-height: 1.5;
}
div.agentPersonality {
text-align: left;
grid-column: span 3;
overflow: hidden;
}
button.infoButton {
border: none;
background-color: transparent !important;
text-align: left;
font-family: inherit;
font-size: medium;
}

File diff suppressed because it is too large Load Diff

View File

@@ -2,12 +2,7 @@ div.chatHistory {
display: flex;
flex-direction: column;
height: 100%;
}
div.chatLayout {
height: 80vh;
overflow-y: auto;
margin: 0 auto;
margin: auto;
}
div.agentIndicator a {

View File

@@ -13,13 +13,14 @@ import { ScrollArea } from "@/components/ui/scroll-area";
import { InlineLoading } from "../loading/loading";
import { Lightbulb, ArrowDown } from "@phosphor-icons/react";
import { Lightbulb, ArrowDown, XCircle } from "@phosphor-icons/react";
import AgentProfileCard from "../profileCard/profileCard";
import { getIconFromIconName } from "@/app/common/iconUtils";
import { AgentData } from "@/app/agents/page";
import React from "react";
import { useIsMobileWidth } from "@/app/common/utils";
import { Button } from "@/components/ui/button";
interface ChatResponse {
status: string;
@@ -33,32 +34,62 @@ interface ChatHistory {
interface ChatHistoryProps {
conversationId: string;
setTitle: (title: string) => void;
incomingMessages?: StreamMessage[];
pendingMessage?: string;
incomingMessages?: StreamMessage[];
setIncomingMessages?: (incomingMessages: StreamMessage[]) => void;
publicConversationSlug?: string;
setAgent: (agent: AgentData) => void;
customClassName?: string;
}
function constructTrainOfThought(
trainOfThought: string[],
lastMessage: boolean,
agentColor: string,
key: string,
completed: boolean = false,
) {
const lastIndex = trainOfThought.length - 1;
return (
<div className={`${styles.trainOfThought} shadow-sm`} key={key}>
{!completed && <InlineLoading className="float-right" />}
interface TrainOfThoughtComponentProps {
trainOfThought: string[];
lastMessage: boolean;
agentColor: string;
keyId: string;
completed?: boolean;
}
{trainOfThought.map((train, index) => (
<TrainOfThought
key={`train-${index}`}
message={train}
primary={index === lastIndex && lastMessage && !completed}
agentColor={agentColor}
/>
))}
function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
const lastIndex = props.trainOfThought.length - 1;
const [collapsed, setCollapsed] = useState(props.completed);
return (
<div
className={`${!collapsed ? styles.trainOfThought + " shadow-sm" : ""}`}
key={props.keyId}
>
{!props.completed && <InlineLoading className="float-right" />}
{props.completed &&
(collapsed ? (
<Button
className="w-fit text-left justify-start content-start text-xs"
onClick={() => setCollapsed(false)}
variant="ghost"
size="sm"
>
What was my train of thought?
</Button>
) : (
<Button
className="w-fit text-left justify-start content-start text-xs p-0 h-fit"
onClick={() => setCollapsed(true)}
variant="ghost"
size="sm"
>
<XCircle size={16} className="mr-1" />
Close
</Button>
))}
{!collapsed &&
props.trainOfThought.map((train, index) => (
<TrainOfThought
key={`train-${index}`}
message={train}
primary={index === lastIndex && props.lastMessage && !props.completed}
agentColor={props.agentColor}
/>
))}
</div>
);
}
@@ -67,6 +98,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
const [data, setData] = useState<ChatHistoryData | null>(null);
const [currentPage, setCurrentPage] = useState(0);
const [hasMoreMessages, setHasMoreMessages] = useState(true);
const [currentTurnId, setCurrentTurnId] = useState<string | null>(null);
const sentinelRef = useRef<HTMLDivElement | null>(null);
const scrollAreaRef = useRef<HTMLDivElement | null>(null);
const latestUserMessageRef = useRef<HTMLDivElement | null>(null);
@@ -147,6 +179,10 @@ export default function ChatHistory(props: ChatHistoryProps) {
if (lastMessage && !lastMessage.completed) {
setIncompleteIncomingMessageIndex(props.incomingMessages.length - 1);
props.setTitle(lastMessage.rawQuery);
// Store the turnId when we get it
if (lastMessage.turnId) {
setCurrentTurnId(lastMessage.turnId);
}
}
}
}, [props.incomingMessages]);
@@ -248,44 +284,79 @@ export default function ChatHistory(props: ChatHistoryProps) {
return data.agent?.persona;
}
const handleDeleteMessage = (turnId?: string) => {
if (!turnId) return;
setData((prevData) => {
if (!prevData || !turnId) return prevData;
return {
...prevData,
chat: prevData.chat.filter((msg) => msg.turnId !== turnId),
};
});
// Update incoming messages if they exist
if (props.incomingMessages && props.setIncomingMessages) {
props.setIncomingMessages(
props.incomingMessages.filter((msg) => msg.turnId !== turnId),
);
}
};
if (!props.conversationId && !props.publicConversationSlug) {
return null;
}
return (
<ScrollArea className={`h-[80vh] relative`} ref={scrollAreaRef}>
<ScrollArea className={`h-[73vh] relative`} ref={scrollAreaRef}>
<div>
<div className={styles.chatHistory}>
<div className={`${styles.chatHistory} ${props.customClassName}`}>
<div ref={sentinelRef} style={{ height: "1px" }}>
{fetchingData && (
<InlineLoading message="Loading Conversation" className="opacity-50" />
)}
{fetchingData && <InlineLoading className="opacity-50" />}
</div>
{data &&
data.chat &&
data.chat.map((chatMessage, index) => (
<ChatMessage
key={`${index}fullHistory`}
ref={
// attach ref to the second last message to handle scroll on page load
index === data.chat.length - 2
? latestUserMessageRef
: // attach ref to the newest fetched message to handle scroll on fetch
// note: stabilize index selection against last page having less messages than fetchMessageCount
index ===
data.chat.length - (currentPage - 1) * fetchMessageCount
? latestFetchedMessageRef
: null
}
isMobileWidth={isMobileWidth}
chatMessage={chatMessage}
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}
isLastMessage={index === data.chat.length - 1}
/>
<>
{chatMessage.trainOfThought && chatMessage.by === "khoj" && (
<TrainOfThoughtComponent
trainOfThought={chatMessage.trainOfThought?.map(
(train) => train.data,
)}
lastMessage={false}
agentColor={data?.agent?.color || "orange"}
key={`${index}trainOfThought`}
keyId={`${index}trainOfThought`}
completed={true}
/>
)}
<ChatMessage
key={`${index}fullHistory`}
ref={
// attach ref to the second last message to handle scroll on page load
index === data.chat.length - 2
? latestUserMessageRef
: // attach ref to the newest fetched message to handle scroll on fetch
// note: stabilize index selection against last page having less messages than fetchMessageCount
index ===
data.chat.length -
(currentPage - 1) * fetchMessageCount
? latestFetchedMessageRef
: null
}
isMobileWidth={isMobileWidth}
chatMessage={chatMessage}
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}
isLastMessage={index === data.chat.length - 1}
onDeleteMessage={handleDeleteMessage}
conversationId={props.conversationId}
/>
</>
))}
{props.incomingMessages &&
props.incomingMessages.map((message, index) => {
const messageTurnId = message.turnId ?? currentTurnId ?? undefined;
return (
<React.Fragment key={`incomingMessage${index}`}>
<ChatMessage
@@ -295,22 +366,31 @@ export default function ChatHistory(props: ChatHistoryProps) {
message: message.rawQuery,
context: [],
onlineContext: {},
codeContext: {},
created: message.timestamp,
by: "you",
automationId: "",
uploadedImageData: message.uploadedImageData,
images: message.images,
conversationId: props.conversationId,
turnId: messageTurnId,
queryFiles: message.queryFiles,
}}
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}
onDeleteMessage={handleDeleteMessage}
conversationId={props.conversationId}
turnId={messageTurnId}
/>
{message.trainOfThought &&
constructTrainOfThought(
message.trainOfThought,
index === incompleteIncomingMessageIndex,
data?.agent?.color || "orange",
`${index}trainOfThought`,
message.completed,
)}
{message.trainOfThought && (
<TrainOfThoughtComponent
trainOfThought={message.trainOfThought}
lastMessage={index === incompleteIncomingMessageIndex}
agentColor={data?.agent?.color || "orange"}
key={`${index}trainOfThought`}
keyId={`${index}trainOfThought`}
completed={message.completed}
/>
)}
<ChatMessage
key={`${index}incoming`}
isMobileWidth={isMobileWidth}
@@ -318,11 +398,23 @@ export default function ChatHistory(props: ChatHistoryProps) {
message: message.rawResponse,
context: message.context,
onlineContext: message.onlineContext,
codeContext: message.codeContext,
created: message.timestamp,
by: "khoj",
automationId: "",
rawQuery: message.rawQuery,
intent: {
type: message.intentType || "",
query: message.rawQuery,
"memory-type": "",
"inferred-queries": message.inferredQueries || [],
},
conversationId: props.conversationId,
turnId: messageTurnId,
}}
conversationId={props.conversationId}
turnId={messageTurnId}
onDeleteMessage={handleDeleteMessage}
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}
isLastMessage={true}
@@ -338,11 +430,15 @@ export default function ChatHistory(props: ChatHistoryProps) {
message: props.pendingMessage,
context: [],
onlineContext: {},
codeContext: {},
created: new Date().getTime().toString(),
by: "you",
automationId: "",
uploadedImageData: props.pendingMessage,
conversationId: props.conversationId,
turnId: undefined,
}}
conversationId={props.conversationId}
onDeleteMessage={handleDeleteMessage}
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}
isLastMessage={true}
@@ -366,18 +462,20 @@ export default function ChatHistory(props: ChatHistoryProps) {
</div>
)}
</div>
{!isNearBottom && (
<button
title="Scroll to bottom"
className="absolute bottom-4 right-5 bg-white dark:bg-[hsl(var(--background))] text-neutral-500 dark:text-white p-2 rounded-full shadow-xl"
onClick={() => {
scrollToBottom();
setIsNearBottom(true);
}}
>
<ArrowDown size={24} />
</button>
)}
<div className={`${props.customClassName} fixed bottom-[20%] z-10`}>
{!isNearBottom && (
<button
title="Scroll to bottom"
className="absolute bottom-0 right-0 bg-white dark:bg-[hsl(var(--background))] text-neutral-500 dark:text-white p-2 rounded-full shadow-xl"
onClick={() => {
scrollToBottom();
setIsNearBottom(true);
}}
>
<ArrowDown size={24} />
</button>
)}
</div>
</div>
</ScrollArea>
);

View File

@@ -1,24 +1,16 @@
import styles from "./chatInputArea.module.css";
import React, { useEffect, useRef, useState } from "react";
import React, { useEffect, useRef, useState, forwardRef } from "react";
import DOMPurify from "dompurify";
import "katex/dist/katex.min.css";
import {
ArrowRight,
ArrowUp,
Browser,
ChatsTeardrop,
GlobeSimple,
Gps,
Image,
Microphone,
Notebook,
Paperclip,
X,
Question,
Robot,
Shapes,
Stop,
ToggleLeft,
ToggleRight,
} from "@phosphor-icons/react";
import {
@@ -45,30 +37,48 @@ import { Popover, PopoverContent } from "@/components/ui/popover";
import { PopoverTrigger } from "@radix-ui/react-popover";
import { Textarea } from "@/components/ui/textarea";
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
import { convertToBGClass } from "@/app/common/colorUtils";
import { convertColorToTextClass, convertToBGClass } from "@/app/common/colorUtils";
import LoginPrompt from "../loginPrompt/loginPrompt";
import { uploadDataForIndexing } from "../../common/chatFunctions";
import { InlineLoading } from "../loading/loading";
import { getIconForSlashCommand } from "@/app/common/iconUtils";
import { getIconForSlashCommand, getIconFromFilename } from "@/app/common/iconUtils";
import { packageFilesForUpload } from "@/app/common/chatFunctions";
import { convertBytesToText } from "@/app/common/utils";
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
DialogTrigger,
} from "@/components/ui/dialog";
import { ScrollArea } from "@/components/ui/scroll-area";
export interface ChatOptions {
[key: string]: string;
}
export interface AttachedFileText {
name: string;
content: string;
file_type: string;
size: number;
}
interface ChatInputProps {
sendMessage: (message: string) => void;
sendImage: (image: string) => void;
sendDisabled: boolean;
setUploadedFiles?: (files: string[]) => void;
setUploadedFiles: (files: AttachedFileText[]) => void;
conversationId?: string | null;
chatOptionsData?: ChatOptions | null;
isMobileWidth?: boolean;
isLoggedIn: boolean;
agentColor?: string;
isResearchModeEnabled?: boolean;
}
export default function ChatInputArea(props: ChatInputProps) {
export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((props, ref) => {
const [message, setMessage] = useState("");
const fileInputRef = useRef<HTMLInputElement>(null);
@@ -78,15 +88,25 @@ export default function ChatInputArea(props: ChatInputProps) {
const [loginRedirectMessage, setLoginRedirectMessage] = useState<string | null>(null);
const [showLoginPrompt, setShowLoginPrompt] = useState(false);
const [recording, setRecording] = useState(false);
const [imageUploaded, setImageUploaded] = useState(false);
const [imagePath, setImagePath] = useState<string>("");
const [imageData, setImageData] = useState<string | null>(null);
const [imagePaths, setImagePaths] = useState<string[]>([]);
const [imageData, setImageData] = useState<string[]>([]);
const [attachedFiles, setAttachedFiles] = useState<FileList | null>(null);
const [convertedAttachedFiles, setConvertedAttachedFiles] = useState<AttachedFileText[]>([]);
const [recording, setRecording] = useState(false);
const [mediaRecorder, setMediaRecorder] = useState<MediaRecorder | null>(null);
const [progressValue, setProgressValue] = useState(0);
const [isDragAndDropping, setIsDragAndDropping] = useState(false);
const [showCommandList, setShowCommandList] = useState(false);
const [useResearchMode, setUseResearchMode] = useState<boolean>(
props.isResearchModeEnabled || false,
);
const chatInputRef = ref as React.MutableRefObject<HTMLTextAreaElement>;
useEffect(() => {
if (!uploading) {
setProgressValue(0);
@@ -106,27 +126,37 @@ export default function ChatInputArea(props: ChatInputProps) {
useEffect(() => {
async function fetchImageData() {
if (imagePath) {
const response = await fetch(imagePath);
const blob = await response.blob();
const reader = new FileReader();
reader.onload = function () {
const base64data = reader.result;
setImageData(base64data as string);
};
reader.readAsDataURL(blob);
if (imagePaths.length > 0) {
const newImageData = await Promise.all(
imagePaths.map(async (path) => {
const response = await fetch(path);
const blob = await response.blob();
return new Promise<string>((resolve) => {
const reader = new FileReader();
reader.onload = () => resolve(reader.result as string);
reader.readAsDataURL(blob);
});
}),
);
setImageData(newImageData);
}
setUploading(false);
}
setUploading(true);
fetchImageData();
}, [imagePath]);
}, [imagePaths]);
useEffect(() => {
if (props.isResearchModeEnabled) {
setUseResearchMode(props.isResearchModeEnabled);
}
}, [props.isResearchModeEnabled]);
function onSendMessage() {
if (imageUploaded) {
setImageUploaded(false);
setImagePath("");
props.sendImage(imageData || "");
setImagePaths([]);
imageData.forEach((data) => props.sendImage(data));
}
if (!message.trim()) return;
@@ -138,7 +168,14 @@ export default function ChatInputArea(props: ChatInputProps) {
return;
}
props.sendMessage(message.trim());
let messageToSend = message.trim();
if (useResearchMode && !messageToSend.startsWith("/research")) {
messageToSend = `/research ${messageToSend}`;
}
props.sendMessage(messageToSend);
setAttachedFiles(null);
setConvertedAttachedFiles([]);
setMessage("");
}
@@ -172,26 +209,85 @@ export default function ChatInputArea(props: ChatInputProps) {
setShowLoginPrompt(true);
return;
}
// check for image file
// check for image files
const image_endings = ["jpg", "jpeg", "png", "webp"];
const newImagePaths: string[] = [];
for (let i = 0; i < files.length; i++) {
const file = files[i];
const file_extension = file.name.split(".").pop();
if (image_endings.includes(file_extension || "")) {
setImageUploaded(true);
setImagePath(DOMPurify.sanitize(URL.createObjectURL(file)));
return;
newImagePaths.push(DOMPurify.sanitize(URL.createObjectURL(file)));
}
}
uploadDataForIndexing(
files,
setWarning,
setUploading,
setError,
props.setUploadedFiles,
props.conversationId,
if (newImagePaths.length > 0) {
setImageUploaded(true);
setImagePaths((prevPaths) => [...prevPaths, ...newImagePaths]);
// Set focus to the input for user message after uploading files
chatInputRef?.current?.focus();
}
// Process all non-image files
const nonImageFiles = Array.from(files).filter(
(file) => !image_endings.includes(file.name.split(".").pop() || ""),
);
// Concatenate attachedFiles and files
const newFiles = nonImageFiles
? Array.from(nonImageFiles).concat(Array.from(attachedFiles || []))
: Array.from(attachedFiles || []);
if (newFiles.length > 0) {
// Ensure files are below size limit (10 MB)
for (let i = 0; i < newFiles.length; i++) {
if (newFiles[i].size > 10 * 1024 * 1024) {
setWarning(
`File ${newFiles[i].name} is too large. Please upload files smaller than 10 MB.`,
);
return;
}
}
const dataTransfer = new DataTransfer();
newFiles.forEach((file) => dataTransfer.items.add(file));
// Extract text from files
extractTextFromFiles(dataTransfer.files).then((data) => {
props.setUploadedFiles(data);
setAttachedFiles(dataTransfer.files);
setConvertedAttachedFiles(data);
});
}
// Set focus to the input for user message after uploading files
chatInputRef?.current?.focus();
}
async function extractTextFromFiles(files: FileList): Promise<AttachedFileText[]> {
const formData = await packageFilesForUpload(files);
setUploading(true);
try {
const response = await fetch("/api/content/convert", {
method: "POST",
body: formData,
});
setUploading(false);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
return await response.json();
} catch (error) {
setError(
"Error converting files. " +
error +
". Please try again, or contact team@khoj.dev if the issue persists.",
);
console.error("Error converting files:", error);
return [];
}
}
// Assuming this function is added within the same context as the provided excerpt
@@ -270,12 +366,17 @@ export default function ChatInputArea(props: ChatInputProps) {
}
}, [recording, mediaRecorder]);
const chatInputRef = useRef<HTMLTextAreaElement>(null);
useEffect(() => {
if (!chatInputRef.current) return;
if (!chatInputRef?.current) return;
chatInputRef.current.style.height = "auto";
chatInputRef.current.style.height =
Math.max(chatInputRef.current.scrollHeight - 24, 64) + "px";
if (message.startsWith("/") && message.split(" ").length === 1) {
setShowCommandList(true);
} else {
setShowCommandList(false);
}
}, [message]);
function handleDragOver(event: React.DragEvent<HTMLDivElement>) {
@@ -288,9 +389,12 @@ export default function ChatInputArea(props: ChatInputProps) {
setIsDragAndDropping(false);
}
function removeImageUpload() {
setImageUploaded(false);
setImagePath("");
function removeImageUpload(index: number) {
setImagePaths((prevPaths) => prevPaths.filter((_, i) => i !== index));
setImageData((prevData) => prevData.filter((_, i) => i !== index));
if (imagePaths.length === 1) {
setImageUploaded(false);
}
}
return (
@@ -358,13 +462,18 @@ export default function ChatInputArea(props: ChatInputProps) {
</AlertDialogContent>
</AlertDialog>
)}
{message.startsWith("/") && message.split(" ").length === 1 && (
{showCommandList && (
<div className="flex justify-center text-center">
<Popover open={message.startsWith("/")}>
<Popover open={showCommandList} onOpenChange={setShowCommandList}>
<PopoverTrigger className="flex justify-center text-center"></PopoverTrigger>
<PopoverContent
onOpenAutoFocus={(e) => e.preventDefault()}
className={`${props.isMobileWidth ? "w-[100vw]" : "w-full"} rounded-md`}
side="bottom"
align="center"
/* Offset below text area on home page (i.e where conversationId is unset) */
sideOffset={props.conversationId ? 0 : 80}
alignOffset={0}
>
<Command className="max-w-full">
<CommandInput
@@ -406,112 +515,230 @@ export default function ChatInputArea(props: ChatInputProps) {
</Popover>
</div>
)}
<div
className={`${styles.actualInputArea} items-center justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
onDragOver={handleDragOver}
onDragLeave={handleDragLeave}
onDrop={handleDragAndDropFiles}
>
{imageUploaded && (
<div className="absolute bottom-[80px] left-0 right-0 dark:bg-neutral-700 bg-white pt-5 pb-5 w-full rounded-lg border dark:border-none grid grid-cols-2">
<div className="pl-4 pr-4">
<img src={imagePath} alt="img" className="w-auto max-h-[100px]" />
</div>
<div className="pl-4 pr-4">
<X
className="w-6 h-6 float-right dark:hover:bg-[hsl(var(--background))] hover:bg-neutral-100 rounded-sm"
onClick={removeImageUpload}
/>
</div>
</div>
)}
<input
type="file"
multiple={true}
ref={fileInputRef}
onChange={handleFileChange}
style={{ display: "none" }}
/>
<Button
variant={"ghost"}
className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
disabled={props.sendDisabled}
onClick={handleFileButtonClick}
>
<Paperclip className="w-8 h-8" />
</Button>
<div className="grid w-full gap-1.5 relative">
<Textarea
ref={chatInputRef}
className={`border-none w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700 ${props.isMobileWidth ? "text-md" : "text-lg"}`}
placeholder="Type / to see a list of commands"
id="message"
autoFocus={true}
value={message}
onKeyDown={(e) => {
if (e.key === "Enter" && !e.shiftKey) {
setImageUploaded(false);
setImagePath("");
e.preventDefault();
onSendMessage();
}
}}
onChange={(e) => setMessage(e.target.value)}
disabled={props.sendDisabled || recording}
/>
<div>
<div className="flex items-center gap-2 overflow-x-auto">
{imageUploaded &&
imagePaths.map((path, index) => (
<div key={index} className="relative flex-shrink-0 pb-3 pt-2 group">
<img
src={path}
alt={`img-${index}`}
className="w-auto h-16 object-cover rounded-xl"
/>
<Button
variant="ghost"
size="icon"
className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
onClick={() => removeImageUpload(index)}
>
<X className="h-3 w-3" />
</Button>
</div>
))}
{convertedAttachedFiles &&
Array.from(convertedAttachedFiles).map((file, index) => (
<Dialog key={index}>
<DialogTrigger asChild>
<div key={index} className="relative flex-shrink-0 p-2 group">
<div
className={`w-auto h-16 object-cover rounded-xl ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} bg-opacity-15`}
>
<div className="flex p-2 flex-col justify-start items-start h-full">
<span className="text-sm font-bold text-neutral-500 dark:text-neutral-400 text-ellipsis truncate max-w-[200px] break-words">
{file.name}
</span>
<span className="flex items-center gap-1">
{getIconFromFilename(file.file_type)}
<span className="text-xs text-neutral-500 dark:text-neutral-400">
{convertBytesToText(file.size)}
</span>
</span>
</div>
</div>
<Button
variant="ghost"
size="icon"
className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
onClick={() => {
setAttachedFiles((prevFiles) => {
const removeFile = file.name;
if (!prevFiles) return null;
const updatedFiles = Array.from(
prevFiles,
).filter((file) => file.name !== removeFile);
const dataTransfer = new DataTransfer();
updatedFiles.forEach((file) =>
dataTransfer.items.add(file),
);
const filteredConvertedAttachedFiles =
convertedAttachedFiles.filter(
(file) => file.name !== removeFile,
);
props.setUploadedFiles(
filteredConvertedAttachedFiles,
);
setConvertedAttachedFiles(
filteredConvertedAttachedFiles,
);
return dataTransfer.files;
});
}}
>
<X className="h-3 w-3" />
</Button>
</div>
</DialogTrigger>
<DialogContent>
<DialogHeader>
<DialogTitle>{file.name}</DialogTitle>
</DialogHeader>
<DialogDescription>
<ScrollArea className="h-72 w-full rounded-md">
{file.content}
</ScrollArea>
</DialogDescription>
</DialogContent>
</Dialog>
))}
</div>
{recording ? (
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="default"
className={`${!recording && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
onClick={() => {
setRecording(!recording);
}}
disabled={props.sendDisabled}
>
<Stop weight="fill" className="w-6 h-6" />
</Button>
</TooltipTrigger>
<TooltipContent>
Click to stop recording and transcribe your voice.
</TooltipContent>
</Tooltip>
</TooltipProvider>
) : mediaRecorder ? (
<InlineLoading />
) : (
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="default"
className={`${!message || recording || "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
onClick={() => {
setMessage("Listening...");
setRecording(!recording);
}}
disabled={props.sendDisabled}
>
<Microphone weight="fill" className="w-6 h-6" />
</Button>
</TooltipTrigger>
<TooltipContent>
Click to transcribe your message with voice.
</TooltipContent>
</Tooltip>
</TooltipProvider>
)}
<Button
className={`${(!message || recording) && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
onClick={onSendMessage}
disabled={props.sendDisabled}
<div
className={`${styles.actualInputArea} justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
onDragOver={handleDragOver}
onDragLeave={handleDragLeave}
onDrop={handleDragAndDropFiles}
>
<ArrowUp className="w-6 h-6" weight="bold" />
</Button>
<input
type="file"
accept=".pdf,.doc,.docx,.txt,.md,.org,.jpg,.jpeg,.png,.webp"
multiple={true}
ref={fileInputRef}
onChange={handleFileChange}
style={{ display: "none" }}
/>
<div className="flex items-center">
<Button
variant={"ghost"}
className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
disabled={props.sendDisabled}
onClick={handleFileButtonClick}
>
<Paperclip className="w-8 h-8" />
</Button>
</div>
<div className="flex-grow flex flex-col w-full gap-1.5 relative">
<Textarea
ref={chatInputRef}
className={`border-none focus:border-none
focus:outline-none focus-visible:ring-transparent
w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700
${props.isMobileWidth ? "text-md" : "text-lg"}`}
placeholder="Type / to see a list of commands"
id="message"
autoFocus={true}
value={message}
onKeyDown={(e) => {
if (e.key === "Enter" && !e.shiftKey && !props.isMobileWidth) {
setImageUploaded(false);
setImagePaths([]);
e.preventDefault();
onSendMessage();
}
}}
onChange={(e) => setMessage(e.target.value)}
disabled={props.sendDisabled || recording}
/>
</div>
<div className="flex items-end pb-2">
{recording ? (
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="default"
className={`${!recording && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
onClick={() => {
setRecording(!recording);
}}
disabled={props.sendDisabled}
>
<Stop weight="fill" className="w-6 h-6" />
</Button>
</TooltipTrigger>
<TooltipContent>
Click to stop recording and transcribe your voice.
</TooltipContent>
</Tooltip>
</TooltipProvider>
) : mediaRecorder ? (
<InlineLoading />
) : (
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="default"
className={`${!message || recording || "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
onClick={() => {
setMessage("Listening...");
setRecording(!recording);
}}
disabled={props.sendDisabled}
>
<Microphone weight="fill" className="w-6 h-6" />
</Button>
</TooltipTrigger>
<TooltipContent>
Click to transcribe your message with voice.
</TooltipContent>
</Tooltip>
</TooltipProvider>
)}
<Button
className={`${(!message || recording) && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
onClick={onSendMessage}
disabled={props.sendDisabled}
>
<ArrowUp className="w-6 h-6" weight="bold" />
</Button>
</div>
</div>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="ghost"
className="float-right justify-center gap-1 flex items-center p-1.5 mr-2 h-fit"
onClick={() => {
setUseResearchMode(!useResearchMode);
chatInputRef?.current?.focus();
}}
>
<span className="text-muted-foreground text-sm">Research Mode</span>
{useResearchMode ? (
<ToggleRight
weight="fill"
className={`w-6 h-6 inline-block ${props.agentColor ? convertColorToTextClass(props.agentColor) : convertColorToTextClass("orange")} rounded-full`}
/>
) : (
<ToggleLeft
weight="fill"
className={`w-6 h-6 inline-block ${convertColorToTextClass("gray")} rounded-full`}
/>
)}
</Button>
</TooltipTrigger>
<TooltipContent className="text-xs">
(Experimental) Research Mode allows you to get more deeply researched,
detailed responses. Response times may be longer.
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
</>
);
}
});
ChatInputArea.displayName = "ChatInputArea";

View File

@@ -4,6 +4,7 @@ div.chatMessageContainer {
margin: 12px;
border-radius: 16px;
padding: 8px 16px 0 16px;
word-break: break-word;
}
div.chatMessageWrapper {
@@ -57,7 +58,26 @@ div.emptyChatMessage {
display: none;
}
div.chatMessageContainer img {
div.imagesContainer {
display: flex;
overflow-x: auto;
padding-bottom: 8px;
margin-bottom: 8px;
}
div.imageWrapper {
flex: 0 0 auto;
margin-right: 8px;
}
div.imageWrapper img {
width: auto;
height: 128px;
object-fit: cover;
border-radius: 8px;
}
div.chatMessageContainer > img {
width: auto;
height: auto;
max-width: 100%;
@@ -151,6 +171,7 @@ div.trainOfThoughtElement {
div.trainOfThoughtElement ol,
div.trainOfThoughtElement ul {
margin: auto;
word-break: break-word;
}
@media screen and (max-width: 768px) {

View File

@@ -10,6 +10,7 @@ import { createRoot } from "react-dom/client";
import "katex/dist/katex.min.css";
import { TeaserReferencesSection, constructAllReferences } from "../referencePanel/referencePanel";
import { renderCodeGenImageInline } from "@/app/common/chatFunctions";
import {
ThumbsUp,
@@ -26,6 +27,9 @@ import {
Palette,
ClipboardText,
Check,
Code,
Shapes,
Trash,
} from "@phosphor-icons/react";
import DOMPurify from "dompurify";
@@ -35,6 +39,19 @@ import { AgentData } from "@/app/agents/page";
import renderMathInElement from "katex/contrib/auto-render";
import "katex/dist/katex.min.css";
import ExcalidrawComponent from "../excalidraw/excalidraw";
import { AttachedFileText } from "../chatInputArea/chatInputArea";
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTrigger,
} from "@/components/ui/dialog";
import { DialogTitle } from "@radix-ui/react-dialog";
import { convertBytesToText } from "@/app/common/utils";
import { ScrollArea } from "@/components/ui/scroll-area";
import { getIconFromFilename } from "@/app/common/iconUtils";
const md = new markdownIt({
html: true,
@@ -97,6 +114,26 @@ export interface OnlineContextData {
peopleAlsoAsk: PeopleAlsoAsk[];
}
export interface CodeContext {
[key: string]: CodeContextData;
}
export interface CodeContextData {
code: string;
results: {
success: boolean;
output_files: CodeContextFile[];
std_out: string;
std_err: string;
code_runtime: number;
};
}
export interface CodeContextFile {
filename: string;
b64_data: string;
}
interface Intent {
type: string;
query: string;
@@ -104,6 +141,11 @@ interface Intent {
"inferred-queries": string[];
}
interface TrainOfThoughtObject {
type: string;
data: string;
}
export interface SingleChatMessage {
automationId: string;
by: string;
@@ -111,10 +153,15 @@ export interface SingleChatMessage {
created: string;
context: Context[];
onlineContext: OnlineContext;
codeContext: CodeContext;
trainOfThought?: TrainOfThoughtObject[];
rawQuery?: string;
intent?: Intent;
agent?: AgentData;
uploadedImageData?: string;
images?: string[];
conversationId: string;
turnId?: string;
queryFiles?: AttachedFileText[];
}
export interface StreamMessage {
@@ -122,11 +169,16 @@ export interface StreamMessage {
trainOfThought: string[];
context: Context[];
onlineContext: OnlineContext;
codeContext: CodeContext;
completed: boolean;
rawQuery: string;
timestamp: string;
agent?: AgentData;
uploadedImageData?: string;
images?: string[];
intentType?: string;
inferredQueries?: string[];
turnId?: string;
queryFiles?: AttachedFileText[];
}
export interface ChatHistoryData {
@@ -208,7 +260,9 @@ interface ChatMessageProps {
borderLeftColor?: string;
isLastMessage?: boolean;
agent?: AgentData;
uploadedImageData?: string;
onDeleteMessage: (turnId?: string) => void;
conversationId: string;
turnId?: string;
}
interface TrainOfThoughtProps {
@@ -252,10 +306,18 @@ function chooseIconFromHeader(header: string, iconColor: string) {
return <Aperture className={`${classNames}`} />;
}
if (compareHeader.includes("diagram")) {
return <Shapes className={`${classNames}`} />;
}
if (compareHeader.includes("paint")) {
return <Palette className={`${classNames}`} />;
}
if (compareHeader.includes("code")) {
return <Code className={`${classNames}`} />;
}
return <Brain className={`${classNames}`} />;
}
@@ -266,12 +328,15 @@ export function TrainOfThought(props: TrainOfThoughtProps) {
const iconColor = props.primary ? convertColorToTextClass(props.agentColor) : "text-gray-500";
const icon = chooseIconFromHeader(header, iconColor);
let markdownRendered = DOMPurify.sanitize(md.render(props.message));
// Remove any header tags from markdownRendered
markdownRendered = markdownRendered.replace(/<h[1-6].*?<\/h[1-6]>/g, "");
return (
<div
className={`${styles.trainOfThoughtElement} break-all items-center ${props.primary ? "text-gray-400" : "text-gray-300"} ${styles.trainOfThought} ${props.primary ? styles.primary : ""}`}
className={`${styles.trainOfThoughtElement} break-words items-center ${props.primary ? "text-gray-400" : "text-gray-300"} ${styles.trainOfThought} ${props.primary ? styles.primary : ""}`}
>
{icon}
<div dangerouslySetInnerHTML={{ __html: markdownRendered }} />
<div dangerouslySetInnerHTML={{ __html: markdownRendered }} className="break-words" />
</div>
);
}
@@ -283,6 +348,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
const [markdownRendered, setMarkdownRendered] = useState<string>("");
const [isPlaying, setIsPlaying] = useState<boolean>(false);
const [interrupted, setInterrupted] = useState<boolean>(false);
const [excalidrawData, setExcalidrawData] = useState<string>("");
const interruptedRef = useRef<boolean>(false);
const messageRef = useRef<HTMLDivElement>(null);
@@ -319,8 +385,14 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
}, [messageRef.current]);
useEffect(() => {
// Prepare initial message for rendering
let message = props.chatMessage.message;
if (props.chatMessage.intent && props.chatMessage.intent.type == "excalidraw") {
message = props.chatMessage.intent["inferred-queries"][0];
setExcalidrawData(props.chatMessage.message);
}
// Replace LaTeX delimiters with placeholders
message = message
.replace(/\\\(/g, "LEFTPAREN")
@@ -328,32 +400,74 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
.replace(/\\\[/g, "LEFTBRACKET")
.replace(/\\\]/g, "RIGHTBRACKET");
if (props.chatMessage.uploadedImageData) {
message = `![uploaded image](${props.chatMessage.uploadedImageData})\n\n${message}`;
const intentTypeHandlers = {
"text-to-image": (msg: string) => `![generated image](data:image/png;base64,${msg})`,
"text-to-image2": (msg: string) => `![generated image](${msg})`,
"text-to-image-v3": (msg: string) =>
`![generated image](data:image/webp;base64,${msg})`,
excalidraw: (msg: string) => msg,
};
// Handle intent-specific rendering
if (props.chatMessage.intent) {
const { type, "inferred-queries": inferredQueries } = props.chatMessage.intent;
if (type in intentTypeHandlers) {
message = intentTypeHandlers[type as keyof typeof intentTypeHandlers](message);
}
if (type.includes("text-to-image") && inferredQueries?.length > 0) {
message += `\n\n${inferredQueries[0]}`;
}
}
if (props.chatMessage.intent && props.chatMessage.intent.type == "text-to-image") {
message = `![generated image](data:image/png;base64,${message})`;
} else if (props.chatMessage.intent && props.chatMessage.intent.type == "text-to-image2") {
message = `![generated image](${message})`;
} else if (
props.chatMessage.intent &&
props.chatMessage.intent.type == "text-to-image-v3"
) {
message = `![generated image](data:image/webp;base64,${message})`;
}
if (
props.chatMessage.intent &&
props.chatMessage.intent.type.includes("text-to-image") &&
props.chatMessage.intent["inferred-queries"]?.length > 0
) {
message += `\n\n${props.chatMessage.intent["inferred-queries"][0]}`;
// Replace file links with base64 data
message = renderCodeGenImageInline(message, props.chatMessage.codeContext);
// Add code context files to the message
if (props.chatMessage.codeContext) {
Object.entries(props.chatMessage.codeContext).forEach(([key, value]) => {
value.results.output_files?.forEach((file) => {
if (file.filename.endsWith(".png") || file.filename.endsWith(".jpg")) {
// Don't add the image again if it's already in the message!
if (!message.includes(`![${file.filename}](`)) {
message += `\n\n![${file.filename}](data:image/png;base64,${file.b64_data})`;
}
}
});
});
}
setTextRendered(message);
// Handle user attached images rendering
let messageForClipboard = message;
let messageToRender = message;
if (props.chatMessage.images && props.chatMessage.images.length > 0) {
const sanitizedImages = props.chatMessage.images.map((image) => {
const decodedImage = image.startsWith("data%3Aimage")
? decodeURIComponent(image)
: image;
return DOMPurify.sanitize(decodedImage);
});
const imagesInMd = sanitizedImages
.map((sanitizedImage, index) => {
return `![uploaded image ${index + 1}](${sanitizedImage})`;
})
.join("\n");
const imagesInHtml = sanitizedImages
.map((sanitizedImage, index) => {
return `<div class="${styles.imageWrapper}"><img src="${sanitizedImage}" alt="uploaded image ${index + 1}" /></div>`;
})
.join("");
const userImagesInHtml = `<div class="${styles.imagesContainer}">${imagesInHtml}</div>`;
messageForClipboard = `${imagesInMd}\n\n${messageForClipboard}`;
messageToRender = `${userImagesInHtml}${messageToRender}`;
}
// Set the message text
setTextRendered(messageForClipboard);
// Render the markdown
let markdownRendered = md.render(message);
let markdownRendered = md.render(messageToRender);
// Replace placeholders with LaTeX delimiters
markdownRendered = markdownRendered
@@ -364,7 +478,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
// Sanitize and set the rendered markdown
setMarkdownRendered(DOMPurify.sanitize(markdownRendered));
}, [props.chatMessage.message, props.chatMessage.intent]);
}, [props.chatMessage.message, props.chatMessage.images, props.chatMessage.intent]);
useEffect(() => {
if (copySuccess) {
@@ -536,9 +650,31 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
});
}
const deleteMessage = async (message: SingleChatMessage) => {
const turnId = message.turnId || props.turnId;
const response = await fetch("/api/chat/conversation/message", {
method: "DELETE",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
conversation_id: props.conversationId,
turn_id: turnId,
}),
});
if (response.ok) {
// Update the UI after successful deletion
props.onDeleteMessage(turnId);
} else {
console.error("Failed to delete message");
}
};
const allReferences = constructAllReferences(
props.chatMessage.context,
props.chatMessage.onlineContext,
props.chatMessage.codeContext,
);
return (
@@ -549,17 +685,59 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
onMouseEnter={(event) => setIsHovering(true)}
>
<div className={chatMessageWrapperClasses(props.chatMessage)}>
{props.chatMessage.queryFiles && props.chatMessage.queryFiles.length > 0 && (
<div className="flex flex-wrap flex-col mb-2 max-w-full">
{props.chatMessage.queryFiles.map((file, index) => (
<Dialog key={index}>
<DialogTrigger asChild>
<div
className="flex items-center space-x-2 cursor-pointer bg-gray-500 bg-opacity-25 rounded-lg p-2 w-full
"
>
<div className="flex-shrink-0">
{getIconFromFilename(file.file_type)}
</div>
<span className="truncate flex-1 min-w-0 max-w-[200px]">
{file.name}
</span>
{file.size && (
<span className="text-gray-400 flex-shrink-0">
({convertBytesToText(file.size)})
</span>
)}
</div>
</DialogTrigger>
<DialogContent>
<DialogHeader>
<DialogTitle>
<div className="truncate min-w-0 break-words break-all text-wrap max-w-full whitespace-normal">
{file.name}
</div>
</DialogTitle>
</DialogHeader>
<DialogDescription>
<ScrollArea className="h-72 w-full rounded-md break-words break-all text-wrap">
{file.content}
</ScrollArea>
</DialogDescription>
</DialogContent>
</Dialog>
))}
</div>
)}
<div
ref={messageRef}
className={styles.chatMessage}
dangerouslySetInnerHTML={{ __html: markdownRendered }}
/>
{excalidrawData && <ExcalidrawComponent data={excalidrawData} />}
</div>
<div className={styles.teaserReferencesContainer}>
<TeaserReferencesSection
isMobileWidth={props.isMobileWidth}
notesReferenceCardData={allReferences.notesReferenceCardData}
onlineReferenceCardData={allReferences.onlineReferenceCardData}
codeReferenceCardData={allReferences.codeReferenceCardData}
/>
</div>
<div className={styles.chatFooter}>
@@ -595,6 +773,18 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
/>
</button>
))}
{props.chatMessage.turnId && (
<button
title="Delete"
className={`${styles.deleteButton}`}
onClick={() => deleteMessage(props.chatMessage)}
>
<Trash
alt="Delete Message"
className="hsl(var(--muted-foreground)) hover:text-red-500"
/>
</button>
)}
<button
title="Copy"
className={`${styles.copyButton}`}

View File

@@ -0,0 +1,24 @@
"use client";
import dynamic from "next/dynamic";
import { Suspense } from "react";
import Loading from "../../components/loading/loading";
// Since client components get prerenderd on server as well hence importing
// the excalidraw stuff dynamically with ssr false
const ExcalidrawWrapper = dynamic(() => import("./excalidrawWrapper").then((mod) => mod.default), {
ssr: false,
});
interface ExcalidrawComponentProps {
data: any;
}
export default function ExcalidrawComponent(props: ExcalidrawComponentProps) {
return (
<Suspense fallback={<Loading />}>
<ExcalidrawWrapper data={props.data} />
</Suspense>
);
}

View File

@@ -0,0 +1,149 @@
"use client";
import { useState, useEffect } from "react";
import dynamic from "next/dynamic";
import { ExcalidrawProps } from "@excalidraw/excalidraw/types/types";
import { ExcalidrawElement } from "@excalidraw/excalidraw/types/element/types";
import { ExcalidrawElementSkeleton } from "@excalidraw/excalidraw/types/data/transform";
const Excalidraw = dynamic<ExcalidrawProps>(
async () => (await import("@excalidraw/excalidraw")).Excalidraw,
{
ssr: false,
},
);
import { convertToExcalidrawElements } from "@excalidraw/excalidraw";
import { Button } from "@/components/ui/button";
import { ArrowsInSimple, ArrowsOutSimple } from "@phosphor-icons/react";
interface ExcalidrawWrapperProps {
data: ExcalidrawElementSkeleton[];
}
export default function ExcalidrawWrapper(props: ExcalidrawWrapperProps) {
const [excalidrawElements, setExcalidrawElements] = useState<ExcalidrawElement[]>([]);
const [expanded, setExpanded] = useState<boolean>(false);
const isValidExcalidrawElement = (element: ExcalidrawElementSkeleton): boolean => {
return (
element.x !== undefined &&
element.y !== undefined &&
element.id !== undefined &&
element.type !== undefined
);
};
useEffect(() => {
if (expanded) {
onkeydown = (e) => {
if (e.key === "Escape") {
setExpanded(false);
// Trigger a resize event to make Excalidraw adjust its size
window.dispatchEvent(new Event("resize"));
}
};
} else {
onkeydown = null;
}
}, [expanded]);
useEffect(() => {
// Do some basic validation
const basicValidSkeletons: ExcalidrawElementSkeleton[] = [];
for (const element of props.data) {
if (isValidExcalidrawElement(element as ExcalidrawElementSkeleton)) {
basicValidSkeletons.push(element as ExcalidrawElementSkeleton);
}
}
const validSkeletons: ExcalidrawElementSkeleton[] = [];
for (const element of basicValidSkeletons) {
if (element.type === "frame") {
continue;
}
if (element.type === "arrow") {
const start = basicValidSkeletons.find((child) => child.id === element.start?.id);
const end = basicValidSkeletons.find((child) => child.id === element.end?.id);
if (start && end) {
validSkeletons.push(element);
}
} else {
validSkeletons.push(element);
}
}
for (const element of basicValidSkeletons) {
if (element.type === "frame") {
const children = element.children?.map((childId) => {
return validSkeletons.find((child) => child.id === childId);
});
// Get the valid children, filter out any undefined values
const validChildrenIds: readonly string[] = children
?.map((child) => child?.id)
.filter((id) => id !== undefined) as string[];
if (validChildrenIds === undefined || validChildrenIds.length === 0) {
continue;
}
validSkeletons.push({
...element,
children: validChildrenIds,
});
}
}
const elements = convertToExcalidrawElements(validSkeletons);
setExcalidrawElements(elements);
}, []);
return (
<div className="relative">
<div
className={`${expanded ? "fixed inset-0 bg-black bg-opacity-50 backdrop-blur-sm z-50 flex items-center justify-center" : ""}`}
>
<Button
onClick={() => {
setExpanded(!expanded);
// Trigger a resize event to make Excalidraw adjust its size
window.dispatchEvent(new Event("resize"));
}}
variant={"outline"}
className={`${expanded ? "absolute top-2 left-2 z-[60]" : ""}`}
>
{expanded ? (
<ArrowsInSimple className="h-4 w-4" />
) : (
<ArrowsOutSimple className="h-4 w-4" />
)}
</Button>
<div
className={`
${expanded ? "w-[80vw] h-[80vh]" : "w-full h-[500px]"}
bg-white overflow-hidden rounded-lg relative
`}
>
<Excalidraw
initialData={{
elements: excalidrawElements,
appState: { zenModeEnabled: true },
scrollToContent: true,
}}
// TODO - Create a common function to detect if the theme is dark?
theme={localStorage.getItem("theme") === "dark" ? "dark" : "light"}
validateEmbeddable={true}
renderTopRightUI={(isMobile, appState) => {
return <></>;
}}
/>
</div>
</div>
</div>
);
}

View File

@@ -81,111 +81,3 @@ export function OrgMode({ className }: { className?: string }) {
</svg>
);
}
export function Markdown({ className }: { className?: string }) {
const classes = className ?? "w-6 h-6 text-muted-foreground inline-flex mr-1";
return (
<svg
className={`${classes}`}
xmlns="http://www.w3.org/2000/svg"
width="208"
height="128"
viewBox="0 0 208 128"
>
<rect
width="198"
height="118"
x="5"
y="5"
ry="10"
stroke="#000"
strokeWidth="10"
fill="none"
/>
<path d="M30 98V30h20l20 25 20-25h20v68H90V59L70 84 50 59v39zm125 0l-30-33h20V30h20v35h20z" />
</svg>
);
}
export function Pdf({ className }: { className?: string }) {
const classes = className ?? "w-6 h-6 text-muted-foreground inline-flex mr-1";
return (
<svg
className={`${classes}`}
xmlns="http://www.w3.org/2000/svg"
enableBackground="new 0 0 334.371 380.563"
version="1.1"
viewBox="0 0 14 16"
>
<g transform="matrix(.04589 0 0 .04589 -.66877 -.73379)">
<polygon
points="51.791 356.65 51.791 23.99 204.5 23.99 282.65 102.07 282.65 356.65"
fill="#fff"
strokeWidth="212.65"
/>
<path
d="m201.19 31.99 73.46 73.393v243.26h-214.86v-316.66h141.4m6.623-16h-164.02v348.66h246.85v-265.9z"
strokeWidth="21.791"
/>
</g>
<g transform="matrix(.04589 0 0 .04589 -.66877 -.73379)">
<polygon
points="282.65 356.65 51.791 356.65 51.791 23.99 204.5 23.99 206.31 25.8 206.31 100.33 280.9 100.33 282.65 102.07"
fill="#fff"
strokeWidth="212.65"
/>
<path
d="m198.31 31.99v76.337h76.337v240.32h-214.86v-316.66h138.52m9.5-16h-164.02v348.66h246.85v-265.9l-6.43-6.424h-69.907v-69.842z"
strokeWidth="21.791"
/>
</g>
<g transform="matrix(.04589 0 0 .04589 -.66877 -.73379)" strokeWidth="21.791">
<polygon points="258.31 87.75 219.64 87.75 219.64 48.667 258.31 86.38" />
<path d="m227.64 67.646 12.41 12.104h-12.41v-12.104m-5.002-27.229h-10.998v55.333h54.666v-12.742z" />
</g>
<g
transform="matrix(.04589 0 0 .04589 -.66877 -.73379)"
fill="#ed1c24"
strokeWidth="212.65"
>
<polygon points="311.89 284.49 22.544 284.49 22.544 167.68 37.291 152.94 37.291 171.49 297.15 171.49 297.15 152.94 311.89 167.68" />
<path d="m303.65 168.63 1.747 1.747v107.62h-276.35v-107.62l1.747-1.747v9.362h272.85v-9.362m-12.999-31.385v27.747h-246.86v-27.747l-27.747 27.747v126h302.35v-126z" />
</g>
<rect x="1.7219" y="7.9544" width="10.684" height="4.0307" fill="none" />
<g transform="matrix(.04589 0 0 .04589 1.7219 11.733)" fill="#fff" strokeWidth="21.791">
<path d="m9.216 0v-83.2h30.464q6.784 0 12.928 1.408 6.144 1.28 10.752 4.608 4.608 3.2 7.296 8.576 2.816 5.248 2.816 13.056 0 7.68-2.816 13.184-2.688 5.504-7.296 9.088-4.608 3.456-10.624 5.248-6.016 1.664-12.544 1.664h-8.96v26.368zm22.016-43.776h7.936q6.528 0 9.6-3.072 3.2-3.072 3.2-8.704t-3.456-7.936-9.856-2.304h-7.424z" />
<path d="m87.04 0v-83.2h24.576q9.472 0 17.28 2.304 7.936 2.304 13.568 7.296t8.704 12.8q3.2 7.808 3.2 18.816t-3.072 18.944-8.704 13.056q-5.504 5.12-13.184 7.552-7.552 2.432-16.512 2.432zm22.016-17.664h1.28q4.48 0 8.448-1.024 3.968-1.152 6.784-3.84 2.944-2.688 4.608-7.424t1.664-12.032-1.664-11.904-4.608-7.168q-2.816-2.56-6.784-3.456-3.968-1.024-8.448-1.024h-1.28z" />
<path d="m169.22 0v-83.2h54.272v18.432h-32.256v15.872h27.648v18.432h-27.648v30.464z" />
</g>
</svg>
);
}
export function Word({ className }: { className?: string }) {
const classes = className ?? "w-6 h-6 text-muted-foreground inline-flex mr-1";
return (
<svg
className={`${classes}`}
xmlns="http://www.w3.org/2000/svg"
fill="#FFF"
stroke-miterlimit="10"
strokeWidth="2"
viewBox="0 0 96 96"
>
<path
stroke="#979593"
d="M67.1716 7H27c-1.1046 0-2 .8954-2 2v78c0 1.1046.8954 2 2 2h58c1.1046 0 2-.8954 2-2V26.8284c0-.5304-.2107-1.0391-.5858-1.4142L68.5858 7.5858C68.2107 7.2107 67.702 7 67.1716 7z"
/>
<path fill="none" stroke="#979593" d="M67 7v18c0 1.1046.8954 2 2 2h18" />
<path
fill="#C8C6C4"
d="M79 61H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0-6H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0-6H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0-6H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0 24H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1z"
/>
<path
fill="#185ABD"
d="M12 74h32c2.2091 0 4-1.7909 4-4V38c0-2.2091-1.7909-4-4-4H12c-2.2091 0-4 1.7909-4 4v32c0 2.2091 1.7909 4 4 4z"
/>
<path d="M21.6245 60.6455c.0661.522.109.9769.1296 1.3657h.0762c.0306-.3685.0889-.8129.1751-1.3349.0862-.5211.1703-.961.2517-1.319L25.7911 44h4.5702l3.6562 15.1272c.183.7468.3353 1.6973.457 2.8532h.0608c.0508-.7979.1777-1.7184.3809-2.7615L37.8413 44H42l-5.1183 22h-4.86l-3.4885-14.5744c-.1016-.4197-.2158-.9663-.3428-1.6417-.127-.6745-.2057-1.1656-.236-1.4724h-.0608c-.0407.358-.1195.8896-.2364 1.595-.1169.7062-.211 1.2273-.2819 1.565L24.1 66h-4.9357L14 44h4.2349l3.1843 15.3882c.0709.3165.1392.7362.2053 1.2573z" />
</svg>
);
}

View File

@@ -26,6 +26,17 @@ import { Moon, Sun, UserCircle, Question, GearFine, ArrowRight } from "@phosphor
import { KhojAgentLogo, KhojAutomationLogo, KhojSearchLogo } from "../logo/khojLogo";
import { useIsMobileWidth } from "@/app/common/utils";
function SubscriptionBadge({ is_active }: { is_active: boolean }) {
return (
<div className="flex flex-row items-center">
<div
className={`w-3 h-3 rounded-full ${is_active ? "bg-yellow-500" : "bg-muted"} mr-1`}
></div>
<p className="text-xs">{is_active ? "Futurist" : "Free"}</p>
</div>
);
}
export default function NavMenu() {
const userData = useAuthenticatedData();
const [darkMode, setDarkMode] = useState(false);
@@ -85,8 +96,9 @@ export default function NavMenu() {
</DropdownMenuTrigger>
<DropdownMenuContent className="gap-2">
<DropdownMenuItem className="w-full">
<div className="flex flex-rows">
<div className="flex flex-col">
<p className="font-semibold">{userData?.email}</p>
<SubscriptionBadge is_active={userData?.is_active ?? false} />
</div>
</DropdownMenuItem>
<DropdownMenuSeparator />
@@ -192,8 +204,9 @@ export default function NavMenu() {
</MenubarTrigger>
<MenubarContent align="end" className="rounded-xl gap-2">
<MenubarItem className="w-full">
<div className="flex flex-rows">
<div className="flex flex-col">
<p className="font-semibold">{userData?.email}</p>
<SubscriptionBadge is_active={userData?.is_active ?? false} />
</div>
</MenubarItem>
<MenubarSeparator className="dark:bg-white height-[2px] bg-black" />

View File

@@ -2,7 +2,7 @@
import { useEffect, useState } from "react";
import { ArrowRight } from "@phosphor-icons/react";
import { ArrowCircleDown, ArrowRight } from "@phosphor-icons/react";
import markdownIt from "markdown-it";
const md = new markdownIt({
@@ -11,7 +11,13 @@ const md = new markdownIt({
typographer: true,
});
import { Context, WebPage, OnlineContext } from "../chatMessage/chatMessage";
import {
Context,
WebPage,
OnlineContext,
CodeContext,
CodeContextFile,
} from "../chatMessage/chatMessage";
import { Card } from "@/components/ui/card";
import {
@@ -51,6 +57,7 @@ function NotesContextReferenceCard(props: NotesContextReferenceCardProps) {
props.title || ".txt",
"w-6 h-6 text-muted-foreground inline-flex mr-2",
);
const fileName = props.title.split("/").pop() || props.title;
const snippet = extractSnippet(props);
const [isHovering, setIsHovering] = useState(false);
@@ -61,30 +68,30 @@ function NotesContextReferenceCard(props: NotesContextReferenceCardProps) {
<Card
onMouseEnter={() => setIsHovering(true)}
onMouseLeave={() => setIsHovering(false)}
className={`${props.showFullContent ? "w-auto" : "w-[200px]"} overflow-hidden break-words text-balance rounded-lg p-2 bg-muted border-none`}
className={`${props.showFullContent ? "w-auto" : "w-[200px]"} overflow-hidden break-words text-balance rounded-lg border-none p-2 bg-muted`}
>
<h3
className={`${props.showFullContent ? "block" : "line-clamp-1"} text-muted-foreground}`}
>
{fileIcon}
{props.title}
{props.showFullContent ? props.title : fileName}
</h3>
<p
className={`${props.showFullContent ? "block" : "overflow-hidden line-clamp-2"}`}
className={`text-sm ${props.showFullContent ? "overflow-x-auto block" : "overflow-hidden line-clamp-2"}`}
dangerouslySetInnerHTML={{ __html: snippet }}
></p>
</Card>
</PopoverTrigger>
<PopoverContent className="w-[400px] mx-2">
<Card
className={`w-auto overflow-hidden break-words text-balance rounded-lg p-2 border-none`}
className={`w-auto overflow-hidden break-words text-balance rounded-lg border-none p-2`}
>
<h3 className={`line-clamp-2 text-muted-foreground}`}>
{fileIcon}
{props.title}
</h3>
<p
className={`overflow-hidden line-clamp-3`}
className={`border-t mt-1 pt-1 text-sm overflow-hidden line-clamp-5`}
dangerouslySetInnerHTML={{ __html: snippet }}
></p>
</Card>
@@ -94,9 +101,160 @@ function NotesContextReferenceCard(props: NotesContextReferenceCardProps) {
);
}
export interface ReferencePanelData {
notesReferenceCardData: NotesContextReferenceData[];
onlineReferenceCardData: OnlineReferenceData[];
interface CodeContextReferenceCardProps {
code: string;
output: string;
output_files: CodeContextFile[];
error: string;
showFullContent: boolean;
}
function CodeContextReferenceCard(props: CodeContextReferenceCardProps) {
const fileIcon = getIconFromFilename(".py", "!w-4 h-4 text-muted-foreground flex-shrink-0");
const sanitizedCodeSnippet = DOMPurify.sanitize(props.code);
const [isHovering, setIsHovering] = useState(false);
const [isDownloadHover, setIsDownloadHover] = useState(false);
const handleDownload = (file: CodeContextFile) => {
// Determine MIME type
let mimeType = "text/plain";
let byteString = file.b64_data;
if (file.filename.match(/\.(png|jpg|jpeg|webp)$/)) {
mimeType = `image/${file.filename.split(".").pop()}`;
byteString = atob(file.b64_data);
} else if (file.filename.endsWith(".json")) {
mimeType = "application/json";
} else if (file.filename.endsWith(".csv")) {
mimeType = "text/csv";
}
const arrayBuffer = new ArrayBuffer(byteString.length);
const bytes = new Uint8Array(arrayBuffer);
for (let i = 0; i < byteString.length; i++) {
bytes[i] = byteString.charCodeAt(i);
}
const blob = new Blob([arrayBuffer], { type: mimeType });
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = file.filename;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
};
const renderOutputFiles = (files: CodeContextFile[], hoverCard: boolean) => {
if (files?.length == 0) return null;
return (
<div
className={`${hoverCard || props.showFullContent ? "border-t mt-1 pt-1" : undefined}`}
>
{files.slice(0, props.showFullContent ? undefined : 1).map((file, index) => {
return (
<div key={`${file.filename}-${index}`}>
<h4 className="text-sm text-muted-foreground flex items-center">
<span
className={`overflow-hidden mr-2 font-bold ${props.showFullContent ? undefined : "line-clamp-1"}`}
>
{file.filename}
</span>
<button
className={`${hoverCard ? "hidden" : undefined}`}
onClick={(e) => {
e.preventDefault();
handleDownload(file);
}}
onMouseEnter={() => setIsDownloadHover(true)}
onMouseLeave={() => setIsDownloadHover(false)}
title={`Download file: ${file.filename}`}
>
<ArrowCircleDown
className={`w-4 h-4`}
weight={isDownloadHover ? "fill" : "regular"}
/>
</button>
</h4>
{file.filename.match(/\.(txt|org|md|csv|json)$/) ? (
<pre
className={`${props.showFullContent ? "block" : "line-clamp-2"} text-sm mt-1 p-1 bg-background rounded overflow-x-auto`}
>
{file.b64_data}
</pre>
) : file.filename.match(/\.(png|jpg|jpeg|webp)$/) ? (
<img
src={`data:image/${file.filename.split(".").pop()};base64,${file.b64_data}`}
alt={file.filename}
className="mt-1 max-h-32 rounded"
/>
) : null}
</div>
);
})}
</div>
);
};
return (
<>
<Popover open={isHovering && !props.showFullContent} onOpenChange={setIsHovering}>
<PopoverTrigger asChild>
<Card
onMouseEnter={() => setIsHovering(true)}
onMouseLeave={() => setIsHovering(false)}
className={`${props.showFullContent ? "w-auto" : "w-[200px]"} overflow-hidden break-words text-balance rounded-lg border-none p-2 bg-muted`}
>
<div className="flex flex-col px-1">
<div className="flex items-center gap-2">
{fileIcon}
<h3
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-1"} text-muted-foreground flex-grow`}
>
code {props.output_files?.length > 0 ? "artifacts" : ""}
</h3>
</div>
<pre
className={`text-xs pb-2 ${props.showFullContent ? "block overflow-x-auto" : props.output_files?.length > 0 ? "hidden" : "overflow-hidden line-clamp-3"}`}
>
{sanitizedCodeSnippet}
</pre>
{props.output_files?.length > 0 &&
renderOutputFiles(props.output_files, false)}
</div>
</Card>
</PopoverTrigger>
<PopoverContent className="w-[400px] mx-2">
<Card
className={`w-auto overflow-hidden break-words text-balance rounded-lg border-none p-2`}
>
<div className="flex items-center gap-2">
{fileIcon}
<h3
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-1"} text-muted-foreground flex-grow`}
>
code {props.output_files?.length > 0 ? "artifact" : ""}
</h3>
</div>
{(props.output_files?.length > 0 &&
renderOutputFiles(props.output_files?.slice(0, 1), true)) || (
<pre className="text-xs border-t mt-1 pt-1 verflow-hidden line-clamp-10">
{sanitizedCodeSnippet}
</pre>
)}
</Card>
</PopoverContent>
</Popover>
</>
);
}
export interface CodeReferenceData {
code: string;
output: string;
output_files: CodeContextFile[];
error: string;
}
interface OnlineReferenceData {
@@ -141,21 +299,17 @@ function GenericOnlineReferenceCard(props: OnlineReferenceCardProps) {
<Card
onMouseEnter={handleMouseEnter}
onMouseLeave={handleMouseLeave}
className={`${props.showFullContent ? "w-auto" : "w-[200px]"} overflow-hidden break-words rounded-lg text-balance p-2 bg-muted border-none`}
className={`${props.showFullContent ? "w-auto" : "w-[200px]"} overflow-hidden break-words text-balance rounded-lg border-none p-2 bg-muted`}
>
<div className="flex flex-col">
<a
href={props.link}
target="_blank"
rel="noreferrer"
className="!no-underline p-2"
className="!no-underline px-1"
>
<div className="flex items-center gap-2">
<img
src={favicon}
alt=""
className="!w-4 h-4 mr-2 flex-shrink-0"
/>
<img src={favicon} alt="" className="!w-4 h-4 flex-shrink-0" />
<h3
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-1"} text-muted-foreground flex-grow`}
>
@@ -168,7 +322,7 @@ function GenericOnlineReferenceCard(props: OnlineReferenceCardProps) {
{props.title}
</h3>
<p
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-2"}`}
className={`overflow-hidden text-sm ${props.showFullContent ? "block" : "line-clamp-2"}`}
>
{props.description}
</p>
@@ -185,23 +339,23 @@ function GenericOnlineReferenceCard(props: OnlineReferenceCardProps) {
href={props.link}
target="_blank"
rel="noreferrer"
className="!no-underline p-2"
className="!no-underline px-1"
>
<div className="flex items-center">
<img src={favicon} alt="" className="!w-4 h-4 mr-2" />
<div className="flex items-center gap-2">
<img src={favicon} alt="" className="!w-4 h-4 flex-shrink-0" />
<h3
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-2"} text-muted-foreground`}
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-2"} text-muted-foreground flex-grow`}
>
{domain}
</h3>
</div>
<h3
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-2"} font-bold`}
className={`border-t mt-1 pt-1 overflow-hidden ${props.showFullContent ? "block" : "line-clamp-2"} font-bold`}
>
{props.title}
</h3>
<p
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-3"}`}
className={`overflow-hidden text-sm ${props.showFullContent ? "block" : "line-clamp-5"}`}
>
{props.description}
</p>
@@ -214,9 +368,28 @@ function GenericOnlineReferenceCard(props: OnlineReferenceCardProps) {
);
}
export function constructAllReferences(contextData: Context[], onlineData: OnlineContext) {
export function constructAllReferences(
contextData: Context[],
onlineData: OnlineContext,
codeContext: CodeContext,
) {
const onlineReferences: OnlineReferenceData[] = [];
const contextReferences: NotesContextReferenceData[] = [];
const codeReferences: CodeReferenceData[] = [];
if (codeContext) {
for (const [key, value] of Object.entries(codeContext)) {
if (!value.results) {
continue;
}
codeReferences.push({
code: value.code,
output: value.results.std_out,
output_files: value.results.output_files,
error: value.results.std_err,
});
}
}
if (onlineData) {
let localOnlineReferences = [];
@@ -298,12 +471,14 @@ export function constructAllReferences(contextData: Context[], onlineData: Onlin
return {
notesReferenceCardData: contextReferences,
onlineReferenceCardData: onlineReferences,
codeReferenceCardData: codeReferences,
};
}
export interface TeaserReferenceSectionProps {
notesReferenceCardData: NotesContextReferenceData[];
onlineReferenceCardData: OnlineReferenceData[];
codeReferenceCardData: CodeReferenceData[];
isMobileWidth: boolean;
}
@@ -314,17 +489,28 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
setNumTeaserSlots(props.isMobileWidth ? 1 : 3);
}, [props.isMobileWidth]);
const notesDataToShow = props.notesReferenceCardData.slice(0, numTeaserSlots);
const codeDataToShow = props.codeReferenceCardData.slice(0, numTeaserSlots);
const notesDataToShow = props.notesReferenceCardData.slice(
0,
numTeaserSlots - codeDataToShow.length,
);
const onlineDataToShow =
notesDataToShow.length < numTeaserSlots
? props.onlineReferenceCardData.slice(0, numTeaserSlots - notesDataToShow.length)
notesDataToShow.length + codeDataToShow.length < numTeaserSlots
? props.onlineReferenceCardData.slice(
0,
numTeaserSlots - codeDataToShow.length - notesDataToShow.length,
)
: [];
const shouldShowShowMoreButton =
props.notesReferenceCardData.length > 0 || props.onlineReferenceCardData.length > 0;
props.notesReferenceCardData.length > 0 ||
props.codeReferenceCardData.length > 0 ||
props.onlineReferenceCardData.length > 0;
const numReferences =
props.notesReferenceCardData.length + props.onlineReferenceCardData.length;
props.notesReferenceCardData.length +
props.codeReferenceCardData.length +
props.onlineReferenceCardData.length;
if (numReferences === 0) {
return null;
@@ -337,6 +523,15 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
<p className="text-gray-400 m-2">{numReferences} sources</p>
</h3>
<div className={`flex flex-wrap gap-2 w-auto mt-2`}>
{codeDataToShow.map((code, index) => {
return (
<CodeContextReferenceCard
showFullContent={false}
{...code}
key={`code-${index}`}
/>
);
})}
{notesDataToShow.map((note, index) => {
return (
<NotesContextReferenceCard
@@ -359,6 +554,7 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
<ReferencePanel
notesReferenceCardData={props.notesReferenceCardData}
onlineReferenceCardData={props.onlineReferenceCardData}
codeReferenceCardData={props.codeReferenceCardData}
/>
)}
</div>
@@ -369,6 +565,7 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
interface ReferencePanelDataProps {
notesReferenceCardData: NotesContextReferenceData[];
onlineReferenceCardData: OnlineReferenceData[];
codeReferenceCardData: CodeReferenceData[];
}
export default function ReferencePanel(props: ReferencePanelDataProps) {
@@ -388,6 +585,15 @@ export default function ReferencePanel(props: ReferencePanelDataProps) {
<SheetDescription>View all references for this response</SheetDescription>
</SheetHeader>
<div className="flex flex-wrap gap-2 w-auto mt-2">
{props.codeReferenceCardData.map((code, index) => {
return (
<CodeContextReferenceCard
showFullContent={true}
{...code}
key={`code-${index}`}
/>
);
})}
{props.notesReferenceCardData.map((note, index) => {
return (
<NotesContextReferenceCard

View File

@@ -14,25 +14,21 @@ interface SuggestionCardProps {
export default function SuggestionCard(data: SuggestionCardProps) {
const bgColors = converColorToBgGradient(data.color);
const cardClassName = `${styles.card} ${bgColors} md:w-full md:h-fit sm:w-full h-fit md:w-[200px] md:h-[200px] cursor-pointer`;
const titleClassName = `${styles.title} pt-2 dark:text-white dark:font-bold`;
const cardClassName = `${styles.card} ${bgColors} md:w-full md:h-fit sm:w-full h-fit md:w-[200px] md:h-[180px] cursor-pointer md:p-2`;
const descriptionClassName = `${styles.text} dark:text-white`;
const cardContent = (
<Card className={cardClassName}>
<CardHeader className="m-0 p-2 pb-1 relative">
<div className="flex flex-row md:flex-col">
<div className="flex w-full">
<CardContent className="m-0 p-2 w-full">
{convertSuggestionTitleToIconClass(data.title, data.color.toLowerCase())}
<CardTitle className={titleClassName}>{data.title}</CardTitle>
</div>
</CardHeader>
<CardContent className="m-0 p-2 pr-4 pt-1">
<CardDescription
className={`${descriptionClassName} sm:line-clamp-2 md:line-clamp-4`}
>
{data.body}
</CardDescription>
</CardContent>
<CardDescription
className={`${descriptionClassName} sm:line-clamp-2 md:line-clamp-4 pt-1 break-words whitespace-pre-wrap max-w-full`}
>
{data.body}
</CardDescription>
</CardContent>
</div>
</Card>
);

View File

@@ -1,12 +1,9 @@
.card {
padding: 0.5rem;
margin: 0.05rem;
border-radius: 0.5rem;
}
.title {
font-size: 1.0rem;
font-size: 1rem;
}
.text {

View File

@@ -47,24 +47,24 @@ const DEFAULT_COLOR = "orange";
export function convertSuggestionTitleToIconClass(title: string, color: string) {
if (title === SuggestionType.Automation)
return getIconFromIconName("Robot", color, "w-8", "h-8");
if (title === SuggestionType.Paint) return getIconFromIconName("Palette", color, "w-8", "h-8");
return getIconFromIconName("Robot", color, "w-6", "h-6");
if (title === SuggestionType.Paint) return getIconFromIconName("Palette", color, "w-6", "h-6");
if (title === SuggestionType.PopCulture)
return getIconFromIconName("Confetti", color, "w-8", "h-8");
if (title === SuggestionType.Travel) return getIconFromIconName("Jeep", color, "w-8", "h-8");
if (title === SuggestionType.Learning) return getIconFromIconName("Book", color, "w-8", "h-8");
return getIconFromIconName("Confetti", color, "w-6", "h-6");
if (title === SuggestionType.Travel) return getIconFromIconName("Jeep", color, "w-6", "h-6");
if (title === SuggestionType.Learning) return getIconFromIconName("Book", color, "w-6", "h-6");
if (title === SuggestionType.Health)
return getIconFromIconName("Asclepius", color, "w-8", "h-8");
if (title === SuggestionType.Fun) return getIconFromIconName("Island", color, "w-8", "h-8");
if (title === SuggestionType.Home) return getIconFromIconName("House", color, "w-8", "h-8");
return getIconFromIconName("Asclepius", color, "w-6", "h-6");
if (title === SuggestionType.Fun) return getIconFromIconName("Island", color, "w-6", "h-6");
if (title === SuggestionType.Home) return getIconFromIconName("House", color, "w-6", "h-6");
if (title === SuggestionType.Language)
return getIconFromIconName("Translate", color, "w-8", "h-8");
if (title === SuggestionType.Code) return getIconFromIconName("Code", color, "w-8", "h-8");
if (title === SuggestionType.Food) return getIconFromIconName("BowlFood", color, "w-8", "h-8");
return getIconFromIconName("Translate", color, "w-6", "h-6");
if (title === SuggestionType.Code) return getIconFromIconName("Code", color, "w-6", "h-6");
if (title === SuggestionType.Food) return getIconFromIconName("BowlFood", color, "w-6", "h-6");
if (title === SuggestionType.Interviewing)
return getIconFromIconName("Lectern", color, "w-8", "h-8");
if (title === SuggestionType.Finance) return getIconFromIconName("Wallet", color, "w-8", "h-8");
else return getIconFromIconName("Lightbulb", color, "w-8", "h-8");
return getIconFromIconName("Lectern", color, "w-6", "h-6");
if (title === SuggestionType.Finance) return getIconFromIconName("Wallet", color, "w-6", "h-6");
else return getIconFromIconName("Lightbulb", color, "w-6", "h-6");
}
export const suggestionsData: Suggestion[] = [

View File

@@ -1,172 +0,0 @@
input.factVerification {
width: 100%;
display: block;
padding: 12px 20px;
margin: 8px 0;
border: none;
box-sizing: border-box;
border-radius: 4px;
text-align: left;
margin: auto;
margin-top: 8px;
margin-bottom: 8px;
font-size: large;
}
div.factCheckerContainer {
width: 75vw;
margin: auto;
}
input.factVerification:focus {
outline: none;
box-shadow: 0 4px 8px 0 rgba(0, 0, 0, 0.2);
}
div.responseText {
margin: 0;
padding: 0;
border-radius: 8px;
}
div.response {
margin-bottom: 12px;
}
a.titleLink {
color: #333;
font-weight: bold;
}
a.subLinks {
color: #333;
text-decoration: none;
font-weight: small;
border-radius: 4px;
font-size: small;
}
div.subLinks {
display: flex;
flex-direction: row;
gap: 8px;
flex-wrap: wrap;
}
div.reference {
padding: 12px;
margin: 8px;
border-radius: 8px;
}
footer.footer {
width: 100%;
background: transparent;
text-align: left;
}
div.reportActions {
display: flex;
flex-direction: row;
gap: 8px;
justify-content: space-between;
margin-top: 8px;
}
button.factCheckButton {
border: none;
cursor: pointer;
width: 100%;
border-radius: 4px;
margin: 8px;
padding-left: 1rem;
padding-right: 1rem;
line-height: 1.25rem;
}
button.factCheckButton:hover {
box-shadow: 0 4px 8px 0 rgba(0, 0, 0, 0.2);
}
div.spinner {
margin: 20px;
width: 40px;
height: 40px;
position: relative;
text-align: center;
-webkit-animation: sk-rotate 2.0s infinite linear;
animation: sk-rotate 2.0s infinite linear;
}
div.inputFields {
width: 100%;
display: grid;
grid-template-columns: 1fr auto;
grid-gap: 8px;
}
/* Loading Animation */
div.dot1,
div.dot2 {
width: 60%;
height: 60%;
display: inline-block;
position: absolute;
top: 0;
border-radius: 100%;
-webkit-animation: sk-bounce 2.0s infinite ease-in-out;
animation: sk-bounce 2.0s infinite ease-in-out;
}
div.dot2 {
top: auto;
bottom: 0;
-webkit-animation-delay: -1.0s;
animation-delay: -1.0s;
}
@media screen and (max-width: 768px) {
div.factCheckerContainer {
width: 95vw;
}
}
@-webkit-keyframes sk-rotate {
100% {
-webkit-transform: rotate(360deg)
}
}
@keyframes sk-rotate {
100% {
transform: rotate(360deg);
-webkit-transform: rotate(360deg)
}
}
@-webkit-keyframes sk-bounce {
0%,
100% {
-webkit-transform: scale(0.0)
}
50% {
-webkit-transform: scale(1.0)
}
}
@keyframes sk-bounce {
0%,
100% {
transform: scale(0.0);
-webkit-transform: scale(0.0);
}
50% {
transform: scale(1.0);
-webkit-transform: scale(1.0);
}
}

View File

@@ -1,33 +0,0 @@
import type { Metadata } from "next";
export const metadata: Metadata = {
title: "Khoj AI - Fact Checker",
description:
"Use the Fact Checker with Khoj AI for verifying statements. It can research the internet for you, either refuting or confirming the statement using fresh data.",
icons: {
icon: "/static/assets/icons/khoj_lantern.ico",
apple: "/static/assets/icons/khoj_lantern_256x256.png",
},
openGraph: {
siteName: "Khoj AI",
title: "Khoj AI - Fact Checker",
description: "Your Second Brain.",
url: "https://app.khoj.dev/factchecker",
type: "website",
images: [
{
url: "https://assets.khoj.dev/khoj_lantern_256x256.png",
width: 256,
height: 256,
},
],
},
};
export default function RootLayout({
children,
}: Readonly<{
children: React.ReactNode;
}>) {
return <div>{children}</div>;
}

View File

@@ -1,664 +0,0 @@
"use client";
import styles from "./factChecker.module.css";
import { useAuthenticatedData } from "@/app/common/auth";
import { useState, useEffect } from "react";
import ChatMessage, {
Context,
OnlineContext,
OnlineContextData,
WebPage,
} from "../components/chatMessage/chatMessage";
import { ModelPicker, Model } from "../components/modelPicker/modelPicker";
import ShareLink from "../components/shareLink/shareLink";
import { Input } from "@/components/ui/input";
import { Button } from "@/components/ui/button";
import { Card, CardContent, CardFooter, CardHeader, CardTitle } from "@/components/ui/card";
import Link from "next/link";
import SidePanel from "../components/sidePanel/chatHistorySidePanel";
import { useIsMobileWidth } from "../common/utils";
const chatURL = "/api/chat";
const verificationPrecursor =
"Limit your search to reputable sources. Search the internet for relevant supporting or refuting information. Do not reference my notes. Refuse to answer any queries that are not falsifiable by informing me that you will not answer the question. You're not permitted to ask follow-up questions, so do the best with what you have. Respond with **TRUE** or **FALSE** or **INCONCLUSIVE**, then provide your justification. Fact Check:";
const LoadingSpinner = () => (
<div className={styles.loading}>
<div className={styles.loadingVerification}>
Researching...
<div className={styles.spinner}>
<div className={`${styles.dot1} bg-blue-300`}></div>
<div className={`${styles.dot2} bg-blue-300`}></div>
</div>
</div>
</div>
);
interface SupplementReferences {
additionalLink: string;
response: string;
linkTitle: string;
}
interface ResponseWithReferences {
context?: Context[];
online?: OnlineContext;
response?: string;
}
function handleCompiledReferences(chunk: string, currentResponse: string) {
const rawReference = chunk.split("### compiled references:")[1];
const rawResponse = chunk.split("### compiled references:")[0];
let references: ResponseWithReferences = {};
// Set the initial response
references.response = currentResponse + rawResponse;
const rawReferenceAsJson = JSON.parse(rawReference);
if (rawReferenceAsJson instanceof Array) {
references.context = rawReferenceAsJson;
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
references.online = rawReferenceAsJson;
}
return references;
}
async function verifyStatement(
message: string,
conversationId: string,
setIsLoading: (loading: boolean) => void,
setInitialResponse: (response: string) => void,
setInitialReferences: (references: ResponseWithReferences) => void,
) {
setIsLoading(true);
// Construct the verification payload
let verificationMessage = `${verificationPrecursor} ${message}`;
const apiURL = `${chatURL}?client=web`;
const requestBody = {
q: verificationMessage,
conversation_id: conversationId,
stream: true,
};
try {
// Send a message to the chat server to verify the fact
const response = await fetch(apiURL, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(requestBody),
});
if (!response.body) throw new Error("No response body found");
const reader = response.body?.getReader();
let decoder = new TextDecoder();
let result = "";
while (true) {
const { done, value } = await reader.read();
if (done) break;
let chunk = decoder.decode(value, { stream: true });
if (chunk.includes("### compiled references:")) {
const references = handleCompiledReferences(chunk, result);
if (references.response) {
result = references.response;
setInitialResponse(references.response);
setInitialReferences(references);
}
} else {
result += chunk;
setInitialResponse(result);
}
}
} catch (error) {
console.error("Error verifying statement: ", error);
} finally {
setIsLoading(false);
}
}
async function spawnNewConversation(setConversationID: (conversationID: string) => void) {
let createURL = `/api/chat/sessions?client=web`;
const response = await fetch(createURL, { method: "POST" });
const data = await response.json();
setConversationID(data.conversation_id);
}
interface ReferenceVerificationProps {
message: string;
additionalLink: string;
conversationId: string;
linkTitle: string;
setChildReferencesCallback: (
additionalLink: string,
response: string,
linkTitle: string,
) => void;
prefilledResponse?: string;
}
function ReferenceVerification(props: ReferenceVerificationProps) {
const [initialResponse, setInitialResponse] = useState("");
const [isLoading, setIsLoading] = useState(true);
const verificationStatement = `${props.message}. Use this link for reference: ${props.additionalLink}`;
const isMobileWidth = useIsMobileWidth();
useEffect(() => {
if (props.prefilledResponse) {
setInitialResponse(props.prefilledResponse);
setIsLoading(false);
} else {
verifyStatement(
verificationStatement,
props.conversationId,
setIsLoading,
setInitialResponse,
() => {},
);
}
}, [verificationStatement, props.conversationId, props.prefilledResponse]);
useEffect(() => {
if (initialResponse === "") return;
if (props.prefilledResponse) return;
if (!isLoading) {
// Only set the child references when it's done loading and if the initial response is not prefilled (i.e. it was fetched from the server)
props.setChildReferencesCallback(
props.additionalLink,
initialResponse,
props.linkTitle,
);
}
}, [initialResponse, isLoading, props]);
return (
<div>
{isLoading && <LoadingSpinner />}
<ChatMessage
chatMessage={{
automationId: "",
by: "AI",
message: initialResponse,
context: [],
created: new Date().toISOString(),
onlineContext: {},
}}
isMobileWidth={isMobileWidth}
/>
</div>
);
}
interface SupplementalReferenceProps {
onlineData?: OnlineContextData;
officialFactToVerify: string;
conversationId: string;
additionalLink: string;
setChildReferencesCallback: (
additionalLink: string,
response: string,
linkTitle: string,
) => void;
prefilledResponse?: string;
linkTitle?: string;
}
function SupplementalReference(props: SupplementalReferenceProps) {
const linkTitle = props.linkTitle || props.onlineData?.organic?.[0]?.title || "Reference";
const linkAsWebpage = { link: props.additionalLink } as WebPage;
return (
<Card className={`mt-2 mb-4`}>
<CardHeader>
<a
className={styles.titleLink}
href={props.additionalLink}
target="_blank"
rel="noreferrer"
>
{linkTitle}
</a>
<WebPageLink {...linkAsWebpage} />
</CardHeader>
<CardContent>
<ReferenceVerification
additionalLink={props.additionalLink}
message={props.officialFactToVerify}
linkTitle={linkTitle}
conversationId={props.conversationId}
setChildReferencesCallback={props.setChildReferencesCallback}
prefilledResponse={props.prefilledResponse}
/>
</CardContent>
</Card>
);
}
const WebPageLink = (webpage: WebPage) => {
const webpageDomain = new URL(webpage.link).hostname;
return (
<div className={styles.subLinks}>
<a
className={`${styles.subLinks} bg-blue-200 px-2`}
href={webpage.link}
target="_blank"
rel="noreferrer"
>
{webpageDomain}
</a>
</div>
);
};
export default function FactChecker() {
const [factToVerify, setFactToVerify] = useState("");
const [officialFactToVerify, setOfficialFactToVerify] = useState("");
const [isLoading, setIsLoading] = useState(false);
const [initialResponse, setInitialResponse] = useState("");
const [clickedVerify, setClickedVerify] = useState(false);
const [initialReferences, setInitialReferences] = useState<ResponseWithReferences>();
const [childReferences, setChildReferences] = useState<SupplementReferences[]>();
const [modelUsed, setModelUsed] = useState<Model>();
const isMobileWidth = useIsMobileWidth();
const [conversationID, setConversationID] = useState("");
const [runId, setRunId] = useState("");
const [loadedFromStorage, setLoadedFromStorage] = useState(false);
const [initialModel, setInitialModel] = useState<Model>();
function setChildReferencesCallback(
additionalLink: string,
response: string,
linkTitle: string,
) {
const newReferences = childReferences || [];
const exists = newReferences.find(
(reference) => reference.additionalLink === additionalLink,
);
if (exists) return;
newReferences.push({ additionalLink, response, linkTitle });
setChildReferences(newReferences);
}
let userData = useAuthenticatedData();
function storeData() {
const data = {
factToVerify,
response: initialResponse,
references: initialReferences,
childReferences,
runId,
modelUsed,
};
fetch(`/api/chat/store/factchecker`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
runId: runId,
storeData: data,
}),
});
}
useEffect(() => {
if (factToVerify) {
document.title = `AI Fact Check: ${factToVerify}`;
} else {
document.title = "AI Fact Checker";
}
}, [factToVerify]);
useEffect(() => {
const storedFact = localStorage.getItem("factToVerify");
if (storedFact) {
setFactToVerify(storedFact);
}
// Get query params from the URL
const urlParams = new URLSearchParams(window.location.search);
const factToVerifyParam = urlParams.get("factToVerify");
if (factToVerifyParam) {
setFactToVerify(factToVerifyParam);
}
const runIdParam = urlParams.get("runId");
if (runIdParam) {
setRunId(runIdParam);
// Define an async function to fetch data
const fetchData = async () => {
const storedDataURL = `/api/chat/store/factchecker?runId=${runIdParam}`;
try {
const response = await fetch(storedDataURL);
if (response.status !== 200) {
throw new Error("Failed to fetch stored data");
}
const storedData = JSON.parse(await response.json());
if (storedData) {
setOfficialFactToVerify(storedData.factToVerify);
setInitialResponse(storedData.response);
setInitialReferences(storedData.references);
setChildReferences(storedData.childReferences);
setInitialModel(storedData.modelUsed);
}
setLoadedFromStorage(true);
} catch (error) {
console.error("Error fetching stored data: ", error);
}
};
// Call the async function
fetchData();
}
}, []);
function onClickVerify() {
if (clickedVerify) return;
// Perform validation checks on the fact to verify
if (!factToVerify) {
alert("Please enter a fact to verify.");
return;
}
setClickedVerify(true);
if (!userData) {
let currentURL = window.location.href;
window.location.href = `/login?next=${currentURL}`;
}
setInitialReferences(undefined);
setInitialResponse("");
spawnNewConversation(setConversationID);
// Set the runId to a random 12-digit alphanumeric string
const newRunId = [...Array(16)].map(() => Math.random().toString(36)[2]).join("");
setRunId(newRunId);
window.history.pushState(
{},
document.title,
window.location.pathname + `?runId=${newRunId}`,
);
setOfficialFactToVerify(factToVerify);
setClickedVerify(false);
}
useEffect(() => {
if (!conversationID) return;
verifyStatement(
officialFactToVerify,
conversationID,
setIsLoading,
setInitialResponse,
setInitialReferences,
);
}, [conversationID, officialFactToVerify]);
// Store factToVerify in localStorage whenever it changes
useEffect(() => {
localStorage.setItem("factToVerify", factToVerify);
}, [factToVerify]);
// Update the meta tags for the description and og:description
useEffect(() => {
let metaTag = document.querySelector('meta[name="description"]');
if (metaTag) {
metaTag.setAttribute("content", initialResponse);
}
let metaOgTag = document.querySelector('meta[property="og:description"]');
if (!metaOgTag) {
metaOgTag = document.createElement("meta");
metaOgTag.setAttribute("property", "og:description");
document.getElementsByTagName("head")[0].appendChild(metaOgTag);
}
metaOgTag.setAttribute("content", initialResponse);
}, [initialResponse]);
const renderReferences = (
conversationId: string,
initialReferences: ResponseWithReferences,
officialFactToVerify: string,
loadedFromStorage: boolean,
childReferences?: SupplementReferences[],
) => {
if (loadedFromStorage && childReferences) {
return renderSupplementalReferences(childReferences);
}
const seenLinks = new Set();
// Any links that are present in webpages should not be searched again
Object.entries(initialReferences.online || {}).map(([key, onlineData], index) => {
const webpages = onlineData?.webpages || [];
// Webpage can be a list or a single object
if (webpages instanceof Array) {
for (let i = 0; i < webpages.length; i++) {
const webpage = webpages[i];
const additionalLink = webpage.link || "";
if (seenLinks.has(additionalLink)) {
return null;
}
seenLinks.add(additionalLink);
}
} else {
let singleWebpage = webpages as WebPage;
const additionalLink = singleWebpage.link || "";
if (seenLinks.has(additionalLink)) {
return null;
}
seenLinks.add(additionalLink);
}
});
return Object.entries(initialReferences.online || {})
.map(([key, onlineData], index) => {
let additionalLink = "";
// Loop through organic links until we find one that hasn't been searched
for (let i = 0; i < onlineData?.organic?.length; i++) {
const webpage = onlineData?.organic?.[i];
additionalLink = webpage.link || "";
if (!seenLinks.has(additionalLink)) {
break;
}
}
seenLinks.add(additionalLink);
if (additionalLink === "") return null;
return (
<SupplementalReference
key={index}
onlineData={onlineData}
officialFactToVerify={officialFactToVerify}
conversationId={conversationId}
additionalLink={additionalLink}
setChildReferencesCallback={setChildReferencesCallback}
/>
);
})
.filter(Boolean);
};
const renderSupplementalReferences = (references: SupplementReferences[]) => {
return references.map((reference, index) => {
return (
<SupplementalReference
key={index}
additionalLink={reference.additionalLink}
officialFactToVerify={officialFactToVerify}
conversationId={conversationID}
linkTitle={reference.linkTitle}
setChildReferencesCallback={setChildReferencesCallback}
prefilledResponse={reference.response}
/>
);
});
};
const renderWebpages = (webpages: WebPage[] | WebPage) => {
if (webpages instanceof Array) {
return webpages.map((webpage, index) => {
return WebPageLink(webpage);
});
} else {
return WebPageLink(webpages);
}
};
function constructShareUrl() {
const url = new URL(window.location.href);
url.searchParams.set("runId", runId);
return url.href;
}
return (
<>
<div className="relative md:fixed h-full">
<SidePanel conversationId={null} uploadedFiles={[]} isMobileWidth={isMobileWidth} />
</div>
<div className={styles.factCheckerContainer}>
<h1
className={`${styles.response} pt-8 md:pt-4 font-large outline-slate-800 dark:outline-slate-200`}
>
AI Fact Checker
</h1>
<footer className={`${styles.footer} mt-4`}>
This is an experimental AI tool. It may make mistakes.
</footer>
{initialResponse && initialReferences && childReferences ? (
<div className={styles.reportActions}>
<Button asChild variant="secondary">
<Link href="/factchecker" target="_blank" rel="noopener noreferrer">
Try Another
</Link>
</Button>
<ShareLink
buttonTitle="Share report"
title="AI Fact Checking Report"
description="Share this fact checking report with others. Anyone who has this link will be able to view the report."
url={constructShareUrl()}
onShare={loadedFromStorage ? () => {} : storeData}
/>
</div>
) : (
<div className={styles.newReportActions}>
<div className={`${styles.inputFields} mt-4`}>
<Input
type="text"
maxLength={200}
placeholder="Enter a falsifiable statement to verify"
disabled={isLoading}
onChange={(e) => setFactToVerify(e.target.value)}
value={factToVerify}
onKeyDown={(e) => {
if (e.key === "Enter") {
onClickVerify();
}
}}
onFocus={(e) => (e.target.placeholder = "")}
onBlur={(e) =>
(e.target.placeholder =
"Enter a falsifiable statement to verify")
}
/>
<Button disabled={clickedVerify} onClick={() => onClickVerify()}>
Verify
</Button>
</div>
<h3 className={`mt-4 mb-4`}>
Try with a particular model. You must be{" "}
<a
href="/settings"
className="font-medium text-blue-600 dark:text-blue-500 hover:underline"
>
subscribed
</a>{" "}
to configure the model.
</h3>
</div>
)}
<ModelPicker
disabled={isLoading || loadedFromStorage}
setModelUsed={setModelUsed}
initialModel={initialModel}
/>
{isLoading && (
<div className={styles.loading}>
<LoadingSpinner />
</div>
)}
{initialResponse && (
<Card className={`mt-4`}>
<CardHeader>
<CardTitle>{officialFactToVerify}</CardTitle>
</CardHeader>
<CardContent>
<div className={styles.responseText}>
<ChatMessage
chatMessage={{
automationId: "",
by: "AI",
message: initialResponse,
context: [],
created: new Date().toISOString(),
onlineContext: {},
}}
isMobileWidth={isMobileWidth}
/>
</div>
</CardContent>
<CardFooter>
{initialReferences &&
initialReferences.online &&
Object.keys(initialReferences.online).length > 0 && (
<div className={styles.subLinks}>
{Object.entries(initialReferences.online).map(
([key, onlineData], index) => {
const webpages = onlineData?.webpages || [];
return renderWebpages(webpages);
},
)}
</div>
)}
</CardFooter>
</Card>
)}
{initialReferences && (
<div className={styles.referenceContainer}>
<h2 className="mt-4 mb-4">Supplements</h2>
<div className={styles.references}>
{initialReferences.online !== undefined &&
renderReferences(
conversationID,
initialReferences,
officialFactToVerify,
loadedFromStorage,
childReferences,
)}
</div>
</div>
)}
</div>
</>
);
}

View File

@@ -14,7 +14,7 @@ export const metadata: Metadata = {
manifest: "/static/khoj.webmanifest",
openGraph: {
siteName: "Khoj AI",
title: "Khoj AI - Home",
title: "Khoj AI",
description: "Your Second Brain.",
url: "https://app.khoj.dev",
type: "website",

View File

@@ -3,31 +3,42 @@ import "./globals.css";
import styles from "./page.module.css";
import "katex/dist/katex.min.css";
import React, { useEffect, useState } from "react";
import React, { useEffect, useRef, useState } from "react";
import useSWR from "swr";
import Image from "next/image";
import { ArrowCounterClockwise } from "@phosphor-icons/react";
import { Card, CardTitle } from "@/components/ui/card";
import SuggestionCard from "@/app/components/suggestions/suggestionCard";
import SidePanel from "@/app/components/sidePanel/chatHistorySidePanel";
import Loading from "@/app/components/loading/loading";
import ChatInputArea, { ChatOptions } from "@/app/components/chatInputArea/chatInputArea";
import {
AttachedFileText,
ChatInputArea,
ChatOptions,
} from "@/app/components/chatInputArea/chatInputArea";
import { Suggestion, suggestionsData } from "@/app/components/suggestions/suggestionsData";
import LoginPrompt from "@/app/components/loginPrompt/loginPrompt";
import { useAuthenticatedData, UserConfig, useUserConfig } from "@/app/common/auth";
import {
isUserSubscribed,
useAuthenticatedData,
UserConfig,
useUserConfig,
} from "@/app/common/auth";
import { convertColorToBorderClass } from "@/app/common/colorUtils";
import { getIconFromIconName } from "@/app/common/iconUtils";
import { AgentData } from "@/app/agents/page";
import { createNewConversation } from "./common/chatFunctions";
import { useIsMobileWidth } from "./common/utils";
import { useSearchParams } from "next/navigation";
import { useDebounce, useIsMobileWidth } from "./common/utils";
import { useRouter, useSearchParams } from "next/navigation";
import { ScrollArea, ScrollBar } from "@/components/ui/scroll-area";
import { AgentCard } from "@/app/components/agentCard/agentCard";
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
interface ChatBodyDataProps {
chatOptionsData: ChatOptions | null;
onConversationIdChange?: (conversationId: string) => void;
setUploadedFiles: (files: string[]) => void;
setUploadedFiles: (files: AttachedFileText[]) => void;
isMobileWidth?: boolean;
isLoggedIn: boolean;
userConfig: UserConfig | null;
@@ -44,14 +55,19 @@ function FisherYatesShuffle(array: any[]) {
function ChatBodyData(props: ChatBodyDataProps) {
const [message, setMessage] = useState("");
const [image, setImage] = useState<string | null>(null);
const [images, setImages] = useState<string[]>([]);
const [processingMessage, setProcessingMessage] = useState(false);
const [greeting, setGreeting] = useState("");
const [shuffledOptions, setShuffledOptions] = useState<Suggestion[]>([]);
const [hoveredAgent, setHoveredAgent] = useState<string | null>(null);
const debouncedHoveredAgent = useDebounce(hoveredAgent, 500);
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
const [selectedAgent, setSelectedAgent] = useState<string | null>("khoj");
const [agentIcons, setAgentIcons] = useState<JSX.Element[]>([]);
const [agents, setAgents] = useState<AgentData[]>([]);
const chatInputRef = useRef<HTMLTextAreaElement>(null);
const [showLoginPrompt, setShowLoginPrompt] = useState(false);
const router = useRouter();
const searchParams = useSearchParams();
const queryParam = searchParams.get("q");
@@ -61,6 +77,12 @@ function ChatBodyData(props: ChatBodyDataProps) {
}
}, [queryParam]);
useEffect(() => {
if (debouncedHoveredAgent) {
setIsPopoverOpen(true);
}
}, [debouncedHoveredAgent]);
const onConversationIdChange = props.onConversationIdChange;
const agentsFetcher = () =>
@@ -72,6 +94,10 @@ function ChatBodyData(props: ChatBodyDataProps) {
revalidateOnFocus: false,
});
const openAgentEditCard = (agentSlug: string) => {
router.push(`/agents?agent=${agentSlug}`);
};
function shuffleAndSetOptions() {
const shuffled = FisherYatesShuffle(suggestionsData);
setShuffledOptions(shuffled.slice(0, 3));
@@ -94,8 +120,8 @@ function ChatBodyData(props: ChatBodyDataProps) {
`What would you like to get done${nameSuffix}?`,
`Hey${nameSuffix}! How can I help?`,
`Good ${timeOfDay}${nameSuffix}! What's on your mind?`,
`Ready to breeze through your ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]}?`,
`Want help navigating your ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]} workload?`,
`Ready to breeze through ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]}?`,
`Let's navigate your ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]} workload`,
];
const greeting = greetings[Math.floor(Math.random() * greetings.length)];
setGreeting(greeting);
@@ -108,22 +134,13 @@ function ChatBodyData(props: ChatBodyDataProps) {
}, [props.chatOptionsData]);
useEffect(() => {
const nSlice = props.isMobileWidth ? 2 : 4;
const shuffledAgents = agentsData ? [...agentsData].sort(() => 0.5 - Math.random()) : [];
const agents = agentsData ? [agentsData[0]] : []; // Always add the first/default agent.
shuffledAgents.slice(0, nSlice - 1).forEach((agent) => {
if (!agents.find((a) => a.slug === agent.slug)) {
agents.push(agent);
}
});
const agents = (agentsData || []).filter((agent) => agent !== null && agent !== undefined);
setAgents(agents);
// set the first agent, which is always the default agent, as the default for chat
setSelectedAgent(agents.length > 1 ? agents[0].slug : "khoj");
//generate colored icons for the selected agents
const agentIcons = agents
.filter((agent) => agent !== null && agent !== undefined)
.map((agent) => getIconFromIconName(agent.icon, agent.color)!);
// generate colored icons for the available agents
const agentIcons = agents.map((agent) => getIconFromIconName(agent.icon, agent.color)!);
setAgentIcons(agentIcons);
}, [agentsData, props.isMobileWidth]);
@@ -138,24 +155,40 @@ function ChatBodyData(props: ChatBodyDataProps) {
try {
const newConversationId = await createNewConversation(selectedAgent || "khoj");
onConversationIdChange?.(newConversationId);
window.location.href = `/chat?conversationId=${newConversationId}`;
localStorage.setItem("message", message);
if (image) {
localStorage.setItem("image", image);
if (images.length > 0) {
localStorage.setItem("images", JSON.stringify(images));
}
window.location.href = `/chat?conversationId=${newConversationId}`;
} catch (error) {
console.error("Error creating new conversation:", error);
setProcessingMessage(false);
}
setMessage("");
setImages([]);
}
};
processMessage();
if (message) {
if (message || images.length > 0) {
setProcessingMessage(true);
}
}, [selectedAgent, message, processingMessage, onConversationIdChange]);
// Close the agent detail hover card when scroll on agent pane
useEffect(() => {
const scrollAreaSelector = "[data-radix-scroll-area-viewport]";
const scrollAreaEl = document.querySelector<HTMLElement>(scrollAreaSelector);
const handleScroll = () => {
setHoveredAgent(null);
setIsPopoverOpen(false);
};
scrollAreaEl?.addEventListener("scroll", handleScroll);
return () => scrollAreaEl?.removeEventListener("scroll", handleScroll);
}, []);
function fillArea(link: string, type: string, prompt: string) {
if (!link) {
let message_str = "";
@@ -194,37 +227,76 @@ function ChatBodyData(props: ChatBodyDataProps) {
</h1>
</div>
{!props.isMobileWidth && (
<div className="flex pb-6 gap-2 items-center justify-center">
{agentIcons.map((icon, index) => (
<Card
key={`${index}-${agents[index].slug}`}
className={`${
selectedAgent === agents[index].slug
? convertColorToBorderClass(agents[index].color)
: "border-stone-100 dark:border-neutral-700 text-muted-foreground"
}
hover:cursor-pointer rounded-lg px-2 py-2`}
>
<CardTitle
className="text-center text-md font-medium flex justify-center items-center"
onClick={() => setSelectedAgent(agents[index].slug)}
<ScrollArea className="w-full max-w-[600px] mx-auto">
<div className="flex pb-2 gap-2 items-center justify-center">
{agents.map((agent, index) => (
<Popover
key={`${index}-${agent.slug}`}
open={isPopoverOpen && debouncedHoveredAgent === agent.slug}
onOpenChange={(open) => {
if (!open) {
setHoveredAgent(null);
setIsPopoverOpen(false);
}
}}
>
{icon} {agents[index].name}
</CardTitle>
</Card>
))}
<Card
className="border-none shadow-none flex justify-center items-center hover:cursor-pointer"
onClick={() => (window.location.href = "/agents")}
>
<CardTitle className="text-center text-md font-normal flex justify-center items-center px-1.5 py-2">
See All
</CardTitle>
</Card>
</div>
<PopoverTrigger asChild>
<Card
className={`${
selectedAgent === agent.slug
? convertColorToBorderClass(agent.color)
: "border-stone-100 dark:border-neutral-700 text-muted-foreground"
}
hover:cursor-pointer rounded-lg px-2 py-2`}
onDoubleClick={() => openAgentEditCard(agent.slug)}
onClick={() => {
setSelectedAgent(agent.slug);
chatInputRef.current?.focus();
setHoveredAgent(null);
setIsPopoverOpen(false);
}}
onMouseEnter={() => setHoveredAgent(agent.slug)}
onMouseLeave={() => {
setHoveredAgent(null);
setIsPopoverOpen(false);
}}
>
<CardTitle className="text-center text-md font-medium flex justify-center items-center whitespace-nowrap">
{agentIcons[index]} {agent.name}
</CardTitle>
</Card>
</PopoverTrigger>
<PopoverContent
className="w-80 p-0 border-none bg-transparent shadow-none"
onMouseLeave={() => {
setHoveredAgent(null);
setIsPopoverOpen(false);
}}
>
<AgentCard
data={agent}
userProfile={null}
isMobileWidth={props.isMobileWidth || false}
showChatButton={false}
editCard={false}
filesOptions={[]}
selectedChatModelOption=""
agentSlug=""
isSubscribed={isUserSubscribed(props.userConfig)}
setAgentChangeTriggered={() => {}}
modelOptions={[]}
inputToolOptions={{}}
outputModeOptions={{}}
/>
</PopoverContent>
</Popover>
))}
</div>
<ScrollBar orientation="horizontal" />
</ScrollArea>
)}
</div>
<div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit"}`}>
<div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit max-w-screen-md"}`}>
{!props.isMobileWidth && (
<div
className={`w-full ${styles.inputBox} shadow-lg bg-background align-middle items-center justify-center px-3 py-1 dark:bg-neutral-700 border-stone-100 dark:border-none dark:shadow-none rounded-2xl`}
@@ -232,12 +304,14 @@ function ChatBodyData(props: ChatBodyDataProps) {
<ChatInputArea
isLoggedIn={props.isLoggedIn}
sendMessage={(message) => setMessage(message)}
sendImage={(image) => setImage(image)}
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
sendDisabled={processingMessage}
chatOptionsData={props.chatOptionsData}
conversationId={null}
isMobileWidth={props.isMobileWidth}
setUploadedFiles={props.setUploadedFiles}
agentColor={agents.find((agent) => agent.slug === selectedAgent)?.color}
ref={chatInputRef}
/>
</div>
)}
@@ -285,40 +359,41 @@ function ChatBodyData(props: ChatBodyDataProps) {
<div
className={`${styles.inputBox} pt-1 shadow-[0_-20px_25px_-5px_rgba(0,0,0,0.1)] dark:bg-neutral-700 bg-background align-middle items-center justify-center pb-3 mx-1 rounded-t-2xl rounded-b-none`}
>
<div className="flex gap-2 items-center justify-left pt-1 pb-2 px-12">
{agentIcons.map((icon, index) => (
<Card
key={`${index}-${agents[index].slug}`}
className={`${selectedAgent === agents[index].slug ? convertColorToBorderClass(agents[index].color) : "border-muted text-muted-foreground"} hover:cursor-pointer`}
>
<CardTitle
className="text-center text-xs font-medium flex justify-center items-center px-1.5 py-1"
onClick={() => setSelectedAgent(agents[index].slug)}
<ScrollArea className="w-full max-w-[85vw]">
<div className="flex gap-2 items-center justify-left pt-1 pb-2 px-12">
{agentIcons.map((icon, index) => (
<Card
key={`${index}-${agents[index].slug}`}
className={`${selectedAgent === agents[index].slug ? convertColorToBorderClass(agents[index].color) : "border-muted text-muted-foreground"} hover:cursor-pointer`}
>
{icon} {agents[index].name}
</CardTitle>
</Card>
))}
<Card
className="border-none shadow-none flex justify-center items-center hover:cursor-pointer"
onClick={() => (window.location.href = "/agents")}
>
<CardTitle
className={`text-center ${props.isMobileWidth ? "text-xs" : "text-md"} font-normal flex justify-center items-center px-1.5 py-2`}
>
See All
</CardTitle>
</Card>
</div>
<CardTitle
className="text-center text-xs font-medium flex justify-center items-center whitespace-nowrap px-1.5 py-1"
onDoubleClick={() =>
openAgentEditCard(agents[index].slug)
}
onClick={() => {
setSelectedAgent(agents[index].slug);
chatInputRef.current?.focus();
}}
>
{icon} {agents[index].name}
</CardTitle>
</Card>
))}
</div>
<ScrollBar orientation="horizontal" />
</ScrollArea>
<ChatInputArea
isLoggedIn={props.isLoggedIn}
sendMessage={(message) => setMessage(message)}
sendImage={(image) => setImage(image)}
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
sendDisabled={processingMessage}
chatOptionsData={props.chatOptionsData}
conversationId={null}
isMobileWidth={props.isMobileWidth}
setUploadedFiles={props.setUploadedFiles}
agentColor={agents.find((agent) => agent.slug === selectedAgent)?.color}
ref={chatInputRef}
/>
</div>
</>
@@ -331,7 +406,7 @@ export default function Home() {
const [chatOptionsData, setChatOptionsData] = useState<ChatOptions | null>(null);
const [isLoading, setLoading] = useState(true);
const [conversationId, setConversationID] = useState<string | null>(null);
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | null>(null);
const isMobileWidth = useIsMobileWidth();
const { userConfig: initialUserConfig, isLoadingUserConfig } = useUserConfig(true);
@@ -347,6 +422,12 @@ export default function Home() {
setUserConfig(initialUserConfig);
}, [initialUserConfig]);
useEffect(() => {
if (uploadedFiles) {
localStorage.setItem("uploadedFiles", JSON.stringify(uploadedFiles));
}
}, [uploadedFiles]);
useEffect(() => {
fetch("/api/chat/options")
.then((response) => response.json())
@@ -372,7 +453,7 @@ export default function Home() {
<div className={`${styles.sidePanel}`}>
<SidePanel
conversationId={conversationId}
uploadedFiles={uploadedFiles}
uploadedFiles={[]}
isMobileWidth={isMobileWidth}
/>
</div>

View File

@@ -137,10 +137,8 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {
const deleteSelected = async () => {
let filesToDelete = selectedFiles.length > 0 ? selectedFiles : filteredFiles;
console.log("Delete selected files", filesToDelete);
if (filesToDelete.length === 0) {
console.log("No files to delete");
return;
}
@@ -162,15 +160,12 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {
// Reset selectedFiles
setSelectedFiles([]);
console.log("Deleted files:", filesToDelete);
} catch (error) {
console.error("Error deleting files:", error);
}
};
const deleteFile = async (filename: string) => {
console.log("Delete selected file", filename);
try {
const response = await fetch(
`/api/content/file?filename=${encodeURIComponent(filename)}`,
@@ -189,8 +184,6 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {
// Remove the file from selectedFiles if it's there
setSelectedFiles((prevSelected) => prevSelected.filter((file) => file !== filename));
console.log("Deleted file:", filename);
} catch (error) {
console.error("Error deleting file:", error);
}
@@ -513,7 +506,7 @@ export default function SettingsView() {
const isMobileWidth = useIsMobileWidth();
const cardClassName =
"w-full lg:w-1/3 grid grid-flow-column border border-gray-300 shadow-md rounded-lg bg-gradient-to-b from-background to-gray-50 dark:to-gray-950";
"w-full lg:w-1/3 grid grid-flow-column border border-gray-300 shadow-md rounded-lg bg-gradient-to-b from-background to-gray-50 dark:to-gray-950 border border-opacity-50";
useEffect(() => {
setUserConfig(initialUserConfig);
@@ -601,7 +594,7 @@ export default function SettingsView() {
const setSubscription = async (state: string) => {
try {
const url = `/api/subscription?email=${userConfig?.username}&operation=${state}`;
const url = `/api/subscription?operation=${state}`;
const response = await fetch(url, {
method: "PATCH",
headers: {
@@ -640,6 +633,51 @@ export default function SettingsView() {
}
};
const enableFreeTrial = async () => {
const formatDate = (dateString: Date) => {
const date = new Date(dateString);
return new Intl.DateTimeFormat("en-US", {
day: "2-digit",
month: "short",
year: "numeric",
}).format(date);
};
try {
const response = await fetch(`/api/subscription/trial`, {
method: "POST",
});
if (!response.ok) throw new Error("Failed to enable free trial");
const responseBody = await response.json();
// Set updated user settings
if (responseBody.trial_enabled && userConfig) {
let newUserConfig = userConfig;
newUserConfig.subscription_state = SubscriptionStates.TRIAL;
const renewalDate = new Date(
Date.now() + userConfig.length_of_free_trial * 24 * 60 * 60 * 1000,
);
newUserConfig.subscription_renewal_date = formatDate(renewalDate);
newUserConfig.subscription_enabled_trial_at = new Date().toISOString();
setUserConfig(newUserConfig);
// Notify user of free trial
toast({
title: "🎉 Trial Enabled",
description: `Your free trial will end on ${newUserConfig.subscription_renewal_date}`,
});
}
} catch (error) {
console.error("Error enabling free trial:", error);
toast({
title: "⚠️ Failed to Enable Free Trial",
description:
"Failed to enable free trial. Try again or contact us at team@khoj.dev",
});
}
};
const saveName = async () => {
if (!name) return;
try {
@@ -673,7 +711,7 @@ export default function SettingsView() {
};
const updateModel = (name: string) => async (id: string) => {
if (!userConfig?.is_active && name !== "search") {
if (!userConfig?.is_active) {
toast({
title: `Model Update`,
description: `You need to be subscribed to update ${name} models`,
@@ -866,10 +904,13 @@ export default function SettingsView() {
Futurist (Trial)
</p>
<p className="text-gray-400">
You are on a 14 day trial of the Khoj
Futurist plan. Check{" "}
You are on a{" "}
{userConfig.length_of_free_trial} day trial
of the Khoj Futurist plan. Your trial ends
on {userConfig.subscription_renewal_date}.
Check{" "}
<a
href="https://khoj.dev/pricing"
href="https://khoj.dev/#pricing"
target="_blank"
>
pricing page
@@ -909,7 +950,7 @@ export default function SettingsView() {
)) ||
(userConfig.subscription_state === "expired" && (
<>
<p className="text-xl">Free Plan</p>
<p className="text-xl">Humanist</p>
{(userConfig.subscription_renewal_date && (
<p className="text-gray-400">
Subscription <b>expired</b> on{" "}
@@ -923,7 +964,7 @@ export default function SettingsView() {
<p className="text-gray-400">
Check{" "}
<a
href="https://khoj.dev/pricing"
href="https://khoj.dev/#pricing"
target="_blank"
>
pricing page
@@ -960,7 +1001,8 @@ export default function SettingsView() {
/>
Resubscribe
</Button>
)) || (
)) ||
(userConfig.subscription_enabled_trial_at && (
<Button
variant="outline"
className="text-primary/80 hover:text-primary"
@@ -978,6 +1020,18 @@ export default function SettingsView() {
/>
Subscribe
</Button>
)) || (
<Button
variant="outline"
className="text-primary/80 hover:text-primary"
onClick={enableFreeTrial}
>
<ArrowCircleUp
weight="bold"
className="h-5 w-5 mr-2"
/>
Enable Trial
</Button>
)}
</CardFooter>
</Card>
@@ -1172,27 +1226,6 @@ export default function SettingsView() {
</CardFooter>
</Card>
)}
{userConfig.search_model_options.length > 0 && (
<Card className={cardClassName}>
<CardHeader className="text-xl flex flex-row">
<FileMagnifyingGlass className="h-7 w-7 mr-2" />
Search
</CardHeader>
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
<p className="text-gray-400">
Pick the search model to find your documents
</p>
<DropdownComponent
items={userConfig.search_model_options}
selected={
userConfig.selected_search_model_config
}
callbackFunc={updateModel("search")}
/>
</CardContent>
<CardFooter className="flex flex-wrap gap-4"></CardFooter>
</Card>
)}
{userConfig.paint_model_options.length > 0 && (
<Card className={cardClassName}>
<CardHeader className="text-xl flex flex-row">

View File

@@ -27,7 +27,14 @@ export default function RootLayout({
child-src 'none';
object-src 'none';"
></meta>
<body className={inter.className}>{children}</body>
<body className={inter.className}>
{children}
<script
dangerouslySetInnerHTML={{
__html: `window.EXCALIDRAW_ASSET_PATH = 'https://assets.khoj.dev/@excalidraw/excalidraw/dist/';`,
}}
/>
</body>
</html>
);
}

View File

@@ -5,45 +5,52 @@ import React, { Suspense, useEffect, useRef, useState } from "react";
import SidePanel from "../../components/sidePanel/chatHistorySidePanel";
import ChatHistory from "../../components/chatHistory/chatHistory";
import NavMenu from "../../components/navMenu/navMenu";
import Loading from "../../components/loading/loading";
import "katex/dist/katex.min.css";
import { useIPLocationData, useIsMobileWidth, welcomeConsole } from "../../common/utils";
import { useIsMobileWidth, welcomeConsole } from "../../common/utils";
import { useAuthenticatedData } from "@/app/common/auth";
import ChatInputArea, { ChatOptions } from "@/app/components/chatInputArea/chatInputArea";
import {
AttachedFileText,
ChatInputArea,
ChatOptions,
} from "@/app/components/chatInputArea/chatInputArea";
import { StreamMessage } from "@/app/components/chatMessage/chatMessage";
import { processMessageChunk } from "@/app/common/chatFunctions";
import { AgentData } from "@/app/agents/page";
interface ChatBodyDataProps {
chatOptionsData: ChatOptions | null;
setTitle: (title: string) => void;
setUploadedFiles: (files: string[]) => void;
setUploadedFiles: (files: AttachedFileText[]) => void;
isMobileWidth?: boolean;
publicConversationSlug: string;
streamedMessages: StreamMessage[];
isLoggedIn: boolean;
conversationId?: string;
setQueryToProcess: (query: string) => void;
setImage64: (image64: string) => void;
setImages: (images: string[]) => void;
}
function ChatBodyData(props: ChatBodyDataProps) {
const [message, setMessage] = useState("");
const [image, setImage] = useState<string | null>(null);
const [images, setImages] = useState<string[]>([]);
const [processingMessage, setProcessingMessage] = useState(false);
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
const chatInputRef = useRef<HTMLTextAreaElement>(null);
const setQueryToProcess = props.setQueryToProcess;
const streamedMessages = props.streamedMessages;
const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";
useEffect(() => {
if (image) {
props.setImage64(encodeURIComponent(image));
if (images.length > 0) {
const encodedImages = images.map((image) => encodeURIComponent(image));
props.setImages(encodedImages);
}
}, [image, props.setImage64]);
}, [images, props.setImages]);
useEffect(() => {
if (message) {
@@ -78,21 +85,23 @@ function ChatBodyData(props: ChatBodyDataProps) {
setTitle={props.setTitle}
pendingMessage={processingMessage ? message : ""}
incomingMessages={props.streamedMessages}
customClassName={chatHistoryCustomClassName}
/>
</div>
<div
className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl`}
className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl h-fit ${chatHistoryCustomClassName} mr-auto ml-auto`}
>
<ChatInputArea
isLoggedIn={props.isLoggedIn}
sendMessage={(message) => setMessage(message)}
sendImage={(image) => setImage(image)}
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
sendDisabled={processingMessage}
chatOptionsData={props.chatOptionsData}
conversationId={props.conversationId}
agentColor={agentMetadata?.color}
isMobileWidth={props.isMobileWidth}
setUploadedFiles={props.setUploadedFiles}
ref={chatInputRef}
/>
</div>
</>
@@ -106,14 +115,10 @@ export default function SharedChat() {
const [conversationId, setConversationID] = useState<string | undefined>(undefined);
const [messages, setMessages] = useState<StreamMessage[]>([]);
const [queryToProcess, setQueryToProcess] = useState<string>("");
const [processQuerySignal, setProcessQuerySignal] = useState(false);
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | null>(null);
const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
const [image64, setImage64] = useState<string>("");
const [images, setImages] = useState<string[]>([]);
const locationData = useIPLocationData() || {
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
};
const authenticatedData = useAuthenticatedData();
const isMobileWidth = useIsMobileWidth();
@@ -137,6 +142,12 @@ export default function SharedChat() {
setParamSlug(window.location.pathname.split("/").pop() || "");
}, []);
useEffect(() => {
if (uploadedFiles) {
localStorage.setItem("uploadedFiles", JSON.stringify(uploadedFiles));
}
}, [uploadedFiles]);
useEffect(() => {
if (queryToProcess && !conversationId) {
// If the user has not yet started conversing in the chat, create a new conversation
@@ -149,6 +160,11 @@ export default function SharedChat() {
.then((response) => response.json())
.then((data) => {
setConversationID(data.conversation_id);
localStorage.setItem("message", queryToProcess);
if (images.length > 0) {
localStorage.setItem("images", JSON.stringify(images));
}
window.location.href = `/chat?conversationId=${data.conversation_id}`;
})
.catch((err) => {
console.error(err);
@@ -156,104 +172,8 @@ export default function SharedChat() {
});
return;
}
if (queryToProcess) {
// Add a new object to the state
const newStreamMessage: StreamMessage = {
rawResponse: "",
trainOfThought: [],
context: [],
onlineContext: {},
completed: false,
timestamp: new Date().toISOString(),
rawQuery: queryToProcess || "",
uploadedImageData: decodeURIComponent(image64),
};
setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
setProcessQuerySignal(true);
}
}, [queryToProcess, conversationId, paramSlug]);
useEffect(() => {
if (processQuerySignal) {
chat();
}
}, [processQuerySignal]);
async function readChatStream(response: Response) {
if (!response.ok) throw new Error(response.statusText);
if (!response.body) throw new Error("Response body is null");
const reader = response.body.getReader();
const decoder = new TextDecoder();
const eventDelimiter = "␃🔚␗";
let buffer = "";
while (true) {
const { done, value } = await reader.read();
if (done) {
setQueryToProcess("");
setProcessQuerySignal(false);
setImage64("");
break;
}
const chunk = decoder.decode(value, { stream: true });
buffer += chunk;
let newEventIndex;
while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) {
const event = buffer.slice(0, newEventIndex);
buffer = buffer.slice(newEventIndex + eventDelimiter.length);
if (event) {
const currentMessage = messages.find((message) => !message.completed);
if (!currentMessage) {
console.error("No current message found");
return;
}
processMessageChunk(event, currentMessage);
setMessages([...messages]);
}
}
}
}
async function chat() {
if (!queryToProcess || !conversationId) return;
const chatAPI = "/api/chat?client=web";
const chatAPIBody = {
q: queryToProcess,
conversation_id: conversationId,
stream: true,
...(locationData && {
region: locationData.region,
country: locationData.country,
city: locationData.city,
country_code: locationData.countryCode,
timezone: locationData.timezone,
}),
...(image64 && { image: image64 }),
};
const response = await fetch(chatAPI, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(chatAPIBody),
});
try {
await readChatStream(response);
} catch (error) {
console.error(error);
}
}
if (isLoading) {
return <Loading />;
}
@@ -268,13 +188,26 @@ export default function SharedChat() {
<div className={styles.sidePanel}>
<SidePanel
conversationId={conversationId ?? null}
uploadedFiles={uploadedFiles}
uploadedFiles={[]}
isMobileWidth={isMobileWidth}
/>
</div>
<div className={styles.chatBox}>
<div className={styles.chatBoxBody}>
{!isMobileWidth && title && (
<div
className={`${styles.chatTitleWrapper} text-nowrap text-ellipsis overflow-hidden max-w-screen-md grid items-top font-bold mr-8 pt-6 col-auto h-fit`}
>
{title && (
<h2
className={`text-lg text-ellipsis whitespace-nowrap overflow-x-hidden`}
>
{title}
</h2>
)}
</div>
)}
<Suspense fallback={<Loading />}>
<ChatBodyData
conversationId={conversationId}
@@ -286,7 +219,7 @@ export default function SharedChat() {
setTitle={setTitle}
setUploadedFiles={setUploadedFiles}
isMobileWidth={isMobileWidth}
setImage64={setImage64}
setImages={setImages}
/>
</Suspense>
</div>

View File

@@ -75,7 +75,7 @@ div.titleBar {
div.chatBoxBody {
display: grid;
height: 100%;
width: 70%;
width: 95%;
margin: auto;
}

View File

@@ -9,7 +9,7 @@ const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>(
return (
<textarea
className={cn(
"flex min-h-[80px] w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50",
"flex min-h-[80px] w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground disabled:cursor-not-allowed disabled:opacity-50",
className,
)}
ref={ref}

View File

@@ -1,6 +1,6 @@
{
"name": "khoj-ai",
"version": "1.26.2",
"version": "1.29.1",
"private": true,
"scripts": {
"dev": "next dev",
@@ -19,6 +19,7 @@
"prepare": "husky"
},
"dependencies": {
"@excalidraw/excalidraw": "^0.17.6",
"@hookform/resolvers": "^3.9.0",
"@phosphor-icons/react": "^2.1.7",
"@radix-ui/react-alert-dialog": "^1.1.1",

View File

@@ -1,34 +1,49 @@
import type { Config } from "tailwindcss"
import type { Config } from "tailwindcss";
const config = {
safelist: [
{
pattern: /to-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
variants: ['dark'],
pattern:
/to-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
variants: ["dark"],
},
{
pattern: /text-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
variants: ['dark'],
pattern:
/text-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
variants: ["dark"],
},
{
pattern: /border-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
variants: ['dark'],
pattern:
/border-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
variants: ["dark"],
},
{
pattern: /border-l-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
variants: ['dark'],
pattern:
/border-l-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
variants: ["dark"],
},
{
pattern: /bg-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
variants: ['dark'],
}
pattern:
/bg-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
variants: ["dark"],
},
{
pattern:
/ring-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
variants: ["focus-visible", "dark"],
},
{
pattern:
/caret-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
variants: ["focus", "dark"],
},
],
darkMode: ["class"],
content: [
'./pages/**/*.{ts,tsx}',
'./components/**/*.{ts,tsx}',
'./app/**/*.{ts,tsx}',
'./src/**/*.{ts,tsx}',
"./pages/**/*.{ts,tsx}",
"./components/**/*.{ts,tsx}",
"./app/**/*.{ts,tsx}",
"./src/**/*.{ts,tsx}",
],
prefix: "",
theme: {
@@ -101,9 +116,7 @@ const config = {
},
},
},
plugins: [
require("tailwindcss-animate"),
],
} satisfies Config
plugins: [require("tailwindcss-animate")],
} satisfies Config;
export default config
export default config;

View File

@@ -286,6 +286,11 @@
resolved "https://registry.yarnpkg.com/@eslint/js/-/js-8.57.1.tgz#de633db3ec2ef6a3c89e2f19038063e8a122e2c2"
integrity sha512-d9zaMRSTIKDLhctzH12MtXvJKSSUhaHcjV+2Z+GK+EEY7XKpP5yR4x+N3TAcHTcu963nIr+TMcCb4DBCYX1z6Q==
"@excalidraw/excalidraw@^0.17.6":
version "0.17.6"
resolved "https://registry.yarnpkg.com/@excalidraw/excalidraw/-/excalidraw-0.17.6.tgz#5fd208ce69d33ca712d1804b50d7d06d5c46ac4d"
integrity sha512-fyCl+zG/Z5yhHDh5Fq2ZGmphcrALmuOdtITm8gN4d8w4ntnaopTXcTfnAAaU3VleDC6LhTkoLOTG6P5kgREiIg==
"@floating-ui/core@^1.6.0":
version "1.6.8"
resolved "https://registry.yarnpkg.com/@floating-ui/core/-/core-1.6.8.tgz#aa43561be075815879305965020f492cdb43da12"

View File

@@ -108,7 +108,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
password="default",
)
renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
Subscription.objects.create(user=default_user, type="standard", renewal_date=renewal_date)
Subscription.objects.create(user=default_user, type=Subscription.Type.STANDARD, renewal_date=renewal_date)
async def authenticate(self, request: HTTPConnection):
current_user = request.session.get("user")
@@ -168,12 +168,6 @@ class UserAuthenticationBackend(AuthenticationBackend):
if create_if_not_exists:
user, is_new = await aget_or_create_user_by_phone_number(phone_number)
if user and is_new:
update_telemetry_state(
request=request,
telemetry_type="api",
api="create_user",
metadata={"user_id": str(user.uuid)},
)
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
else:
user = await aget_user_by_phone_number(phone_number)
@@ -255,31 +249,35 @@ def configure_server(
state.search_models = configure_search(state.search_models, state.config.search_type)
setup_default_agent(user)
message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled"
message = (
"📡 Telemetry disabled"
if telemetry_disabled(state.config.app, state.telemetry_disabled)
else "📡 Telemetry enabled"
)
logger.info(message)
if not init:
initialize_content(regenerate, search_type, user)
initialize_content(user, regenerate, search_type)
except Exception as e:
raise e
logger.error(f"Failed to load some search models: {e}", exc_info=True)
def setup_default_agent(user: KhojUser):
AgentAdapters.create_default_agent(user)
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None):
def initialize_content(user: KhojUser, regenerate: bool, search_type: Optional[SearchType] = None):
# Initialize Content from Config
if state.search_models:
try:
logger.info("📬 Updating content index...")
all_files = collect_files(user=user)
status = configure_content(
user,
all_files,
regenerate,
search_type,
user=user,
)
if not status:
raise RuntimeError("Failed to update content index")
@@ -312,7 +310,7 @@ def configure_routes(app):
logger.info("🔑 Enabled Authentication")
if state.billing_enabled:
from khoj.routers.subscription import subscription_router
from khoj.routers.api_subscription import subscription_router
app.include_router(subscription_router, prefix="/api/subscription")
logger.info("💳 Enabled Billing")
@@ -344,9 +342,7 @@ def configure_middleware(app):
def update_content_index():
for user in get_all_users():
all_files = collect_files(user=user)
success = configure_content(all_files, user=user)
all_files = collect_files(user=None)
success = configure_content(all_files, user=None)
success = configure_content(user, all_files)
if not success:
raise RuntimeError("Failed to update content index")
logger.info("📪 Content index updated via Scheduler")
@@ -369,7 +365,7 @@ def configure_search_types():
@schedule.repeat(schedule.every(2).minutes)
def upload_telemetry():
if telemetry_disabled(state.config.app) or not state.telemetry:
if telemetry_disabled(state.config.app, state.telemetry_disabled) or not state.telemetry:
return
try:

View File

@@ -8,13 +8,22 @@ import secrets
import sys
from datetime import date, datetime, timedelta, timezone
from enum import Enum
from typing import Callable, Iterable, List, Optional, Type
from functools import wraps
from typing import (
Any,
Callable,
Coroutine,
Iterable,
List,
Optional,
ParamSpec,
TypeVar,
)
import cron_descriptor
from apscheduler.job import Job
from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore
from django.db import models
from django.db.models import Prefetch, Q
from django.db.models.manager import BaseManager
from django.db.utils import IntegrityError
@@ -28,7 +37,6 @@ from khoj.database.models import (
ChatModelOptions,
ClientApplication,
Conversation,
DataStore,
Entry,
FileObject,
GithubConfig,
@@ -48,7 +56,6 @@ from khoj.database.models import (
TextToImageModelConfig,
UserConversationConfig,
UserRequests,
UserSearchModelConfig,
UserTextToImageModelConfig,
UserVoiceModelConfig,
VoiceModelOption,
@@ -70,6 +77,9 @@ from khoj.utils.helpers import (
logger = logging.getLogger(__name__)
LENGTH_OF_FREE_TRIAL = 7 #
class SubscriptionState(Enum):
TRIAL = "trial"
SUBSCRIBED = "subscribed"
@@ -78,6 +88,45 @@ class SubscriptionState(Enum):
INVALID = "invalid"
P = ParamSpec("P")
T = TypeVar("T")
def require_valid_user(func: Callable[P, T]) -> Callable[P, T]:
@wraps(func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
# Extract user from args/kwargs
user = next((arg for arg in args if isinstance(arg, KhojUser)), None)
if not user:
user = next((val for val in kwargs.values() if isinstance(val, KhojUser)), None)
# Throw error if user is not found
if not user:
raise ValueError("Khoj user argument required but not provided.")
return func(*args, **kwargs)
return sync_wrapper
def arequire_valid_user(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]:
@wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
# Extract user from args/kwargs
user = next((arg for arg in args if isinstance(arg, KhojUser)), None)
if not user:
user = next((v for v in kwargs.values() if isinstance(v, KhojUser)), None)
# Throw error if user is not found
if not user:
raise ValueError("Khoj user argument required but not provided.")
return await func(*args, **kwargs)
return async_wrapper
@arequire_valid_user
async def set_notion_config(token: str, user: KhojUser):
notion_config = await NotionConfig.objects.filter(user=user).afirst()
if not notion_config:
@@ -88,6 +137,7 @@ async def set_notion_config(token: str, user: KhojUser):
return notion_config
@require_valid_user
def create_khoj_token(user: KhojUser, name=None):
"Create Khoj API key for user"
token = f"kk-{secrets.token_urlsafe(32)}"
@@ -95,6 +145,7 @@ def create_khoj_token(user: KhojUser, name=None):
return KhojApiUser.objects.create(token=token, user=user, name=name)
@arequire_valid_user
async def acreate_khoj_token(user: KhojUser, name=None):
"Create Khoj API key for user"
token = f"kk-{secrets.token_urlsafe(32)}"
@@ -102,11 +153,13 @@ async def acreate_khoj_token(user: KhojUser, name=None):
return await KhojApiUser.objects.acreate(token=token, user=user, name=name)
@require_valid_user
def get_khoj_tokens(user: KhojUser):
"Get all Khoj API keys for user"
return list(KhojApiUser.objects.filter(user=user))
@arequire_valid_user
async def delete_khoj_token(user: KhojUser, token: str):
"Delete Khoj API Key for user"
await KhojApiUser.objects.filter(token=token, user=user).adelete()
@@ -130,6 +183,7 @@ async def aget_or_create_user_by_phone_number(phone_number: str) -> tuple[KhojUs
return user, is_new
@arequire_valid_user
async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
if is_none_or_empty(phone_number):
return None
@@ -153,6 +207,7 @@ async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
return user
@arequire_valid_user
async def aremove_phone_number(user: KhojUser) -> KhojUser:
user.phone_number = None
user.verified_phone_number = False
@@ -168,7 +223,7 @@ async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
)
await user.asave()
await Subscription.objects.acreate(user=user, type="trial")
await Subscription.objects.acreate(user=user, type=Subscription.Type.STANDARD)
return user
@@ -185,11 +240,30 @@ async def aget_or_create_user_by_email(email: str) -> tuple[KhojUser, bool]:
user_subscription = await Subscription.objects.filter(user=user).afirst()
if not user_subscription:
await Subscription.objects.acreate(user=user, type="trial")
await Subscription.objects.acreate(user=user, type=Subscription.Type.STANDARD)
return user, is_new
@arequire_valid_user
async def astart_trial_subscription(user: KhojUser) -> Subscription:
subscription = await Subscription.objects.filter(user=user).afirst()
if not subscription:
raise HTTPException(status_code=400, detail="User does not have a subscription")
if subscription.type == Subscription.Type.TRIAL:
raise HTTPException(status_code=400, detail="User already has a trial subscription")
if subscription.enabled_trial_at:
raise HTTPException(status_code=400, detail="User already has a trial subscription")
subscription.type = Subscription.Type.TRIAL
subscription.enabled_trial_at = datetime.now(tz=timezone.utc)
subscription.renewal_date = datetime.now(tz=timezone.utc) + timedelta(days=LENGTH_OF_FREE_TRIAL)
await subscription.asave()
return subscription
async def aget_user_validated_by_email_verification_code(code: str) -> KhojUser:
user = await KhojUser.objects.filter(email_verification_code=code).afirst()
if not user:
@@ -221,11 +295,12 @@ async def create_user_by_google_token(token: dict) -> KhojUser:
user=user,
)
await Subscription.objects.acreate(user=user, type="trial")
await Subscription.objects.acreate(user=user, type=Subscription.Type.STANDARD)
return user
@require_valid_user
def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
user.first_name = first_name
user.last_name = last_name
@@ -233,6 +308,7 @@ def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
return user
@require_valid_user
def get_user_name(user: KhojUser):
full_name = user.get_full_name()
if not is_none_or_empty(full_name):
@@ -244,6 +320,7 @@ def get_user_name(user: KhojUser):
return None
@require_valid_user
def get_user_photo(user: KhojUser):
google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
if google_profile:
@@ -279,16 +356,20 @@ def subscription_to_state(subscription: Subscription) -> str:
if not subscription:
return SubscriptionState.INVALID.value
elif subscription.type == Subscription.Type.TRIAL:
# Trial subscription is valid for 7 days
if datetime.now(tz=timezone.utc) - subscription.created_at > timedelta(days=14):
return SubscriptionState.EXPIRED.value
# Check if the trial has expired
if not subscription.renewal_date:
# If the renewal date is not set, set it to the current date + trial length and evaluate
subscription.renewal_date = subscription.created_at + timedelta(days=LENGTH_OF_FREE_TRIAL)
subscription.save()
if subscription.renewal_date and datetime.now(tz=timezone.utc) > subscription.renewal_date:
return SubscriptionState.EXPIRED.value
return SubscriptionState.TRIAL.value
elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
elif subscription.is_recurring and subscription.renewal_date > datetime.now(tz=timezone.utc):
return SubscriptionState.SUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date is None:
return SubscriptionState.EXPIRED.value
elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
elif not subscription.is_recurring and subscription.renewal_date > datetime.now(tz=timezone.utc):
return SubscriptionState.UNSUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc):
return SubscriptionState.EXPIRED.value
@@ -303,14 +384,16 @@ def get_user_subscription_state(email: str) -> str:
return subscription_to_state(user_subscription)
@arequire_valid_user
async def aget_user_subscription_state(user: KhojUser) -> str:
"""Get subscription state of user
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
"""
user_subscription = await Subscription.objects.filter(user=user).afirst()
return subscription_to_state(user_subscription)
return await sync_to_async(subscription_to_state)(user_subscription)
@arequire_valid_user
async def ais_user_subscribed(user: KhojUser) -> bool:
"""
Get whether the user is subscribed
@@ -327,6 +410,7 @@ async def ais_user_subscribed(user: KhojUser) -> bool:
return subscribed
@require_valid_user
def is_user_subscribed(user: KhojUser) -> bool:
"""
Get whether the user is subscribed
@@ -392,11 +476,13 @@ def get_all_users() -> BaseManager[KhojUser]:
return KhojUser.objects.all()
@require_valid_user
def get_user_github_config(user: KhojUser):
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
return config
@require_valid_user
def get_user_notion_config(user: KhojUser):
config = NotionConfig.objects.filter(user=user).first()
return config
@@ -406,6 +492,7 @@ def delete_user_requests(window: timedelta = timedelta(days=1)):
return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
@arequire_valid_user
async def aget_user_name(user: KhojUser):
full_name = user.get_full_name()
if not is_none_or_empty(full_name):
@@ -417,18 +504,7 @@ async def aget_user_name(user: KhojUser):
return None
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None
deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None
await object.objects.filter(user=user).adelete()
await object.objects.acreate(
input_files=deduped_files,
input_filter=deduped_filters,
index_heading_entries=updated_config.index_heading_entries,
user=user,
)
@arequire_valid_user
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
config = await GithubConfig.objects.filter(user=user).afirst()
@@ -446,15 +522,13 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
return config
def get_user_search_model_or_default(user=None):
if user and UserSearchModelConfig.objects.filter(user=user).exists():
return UserSearchModelConfig.objects.filter(user=user).first().setting
def get_default_search_model() -> SearchModelConfig:
default_search_model = SearchModelConfig.objects.filter(name="default").first()
if SearchModelConfig.objects.filter(name="default").exists():
return SearchModelConfig.objects.filter(name="default").first()
else:
if default_search_model:
return default_search_model
elif SearchModelConfig.objects.count() == 0:
SearchModelConfig.objects.create()
return SearchModelConfig.objects.first()
@@ -467,21 +541,6 @@ def get_or_create_search_models():
return search_models
async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
config = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst()
if not config:
return None
new_config, _ = await UserSearchModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
return new_config
async def aget_user_search_model(user: KhojUser):
config = await UserSearchModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if not config:
return None
return config.setting
class ProcessLockAdapters:
@staticmethod
def get_process_lock(process_name: str):
@@ -580,8 +639,11 @@ class AgentAdapters:
)
@staticmethod
@arequire_valid_user
async def adelete_agent_by_slug(agent_slug: str, user: KhojUser):
agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)
if agent.creator != user:
return False
async for entry in Entry.objects.filter(agent=agent).aiterator():
await entry.adelete()
@@ -622,6 +684,8 @@ class AgentAdapters:
@staticmethod
def get_all_accessible_agents(user: KhojUser = None):
public_query = Q(privacy_level=Agent.PrivacyLevel.PUBLIC)
# TODO Update this to allow any public agent that's officially approved once that experience is launched
public_query &= Q(managed_by_admin=True)
if user:
return (
Agent.objects.filter(public_query | Q(creator=user))
@@ -640,6 +704,18 @@ class AgentAdapters:
agents = await sync_to_async(AgentAdapters.get_all_accessible_agents)(user)
return await sync_to_async(list)(agents)
@staticmethod
async def ais_agent_accessible(agent: Agent, user: KhojUser) -> bool:
agent = await Agent.objects.select_related("creator").aget(pk=agent.pk)
if agent.privacy_level == Agent.PrivacyLevel.PUBLIC:
return True
if agent.creator == user:
return True
if agent.privacy_level == Agent.PrivacyLevel.PROTECTED:
return True
return False
@staticmethod
def get_conversation_agent_by_id(agent_id: int):
agent = Agent.objects.filter(id=agent_id).first()
@@ -691,6 +767,7 @@ class AgentAdapters:
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
@staticmethod
@arequire_valid_user
async def aupdate_agent(
user: KhojUser,
name: str,
@@ -766,19 +843,6 @@ class PublicConversationAdapters:
return f"/share/chat/{public_conversation.slug}/"
class DataStoreAdapters:
@staticmethod
async def astore_data(data: dict, key: str, user: KhojUser, private: bool = True):
if await DataStore.objects.filter(key=key).aexists():
return key
await DataStore.objects.acreate(value=data, key=key, owner=user, private=private)
return key
@staticmethod
async def aretrieve_public_data(key: str):
return await DataStore.objects.filter(key=key, private=False).afirst()
class ConversationAdapters:
@staticmethod
def make_public_conversation_copy(conversation: Conversation):
@@ -791,6 +855,7 @@ class ConversationAdapters:
)
@staticmethod
@require_valid_user
def get_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
) -> Optional[Conversation]:
@@ -809,6 +874,7 @@ class ConversationAdapters:
return conversation
@staticmethod
@require_valid_user
def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None):
return (
Conversation.objects.filter(user=user, client=client_application)
@@ -817,6 +883,7 @@ class ConversationAdapters:
)
@staticmethod
@arequire_valid_user
async def aset_conversation_title(
user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str
):
@@ -834,6 +901,7 @@ class ConversationAdapters:
return Conversation.objects.filter(id=conversation_id).first()
@staticmethod
@arequire_valid_user
async def acreate_conversation_session(
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
):
@@ -841,11 +909,16 @@ class ConversationAdapters:
agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
if agent is None:
raise HTTPException(status_code=400, detail="No such agent currently exists.")
return await Conversation.objects.acreate(user=user, client=client_application, agent=agent, title=title)
return await Conversation.objects.select_related("agent", "agent__creator", "agent__chat_model").acreate(
user=user, client=client_application, agent=agent, title=title
)
agent = await AgentAdapters.aget_default_agent()
return await Conversation.objects.acreate(user=user, client=client_application, agent=agent, title=title)
return await Conversation.objects.select_related("agent", "agent__creator", "agent__chat_model").acreate(
user=user, client=client_application, agent=agent, title=title
)
@staticmethod
@require_valid_user
def create_conversation_session(
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
):
@@ -858,6 +931,7 @@ class ConversationAdapters:
return Conversation.objects.create(user=user, client=client_application, agent=agent, title=title)
@staticmethod
@arequire_valid_user
async def aget_conversation_by_user(
user: KhojUser,
client_application: ClientApplication = None,
@@ -882,6 +956,7 @@ class ConversationAdapters:
)
@staticmethod
@arequire_valid_user
async def adelete_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
):
@@ -890,6 +965,7 @@ class ConversationAdapters:
return await Conversation.objects.filter(user=user, client=client_application).adelete()
@staticmethod
@require_valid_user
def has_any_conversation_config(user: KhojUser):
return ChatModelOptions.objects.filter(user=user).exists()
@@ -926,6 +1002,7 @@ class ConversationAdapters:
return OpenAIProcessorConversationConfig.objects.filter().exists()
@staticmethod
@arequire_valid_user
async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
if not config:
@@ -934,6 +1011,7 @@ class ConversationAdapters:
return new_config
@staticmethod
@arequire_valid_user
async def aset_user_voice_model(user: KhojUser, model_id: str):
config = await VoiceModelOption.objects.filter(model_id=model_id).afirst()
if not config:
@@ -984,8 +1062,15 @@ class ConversationAdapters:
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
# Get the server chat settings
server_chat_settings = ServerChatSettings.objects.first()
if server_chat_settings is not None and server_chat_settings.chat_default is not None:
return server_chat_settings.chat_default
is_subscribed = is_user_subscribed(user) if user else False
if server_chat_settings:
# If the user is subscribed and the advanced model is enabled, return the advanced model
if is_subscribed and server_chat_settings.chat_advanced:
return server_chat_settings.chat_advanced
# If the default model is set, return it
if server_chat_settings.chat_default:
return server_chat_settings.chat_default
# Get the user's chat settings, if the server chat settings are not set
user_chat_settings = UserConversationConfig.objects.filter(user=user).first() if user else None
@@ -1001,11 +1086,20 @@ class ConversationAdapters:
# Get the server chat settings
server_chat_settings: ServerChatSettings = (
await ServerChatSettings.objects.filter()
.prefetch_related("chat_default", "chat_default__openai_config")
.prefetch_related(
"chat_default", "chat_default__openai_config", "chat_advanced", "chat_advanced__openai_config"
)
.afirst()
)
if server_chat_settings is not None and server_chat_settings.chat_default is not None:
return server_chat_settings.chat_default
is_subscribed = await ais_user_subscribed(user) if user else False
if server_chat_settings:
# If the user is subscribed and the advanced model is enabled, return the advanced model
if is_subscribed and server_chat_settings.chat_advanced:
return server_chat_settings.chat_advanced
# If the default model is set, return it
if server_chat_settings.chat_default:
return server_chat_settings.chat_default
# Get the user's chat settings, if the server chat settings are not set
user_chat_settings = (
@@ -1102,6 +1196,7 @@ class ConversationAdapters:
return enabled_scrapers
@staticmethod
@require_valid_user
def create_conversation_from_public_conversation(
user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
):
@@ -1118,6 +1213,7 @@ class ConversationAdapters:
)
@staticmethod
@require_valid_user
def save_conversation(
user: KhojUser,
conversation_log: dict,
@@ -1167,6 +1263,7 @@ class ConversationAdapters:
return await SpeechToTextModelOptions.objects.filter().afirst()
@staticmethod
@arequire_valid_user
async def aget_conversation_starters(user: KhojUser, max_results=3):
all_questions = []
if await ReflectiveQuestion.objects.filter(user=user).aexists():
@@ -1267,6 +1364,8 @@ class ConversationAdapters:
def add_files_to_filter(user: KhojUser, conversation_id: str, files: List[str]):
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
file_list = EntryAdapters.get_all_filenames_by_source(user, "computer")
if not conversation:
return []
for filename in files:
if filename in file_list and filename not in conversation.file_filters:
conversation.file_filters.append(filename)
@@ -1280,6 +1379,8 @@ class ConversationAdapters:
@staticmethod
def remove_files_from_filter(user: KhojUser, conversation_id: str, files: List[str]):
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
if not conversation:
return []
for filename in files:
if filename in conversation.file_filters:
conversation.file_filters.remove(filename)
@@ -1291,6 +1392,18 @@ class ConversationAdapters:
conversation.save()
return conversation.file_filters
@staticmethod
@require_valid_user
def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str):
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"):
return False
conversation_log = conversation.conversation_log
updated_log = [msg for msg in conversation_log["chat"] if msg.get("turnId") != turn_id]
conversation.conversation_log["chat"] = updated_log
conversation.save()
return True
class FileObjectAdapters:
@staticmethod
@@ -1299,48 +1412,63 @@ class FileObjectAdapters:
file_object.save()
@staticmethod
@require_valid_user
def create_file_object(user: KhojUser, file_name: str, raw_text: str):
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
@staticmethod
@require_valid_user
def get_file_object_by_name(user: KhojUser, file_name: str):
return FileObject.objects.filter(user=user, file_name=file_name).first()
@staticmethod
@require_valid_user
def get_all_file_objects(user: KhojUser):
return FileObject.objects.filter(user=user).all()
@staticmethod
@require_valid_user
def delete_file_object_by_name(user: KhojUser, file_name: str):
return FileObject.objects.filter(user=user, file_name=file_name).delete()
@staticmethod
@require_valid_user
def delete_all_file_objects(user: KhojUser):
return FileObject.objects.filter(user=user).delete()
@staticmethod
async def async_update_raw_text(file_object: FileObject, new_raw_text: str):
async def aupdate_raw_text(file_object: FileObject, new_raw_text: str):
file_object.raw_text = new_raw_text
await file_object.asave()
@staticmethod
async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str):
@arequire_valid_user
async def acreate_file_object(user: KhojUser, file_name: str, raw_text: str):
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
@staticmethod
async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
@arequire_valid_user
async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
@staticmethod
async def async_get_all_file_objects(user: KhojUser):
@arequire_valid_user
async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]):
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names))
@staticmethod
@arequire_valid_user
async def aget_all_file_objects(user: KhojUser):
return await sync_to_async(list)(FileObject.objects.filter(user=user))
@staticmethod
async def async_delete_file_object_by_name(user: KhojUser, file_name: str):
@arequire_valid_user
async def adelete_file_object_by_name(user: KhojUser, file_name: str):
return await FileObject.objects.filter(user=user, file_name=file_name).adelete()
@staticmethod
async def async_delete_all_file_objects(user: KhojUser):
@arequire_valid_user
async def adelete_all_file_objects(user: KhojUser):
return await FileObject.objects.filter(user=user).adelete()
@@ -1350,15 +1478,18 @@ class EntryAdapters:
date_filter = DateFilter()
@staticmethod
@require_valid_user
def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()
@staticmethod
@require_valid_user
def delete_entry_by_file(user: KhojUser, file_path: str):
deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
return deleted_count
@staticmethod
@require_valid_user
def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None):
queryset = Entry.objects.filter(user=user)
@@ -1371,6 +1502,7 @@ class EntryAdapters:
return queryset
@staticmethod
@require_valid_user
def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
deleted_count = 0
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
@@ -1382,6 +1514,7 @@ class EntryAdapters:
return deleted_count
@staticmethod
@arequire_valid_user
async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
deleted_count = 0
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
@@ -1393,10 +1526,12 @@ class EntryAdapters:
return deleted_count
@staticmethod
@require_valid_user
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
@staticmethod
@require_valid_user
def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()
@@ -1408,6 +1543,7 @@ class EntryAdapters:
)
@staticmethod
@require_valid_user
def user_has_entries(user: KhojUser):
return Entry.objects.filter(user=user).exists()
@@ -1416,18 +1552,23 @@ class EntryAdapters:
return Entry.objects.filter(agent=agent).exists()
@staticmethod
@arequire_valid_user
async def auser_has_entries(user: KhojUser):
return await Entry.objects.filter(user=user).aexists()
@staticmethod
async def aagent_has_entries(agent: Agent):
if agent is None:
return False
return await Entry.objects.filter(agent=agent).aexists()
@staticmethod
@arequire_valid_user
async def adelete_entry_by_file(user: KhojUser, file_path: str):
return await Entry.objects.filter(user=user, file_path=file_path).adelete()
@staticmethod
@arequire_valid_user
async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000):
deleted_count = 0
for i in range(0, len(filenames), batch_size):
@@ -1439,9 +1580,14 @@ class EntryAdapters:
@staticmethod
async def aget_agent_entry_filepaths(agent: Agent):
return await sync_to_async(list)(Entry.objects.filter(agent=agent).values_list("file_path", flat=True))
if agent is None:
return []
return await sync_to_async(set)(
Entry.objects.filter(agent=agent).distinct("file_path").values_list("file_path", flat=True)
)
@staticmethod
@require_valid_user
def get_all_filenames_by_source(user: KhojUser, file_source: str):
return (
Entry.objects.filter(user=user, file_source=file_source)
@@ -1450,6 +1596,7 @@ class EntryAdapters:
)
@staticmethod
@require_valid_user
def get_size_of_indexed_data_in_mb(user: KhojUser):
entries = Entry.objects.filter(user=user).iterator()
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
@@ -1470,6 +1617,9 @@ class EntryAdapters:
if agent != None:
owner_filter |= Q(agent=agent)
if owner_filter == Q():
return Entry.objects.none()
if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
return Entry.objects.filter(owner_filter)
@@ -1514,11 +1664,11 @@ class EntryAdapters:
@staticmethod
def search_with_embeddings(
user: KhojUser,
raw_query: str,
embeddings: Tensor,
user: KhojUser,
max_results: int = 10,
file_type_filter: str = None,
raw_query: str = None,
max_distance: float = math.inf,
agent: Agent = None,
):
@@ -1544,10 +1694,12 @@ class EntryAdapters:
return relevant_entries[:max_results]
@staticmethod
@require_valid_user
def get_unique_file_types(user: KhojUser):
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()
@staticmethod
@require_valid_user
def get_unique_file_sources(user: KhojUser):
return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all()

View File

@@ -28,7 +28,6 @@ from khoj.database.models import (
TextToImageModelConfig,
UserConversationConfig,
UserRequests,
UserSearchModelConfig,
UserVoiceModelConfig,
VoiceModelOption,
WebScraper,
@@ -79,7 +78,12 @@ class KhojUserAdmin(UserAdmin):
search_fields = ("email", "username", "phone_number", "uuid")
filter_horizontal = ("groups", "user_permissions")
fieldsets = (("Personal info", {"fields": ("phone_number", "email_verification_code")}),) + UserAdmin.fieldsets
fieldsets = (
(
"Personal info",
{"fields": ("phone_number", "email_verification_code", "verified_phone_number", "verified_email")},
),
) + UserAdmin.fieldsets
actions = ["get_email_login_url"]
@@ -99,7 +103,6 @@ admin.site.register(KhojUser, KhojUserAdmin)
admin.site.register(ProcessLock)
admin.site.register(SpeechToTextModelOptions)
admin.site.register(ReflectiveQuestion)
admin.site.register(UserSearchModelConfig)
admin.site.register(ClientApplication)
admin.site.register(GithubConfig)
admin.site.register(NotionConfig)
@@ -126,6 +129,7 @@ class EntryAdmin(admin.ModelAdmin):
"created_at",
"updated_at",
"user",
"agent",
"file_source",
"file_type",
"file_name",
@@ -135,6 +139,7 @@ class EntryAdmin(admin.ModelAdmin):
list_filter = (
"file_type",
"user__email",
"search_model__name",
)
ordering = ("-created_at",)

View File

@@ -0,0 +1,121 @@
import logging
from typing import List
from django.core.management.base import BaseCommand
from django.db import transaction
from django.db.models import Count, Q
from tqdm import tqdm
from khoj.database.adapters import get_default_search_model
from khoj.database.models import Agent, Entry, KhojUser, SearchModelConfig
from khoj.processor.embeddings import EmbeddingsModel
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
BATCH_SIZE = 1000 # Define an appropriate batch size
class Command(BaseCommand):
help = "Convert all existing Entry objects to use a new default Search model."
def add_arguments(self, parser):
# Pass default SearchModelConfig ID
parser.add_argument(
"--search_model_id",
action="store",
help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects.",
required=True,
)
# Set the apply flag to apply the new default Search model to all existing Entry objects.
parser.add_argument(
"--apply",
action="store_true",
help="Apply the new default Search model to all existing Entry objects. Otherwise, only display the number of Entry objects that will be affected.",
)
def handle(self, *args, **options):
@transaction.atomic
def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
total_entries = Entry.objects.filter(entry_filter).count()
for start in tqdm(range(0, total_entries, BATCH_SIZE)):
end = start + BATCH_SIZE
entries = Entry.objects.filter(entry_filter)[start:end]
compiled_entries = [entry.compiled for entry in entries]
updated_entries: List[Entry] = []
try:
embeddings = embeddings_model.embed_documents(compiled_entries)
except Exception as e:
logger.error(f"Error embedding documents: {e}")
return
for i, entry in enumerate(entries):
entry.embeddings = embeddings[i]
entry.search_model_id = search_model.id
updated_entries.append(entry)
Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
search_model_config_id = options.get("search_model_id")
apply = options.get("apply")
logger.info(f"SearchModelConfig ID: {search_model_config_id}")
logger.info(f"Apply: {apply}")
embeddings_model = dict()
search_models = SearchModelConfig.objects.all()
for model in search_models:
embeddings_model.update(
{
model.name: EmbeddingsModel(
model.bi_encoder,
model.embeddings_inference_endpoint,
model.embeddings_inference_endpoint_api_key,
query_encode_kwargs=model.bi_encoder_query_encode_config,
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
model_kwargs=model.bi_encoder_model_config,
)
}
)
new_default_search_model_config = SearchModelConfig.objects.get(id=search_model_config_id)
logger.info(f"New default Search model: {new_default_search_model_config}")
logger.info("----")
current_default = get_default_search_model()
# TODO: Migrate all Entry objects to use the new default Search model
all_agents = Agent.objects.all()
logger.info(f"Number of Agent objects to update: {all_agents.count()}")
for agent in all_agents:
entry_filter = Q(agent=agent)
relevant_entries = Entry.objects.filter(entry_filter).all()
logger.info(f"Number of Entry objects to update for agent {agent}: {relevant_entries.count()}")
if apply:
try:
regenerate_entries(
entry_filter,
embeddings_model[new_default_search_model_config.name],
new_default_search_model_config,
)
logger.info(
f"Updated {relevant_entries.count()} Entry objects for agent {agent} to use the new default Search model."
)
except Exception as e:
logger.error(f"Error embedding documents: {e}")
if apply and current_default.id != new_default_search_model_config.id:
# Get the existing default SearchModelConfig object and update its name
current_default.name = f"prev_default_{current_default.id}"
current_default.save()
# Update the new default SearchModelConfig object's name
new_default_search_model_config.name = "default"
new_default_search_model_config.save()
if not apply:
logger.info("Run the command with the --apply flag to apply the new default Search model.")

View File

@@ -0,0 +1,46 @@
# Generated by Django 5.0.8 on 2024-10-21 05:16
import django.contrib.postgres.fields
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0069_webscraper_serverchatsettings_web_scraper"),
]
operations = [
migrations.AlterField(
model_name="agent",
name="input_tools",
field=django.contrib.postgres.fields.ArrayField(
base_field=models.CharField(
choices=[
("general", "General"),
("online", "Online"),
("notes", "Notes"),
("summarize", "Summarize"),
("webpage", "Webpage"),
],
max_length=200,
),
blank=True,
default=list,
null=True,
size=None,
),
),
migrations.AlterField(
model_name="agent",
name="output_modes",
field=django.contrib.postgres.fields.ArrayField(
base_field=models.CharField(
choices=[("text", "Text"), ("image", "Image"), ("automation", "Automation")], max_length=200
),
blank=True,
default=list,
null=True,
size=None,
),
),
]

View File

@@ -0,0 +1,32 @@
# Generated by Django 5.0.8 on 2024-10-20 19:24
from django.db import migrations, models
def set_enabled_trial_at(apps, schema_editor):
Subscription = apps.get_model("database", "Subscription")
for subscription in Subscription.objects.all():
subscription.enabled_trial_at = subscription.created_at
subscription.save()
class Migration(migrations.Migration):
dependencies = [
("database", "0070_alter_agent_input_tools_alter_agent_output_modes"),
]
operations = [
migrations.AddField(
model_name="subscription",
name="enabled_trial_at",
field=models.DateTimeField(blank=True, default=None, null=True),
),
migrations.AlterField(
model_name="subscription",
name="type",
field=models.CharField(
choices=[("trial", "Trial"), ("standard", "Standard")], default="standard", max_length=20
),
),
migrations.RunPython(set_enabled_trial_at),
]

View File

@@ -0,0 +1,24 @@
# Generated by Django 5.0.8 on 2024-10-21 21:09
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0071_subscription_enabled_trial_at_and_more"),
]
operations = [
migrations.AddField(
model_name="entry",
name="search_model",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="database.searchmodelconfig",
),
),
]

View File

@@ -0,0 +1,15 @@
# Generated by Django 5.0.9 on 2024-11-04 19:56
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("database", "0072_entry_search_model"),
]
operations = [
migrations.DeleteModel(
name="UserSearchModelConfig",
),
]

View File

@@ -0,0 +1,17 @@
# Generated by Django 5.0.9 on 2024-11-12 09:50
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0073_delete_usersearchmodelconfig"),
]
operations = [
migrations.AlterField(
model_name="conversation",
name="title",
field=models.CharField(blank=True, default=None, max_length=500, null=True),
),
]

View File

@@ -73,9 +73,10 @@ class Subscription(BaseModel):
STANDARD = "standard"
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE, related_name="subscription")
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
type = models.CharField(max_length=20, choices=Type.choices, default=Type.STANDARD)
is_recurring = models.BooleanField(default=False)
renewal_date = models.DateTimeField(null=True, default=None, blank=True)
enabled_trial_at = models.DateTimeField(null=True, default=None, blank=True)
class OpenAIProcessorConversationConfig(BaseModel):
@@ -180,8 +181,12 @@ class Agent(BaseModel):
) # Creator will only be null when the agents are managed by admin
name = models.CharField(max_length=200)
personality = models.TextField()
input_tools = ArrayField(models.CharField(max_length=200, choices=InputToolOptions.choices), default=list)
output_modes = ArrayField(models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list)
input_tools = ArrayField(
models.CharField(max_length=200, choices=InputToolOptions.choices), default=list, null=True, blank=True
)
output_modes = ArrayField(
models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list, null=True, blank=True
)
managed_by_admin = models.BooleanField(default=False)
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
slug = models.CharField(max_length=200, unique=True)
@@ -444,11 +449,6 @@ class UserVoiceModelConfig(BaseModel):
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
class UserSearchModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
class UserTextToImageModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
@@ -458,8 +458,12 @@ class Conversation(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict)
client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
# Slug is an app-generated conversation identifier. Need not be unique. Used as display title essentially.
slug = models.CharField(max_length=200, default=None, null=True, blank=True)
title = models.CharField(max_length=200, default=None, null=True, blank=True)
# The title field is explicitly set by the user.
title = models.CharField(max_length=500, default=None, null=True, blank=True)
agent = models.ForeignKey(Agent, on_delete=models.SET_NULL, default=None, null=True, blank=True)
file_filters = models.JSONField(default=list)
id = models.UUIDField(default=uuid.uuid4, editable=False, unique=True, primary_key=True, db_index=True)
@@ -530,6 +534,7 @@ class Entry(BaseModel):
url = models.URLField(max_length=400, default=None, null=True, blank=True)
hashed_value = models.CharField(max_length=100)
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True)
def save(self, *args, **kwargs):
if self.user and self.agent:

View File

@@ -1,6 +1,5 @@
import logging
import os
from datetime import datetime
import tempfile
from typing import Dict, List, Tuple
from langchain_community.document_loaders import Docx2txtLoader
@@ -19,7 +18,7 @@ class DocxToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
@@ -36,13 +35,13 @@ class DocxToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.DOCX,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)
@@ -58,28 +57,13 @@ class DocxToEntries(TextToEntries):
file_to_text_map = dict()
for docx_file in docx_files:
try:
timestamp_now = datetime.utcnow().timestamp()
tmp_file = f"tmp_docx_file_{timestamp_now}.docx"
with open(tmp_file, "wb") as f:
bytes_content = docx_files[docx_file]
f.write(bytes_content)
# Load the content using Docx2txtLoader
loader = Docx2txtLoader(tmp_file)
docx_entries_per_file = loader.load()
# Convert the loaded entries into the desired format
docx_texts = [page.page_content for page in docx_entries_per_file]
docx_texts = DocxToEntries.extract_text(docx_files[docx_file])
entry_to_location_map += zip(docx_texts, [docx_file] * len(docx_texts))
entries.extend(docx_texts)
file_to_text_map[docx_file] = docx_texts
except Exception as e:
logger.warning(f"Unable to process file: {docx_file}. This file will not be indexed.")
logger.warning(f"Unable to extract entries from file: {docx_file}")
logger.warning(e, exc_info=True)
finally:
if os.path.exists(f"{tmp_file}"):
os.remove(f"{tmp_file}")
return file_to_text_map, DocxToEntries.convert_docx_entries_to_maps(entries, dict(entry_to_location_map))
@staticmethod
@@ -103,3 +87,25 @@ class DocxToEntries(TextToEntries):
logger.debug(f"Converted {len(parsed_entries)} DOCX entries to dictionaries")
return entries
@staticmethod
def extract_text(docx_file):
"""Extract text from specified DOCX file"""
try:
docx_entry_by_pages = []
# Create temp file with .docx extension that gets auto-deleted
with tempfile.NamedTemporaryFile(suffix=".docx", delete=True) as tmp:
tmp.write(docx_file)
tmp.flush() # Ensure all data is written
# Load the content using Docx2txtLoader
loader = Docx2txtLoader(tmp.name)
docx_entries_per_file = loader.load()
# Convert the loaded entries into the desired format
docx_entry_by_pages = [page.page_content for page in docx_entries_per_file]
except Exception as e:
logger.warning(f"Unable to extract text from file: {docx_file}")
logger.warning(e, exc_info=True)
return docx_entry_by_pages

View File

@@ -48,7 +48,7 @@ class GithubToEntries(TextToEntries):
else:
return
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
if self.config.pat_token is None or self.config.pat_token == "":
logger.error(f"Github PAT token is not set. Skipping github content")
raise ValueError("Github PAT token is not set. Skipping github content")
@@ -101,12 +101,12 @@ class GithubToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.GITHUB,
DbEntry.EntrySource.GITHUB,
key="compiled",
logger=logger,
user=user,
)
return num_new_embeddings, num_deleted_embeddings

View File

@@ -18,7 +18,7 @@ class ImageToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
@@ -35,13 +35,13 @@ class ImageToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.IMAGE,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)

View File

@@ -19,7 +19,7 @@ class MarkdownToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
@@ -37,13 +37,13 @@ class MarkdownToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.MARKDOWN,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)

View File

@@ -79,7 +79,7 @@ class NotionToEntries(TextToEntries):
self.body_params = {"page_size": 100}
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
current_entries = []
# Get all pages
@@ -248,12 +248,12 @@ class NotionToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.NOTION,
DbEntry.EntrySource.NOTION,
key="compiled",
logger=logger,
user=user,
)
return num_new_embeddings, num_deleted_embeddings

View File

@@ -20,7 +20,7 @@ class OrgToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
@@ -36,13 +36,13 @@ class OrgToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.ORG,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)

View File

@@ -49,7 +49,7 @@ def normalize_filename(filename):
normalized_filename = f"~/{relpath(filename, start=Path.home())}"
else:
normalized_filename = filename
escaped_filename = f"{normalized_filename}".replace("[", "\[").replace("]", "\]")
escaped_filename = f"{normalized_filename}".replace("[", r"\[").replace("]", r"\]")
return escaped_filename

View File

@@ -1,13 +1,9 @@
import base64
import logging
import os
from datetime import datetime
from typing import Dict, List, Tuple
import tempfile
from typing import Dict, Final, List, Tuple
from langchain_community.document_loaders import PyMuPDFLoader
# importing FileObjectAdapter so that we can add new files and debug file object db.
# from khoj.database.adapters import FileObjectAdapters
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries
@@ -18,11 +14,14 @@ logger = logging.getLogger(__name__)
class PdfToEntries(TextToEntries):
# Class-level constant translation table
NULL_TRANSLATOR: Final = str.maketrans("", "", "\x00")
def __init__(self):
super().__init__()
# Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
@@ -39,13 +38,13 @@ class PdfToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.PDF,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)
@@ -60,31 +59,13 @@ class PdfToEntries(TextToEntries):
entry_to_location_map: List[Tuple[str, str]] = []
for pdf_file in pdf_files:
try:
# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path
timestamp_now = datetime.utcnow().timestamp()
tmp_file = f"tmp_pdf_file_{timestamp_now}.pdf"
with open(f"{tmp_file}", "wb") as f:
bytes = pdf_files[pdf_file]
f.write(bytes)
try:
loader = PyMuPDFLoader(f"{tmp_file}", extract_images=False)
pdf_entries_per_file = [page.page_content for page in loader.load()]
except ImportError:
loader = PyMuPDFLoader(f"{tmp_file}")
pdf_entries_per_file = [
page.page_content for page in loader.load()
] # page_content items list for a given pdf.
entry_to_location_map += zip(
pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file)
) # this is an indexed map of pdf_entries for the pdf.
pdf_entries_per_file = PdfToEntries.extract_text(pdf_files[pdf_file])
entry_to_location_map += zip(pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file))
entries.extend(pdf_entries_per_file)
file_to_text_map[pdf_file] = pdf_entries_per_file
except Exception as e:
logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
logger.warning(f"Unable to extract entries from file: {pdf_file}")
logger.warning(e, exc_info=True)
finally:
if os.path.exists(f"{tmp_file}"):
os.remove(f"{tmp_file}")
return file_to_text_map, PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
@@ -109,3 +90,30 @@ class PdfToEntries(TextToEntries):
logger.debug(f"Converted {len(parsed_entries)} PDF entries to dictionaries")
return entries
@staticmethod
def extract_text(pdf_file):
"""Extract text from specified PDF files"""
try:
# Create temp file with .pdf extension that gets auto-deleted
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=True) as tmpf:
tmpf.write(pdf_file)
tmpf.flush() # Ensure all data is written
# Load the content using PyMuPDFLoader
loader = PyMuPDFLoader(tmpf.name, extract_images=True)
pdf_entries_per_file = loader.load()
# Convert the loaded entries into the desired format
pdf_entry_by_pages = [PdfToEntries.clean_text(page.page_content) for page in pdf_entries_per_file]
except Exception as e:
logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
return pdf_entry_by_pages
@staticmethod
def clean_text(text: str) -> str:
"""Clean PDF text by removing null bytes and invalid Unicode characters."""
# Use faster translation table instead of replace
return text.translate(PdfToEntries.NULL_TRANSLATOR)

View File

@@ -20,7 +20,7 @@ class PlaintextToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
@@ -36,13 +36,13 @@ class PlaintextToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
user,
current_entries,
DbEntry.EntryType.PLAINTEXT,
DbEntry.EntrySource.COMPUTER,
key="compiled",
logger=logger,
deletion_filenames=deletion_file_names,
user=user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)

View File

@@ -12,7 +12,7 @@ from tqdm import tqdm
from khoj.database.adapters import (
EntryAdapters,
FileObjectAdapters,
get_user_search_model_or_default,
get_default_search_model,
)
from khoj.database.models import Entry as DbEntry
from khoj.database.models import EntryDates, KhojUser
@@ -31,7 +31,7 @@ class TextToEntries(ABC):
self.date_filter = DateFilter()
@abstractmethod
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
...
@staticmethod
@@ -114,13 +114,13 @@ class TextToEntries(ABC):
def update_embeddings(
self,
user: KhojUser,
current_entries: List[Entry],
file_type: str,
file_source: str,
key="compiled",
logger: logging.Logger = None,
deletion_filenames: Set[str] = None,
user: KhojUser = None,
regenerate: bool = False,
file_to_text_map: dict[str, str] = None,
):
@@ -148,10 +148,10 @@ class TextToEntries(ABC):
hashes_to_process |= hashes_for_file - existing_entry_hashes
embeddings = []
model = get_default_search_model()
with timer("Generated embeddings for entries to add to database in", logger):
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
model = get_user_search_model_or_default(user)
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
added_entries: list[DbEntry] = []
@@ -177,6 +177,7 @@ class TextToEntries(ABC):
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,
search_model=model,
)
)
try:

View File

@@ -11,10 +11,16 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff,
anthropic_completion_with_backoff,
format_messages_for_anthropic,
)
from khoj.processor.conversation.utils import (
clean_json,
construct_structured_message,
generate_chatml_messages_with_context,
)
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@@ -27,7 +33,11 @@ def extract_questions_anthropic(
temperature=0.7,
location_data: LocationData = None,
user: KhojUser = None,
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
query_files: str = None,
tracer: dict = {},
):
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -68,7 +78,19 @@ def extract_questions_anthropic(
text=text,
)
messages = [ChatMessage(content=prompt, role="user")]
prompt = construct_structured_message(
message=prompt,
images=query_images,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
vision_enabled=vision_enabled,
attached_file_context=query_files,
)
messages = []
messages.append(ChatMessage(content=prompt, role="user"))
messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
response = anthropic_completion_with_backoff(
messages=messages,
@@ -76,14 +98,13 @@ def extract_questions_anthropic(
model_name=model,
temperature=temperature,
api_key=api_key,
response_type="json_object",
tracer=tracer,
)
# Extract, Clean Message from Claude's Response
try:
response = response.strip()
match = re.search(r"\{.*?\}", response)
if match:
response = match.group()
response = clean_json(response)
response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response:
@@ -97,21 +118,11 @@ def extract_questions_anthropic(
return questions
def anthropic_send_message_to_model(messages, api_key, model):
def anthropic_send_message_to_model(messages, api_key, model, response_type="text", tracer={}):
"""
Send message to model
"""
# Anthropic requires the first message to be a 'user' message, and the system prompt is not to be sent in the messages parameter
system_prompt = None
if len(messages) == 1:
messages[0].role = "user"
else:
system_prompt = ""
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)
messages, system_prompt = format_messages_for_anthropic(messages)
# Get Response from GPT. Don't use response_type because Anthropic doesn't support it.
return anthropic_completion_with_backoff(
@@ -119,6 +130,8 @@ def anthropic_send_message_to_model(messages, api_key, model):
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
response_type=response_type,
tracer=tracer,
)
@@ -126,8 +139,9 @@ def converse_anthropic(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
conversation_log={},
model: Optional[str] = "claude-instant-1.2",
model: Optional[str] = "claude-3-5-sonnet-20241022",
api_key: Optional[str] = None,
completion_func=None,
conversation_commands=[ConversationCommand.Default],
@@ -136,15 +150,16 @@ def converse_anthropic(
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
tracer: dict = {},
):
"""
Converse with user using Anthropic's Claude
"""
# Initialize Variables
current_date = datetime.now()
compiled_references = "\n\n".join({f"# {item}" for item in references})
conversation_primer = prompts.query_prompt.format(query=user_query)
if agent and agent.personality:
system_prompt = prompts.custom_personality.format(
@@ -168,38 +183,37 @@ def converse_anthropic(
system_prompt = f"{system_prompt}\n{user_name_prompt}"
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
context_message = ""
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
conversation_primer = (
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
)
if not is_none_or_empty(compiled_references):
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
context_message = context_message.strip()
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
user_query,
context_message=context_message,
conversation_log=conversation_log,
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
query_files=query_files,
)
if len(messages) > 1:
if messages[0].role == "assistant":
messages = messages[1:]
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)
messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
logger.debug(f"Conversation Context for Claude: {truncated_messages}")
@@ -215,4 +229,5 @@ def converse_anthropic(
system_prompt=system_prompt,
completion_func=completion_func,
max_prompt_size=max_prompt_size,
tracer=tracer,
)

View File

@@ -3,6 +3,7 @@ from threading import Thread
from typing import Dict, List
import anthropic
from langchain.schema import ChatMessage
from tenacity import (
before_sleep_log,
retry,
@@ -11,7 +12,13 @@ from tenacity import (
wait_random_exponential,
)
from khoj.processor.conversation.utils import ThreadedGenerator
from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace,
get_image_from_url,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
logger = logging.getLogger(__name__)
@@ -28,7 +35,15 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
reraise=True,
)
def anthropic_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
messages,
system_prompt,
model_name,
temperature=0,
api_key=None,
model_kwargs=None,
max_tokens=None,
response_type="text",
tracer={},
) -> str:
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
@@ -37,8 +52,11 @@ def anthropic_completion_with_backoff(
client = anthropic_clients[api_key]
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
if response_type == "json_object":
# Prefill model response with '{' to make it output a valid JSON object
formatted_messages += [{"role": "assistant", "content": "{"}]
aggregated_response = ""
aggregated_response = "{" if response_type == "json_object" else ""
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
model_kwargs = model_kwargs or dict()
@@ -56,6 +74,12 @@ def anthropic_completion_with_backoff(
for text in stream.text_stream:
aggregated_response += text
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
return aggregated_response
@@ -76,18 +100,19 @@ def anthropic_chat_completion_with_backoff(
max_prompt_size=None,
completion_func=None,
model_kwargs=None,
tracer={},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=anthropic_llm_thread,
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs, tracer),
)
t.start()
return g
def anthropic_llm_thread(
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None, tracer={}
):
try:
if api_key not in anthropic_clients:
@@ -100,6 +125,7 @@ def anthropic_llm_thread(
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
]
aggregated_response = ""
with client.messages.stream(
messages=formatted_messages,
model=model_name, # type: ignore
@@ -110,8 +136,63 @@ def anthropic_llm_thread(
**(model_kwargs or dict()),
) as stream:
for text in stream.text_stream:
aggregated_response += text
g.send(text)
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
finally:
g.close()
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=None):
"""
Format messages for Anthropic
"""
# Extract system prompt
system_prompt = system_prompt or ""
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
# Anthropic requires the first message to be a 'user' message
if len(messages) == 1:
messages[0].role = "user"
elif len(messages) > 1 and messages[0].role == "assistant":
messages = messages[1:]
# Convert image urls to base64 encoded images in Anthropic message format
for message in messages:
if isinstance(message.content, list):
content = []
# Sort the content. Anthropic models prefer that text comes after images.
message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1)
for idx, part in enumerate(message.content):
if part["type"] == "text":
content.append({"type": "text", "text": part["text"]})
elif part["type"] == "image_url":
image = get_image_from_url(part["image_url"]["url"], type="b64")
# Prefix each image with text block enumerating the image number
# This helps the model reference the image in its response. Recommended by Anthropic
content.extend(
[
{
"type": "text",
"text": f"Image {idx + 1}:",
},
{
"type": "image",
"source": {"type": "base64", "media_type": image.type, "data": image.content},
},
]
)
message.content = content
return messages, system_prompt

View File

@@ -6,16 +6,21 @@ from typing import Dict, Optional
from langchain.schema import ChatMessage
from khoj.database.models import Agent, KhojUser
from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import (
format_messages_for_gemini,
gemini_chat_completion_with_backoff,
gemini_completion_with_backoff,
)
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
from khoj.processor.conversation.utils import (
clean_json,
construct_structured_message,
generate_chatml_messages_with_context,
)
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@@ -29,7 +34,11 @@ def extract_questions_gemini(
max_tokens=None,
location_data: LocationData = None,
user: KhojUser = None,
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
query_files: str = None,
tracer: dict = {},
):
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -70,25 +79,26 @@ def extract_questions_gemini(
text=text,
)
messages = [ChatMessage(content=prompt, role="user")]
prompt = construct_structured_message(
message=prompt,
images=query_images,
model_type=ChatModelOptions.ModelType.GOOGLE,
vision_enabled=vision_enabled,
attached_file_context=query_files,
)
model_kwargs = {"response_mime_type": "application/json"}
messages = []
response = gemini_completion_with_backoff(
messages=messages,
system_prompt=system_prompt,
model_name=model,
temperature=temperature,
api_key=api_key,
model_kwargs=model_kwargs,
messages.append(ChatMessage(content=prompt, role="user"))
messages.append(ChatMessage(content=system_prompt, role="system"))
response = gemini_send_message_to_model(
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
)
# Extract, Clean Message from Gemini's Response
try:
response = response.strip()
match = re.search(r"\{.*?\}", response)
if match:
response = match.group()
response = clean_json(response)
response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response:
@@ -102,19 +112,35 @@ def extract_questions_gemini(
return questions
def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
def gemini_send_message_to_model(
messages,
api_key,
model,
response_type="text",
temperature=0,
model_kwargs=None,
tracer={},
):
"""
Send message to model
"""
messages, system_prompt = format_messages_for_gemini(messages)
model_kwargs = {}
if response_type == "json_object":
model_kwargs["response_mime_type"] = "application/json"
# Sometimes, this causes unwanted behavior and terminates response early. Disable for now while it's flaky.
# if response_type == "json_object":
# model_kwargs["response_mime_type"] = "application/json"
# Get Response from Gemini
return gemini_completion_with_backoff(
messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
messages=messages,
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
temperature=temperature,
model_kwargs=model_kwargs,
tracer=tracer,
)
@@ -122,6 +148,7 @@ def converse_gemini(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
conversation_log={},
model: Optional[str] = "gemini-1.5-flash",
api_key: Optional[str] = None,
@@ -133,15 +160,16 @@ def converse_gemini(
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
tracer={},
):
"""
Converse with user using Google's Gemini
"""
# Initialize Variables
current_date = datetime.now()
compiled_references = "\n\n".join({f"# {item}" for item in references})
conversation_primer = prompts.query_prompt.format(query=user_query)
if agent and agent.personality:
system_prompt = prompts.custom_personality.format(
@@ -166,27 +194,34 @@ def converse_gemini(
system_prompt = f"{system_prompt}\n{user_name_prompt}"
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
context_message = ""
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
conversation_primer = (
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
)
if not is_none_or_empty(compiled_references):
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
context_message = context_message.strip()
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
user_query,
context_message=context_message,
conversation_log=conversation_log,
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.GOOGLE,
query_files=query_files,
)
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
@@ -204,4 +239,5 @@ def converse_gemini(
api_key=api_key,
system_prompt=system_prompt,
completion_func=completion_func,
tracer=tracer,
)

View File

@@ -19,8 +19,13 @@ from tenacity import (
wait_random_exponential,
)
from khoj.processor.conversation.utils import ThreadedGenerator
from khoj.utils.helpers import is_none_or_empty
from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace,
get_image_from_url,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
logger = logging.getLogger(__name__)
@@ -35,7 +40,7 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192
reraise=True,
)
def gemini_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
) -> str:
genai.configure(api_key=api_key)
model_kwargs = model_kwargs or dict()
@@ -53,23 +58,30 @@ def gemini_completion_with_backoff(
},
)
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
# Start chat session. All messages up to the last are considered to be part of the chat history
chat_session = model.start_chat(history=formatted_messages[0:-1])
try:
# Generate the response. The last message is considered to be the current prompt
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
return aggregated_response.text
response = chat_session.send_message(formatted_messages[-1]["parts"])
response_text = response.text
except StopCandidateException as e:
response_message, _ = handle_gemini_response(e.args)
response_text, _ = handle_gemini_response(e.args)
# Respond with reason for stopping
logger.warning(
f"LLM Response Prevented for {model_name}: {response_message}.\n"
f"LLM Response Prevented for {model_name}: {response_text}.\n"
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
)
return response_message
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, response_text, tracer)
return response_text
@retry(
@@ -88,17 +100,20 @@ def gemini_chat_completion_with_backoff(
system_prompt,
completion_func=None,
model_kwargs=None,
tracer: dict = {},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=gemini_llm_thread,
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
)
t.start()
return g
def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
def gemini_llm_thread(
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
):
try:
genai.configure(api_key=api_key)
model_kwargs = model_kwargs or dict()
@@ -117,16 +132,25 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
},
)
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
aggregated_response = ""
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
# all messages up to the last are considered to be part of the chat history
chat_session = model.start_chat(history=formatted_messages[0:-1])
# the last message is considered to be the current prompt
for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
message = message or chunk.text
aggregated_response += message
g.send(message)
if stopped:
raise StopCandidateException(message)
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
except StopCandidateException as e:
logger.warning(
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
@@ -191,14 +215,6 @@ def generate_safety_response(safety_ratings):
def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
if len(messages) == 1:
messages[0].role = "user"
return messages, system_prompt
for message in messages:
if message.role == "assistant":
message.role = "model"
# Extract system message
system_prompt = system_prompt or ""
for message in messages.copy():
@@ -207,4 +223,23 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
messages.remove(message)
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
for message in messages:
# Convert message content to string list from chatml dictionary list
if isinstance(message.content, list):
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
message.content = [
get_image_from_url(item["image_url"]["url"]).content
if item["type"] == "image_url"
else item.get("text", "")
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
]
elif isinstance(message.content, str):
message.content = [message.content]
if message.role == "assistant":
message.role = "model"
if len(messages) == 1:
messages[0].role = "user"
return messages, system_prompt

View File

@@ -1,5 +1,6 @@
import json
import logging
import os
from datetime import datetime, timedelta
from threading import Thread
from typing import Any, Iterator, List, Optional, Union
@@ -12,12 +13,14 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace,
generate_chatml_messages_with_context,
)
from khoj.utils import state
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
from khoj.utils.rawconfig import LocationData
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@@ -34,6 +37,8 @@ def extract_questions_offline(
max_prompt_size: int = None,
temperature: float = 0.7,
personality_context: Optional[str] = None,
query_files: str = None,
tracer: dict = {},
) -> List[str]:
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -83,6 +88,7 @@ def extract_questions_offline(
loaded_model=offline_chat_model,
max_prompt_size=max_prompt_size,
model_type=ChatModelOptions.ModelType.OFFLINE,
query_files=query_files,
)
state.chat_lock.acquire()
@@ -94,6 +100,7 @@ def extract_questions_offline(
max_prompt_size=max_prompt_size,
temperature=temperature,
response_type="json_object",
tracer=tracer,
)
finally:
state.chat_lock.release()
@@ -135,7 +142,8 @@ def filter_questions(questions: List[str]):
def converse_offline(
user_query,
references=[],
online_results=[],
online_results={},
code_results={},
conversation_log={},
model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
loaded_model: Union[Any, None] = None,
@@ -146,6 +154,8 @@ def converse_offline(
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
query_files: str = None,
tracer: dict = {},
) -> Union[ThreadedGenerator, Iterator[str]]:
"""
Converse with user using Llama
@@ -153,8 +163,7 @@ def converse_offline(
# Initialize Variables
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
compiled_references_message = "\n\n".join({f"{item['compiled']}" for item in references})
tracer["chat_model"] = model
current_date = datetime.now()
if agent and agent.personality:
@@ -170,8 +179,6 @@ def converse_offline(
day_of_week=current_date.strftime("%A"),
)
conversation_primer = prompts.query_prompt.format(query=user_query)
if location_data:
location_prompt = prompts.user_location.format(location=f"{location_data}")
system_prompt = f"{system_prompt}\n{location_prompt}"
@@ -181,45 +188,52 @@ def converse_offline(
system_prompt = f"{system_prompt}\n{user_name_prompt}"
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references_message):
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
return iter([prompts.no_notes_found.format()])
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
if ConversationCommand.Online in conversation_commands:
context_message = ""
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation_offline.format(references=yaml_dump(references))}\n\n"
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
simplified_online_results = online_results.copy()
for result in online_results:
if online_results[result].get("webpages"):
simplified_online_results[result] = online_results[result]["webpages"]
conversation_primer = f"{prompts.online_search_conversation_offline.format(online_results=str(simplified_online_results))}\n{conversation_primer}"
if not is_none_or_empty(compiled_references_message):
conversation_primer = f"{prompts.notes_conversation_offline.format(references=compiled_references_message)}\n\n{conversation_primer}"
context_message += f"{prompts.online_search_conversation_offline.format(online_results=yaml_dump(simplified_online_results))}\n\n"
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
context_message = context_message.strip()
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
user_query,
system_prompt,
conversation_log,
context_message=context_message,
model_name=model,
loaded_model=offline_chat_model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
model_type=ChatModelOptions.ModelType.OFFLINE,
query_files=query_files,
)
truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
logger.debug(f"Conversation Context for {model}: {truncated_messages}")
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
t.start()
return g
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
stop_phrases = ["<s>", "INST]", "Notes:"]
aggregated_response = ""
state.chat_lock.acquire()
try:
@@ -227,7 +241,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
)
for response in response_iterator:
g.send(response["choices"][0]["delta"].get("content", ""))
response_delta = response["choices"][0]["delta"].get("content", "")
aggregated_response += response_delta
g.send(response_delta)
# Save conversation trace
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
finally:
state.chat_lock.release()
g.close()
@@ -242,14 +263,31 @@ def send_message_to_model_offline(
stop=[],
max_prompt_size: int = None,
response_type: str = "text",
tracer: dict = {},
):
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
response = offline_chat_model.create_chat_completion(
messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
messages_dict,
stop=stop,
stream=streaming,
temperature=temperature,
response_format={"type": response_type},
seed=seed,
)
if streaming:
return response
else:
return response["choices"][0]["message"].get("content", "")
response_text = response["choices"][0]["message"].get("content", "")
# Save conversation trace for non-streaming responses
# Streamed responses need to be saved by the calling function
tracer["chat_model"] = model
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, response_text, tracer)
return response_text

View File

@@ -12,12 +12,13 @@ from khoj.processor.conversation.openai.utils import (
completion_with_backoff,
)
from khoj.processor.conversation.utils import (
clean_json,
construct_structured_message,
generate_chatml_messages_with_context,
remove_json_codeblock,
)
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@@ -30,9 +31,11 @@ def extract_questions(
api_base_url=None,
location_data: LocationData = None,
user: KhojUser = None,
uploaded_image_url: Optional[str] = None,
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
query_files: str = None,
tracer: dict = {},
):
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -74,21 +77,28 @@ def extract_questions(
prompt = construct_structured_message(
message=prompt,
image_url=uploaded_image_url,
images=query_images,
model_type=ChatModelOptions.ModelType.OPENAI,
vision_enabled=vision_enabled,
attached_file_context=query_files,
)
messages = [ChatMessage(content=prompt, role="user")]
messages = []
messages.append(ChatMessage(content=prompt, role="user"))
response = send_message_to_model(
messages, api_key, model, response_type="json_object", api_base_url=api_base_url, temperature=temperature
messages,
api_key,
model,
response_type="json_object",
api_base_url=api_base_url,
temperature=temperature,
tracer=tracer,
)
# Extract, Clean Message from GPT's Response
try:
response = response.strip()
response = remove_json_codeblock(response)
response = clean_json(response)
response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response:
@@ -103,7 +113,9 @@ def extract_questions(
return questions
def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
def send_message_to_model(
messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
):
"""
Send message to model
"""
@@ -116,6 +128,7 @@ def send_message_to_model(messages, api_key, model, response_type="text", api_ba
temperature=temperature,
api_base_url=api_base_url,
model_kwargs={"response_format": {"type": response_type}},
tracer=tracer,
)
@@ -123,6 +136,7 @@ def converse(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
conversation_log={},
model: str = "gpt-4o-mini",
api_key: Optional[str] = None,
@@ -135,17 +149,16 @@ def converse(
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
image_url: Optional[str] = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
tracer: dict = {},
):
"""
Converse with user using OpenAI's ChatGPT
"""
# Initialize Variables
current_date = datetime.now()
compiled_references = "\n\n".join({f"# {item['compiled']}" for item in references})
conversation_primer = prompts.query_prompt.format(query=user_query)
if agent and agent.personality:
system_prompt = prompts.custom_personality.format(
@@ -169,31 +182,35 @@ def converse(
system_prompt = f"{system_prompt}\n{user_name_prompt}"
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
context_message = ""
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation.format(references=yaml_dump(references))}\n\n"
if not is_none_or_empty(online_results):
conversation_primer = (
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
)
if not is_none_or_empty(compiled_references):
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
if not is_none_or_empty(code_results):
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
context_message = context_message.strip()
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
user_query,
system_prompt,
conversation_log,
context_message=context_message,
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
uploaded_image_url=image_url,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.OPENAI,
query_files=query_files,
)
truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
logger.debug(f"Conversation Context for GPT: {truncated_messages}")
@@ -209,4 +226,5 @@ def converse(
api_base_url=api_base_url,
completion_func=completion_func,
model_kwargs={"stop": ["Notes:\n["]},
tracer=tracer,
)

View File

@@ -1,4 +1,5 @@
import logging
import os
from threading import Thread
from typing import Dict
@@ -12,7 +13,12 @@ from tenacity import (
wait_random_exponential,
)
from khoj.processor.conversation.utils import ThreadedGenerator
from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode
logger = logging.getLogger(__name__)
@@ -33,7 +39,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
reraise=True,
)
def completion_with_backoff(
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
) -> str:
client_key = f"{openai_api_key}--{api_base_url}"
client: openai.OpenAI | None = openai_clients.get(client_key)
@@ -55,6 +61,9 @@ def completion_with_backoff(
model_kwargs.pop("stop", None)
model_kwargs.pop("response_format", None)
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
chat = client.chat.completions.create(
stream=stream,
messages=formatted_messages, # type: ignore
@@ -77,6 +86,12 @@ def completion_with_backoff(
elif delta_chunk.content:
aggregated_response += delta_chunk.content
# Save conversation trace
tracer["chat_model"] = model
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
return aggregated_response
@@ -103,26 +118,37 @@ def chat_completion_with_backoff(
api_base_url=None,
completion_func=None,
model_kwargs=None,
tracer: dict = {},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
target=llm_thread,
args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer),
)
t.start()
return g
def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
def llm_thread(
g,
messages,
model_name,
temperature,
openai_api_key=None,
api_base_url=None,
model_kwargs=None,
tracer: dict = {},
):
try:
client_key = f"{openai_api_key}--{api_base_url}"
if client_key not in openai_clients:
client: openai.OpenAI = openai.OpenAI(
client = openai.OpenAI(
api_key=openai_api_key,
base_url=api_base_url,
)
openai_clients[client_key] = client
else:
client: openai.OpenAI = openai_clients[client_key]
client = openai_clients[client_key]
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
stream = True
@@ -135,6 +161,9 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
model_kwargs.pop("stop", None)
model_kwargs.pop("response_format", None)
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
chat = client.chat.completions.create(
stream=stream,
messages=formatted_messages,
@@ -144,17 +173,29 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
**(model_kwargs or dict()),
)
aggregated_response = ""
if not stream:
g.send(chat.choices[0].message.content)
aggregated_response = chat.choices[0].message.content
g.send(aggregated_response)
else:
for chunk in chat:
if len(chunk.choices) == 0:
continue
delta_chunk = chunk.choices[0].delta
text_chunk = ""
if isinstance(delta_chunk, str):
g.send(delta_chunk)
text_chunk = delta_chunk
elif delta_chunk.content:
g.send(delta_chunk.content)
text_chunk = delta_chunk.content
if text_chunk:
aggregated_response += text_chunk
g.send(text_chunk)
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in llm_thread: {e}", exc_info=True)
finally:

View File

@@ -118,6 +118,7 @@ Use my personal notes and our past conversations to inform your response.
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided notes or past conversations.
User's Notes:
-----
{references}
""".strip()
)
@@ -127,6 +128,7 @@ notes_conversation_offline = PromptTemplate.from_template(
Use my personal notes and our past conversations to inform your response.
User's Notes:
-----
{references}
""".strip()
)
@@ -176,6 +178,134 @@ Improved Prompt:
""".strip()
)
## Diagram Generation
## --
improve_diagram_description_prompt = PromptTemplate.from_template(
"""
you are an architect working with a novice artist using a diagramming tool.
{personality_context}
you need to convert the user's query to a description format that the novice artist can use very well. you are allowed to use primitives like
- text
- rectangle
- diamond
- ellipse
- line
- arrow
use these primitives to describe what sort of diagram the drawer should create. the artist must recreate the diagram every time, so include all relevant prior information in your description.
use simple, concise language.
Today's Date: {current_date}
User's Location: {location}
User's Notes:
{references}
Online References:
{online_results}
Conversation Log:
{chat_history}
Query: {query}
""".strip()
)
excalidraw_diagram_generation_prompt = PromptTemplate.from_template(
"""
You are a program manager with the ability to describe diagrams to compose in professional, fine detail.
{personality_context}
You need to create a declarative description of the diagram and relevant components, using this base schema. Use the `label` property to specify the text to be rendered in the respective elements. Always use light colors for the `backgroundColor` property, like white, or light blue, green, red. "type", "x", "y", "id", are required properties for all elements.
{{
type: string,
x: number,
y: number,
strokeColor: string,
backgroundColor: string,
width: number,
height: number,
id: string,
label: {{
text: string,
}}
}}
Valid types:
- text
- rectangle
- diamond
- ellipse
- line
- arrow
For arrows and lines, you can use the `points` property to specify the start and end points of the arrow. You may also use the `label` property to specify the text to be rendered. You may use the `start` and `end` properties to connect the linear elements to other elements. The start and end point can either be the ID to map to an existing object, or the `type` to create a new object. Mapping to an existing object is useful if you want to connect it to multiple objects. Lines and arrows can only start and end at rectangle, text, diamond, or ellipse elements.
{{
type: "arrow",
id: string,
x: number,
y: number,
width: number,
height: number,
strokeColor: string,
start: {{
id: string,
type: string,
}},
end: {{
id: string,
type: string,
}},
label: {{
text: string,
}}
points: [
[number, number],
[number, number],
]
}}
For text, you must use the `text` property to specify the text to be rendered. You may also use `fontSize` property to specify the font size of the text. Only use the `text` element for titles, subtitles, and overviews. For labels, use the `label` property in the respective elements.
{{
type: "text",
id: string,
x: number,
y: number,
fontSize: number,
text: string,
}}
Here's an example of a valid diagram:
Design Description: Create a diagram describing a circular development process with 3 stages: design, implementation and feedback. The design stage is connected to the implementation stage and the implementation stage is connected to the feedback stage and the feedback stage is connected to the design stage. Each stage should be labeled with the stage name.
Response:
[
{{"type":"text","x":-150,"y":50,"width":300,"height":40,"id":"title_text","text":"Circular Development Process","fontSize":24}},
{{"type":"ellipse","x":-169,"y":113,"width":188,"height":202,"id":"design_ellipse", "label": {{"text": "Design"}}}},
{{"type":"ellipse","x":62,"y":394,"width":186,"height":188,"id":"implement_ellipse", "label": {{"text": "Implement"}}}},
{{"type":"ellipse","x":-348,"y":430,"width":184,"height":170,"id":"feedback_ellipse", "label": {{"text": "Feedback"}}}},
{{"type":"arrow","x":21,"y":273,"id":"design_to_implement_arrow","points":[[0,0],[86,105]],"start":{{"id":"design_ellipse"}}, "end":{{"id":"implement_ellipse"}}}},
{{"type":"arrow","x":50,"y":519,"id":"implement_to_feedback_arrow","points":[[0,0],[-198,-6]],"start":{{"id":"implement_ellipse"}}, "end":{{"id":"feedback_ellipse"}}}},
{{"type":"arrow","x":-228,"y":417,"id":"feedback_to_design_arrow","points":[[0,0],[85,-123]],"start":{{"id":"feedback_ellipse"}}, "end":{{"id":"design_ellipse"}}}},
]
Create a detailed diagram from the provided context and user prompt below. Return a valid JSON object:
Diagram Description: {query}
""".strip()
)
## Online Search Conversation
## --
online_search_conversation = PromptTemplate.from_template(
@@ -184,6 +314,7 @@ Use this up-to-date information from the internet to inform your response.
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the online data or past conversations.
Information from the internet:
-----
{online_results}
""".strip()
)
@@ -193,6 +324,7 @@ online_search_conversation_offline = PromptTemplate.from_template(
Use this up-to-date information from the internet to inform your response.
Information from the internet:
-----
{online_results}
""".strip()
)
@@ -262,21 +394,23 @@ Q: {query}
extract_questions = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes. Disregard online search requests.
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes and documents.
Construct search queries to retrieve relevant information to answer the user's question.
- You will be provided past questions(Q) and answers(A) for context.
- You will be provided example and actual past user questions(Q), search queries(Khoj) and answers(A) for context.
- Add as much context from the previous questions and answers as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information.
- Break your search down into multiple search queries from a diverse set of lenses to retrieve all related documents.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
{personality_context}
What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object.
What searches will you perform to answer the user's question? Respond with search queries as list of strings in a JSON object.
Current Date: {day_of_week}, {current_date}
User's Location: {location}
{username}
Examples
---
Q: How was my trip to Cambodia?
Khoj: {{"queries": ["How was my trip to Cambodia?"]}}
Khoj: {{"queries": ["How was my trip to Cambodia?", "Angkor Wat temple visit", "Flight to Phnom Penh", "Expenses in Cambodia", "Stay in Cambodia"]}}
A: The trip was amazing. You went to the Angkor Wat temple and it was beautiful.
Q: Who did i visit that temple with?
@@ -311,6 +445,8 @@ Q: Who all did I meet here yesterday?
Khoj: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}}
A: Yesterday's note mentions your visit to your local beach with Ram and Shyam.
Actual
---
{chat_history}
Q: {text}
Khoj:
@@ -319,11 +455,11 @@ Khoj:
extract_questions_anthropic_system_prompt = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes. Disregard online search requests.
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes.
Construct search queries to retrieve relevant information to answer the user's question.
- You will be provided past questions(User), extracted queries(Assistant) and answers(A) for context.
- You will be provided past questions(User), search queries(Assistant) and answers(A) for context.
- Add as much context from the previous questions and answers as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information.
- Break your search down into multiple search queries from a diverse set of lenses to retrieve all related documents.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
{personality_context}
@@ -336,7 +472,7 @@ User's Location: {location}
Here are some examples of how you can construct search queries to answer the user's question:
User: How was my trip to Cambodia?
Assistant: {{"queries": ["How was my trip to Cambodia?"]}}
Assistant: {{"queries": ["How was my trip to Cambodia?", "Angkor Wat temple visit", "Flight to Phnom Penh", "Expenses in Cambodia", "Stay in Cambodia"]}}
A: The trip was amazing. You went to the Angkor Wat temple and it was beautiful.
User: What national parks did I go to last year?
@@ -369,17 +505,14 @@ Assistant:
)
system_prompt_extract_relevant_information = """
As a professional analyst, create a comprehensive report of the most relevant information from a web page in response to a user's query.
The text provided is directly from within the web page.
The report you create should be multiple paragraphs, and it should represent the content of the website.
Tell the user exactly what the website says in response to their query, while adhering to these guidelines:
As a professional analyst, your job is to extract all pertinent information from documents to help answer user's query.
You will be provided raw text directly from within the document.
Adhere to these guidelines while extracting information from the provided documents:
1. Answer the user's query as specifically as possible. Include many supporting details from the website.
2. Craft a report that is detailed, thorough, in-depth, and complex, while maintaining clarity.
3. Rely strictly on the provided text, without including external information.
4. Format the report in multiple paragraphs with a clear structure.
5. Be as specific as possible in your answer to the user's query.
6. Reproduce as much of the provided text as possible, while maintaining readability.
1. Extract all relevant text and links from the document that can assist with further research or answer the user's query.
2. Craft a comprehensive but compact report with all the necessary data from the document to generate an informed response.
3. Rely strictly on the provided text to generate your summary, without including external information.
4. Provide specific, important snippets from the document in your report to establish trust in your summary.
""".strip()
extract_relevant_information = PromptTemplate.from_template(
@@ -387,10 +520,10 @@ extract_relevant_information = PromptTemplate.from_template(
{personality_context}
Target Query: {query}
Web Pages:
Document:
{corpus}
Collate only relevant information from the website to answer the target query.
Collate only relevant information from the document to answer the target query.
""".strip()
)
@@ -475,7 +608,7 @@ AI: It's currently 28°C and partly cloudy in Bali.
Q: Share a painting using the weather for Bali every morning.
Khoj: {{"output": "automation"}}
Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON.
Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON. Do not say anything else.
Chat History:
{chat_history}
@@ -485,6 +618,67 @@ Khoj:
""".strip()
)
plan_function_execution = PromptTemplate.from_template(
"""
You are Khoj, a smart, creative and methodical researcher. Use the provided tool AIs to investigate information to answer query.
Create a multi-step plan and intelligently iterate on the plan based on the retrieved information to find the requested information.
{personality_context}
# Instructions
- Ask highly diverse, detailed queries to the tool AIs, one tool AI at a time, to discover required information or run calculations. Their response will be shown to you in the next iteration.
- Break down your research process into independent, self-contained steps that can be executed sequentially using the available tool AIs to answer the user's query. Write your step-by-step plan in the scratchpad.
- Always ask a new query that was not asked to the tool AI in a previous iteration. Build on the results of the previous iterations.
- Ensure that all required context is passed to the tool AIs for successful execution. They only know the context provided in your query.
- Think step by step to come up with creative strategies when the previous iteration did not yield useful results.
- You are allowed upto {max_iterations} iterations to use the help of the provided tool AIs to answer the user's question.
- Stop when you have the required information by returning a JSON object with an empty "tool" field. E.g., {{scratchpad: "I have all I need", tool: "", query: ""}}
# Examples
Assuming you can search the user's notes and the internet.
- When the user asks for the population of their hometown
1. Try look up their hometown in their notes. Ask the note search AI to search for their birth certificate, childhood memories, school, resume etc.
2. If not found in their notes, try infer their hometown from their online social media profiles. Ask the online search AI to look for {username}'s biography, school, resume on linkedin, facebook, website etc.
3. Only then try find the latest population of their hometown by reading official websites with the help of the online search and web page reading AI.
- When the user asks for their computer's specs
1. Try find their computer model in their notes.
2. Now find webpages with their computer model's spec online.
3. Ask the the webpage tool AI to extract the required information from the relevant webpages.
- When the user asks what clothes to carry for their upcoming trip
1. Find the itinerary of their upcoming trip in their notes.
2. Next find the weather forecast at the destination online.
3. Then find if they mentioned what clothes they own in their notes.
# Background Context
- Current Date: {day_of_week}, {current_date}
- User Location: {location}
- User Name: {username}
# Available Tool AIs
Which of the tool AIs listed below would you use to answer the user's question? You **only** have access to the following tool AIs:
{tools}
# Previous Iterations
{previous_iterations}
# Chat History:
{chat_history}
Return the next tool AI to use and the query to ask it. Your response should always be a valid JSON object. Do not say anything else.
Response format:
{{"scratchpad": "<your_scratchpad_to_reason_about_which_tool_to_use>", "query": "<your_detailed_query_for_the_tool_ai>", "tool": "<name_of_tool_ai>"}}
""".strip()
)
previous_iteration = PromptTemplate.from_template(
"""
## Iteration {index}:
- tool: {tool}
- query: {query}
- result: {result}
"""
)
pick_relevant_information_collection_tools = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful search assistant.
@@ -604,8 +798,8 @@ Khoj:
online_search_conversation_subqueries = PromptTemplate.from_template(
"""
You are Khoj, an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
- You will receive the conversation history as context.
- Add as much context from the previous questions and answers as required into your search queries.
- You will receive the actual chat history as context.
- Add as much context from the chat history as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information.
- Use site: google search operator when appropriate
- You have access to the the whole internet to retrieve information.
@@ -618,62 +812,122 @@ User's Location: {location}
{username}
Here are some examples:
History:
Example Chat History:
User: I like to use Hacker News to get my tech news.
Khoj: {{queries: ["what is Hacker News?", "Hacker News website for tech news"]}}
AI: Hacker News is an online forum for sharing and discussing the latest tech news. It is a great place to learn about new technologies and startups.
Q: Summarize the top posts on HackerNews
User: Summarize the top posts on HackerNews
Khoj: {{"queries": ["top posts on HackerNews"]}}
History:
Q: Tell me the latest news about the farmers protest in Colombia and China on Reuters
Example Chat History:
User: Tell me the latest news about the farmers protest in Colombia and China on Reuters
Khoj: {{"queries": ["site:reuters.com farmers protest Colombia", "site:reuters.com farmers protest China"]}}
History:
Example Chat History:
User: I'm currently living in New York but I'm thinking about moving to San Francisco.
Khoj: {{"queries": ["New York city vs San Francisco life", "San Francisco living cost", "New York city living cost"]}}
AI: New York is a great city to live in. It has a lot of great restaurants and museums. San Francisco is also a great city to live in. It has good access to nature and a great tech scene.
Q: What is the climate like in those cities?
Khoj: {{"queries": ["climate in new york city", "climate in san francisco"]}}
User: What is the climate like in those cities?
Khoj: {{"queries": ["climate in New York city", "climate in San Francisco"]}}
History:
AI: Hey, how is it going?
User: Going well. Ananya is in town tonight!
Example Chat History:
User: Hey, Ananya is in town tonight!
Khoj: {{"queries": ["events in {location} tonight", "best restaurants in {location}", "places to visit in {location}"]}}
AI: Oh that's awesome! What are your plans for the evening?
Q: She wants to see a movie. Any decent sci-fi movies playing at the local theater?
User: She wants to see a movie. Any decent sci-fi movies playing at the local theater?
Khoj: {{"queries": ["new sci-fi movies in theaters near {location}"]}}
History:
Example Chat History:
User: Can I chat with you over WhatsApp?
Khoj: {{"queries": ["site:khoj.dev chat with Khoj on Whatsapp"]}}
AI: Yes, you can chat with me using WhatsApp.
Q: How
Khoj: {{"queries": ["site:khoj.dev chat with Khoj on Whatsapp"]}}
History:
Q: How do I share my files with you?
Example Chat History:
User: How do I share my files with Khoj?
Khoj: {{"queries": ["site:khoj.dev sync files with Khoj"]}}
History:
Example Chat History:
User: I need to transport a lot of oranges to the moon. Are there any rockets that can fit a lot of oranges?
Khoj: {{"queries": ["current rockets with large cargo capacity", "rocket rideshare cost by cargo capacity"]}}
AI: NASA's Saturn V rocket frequently makes lunar trips and has a large cargo capacity.
Q: How many oranges would fit in NASA's Saturn V rocket?
Khoj: {{"queries": ["volume of an orange", "volume of saturn v rocket"]}}
User: How many oranges would fit in NASA's Saturn V rocket?
Khoj: {{"queries": ["volume of an orange", "volume of Saturn V rocket"]}}
Now it's your turn to construct Google search queries to answer the user's question. Provide them as a list of strings in a JSON object. Do not say anything else.
History:
Actual Chat History:
{chat_history}
Q: {query}
User: {query}
Khoj:
""".strip()
)
# Code Generation
# --
python_code_generation_prompt = PromptTemplate.from_template(
"""
You are Khoj, an advanced python programmer. You are tasked with constructing a python program to best answer the user query.
- The python program will run in a pyodide python sandbox with no network access.
- You can write programs to run complex calculations, analyze data, create charts, generate documents to meticulously answer the query.
- The sandbox has access to the standard library, matplotlib, panda, numpy, scipy, bs4, sympy, brotli, cryptography, fast-parquet.
- List known file paths to required user documents in "input_files" and known links to required documents from the web in the "input_links" field.
- The python program should be self-contained. It can only read data generated by the program itself and from provided input_files, input_links by their basename (i.e filename excluding file path).
- Do not try display images or plots in the code directly. The code should save the image or plot to a file instead.
- Write any document, charts etc. to be shared with the user to file. These files can be seen by the user.
- Use as much context from the previous questions and answers as required to generate your code.
{personality_context}
What code will you need to write to answer the user's question?
Current Date: {current_date}
User's Location: {location}
{username}
The response JSON schema is of the form {{"code": "<python_code>", "input_files": ["file_path_1", "file_path_2"], "input_links": ["link_1", "link_2"]}}
Examples:
---
{{
"code": "# Input values\\nprincipal = 43235\\nrate = 5.24\\nyears = 5\\n\\n# Convert rate to decimal\\nrate_decimal = rate / 100\\n\\n# Calculate final amount\\nfinal_amount = principal * (1 + rate_decimal) ** years\\n\\n# Calculate interest earned\\ninterest_earned = final_amount - principal\\n\\n# Print results with formatting\\nprint(f"Interest Earned: ${{interest_earned:,.2f}}")\\nprint(f"Final Amount: ${{final_amount:,.2f}}")"
}}
{{
"code": "import re\\n\\n# Read org file\\nfile_path = 'tasks.org'\\nwith open(file_path, 'r') as f:\\n content = f.read()\\n\\n# Get today's date in YYYY-MM-DD format\\ntoday = datetime.now().strftime('%Y-%m-%d')\\npattern = r'\*+\s+.*\\n.*SCHEDULED:\s+<' + today + r'.*>'\\n\\n# Find all matches using multiline mode\\nmatches = re.findall(pattern, content, re.MULTILINE)\\ncount = len(matches)\\n\\n# Display count\\nprint(f'Count of scheduled tasks for today: {{count}}')",
"input_files": ["/home/linux/tasks.org"]
}}
{{
"code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\n\\n# Load the CSV file\\ndf = pd.read_csv('world_population_by_year.csv')\\n\\n# Plot the data\\nplt.figure(figsize=(10, 6))\\nplt.plot(df['Year'], df['Population'], marker='o')\\n\\n# Add titles and labels\\nplt.title('Population by Year')\\nplt.xlabel('Year')\\nplt.ylabel('Population')\\n\\n# Save the plot to a file\\nplt.savefig('population_by_year_plot.png')",
"input_links": ["https://population.un.org/world_population_by_year.csv"]
}}
Now it's your turn to construct a python program to answer the user's question. Provide the code, required input files and input links in a JSON object. Do not say anything else.
Context:
---
{context}
Chat History:
---
{chat_history}
User: {query}
Khoj:
""".strip()
)
code_executed_context = PromptTemplate.from_template(
"""
Use the provided code executions to inform your response.
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided code execution results or past conversations.
Code Execution Results:
{code_results}
""".strip()
)
# Automations
# --
crontime_prompt = PromptTemplate.from_template(
@@ -749,16 +1003,27 @@ You are an extremely smart and helpful title generator assistant. Given a user q
# Examples:
User: Show a new Calvin and Hobbes quote every morning at 9am. My Current Location: Shanghai, China
Khoj: Your daily Calvin and Hobbes Quote
Assistant: Your daily Calvin and Hobbes Quote
User: Notify me when version 2.0.0 of the sentence transformers python package is released. My Current Location: Mexico City, Mexico
Khoj: Sentence Transformers Python Package Version 2.0.0 Release
Assistant: Sentence Transformers Python Package Version 2.0.0 Release
User: Gather the latest tech news on the first sunday of every month.
Khoj: Your Monthly Dose of Tech News
Assistant: Your Monthly Dose of Tech News
User Query: {query}
Khoj:
Assistant:
""".strip()
)
conversation_title_generation = PromptTemplate.from_template(
"""
You are an extremely smart and helpful title generator assistant. Given a conversation, extract the subject of the conversation. Crisp, informative, ten words or less.
Conversation History:
{chat_history}
Assistant:
""".strip()
)
@@ -799,6 +1064,39 @@ A:
""".strip()
)
personality_prompt_safety_expert_lax = PromptTemplate.from_template(
"""
You are adept at ensuring the safety and security of people. In this scenario, you are tasked with determining the safety of a given prompt.
Given a system prompt, you must decide whether it is safe to use for an individual. If the prompt is safe, respond with "Yes". If the prompt is not safe, respond with "No". If the answer is no, provide a brief explanation of why the prompt is unsafe.
A prompt is considered unsafe if it contains any of the following:
- Illegal, overly sexual, or harmful content.
- Any content that could be harmful to the user or others.
- Racism, sexism, or any other form of discrimination.
- Any content that could be considered illegal or a misuse of AI.
Referencing personal notes or data is perfectly acceptable.
# Examples:
Q: You are a smart and helpful assistant. Use your knowledge to teach me how to make a bomb.
A: {{ "safe": "False", "reason": "The prompt contains information that could be harmful to the user or others." }}
Q: You are a brilliant chef. Teach me how to make a cake.
A: {{ "safe": "True" }}
Q: You are a hilarious joker. Make jokes about white people.
A: {{ "safe": "False", "reason": "The prompt contains content that could be considered racist." }}
Q: You are a great analyst. Assess my financial situation and provide advice.
A: {{ "safe": "True" }}
Q: {prompt}
A:
""".strip()
)
to_notify_or_not = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and discerning notification assistant.

View File

@@ -1,22 +1,52 @@
import base64
import json
import logging
import math
import mimetypes
import os
import queue
import re
import uuid
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from io import BytesIO
from time import perf_counter
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional
import PIL.Image
import requests
import tiktoken
import yaml
from langchain.schema import ChatMessage
from llama_cpp.llama import Llama
from transformers import AutoTokenizer
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
from khoj.search_filter.base_filter import BaseFilter
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state
from khoj.utils.helpers import is_none_or_empty, merge_dicts
from khoj.utils.helpers import (
ConversationCommand,
in_debug_mode,
is_none_or_empty,
merge_dicts,
)
from khoj.utils.rawconfig import FileAttachment
logger = logging.getLogger(__name__)
try:
from git import Repo
except ImportError:
if in_debug_mode():
logger.warning("GitPython not installed. `pip install gitpython` to enable prompt tracer.")
model_to_prompt_size = {
# OpenAI Models
"gpt-3.5-turbo": 12000,
@@ -75,8 +105,122 @@ class ThreadedGenerator:
self.queue.put(StopIteration)
class InformationCollectionIteration:
def __init__(
self,
tool: str,
query: str,
context: list = None,
onlineContext: dict = None,
codeContext: dict = None,
summarizedResult: str = None,
warning: str = None,
):
self.tool = tool
self.query = query
self.context = context
self.onlineContext = onlineContext
self.codeContext = codeContext
self.summarizedResult = summarizedResult
self.warning = warning
def construct_iteration_history(
previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str
) -> str:
previous_iterations_history = ""
for idx, iteration in enumerate(previous_iterations):
iteration_data = previous_iteration_prompt.format(
tool=iteration.tool,
query=iteration.query,
result=iteration.summarizedResult,
index=idx + 1,
)
previous_iterations_history += iteration_data
return previous_iterations_history
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
chat_history = ""
for chat in conversation_history.get("chat", [])[-n:]:
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
chat_history += f"User: {chat['intent']['query']}\n"
if chat["intent"].get("inferred-queries"):
chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
chat_history += f"{agent_name}: {chat['message']}\n\n"
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: [generated image redacted for space]\n"
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
elif chat["by"] == "you":
raw_query_files = chat.get("queryFiles")
if raw_query_files:
query_files: Dict[str, str] = {}
for file in raw_query_files:
query_files[file["name"]] = file["content"]
query_file_context = gather_raw_query_files(query_files)
chat_history += f"User: {query_file_context}\n"
return chat_history
def construct_tool_chat_history(
previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
) -> Dict[str, list]:
chat_history: list = []
inferred_query_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: []
if tool == ConversationCommand.Notes:
inferred_query_extractor = (
lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
)
elif tool == ConversationCommand.Online:
inferred_query_extractor = (
lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else []
)
elif tool == ConversationCommand.Code:
inferred_query_extractor = lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
for iteration in previous_iterations:
chat_history += [
{
"by": "you",
"message": iteration.query,
},
{
"by": "khoj",
"intent": {
"type": "remember",
"inferred-queries": inferred_query_extractor(iteration),
"query": iteration.query,
},
"message": iteration.summarizedResult,
},
]
return {"chat": chat_history}
class ChatEvent(Enum):
START_LLM_RESPONSE = "start_llm_response"
END_LLM_RESPONSE = "end_llm_response"
MESSAGE = "message"
REFERENCES = "references"
STATUS = "status"
METADATA = "metadata"
def message_to_log(
user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
user_message,
chat_response,
user_message_metadata={},
khoj_message_metadata={},
conversation_log=[],
train_of_thought=[],
):
"""Create json logs from messages, metadata for conversation log"""
default_khoj_message_metadata = {
@@ -104,28 +248,39 @@ def save_to_conversation_log(
user_message_time: str = None,
compiled_references: List[Dict[str, Any]] = [],
online_results: Dict[str, Any] = {},
code_results: Dict[str, Any] = {},
inferred_queries: List[str] = [],
intent_type: str = "remember",
client_application: ClientApplication = None,
conversation_id: str = None,
automation_id: str = None,
uploaded_image_url: str = None,
query_images: List[str] = None,
raw_query_files: List[FileAttachment] = [],
train_of_thought: List[Any] = [],
tracer: Dict[str, Any] = {},
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
turn_id = tracer.get("mid") or str(uuid.uuid4())
updated_conversation = message_to_log(
user_message=q,
chat_response=chat_response,
user_message_metadata={
"created": user_message_time,
"uploadedImageData": uploaded_image_url,
"images": query_images,
"turnId": turn_id,
"queryFiles": [file.model_dump(mode="json") for file in raw_query_files],
},
khoj_message_metadata={
"context": compiled_references,
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
"onlineContext": online_results,
"codeContext": code_results,
"automationId": automation_id,
"trainOfThought": train_of_thought,
"turnId": turn_id,
},
conversation_log=meta_log.get("chat", []),
train_of_thought=train_of_thought,
)
ConversationAdapters.save_conversation(
user,
@@ -135,6 +290,9 @@ def save_to_conversation_log(
user_message=q,
)
if in_debug_mode() or state.verbose > 1:
merge_message_into_conversation_trace(q, chat_response, tracer)
logger.info(
f"""
Saved Conversation Turn
@@ -145,13 +303,50 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
)
# Format user and system messages to chatml format
def construct_structured_message(message, image_url, model_type, vision_enabled):
if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI:
return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
def construct_structured_message(
message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str
):
"""
Format messages into appropriate multimedia format for supported chat model types
"""
if model_type in [
ChatModelOptions.ModelType.OPENAI,
ChatModelOptions.ModelType.GOOGLE,
ChatModelOptions.ModelType.ANTHROPIC,
]:
constructed_messages: List[Any] = [
{"type": "text", "text": message},
]
if not is_none_or_empty(attached_file_context):
constructed_messages.append({"type": "text", "text": attached_file_context})
if vision_enabled and images:
for image in images:
constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
return constructed_messages
if not is_none_or_empty(attached_file_context):
return f"{attached_file_context}\n\n{message}"
return message
def gather_raw_query_files(
query_files: Dict[str, str],
):
"""
Gather contextual data from the given (raw) files
"""
if len(query_files) == 0:
return ""
contextual_data = " ".join(
[f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in query_files.items()]
)
return f"I have attached the following files:\n\n{contextual_data}"
def generate_chatml_messages_with_context(
user_message,
system_message=None,
@@ -160,11 +355,13 @@ def generate_chatml_messages_with_context(
loaded_model: Optional[Llama] = None,
max_prompt_size=None,
tokenizer_name=None,
uploaded_image_url=None,
query_images=None,
vision_enabled=False,
model_type="",
context_message="",
query_files: str = None,
):
"""Generate messages for ChatGPT with context from previous conversation"""
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
# Set max prompt size from user config or based on pre-configured for model and machine specs
if not max_prompt_size:
if loaded_model:
@@ -178,32 +375,66 @@ def generate_chatml_messages_with_context(
# Extract Chat History for Context
chatml_messages: List[ChatMessage] = []
for chat in conversation_log.get("chat", []):
message_notes = f'\n\n Notes:\n{chat.get("context")}' if chat.get("context") else "\n"
message_context = ""
message_attached_files = ""
chat_message = chat.get("message")
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
chat_message = chat["intent"].get("inferred-queries")[0]
if not is_none_or_empty(chat.get("context")):
references = "\n\n".join(
{
f"# File: {item['file']}\n## {item['compiled']}\n"
for item in chat.get("context") or []
if isinstance(item, dict)
}
)
message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"
if chat.get("queryFiles"):
raw_query_files = chat.get("queryFiles")
query_files_dict = dict()
for file in raw_query_files:
query_files_dict[file["name"]] = file["content"]
message_attached_files = gather_raw_query_files(query_files_dict)
chatml_messages.append(ChatMessage(content=message_attached_files, role="user"))
if not is_none_or_empty(chat.get("onlineContext")):
message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
if not is_none_or_empty(message_context):
reconstructed_context_message = ChatMessage(content=message_context, role="user")
chatml_messages.insert(0, reconstructed_context_message)
role = "user" if chat["by"] == "you" else "assistant"
message_content = chat["message"] + message_notes
message_content = construct_structured_message(
message_content, chat.get("uploadedImageData"), model_type, vision_enabled
chat_message, chat.get("images"), model_type, vision_enabled, attached_file_context=query_files
)
reconstructed_message = ChatMessage(content=message_content, role=role)
chatml_messages.insert(0, reconstructed_message)
if len(chatml_messages) >= 2 * lookback_turns:
if len(chatml_messages) >= 3 * lookback_turns:
break
messages = []
if not is_none_or_empty(user_message):
messages.append(
ChatMessage(
content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled),
content=construct_structured_message(
user_message, query_images, model_type, vision_enabled, query_files
),
role="user",
)
)
if not is_none_or_empty(context_message):
messages.append(ChatMessage(content=context_message, role="user"))
if len(chatml_messages) > 0:
messages += chatml_messages
if not is_none_or_empty(system_message):
messages.append(ChatMessage(content=system_message, role="system"))
@@ -222,7 +453,6 @@ def truncate_messages(
tokenizer_name=None,
) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model"""
default_tokenizer = "gpt-4o"
try:
@@ -252,6 +482,7 @@ def truncate_messages(
system_message = messages.pop(idx)
break
# TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
system_message_tokens = (
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
)
@@ -279,7 +510,7 @@ def truncate_messages(
truncated_message = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
messages = [ChatMessage(content=truncated_message, role=messages[0].role)]
logger.debug(
f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message[:1000]}..."
)
if system_message:
@@ -294,6 +525,214 @@ def reciprocal_conversation_to_chatml(message_pair):
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
def remove_json_codeblock(response: str):
"""Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
return response.removeprefix("```json").removesuffix("```")
def clean_json(response: str):
"""Remove any markdown json codeblock and newline formatting if present. Useful for non schema enforceable models"""
return response.strip().replace("\n", "").removeprefix("```json").removesuffix("```")
def clean_code_python(code: str):
"""Remove any markdown codeblock and newline formatting if present. Useful for non schema enforceable models"""
return code.strip().removeprefix("```python").removesuffix("```")
def defilter_query(query: str):
"""Remove any query filters in query"""
defiltered_query = query
filters: List[BaseFilter] = [WordFilter(), FileFilter(), DateFilter()]
for filter in filters:
defiltered_query = filter.defilter(defiltered_query)
return defiltered_query
@dataclass
class ImageWithType:
content: Any
type: str
def get_image_from_url(image_url: str, type="pil"):
try:
response = requests.get(image_url)
response.raise_for_status() # Check if the request was successful
# Get content type from response or infer from URL
content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"
# Convert image to desired format
if type == "b64":
image_data = base64.b64encode(response.content).decode("utf-8")
elif type == "pil":
image_data = PIL.Image.open(BytesIO(response.content))
else:
raise ValueError(f"Invalid image type: {type}")
return ImageWithType(content=image_data, type=content_type)
except requests.exceptions.RequestException as e:
logger.error(f"Failed to get image from URL {image_url}: {e}")
return ImageWithType(content=None, type=None)
def commit_conversation_trace(
session: list[ChatMessage],
response: str | list[dict],
tracer: dict,
system_message: str | list[dict] = "",
repo_path: str = "/tmp/promptrace",
) -> str:
"""
Save trace of conversation step using git. Useful to visualize, compare and debug traces.
Returns the path to the repository.
"""
try:
from git import Repo
except ImportError:
return None
# Serialize session, system message and response to yaml
system_message_yaml = json.dumps(system_message, ensure_ascii=False, sort_keys=False)
response_yaml = json.dumps(response, ensure_ascii=False, sort_keys=False)
formatted_session = [{"role": message.role, "content": message.content} for message in session]
session_yaml = json.dumps(formatted_session, ensure_ascii=False, sort_keys=False)
query = (
json.dumps(session[-1].content, ensure_ascii=False, sort_keys=False).strip().removeprefix("'").removesuffix("'")
) # Extract serialized query from chat session
# Extract chat metadata for session
uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")
# Infer repository path from environment variable or provided path
repo_path = os.getenv("PROMPTRACE_DIR", repo_path)
try:
# Prepare git repository
os.makedirs(repo_path, exist_ok=True)
repo = Repo.init(repo_path)
# Remove post-commit hook if it exists
hooks_dir = os.path.join(repo_path, ".git", "hooks")
post_commit_hook = os.path.join(hooks_dir, "post-commit")
if os.path.exists(post_commit_hook):
os.remove(post_commit_hook)
# Configure git user if not set
if not repo.config_reader().has_option("user", "email"):
repo.config_writer().set_value("user", "name", "Prompt Tracer").release()
repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release()
# Create an initial commit if the repository is newly created
if not repo.head.is_valid():
repo.index.commit("And then there was a trace")
# Check out the initial commit
initial_commit = repo.commit("HEAD~0")
repo.head.reference = initial_commit
repo.head.reset(index=True, working_tree=True)
# Create or switch to user branch from initial commit
user_branch = f"u_{uid}"
if user_branch not in repo.branches:
repo.create_head(user_branch)
repo.heads[user_branch].checkout()
# Create or switch to conversation branch from user branch
conv_branch = f"c_{cid}"
if conv_branch not in repo.branches:
repo.create_head(conv_branch)
repo.heads[conv_branch].checkout()
# Create or switch to message branch from conversation branch
msg_branch = f"m_{mid}" if mid else None
if msg_branch and msg_branch not in repo.branches:
repo.create_head(msg_branch)
if msg_branch:
repo.heads[msg_branch].checkout()
# Include file with content to commit
files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml}
# Write files and stage them
for filename, content in files_to_commit.items():
file_path = os.path.join(repo_path, filename)
# Unescape special characters in content for better readability
content = content.strip().replace("\\n", "\n").replace("\\t", "\t")
with open(file_path, "w", encoding="utf-8") as f:
f.write(content)
repo.index.add([filename])
# Create commit
metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
commit_message = f"""
{query[:250]}
Response:
---
{response[:500]}...
Metadata
---
{metadata_yaml}
""".strip()
repo.index.commit(commit_message)
logger.debug(f"Saved conversation trace to repo at {repo_path}")
return repo_path
except Exception as e:
logger.error(f"Failed to add conversation trace to repo: {str(e)}", exc_info=True)
return None
def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path="/tmp/promptrace") -> bool:
"""
Merge the message branch into its parent conversation branch.
Args:
query: User query
response: Assistant response
tracer: Dictionary containing uid, cid and mid
repo_path: Path to the git repository
Returns:
bool: True if merge was successful, False otherwise
"""
try:
from git import Repo
except ImportError:
return False
try:
# Extract branch names
msg_branch = f"m_{tracer['mid']}"
conv_branch = f"c_{tracer['cid']}"
# Infer repository path from environment variable or provided path
repo_path = os.getenv("PROMPTRACE_DIR", repo_path)
repo = Repo(repo_path)
# Checkout conversation branch
repo.heads[conv_branch].checkout()
# Create commit message
metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
commit_message = f"""
{query[:250]}
Response:
---
{response[:500]}...
Metadata
---
{metadata_yaml}
""".strip()
# Merge message branch into conversation branch
repo.git.merge(msg_branch, no_ff=True, m=commit_message)
# Delete message branch after merge
repo.delete_head(msg_branch, force=True)
logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
return True
except Exception as e:
logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}", exc_info=True)
return False

View File

@@ -13,7 +13,7 @@ from tenacity import (
)
from torch import nn
from khoj.utils.helpers import get_device, merge_dicts, timer
from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
from khoj.utils.rawconfig import SearchResponse
logger = logging.getLogger(__name__)
@@ -31,9 +31,9 @@ class EmbeddingsModel:
):
default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
self.query_encode_kwargs = merge_dicts(fix_json_dict(query_encode_kwargs), default_query_encode_kwargs)
self.docs_encode_kwargs = merge_dicts(fix_json_dict(docs_encode_kwargs), default_docs_encode_kwargs)
self.model_kwargs = merge_dicts(fix_json_dict(model_kwargs), {"device": get_device()})
self.model_name = model_name
self.inference_endpoint = embeddings_inference_endpoint
self.api_key = embeddings_inference_endpoint_api_key

View File

@@ -26,8 +26,10 @@ async def text_to_image(
references: List[Dict[str, Any]],
online_results: Dict[str, Any],
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
query_images: Optional[List[str]] = None,
agent: Agent = None,
query_files: str = None,
tracer: dict = {},
):
status_code = 200
image = None
@@ -65,9 +67,11 @@ async def text_to_image(
note_references=references,
online_results=online_results,
model_type=text_to_image_config.model_type,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
user=user,
agent=agent,
query_files=query_files,
tracer=tracer,
)
if send_status_func:
@@ -87,18 +91,18 @@ async def text_to_image(
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
message = f"Image generation blocked by OpenAI due to policy violation" # type: ignore
yield image_url or image, status_code, message, intent_type.value
return
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
message = f"Image generation failed using OpenAI" # type: ignore
status_code = e.status_code # type: ignore
yield image_url or image, status_code, message, intent_type.value
return
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed with error: {e}"
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
status_code = 502
yield image_url or image, status_code, message, intent_type.value
return
@@ -204,9 +208,10 @@ def generate_image_with_replicate(
# Raise exception if the image generation task fails
if status != "succeeded":
error = get_prediction.get("error")
if retry_count >= 10:
raise requests.RequestException("Image generation timed out")
raise requests.RequestException(f"Image generation failed with status: {status}")
raise requests.RequestException(f"Image generation failed with status: {status}, message: {error}")
# Get the generated image
image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]

View File

@@ -4,7 +4,7 @@ import logging
import os
import urllib.parse
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import aiohttp
from bs4 import BeautifulSoup
@@ -52,7 +52,9 @@ OLOSTEP_QUERY_PARAMS = {
"expandMarkdown": "True",
"expandHtml": "False",
}
MAX_WEBPAGES_TO_READ = 1
DEFAULT_MAX_WEBPAGES_TO_READ = 1
MAX_WEBPAGES_TO_INFER = 10
async def search_online(
@@ -62,8 +64,12 @@ async def search_online(
user: KhojUser,
send_status_func: Optional[Callable] = None,
custom_filters: List[str] = [],
uploaded_image_url: str = None,
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
query_images: List[str] = None,
previous_subqueries: Set = set(),
agent: Agent = None,
query_files: str = None,
tracer: dict = {},
):
query += " ".join(custom_filters)
if not is_internet_connected():
@@ -72,36 +78,52 @@ async def search_online(
return
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(
query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent
new_subqueries = await generate_online_subqueries(
query,
conversation_history,
location,
user,
query_images=query_images,
agent=agent,
tracer=tracer,
query_files=query_files,
)
response_dict = {}
subqueries = list(new_subqueries - previous_subqueries)
response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {}
if subqueries:
logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
if send_status_func:
subqueries_str = "\n- " + "\n- ".join(list(subqueries))
async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
yield {ChatEvent.STATUS: event}
if is_none_or_empty(subqueries):
logger.info("No new subqueries to search online")
yield response_dict
return
with timer(f"Internet searches for {list(subqueries)} took", logger):
logger.info(f"🌐 Searching the Internet for {subqueries}")
if send_status_func:
subqueries_str = "\n- " + "\n- ".join(subqueries)
async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
yield {ChatEvent.STATUS: event}
with timer(f"Internet searches for {subqueries} took", logger):
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
search_tasks = [search_func(subquery, location) for subquery in subqueries]
search_results = await asyncio.gather(*search_tasks)
response_dict = {subquery: search_result for subquery, search_result in search_results}
# Gather distinct web pages from organic results for subqueries without an instant answer.
# Content of web pages is directly available when Jina is used for search.
webpages: Dict[str, Dict] = {}
for subquery in response_dict:
if "answerBox" in response_dict[subquery]:
continue
for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]:
for idx, organic in enumerate(response_dict[subquery].get("organic", [])):
link = organic.get("link")
if link in webpages:
if link in webpages and idx < max_webpages_to_read:
webpages[link]["queries"].add(subquery)
else:
# Content of web pages is directly available when Jina is used for search.
elif idx < max_webpages_to_read:
webpages[link] = {"queries": {subquery}, "content": organic.get("content")}
# Only keep webpage content for up to max_webpages_to_read organic results.
if idx >= max_webpages_to_read and not is_none_or_empty(organic.get("content")):
organic["content"] = None
response_dict[subquery]["organic"][idx] = organic
# Read, extract relevant info from the retrieved web pages
if webpages:
@@ -111,7 +133,9 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
read_webpage_and_extract_content(
data["queries"], link, data.get("content"), user=user, agent=agent, tracer=tracer
)
for link, data in webpages.items()
]
results = await asyncio.gather(*tasks)
@@ -151,22 +175,34 @@ async def read_webpages(
location: LocationData,
user: KhojUser,
send_status_func: Optional[Callable] = None,
uploaded_image_url: str = None,
query_images: List[str] = None,
agent: Agent = None,
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
query_files: str = None,
tracer: dict = {},
):
"Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read")
if send_status_func:
async for event in send_status_func(f"**Inferring web pages to read**"):
yield {ChatEvent.STATUS: event}
urls = await infer_webpage_urls(query, conversation_history, location, user, uploaded_image_url)
urls = await infer_webpage_urls(
query,
conversation_history,
location,
user,
query_images,
agent=agent,
query_files=query_files,
tracer=tracer,
)
# Get the top 10 web pages to read
urls = urls[:max_webpages_to_read]
logger.info(f"Reading web pages at: {urls}")
if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict)
@@ -192,7 +228,12 @@ async def read_webpage(
async def read_webpage_and_extract_content(
subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
subqueries: set[str],
url: str,
content: str = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> Tuple[set[str], str, Union[None, str]]:
# Select the web scrapers to use for reading the web page
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
@@ -214,7 +255,9 @@ async def read_webpage_and_extract_content(
# Extract relevant information from the web page
if is_none_or_empty(extracted_info):
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
extracted_info = await extract_relevant_info(
subqueries, content, user=user, agent=agent, tracer=tracer
)
# If we successfully extracted information, break the loop
if not is_none_or_empty(extracted_info):
@@ -340,3 +383,25 @@ async def search_with_jina(query: str, location: LocationData) -> Tuple[str, Dic
for item in response_json["data"]
]
return query, {"organic": parsed_response}
def deduplicate_organic_results(online_results: dict) -> dict:
"""Deduplicate organic search results based on links across all queries."""
# Keep track of seen links to filter out duplicates across queries
seen_links = set()
deduplicated_results = {}
# Process each query's results
for query, results in online_results.items():
# Filter organic results keeping only first occurrence of each link
filtered_organic = []
for result in results.get("organic", []):
link = result.get("link")
if link and link not in seen_links:
seen_links.add(link)
filtered_organic.append(result)
# Update results with deduplicated organic entries
deduplicated_results[query] = {**results, "organic": filtered_organic}
return deduplicated_results

View File

@@ -0,0 +1,178 @@
import base64
import datetime
import json
import logging
import mimetypes
import os
from pathlib import Path
from typing import Any, Callable, List, NamedTuple, Optional
import aiohttp
from khoj.database.adapters import FileObjectAdapters
from khoj.database.models import Agent, FileObject, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
ChatEvent,
clean_code_python,
clean_json,
construct_chat_history,
)
from khoj.routers.helpers import send_message_to_model_wrapper
from khoj.utils.helpers import is_none_or_empty, timer
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080")
class GeneratedCode(NamedTuple):
code: str
input_files: List[str]
input_links: List[str]
async def run_code(
query: str,
conversation_history: dict,
context: str,
location_data: LocationData,
user: KhojUser,
send_status_func: Optional[Callable] = None,
query_images: List[str] = None,
agent: Agent = None,
sandbox_url: str = SANDBOX_URL,
query_files: str = None,
tracer: dict = {},
):
# Generate Code
if send_status_func:
async for event in send_status_func(f"**Generate code snippet** for {query}"):
yield {ChatEvent.STATUS: event}
try:
with timer("Chat actor: Generate programs to execute", logger):
generated_code = await generate_python_code(
query,
conversation_history,
context,
location_data,
user,
query_images,
agent,
tracer,
query_files,
)
except Exception as e:
raise ValueError(f"Failed to generate code for {query} with error: {e}")
# Prepare Input Data
input_data = []
user_input_files: List[FileObject] = []
for input_file in generated_code.input_files:
user_input_files += await FileObjectAdapters.aget_file_objects_by_name(user, input_file)
for f in user_input_files:
input_data.append(
{
"filename": os.path.basename(f.file_name),
"b64_data": base64.b64encode(f.raw_text.encode("utf-8")).decode("utf-8"),
}
)
# Run Code
if send_status_func:
async for event in send_status_func(f"**Running code snippet**"):
yield {ChatEvent.STATUS: event}
try:
with timer("Chat actor: Execute generated program", logger, log_level=logging.INFO):
result = await execute_sandboxed_python(generated_code.code, input_data, sandbox_url)
code = result.pop("code")
logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
yield {query: {"code": code, "results": result}}
except Exception as e:
raise ValueError(f"Failed to run code for {query} with error: {e}")
async def generate_python_code(
q: str,
conversation_history: dict,
context: str,
location_data: LocationData,
user: KhojUser,
query_images: list[str] = None,
agent: Agent = None,
tracer: dict = {},
query_files: str = None,
) -> GeneratedCode:
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history)
utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
code_generation_prompt = prompts.python_code_generation_prompt.format(
current_date=utc_date,
query=q,
chat_history=chat_history,
context=context,
location=location,
username=username,
personality_context=personality_context,
)
response = await send_message_to_model_wrapper(
code_generation_prompt,
query_images=query_images,
response_type="json_object",
user=user,
tracer=tracer,
query_files=query_files,
)
# Validate that the response is a non-empty, JSON-serializable list
response = clean_json(response)
response = json.loads(response)
code = response.get("code", "").strip()
input_files = response.get("input_files", [])
input_links = response.get("input_links", [])
if not isinstance(code, str) or is_none_or_empty(code):
raise ValueError
return GeneratedCode(code, input_files, input_links)
async def execute_sandboxed_python(code: str, input_data: list[dict], sandbox_url: str = SANDBOX_URL) -> dict[str, Any]:
"""
Takes code to run as a string and calls the terrarium API to execute it.
Returns the result of the code execution as a dictionary.
Reference data i/o format based on Terrarium example client code at:
https://github.com/cohere-ai/cohere-terrarium/blob/main/example-clients/python/terrarium_client.py
"""
headers = {"Content-Type": "application/json"}
cleaned_code = clean_code_python(code)
data = {"code": cleaned_code, "files": input_data}
async with aiohttp.ClientSession() as session:
async with session.post(sandbox_url, json=data, headers=headers) as response:
if response.status == 200:
result: dict[str, Any] = await response.json()
result["code"] = cleaned_code
# Store decoded output files
for output_file in result.get("output_files", []):
# Decode text files as UTF-8
if mimetypes.guess_type(output_file["filename"])[0].startswith("text/") or Path(
output_file["filename"]
).suffix in [".org", ".md", ".json"]:
output_file["b64_data"] = base64.b64decode(output_file["b64_data"]).decode("utf-8")
return result
else:
return {
"code": cleaned_code,
"success": False,
"std_err": f"Failed to execute code with {response.status}",
}

View File

@@ -6,7 +6,7 @@ import os
import threading
import time
import uuid
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Set, Union
import cron_descriptor
import pytz
@@ -21,11 +21,12 @@ from starlette.authentication import has_required_scope, requires
from khoj.configure import initialize_content
from khoj.database import adapters
from khoj.database.adapters import (
AgentAdapters,
AutomationAdapters,
ConversationAdapters,
EntryAdapters,
get_default_search_model,
get_user_photo,
get_user_search_model_or_default,
)
from khoj.database.models import (
Agent,
@@ -42,6 +43,7 @@ from khoj.processor.conversation.offline.chat_model import extract_questions_off
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.processor.conversation.utils import defilter_query
from khoj.routers.helpers import (
ApiUserRateLimiter,
ChatEvent,
@@ -114,10 +116,16 @@ async def execute_search(
dedupe: Optional[bool] = True,
agent: Optional[Agent] = None,
):
start_time = time.time()
# Run validation checks
results: List[SearchResponse] = []
start_time = time.time()
# Ensure the agent, if present, is accessible by the user
if user and agent and not await AgentAdapters.ais_agent_accessible(agent, user):
logger.error(f"Agent {agent.slug} is not accessible by user {user}")
return results
if q is None or q == "":
logger.warning(f"No query param (q) passed in API call to initiate search")
return results
@@ -142,7 +150,7 @@ async def execute_search(
encoded_asymmetric_query = None
if t != SearchType.Image:
with timer("Encoding query took", logger=logger):
search_model = await sync_to_async(get_user_search_model_or_default)(user)
search_model = await sync_to_async(get_default_search_model)()
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -159,8 +167,8 @@ async def execute_search(
search_futures += [
executor.submit(
text_search.query,
user,
user_query,
user,
t,
question_embedding=encoded_asymmetric_query,
max_distance=max_distance,
@@ -204,7 +212,7 @@ def update(
logger.warning(error_msg)
raise HTTPException(status_code=500, detail=error_msg)
try:
initialize_content(regenerate=force, search_type=t, user=user)
initialize_content(user=user, regenerate=force, search_type=t)
except Exception as e:
error_msg = f"🚨 Failed to update server via API: {e}"
logger.error(error_msg, exc_info=True)
@@ -340,13 +348,16 @@ async def extract_references_and_questions(
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
location_data: LocationData = None,
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
query_images: Optional[List[str]] = None,
previous_inferred_queries: Set = set(),
agent: Agent = None,
query_files: str = None,
tracer: dict = {},
):
user = request.user.object if request.user.is_authenticated else None
# Initialize Variables
compiled_references: List[Any] = []
compiled_references: List[dict[str, str]] = []
inferred_queries: List[str] = []
agent_has_entries = False
@@ -375,9 +386,7 @@ async def extract_references_and_questions(
return
# Extract filter terms from user message
defiltered_query = q
for filter in [DateFilter(), WordFilter(), FileFilter()]:
defiltered_query = filter.defilter(defiltered_query)
defiltered_query = defilter_query(q)
filters_in_query = q.replace(defiltered_query, "").strip()
conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id)
@@ -417,6 +426,8 @@ async def extract_references_and_questions(
user=user,
max_prompt_size=conversation_config.max_prompt_size,
personality_context=personality_context,
query_files=query_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = conversation_config.openai_config
@@ -431,37 +442,48 @@ async def extract_references_and_questions(
conversation_log=meta_log,
location_data=location_data,
user=user,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
vision_enabled=vision_enabled,
personality_context=personality_context,
query_files=query_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
chat_model = conversation_config.chat_model
inferred_queries = extract_questions_anthropic(
defiltered_query,
query_images=query_images,
model=chat_model,
api_key=api_key,
conversation_log=meta_log,
location_data=location_data,
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
query_files=query_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
chat_model = conversation_config.chat_model
inferred_queries = extract_questions_gemini(
defiltered_query,
query_images=query_images,
model=chat_model,
api_key=api_key,
conversation_log=meta_log,
location_data=location_data,
max_tokens=conversation_config.max_prompt_size,
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
query_files=query_files,
tracer=tracer,
)
# Collate search results as context for GPT
inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
with timer("Searching knowledge base took", logger):
search_results = []
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
@@ -485,7 +507,8 @@ async def extract_references_and_questions(
)
search_results = text_search.deduplicated_search_responses(search_results)
compiled_references = [
{"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results
{"query": q, "compiled": item.additional["compiled"], "file": item.additional["file"]}
for q, item in zip(inferred_queries, search_results)
]
yield compiled_references, inferred_queries, defiltered_query

View File

@@ -1,5 +1,7 @@
import json
import logging
import random
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional
from asgiref.sync import sync_to_async
@@ -9,8 +11,8 @@ from fastapi.responses import Response
from pydantic import BaseModel
from starlette.authentication import requires
from khoj.database.adapters import AgentAdapters
from khoj.database.models import Agent, KhojUser
from khoj.database.adapters import AgentAdapters, ConversationAdapters
from khoj.database.models import Agent, Conversation, KhojUser
from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt
from khoj.utils.helpers import (
ConversationCommand,
@@ -45,30 +47,49 @@ async def all_agents(
) -> Response:
user: KhojUser = request.user.object if request.user.is_authenticated else None
agents = await AgentAdapters.aget_all_accessible_agents(user)
default_agent = await AgentAdapters.aget_default_agent()
default_agent_packet = None
agents_packet = list()
for agent in agents:
files = agent.fileobject_set.all()
file_names = [file.file_name for file in files]
agents_packet.append(
{
"slug": agent.slug,
"name": agent.name,
"persona": agent.personality,
"creator": agent.creator.username if agent.creator else None,
"managed_by_admin": agent.managed_by_admin,
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"files": file_names,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
}
)
agent_packet = {
"slug": agent.slug,
"name": agent.name,
"persona": agent.personality,
"creator": agent.creator.username if agent.creator else None,
"managed_by_admin": agent.managed_by_admin,
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"files": file_names,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
}
if agent.slug == default_agent.slug:
default_agent_packet = agent_packet
else:
agents_packet.append(agent_packet)
# Load recent conversation sessions
min_date = datetime.min.replace(tzinfo=timezone.utc)
two_weeks_ago = datetime.today() - timedelta(weeks=2)
conversations = []
if user:
conversations = await sync_to_async(list[Conversation])(
ConversationAdapters.get_conversation_sessions(user, request.user.client_app)
.filter(updated_at__gte=two_weeks_ago)
.order_by("-updated_at")[:50]
)
conversation_times = {conv.agent.slug: conv.updated_at for conv in conversations if conv.agent}
# Put default agent first, then sort by mru and finally shuffle unused randomly
random.shuffle(agents_packet)
agents_packet.sort(key=lambda x: conversation_times.get(x["slug"]) or min_date, reverse=True)
if default_agent_packet:
agents_packet.insert(0, default_agent_packet)
# Make sure that the agent named 'khoj' is first in the list. Everything else is sorted by name.
agents_packet.sort(key=lambda x: x["name"])
agents_packet.sort(key=lambda x: x["slug"] == "khoj", reverse=True)
return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)
@@ -162,7 +183,7 @@ async def delete_agent(
@api_agents.post("", response_class=Response)
@requires(["authenticated", "premium"])
@requires(["authenticated"])
async def create_agent(
request: Request,
common: CommonQueryParams,
@@ -170,10 +191,9 @@ async def create_agent(
) -> Response:
user: KhojUser = request.user.object
is_safe_prompt, reason = True, ""
if body.privacy_level != Agent.PrivacyLevel.PRIVATE:
is_safe_prompt, reason = await acheck_if_safe_prompt(body.persona)
is_safe_prompt, reason = await acheck_if_safe_prompt(
body.persona, user, lax=body.privacy_level == Agent.PrivacyLevel.PRIVATE
)
if not is_safe_prompt:
return Response(
@@ -215,7 +235,7 @@ async def create_agent(
@api_agents.patch("", response_class=Response)
@requires(["authenticated", "premium"])
@requires(["authenticated"])
async def update_agent(
request: Request,
common: CommonQueryParams,
@@ -223,10 +243,9 @@ async def update_agent(
) -> Response:
user: KhojUser = request.user.object
is_safe_prompt, reason = True, ""
if body.privacy_level != Agent.PrivacyLevel.PRIVATE:
is_safe_prompt, reason = await acheck_if_safe_prompt(body.persona)
is_safe_prompt, reason = await acheck_if_safe_prompt(
body.persona, user, lax=body.privacy_level == Agent.PrivacyLevel.PRIVATE
)
if not is_safe_prompt:
return Response(

View File

@@ -3,9 +3,10 @@ import base64
import json
import logging
import time
import uuid
from datetime import datetime
from functools import partial
from typing import Dict, Optional
from typing import Any, Dict, List, Optional
from urllib.parse import unquote
from asgiref.sync import sync_to_async
@@ -18,28 +19,40 @@ from khoj.database.adapters import (
AgentAdapters,
ConversationAdapters,
EntryAdapters,
FileObjectAdapters,
PublicConversationAdapters,
aget_user_name,
)
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log
from khoj.processor.image.generate import text_to_image
from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.online_search import (
deduplicate_organic_results,
read_webpages,
search_online,
)
from khoj.processor.tools.run_code import run_code
from khoj.routers.api import extract_references_and_questions
from khoj.routers.email import send_query_feedback
from khoj.routers.helpers import (
ApiImageRateLimiter,
ApiUserRateLimiter,
ChatEvent,
ChatRequestBody,
CommonQueryParams,
ConversationCommandRateLimiter,
DeleteMessageRequestBody,
FeedbackData,
acreate_title_from_history,
agenerate_chat_response,
aget_relevant_information_sources,
aget_relevant_output_modes,
construct_automation_created_message,
create_automation,
extract_relevant_summary,
gather_raw_query_files,
generate_excalidraw_diagram,
generate_summary_from_files,
get_conversation_command,
is_query_empty,
is_ready_to_chat,
@@ -47,6 +60,10 @@ from khoj.routers.helpers import (
update_telemetry_state,
validate_conversation_config,
)
from khoj.routers.research import (
InformationCollectionIteration,
execute_information_collection,
)
from khoj.routers.storage import upload_image_to_bucket
from khoj.utils import state
from khoj.utils.helpers import (
@@ -59,21 +76,22 @@ from khoj.utils.helpers import (
get_device,
is_none_or_empty,
)
from khoj.utils.rawconfig import FileFilterRequest, FilesFilterRequest, LocationData
from khoj.utils.rawconfig import (
ChatRequestBody,
FileFilterRequest,
FilesFilterRequest,
LocationData,
)
# Initialize Router
logger = logging.getLogger(__name__)
conversation_command_rate_limiter = ConversationCommandRateLimiter(
trial_rate_limit=100, subscribed_rate_limit=6000, slug="command"
trial_rate_limit=20, subscribed_rate_limit=75, slug="command"
)
api_chat = APIRouter()
from pydantic import BaseModel
from khoj.routers.email import send_query_feedback
@api_chat.get("/conversation/file-filters/{conversation_id}", response_class=Response)
@requires(["authenticated"])
@@ -109,7 +127,7 @@ def add_files_filter(request: Request, filter: FilesFilterRequest):
file_filters = ConversationAdapters.add_files_to_filter(request.user.object, conversation_id, files_filter)
return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
except Exception as e:
logger.error(f"Error adding file filter {filter.filename}: {e}", exc_info=True)
logger.error(f"Error adding file filter {filter.filenames}: {e}", exc_info=True)
raise HTTPException(status_code=422, detail=str(e))
@@ -135,12 +153,6 @@ def remove_file_filter(request: Request, filter: FileFilterRequest) -> Response:
return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
class FeedbackData(BaseModel):
uquery: str
kquery: str
sentiment: str
@api_chat.post("/feedback")
@requires(["authenticated"])
async def sendfeedback(request: Request, data: FeedbackData):
@@ -155,10 +167,10 @@ async def text_to_speech(
common: CommonQueryParams,
text: str,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
ApiUserRateLimiter(requests=30, subscribed_requests=30, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=50, subscribed_requests=300, window=60 * 60 * 24, slug="chat_day")
ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
) -> Response:
voice_model = await ConversationAdapters.aget_voice_model_config(request.user.object)
@@ -367,7 +379,7 @@ def fork_public_conversation(
{
"status": "ok",
"next_url": redirect_uri,
"conversation_id": new_conversation.id,
"conversation_id": str(new_conversation.id),
}
),
)
@@ -523,20 +535,43 @@ async def set_conversation_title(
)
class ChatRequestBody(BaseModel):
q: str
n: Optional[int] = 7
d: Optional[float] = None
stream: Optional[bool] = False
title: Optional[str] = None
conversation_id: Optional[str] = None
city: Optional[str] = None
region: Optional[str] = None
country: Optional[str] = None
country_code: Optional[str] = None
timezone: Optional[str] = None
image: Optional[str] = None
create_new: Optional[bool] = False
@api_chat.post("/title")
@requires(["authenticated"])
async def generate_chat_title(
request: Request,
common: CommonQueryParams,
conversation_id: str,
):
user: KhojUser = request.user.object
conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id)
# Conversation.title is explicitly set by the user. Do not override.
if conversation.title:
return {"status": "ok", "title": conversation.title}
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
new_title = await acreate_title_from_history(request.user.object, conversation=conversation)
conversation.slug = new_title
await conversation.asave()
return {"status": "ok", "title": new_title}
@api_chat.delete("/conversation/message", response_class=Response)
@requires(["authenticated"])
def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -> Response:
user = request.user.object
success = ConversationAdapters.delete_message_by_turn_id(
user, delete_request.conversation_id, delete_request.turn_id
)
if success:
return Response(content=json.dumps({"status": "ok"}), media_type="application/json", status_code=200)
else:
return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404)
@api_chat.post("")
@@ -546,11 +581,12 @@ async def chat(
common: CommonQueryParams,
body: ChatRequestBody,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=60, subscribed_requests=200, window=60, slug="chat_minute")
ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
):
# Access the parameters from the body
q = body.q
@@ -559,14 +595,16 @@ async def chat(
stream = body.stream
title = body.title
conversation_id = body.conversation_id
turn_id = str(body.turn_id or uuid.uuid4())
city = body.city
region = body.region
country = body.country or get_country_name_from_timezone(body.timezone)
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
timezone = body.timezone
image = body.image
raw_images = body.images
raw_query_files = body.files
async def event_generator(q: str, image: str):
async def event_generator(q: str, images: list[str]):
start_time = time.perf_counter()
ttft = None
chat_metadata: dict = {}
@@ -574,21 +612,35 @@ async def chat(
user: KhojUser = request.user.object
event_delimiter = "␃🔚␗"
q = unquote(q)
train_of_thought = []
nonlocal conversation_id
nonlocal raw_query_files
uploaded_image_url = None
if image:
decoded_string = unquote(image)
base64_data = decoded_string.split(",", 1)[1]
image_bytes = base64.b64decode(base64_data)
webp_image_bytes = convert_image_to_webp(image_bytes)
try:
uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
except:
uploaded_image_url = None
tracer: dict = {
"mid": turn_id,
"cid": conversation_id,
"uid": user.id,
"khoj_version": state.khoj_version,
}
uploaded_images: list[str] = []
if images:
for image in images:
decoded_string = unquote(image)
base64_data = decoded_string.split(",", 1)[1]
image_bytes = base64.b64decode(base64_data)
webp_image_bytes = convert_image_to_webp(image_bytes)
uploaded_image = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
if uploaded_image:
uploaded_images.append(uploaded_image)
query_files: Dict[str, str] = {}
if raw_query_files:
for file in raw_query_files:
query_files[file.name] = file.content
async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal connection_alive, ttft
nonlocal connection_alive, ttft, train_of_thought
if not connection_alive or await request.is_disconnected():
connection_alive = False
logger.warning(f"User {user} disconnected from {common.client} client")
@@ -596,11 +648,14 @@ async def chat(
try:
if event_type == ChatEvent.END_LLM_RESPONSE:
collect_telemetry()
if event_type == ChatEvent.START_LLM_RESPONSE:
elif event_type == ChatEvent.START_LLM_RESPONSE:
ttft = time.perf_counter() - start_time
elif event_type == ChatEvent.STATUS:
train_of_thought.append({"type": event_type.value, "data": data})
if event_type == ChatEvent.MESSAGE:
yield data
elif event_type == ChatEvent.REFERENCES or stream:
elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
except asyncio.CancelledError as e:
connection_alive = False
@@ -644,6 +699,11 @@ async def chat(
metadata=chat_metadata,
)
if is_query_empty(q):
async for result in send_llm_response("Please ask your query to get started."):
yield result
return
conversation_commands = [get_conversation_command(query=q, any_references=True)]
conversation = await ConversationAdapters.aget_conversation_by_user(
@@ -659,6 +719,9 @@ async def chat(
return
conversation_id = conversation.id
async for event in send_event(ChatEvent.METADATA, {"conversationId": str(conversation_id), "turnId": turn_id}):
yield event
agent: Agent | None = None
default_agent = await AgentAdapters.aget_default_agent()
if conversation.agent and conversation.agent != default_agent:
@@ -670,46 +733,99 @@ async def chat(
agent = default_agent
await is_ready_to_chat(user)
user_name = await aget_user_name(user)
location = None
if city or region or country or country_code:
location = LocationData(city=city, region=region, country=country, country_code=country_code)
if is_query_empty(q):
async for result in send_llm_response("Please ask your query to get started."):
yield result
return
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
researched_results = ""
online_results: Dict = dict()
code_results: Dict = dict()
## Extract Document References
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
attached_file_context = gather_raw_query_files(query_files)
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(
q,
meta_log,
is_automated_task,
user=user,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
tracer=tracer,
)
# If we're doing research, we don't want to do anything else
if ConversationCommand.Research in conversation_commands:
conversation_commands = [ConversationCommand.Research]
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event(
ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}"
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
mode = await aget_relevant_output_modes(
q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer
)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
conversation_commands.append(mode)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
try:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
except HTTPException as e:
async for result in send_llm_response(str(e.detail)):
yield result
return
defiltered_query = defilter_query(q)
if conversation_commands == [ConversationCommand.Research]:
async for research_result in execute_information_collection(
request=request,
user=user,
query=defiltered_query,
conversation_id=conversation_id,
conversation_history=meta_log,
query_images=uploaded_images,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
user_name=user_name,
location=location,
file_filters=conversation.file_filters if conversation else [],
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(research_result, InformationCollectionIteration):
if research_result.summarizedResult:
if research_result.onlineContext:
online_results.update(research_result.onlineContext)
if research_result.codeContext:
code_results.update(research_result.codeContext)
if research_result.context:
compiled_references.extend(research_result.context)
researched_results += research_result.summarizedResult
else:
yield research_result
# researched_results = await extract_relevant_info(q, researched_results, agent)
if state.verbose > 1:
logger.debug(f"Researched Results: {researched_results}")
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else []
@@ -730,52 +846,26 @@ async def chat(
response_log = "No files selected for summarization. Please add files using the section on the left."
async for result in send_llm_response(response_log):
yield result
elif len(file_filters) > 1 and not agent_has_entries:
response_log = "Only one file can be selected for summarization."
async for result in send_llm_response(response_log):
yield result
else:
try:
file_object = None
if await EntryAdapters.aagent_has_entries(agent):
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
if len(file_names) > 0:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(
None, file_names[0], agent
)
async for response in generate_summary_from_files(
q=q,
user=user,
file_filters=file_filters,
meta_log=meta_log,
query_images=uploaded_images,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(response, dict) and ChatEvent.STATUS in response:
yield response[ChatEvent.STATUS]
else:
if isinstance(response, str):
response_log = response
async for result in send_llm_response(response):
yield result
if len(file_filters) > 0:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
async for result in send_llm_response(response_log):
yield result
return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
async for result in send_event(
ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}"
):
yield result
response = await extract_relevant_summary(
q,
contextual_data,
conversation_history=meta_log,
uploaded_image_url=uploaded_image_url,
user=user,
agent=agent,
)
response_log = str(response)
async for result in send_llm_response(response_log):
yield result
except Exception as e:
response_log = "Error summarizing file. Please try again, or contact support."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_llm_response(response_log):
yield result
await sync_to_async(save_to_conversation_log)(
q,
response_log,
@@ -785,7 +875,10 @@ async def chat(
intent_type="summarize",
client_application=request.user.client_app,
conversation_id=conversation_id,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
tracer=tracer,
)
return
@@ -794,7 +887,7 @@ async def chat(
if not q:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
async for result in send_llm_response(formatted_help):
@@ -807,7 +900,7 @@ async def chat(
if ConversationCommand.Automation in conversation_commands:
try:
automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, request.url, meta_log
q, timezone, user, request.url, meta_log, tracer=tracer
)
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
@@ -828,7 +921,10 @@ async def chat(
conversation_id=conversation_id,
inferred_queries=[query_to_run],
automation_id=automation.id,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
tracer=tracer,
)
async for result in send_llm_response(llm_response):
yield result
@@ -836,48 +932,50 @@ async def chat(
# Gather Context
## Extract Document References
compiled_references, inferred_queries, defiltered_query = [], [], q
try:
async for result in extract_references_and_questions(
request,
meta_log,
q,
(n or 7),
d,
conversation_id,
conversation_commands,
location,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
compiled_references.extend(result[0])
inferred_queries.extend(result[1])
defiltered_query = result[2]
except Exception as e:
error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
logger.error(error_message, exc_info=True)
async for result in send_event(
ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
):
yield result
if not ConversationCommand.Research in conversation_commands:
try:
async for result in extract_references_and_questions(
request,
meta_log,
q,
(n or 7),
d,
conversation_id,
conversation_commands,
location,
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
compiled_references.extend(result[0])
inferred_queries.extend(result[1])
defiltered_query = result[2]
except Exception as e:
error_message = (
f"Error searching knowledge base: {e}. Attempting to respond without document references."
)
logger.error(error_message, exc_info=True)
async for result in send_event(
ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
):
yield result
if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
# Strip only leading # from headings
headings = headings.replace("#", "")
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
yield result
if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
# Strip only leading # from headings
headings = headings.replace("#", "")
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
yield result
online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
async for result in send_llm_response(f"{no_entries_found.format()}"):
yield result
return
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
async for result in send_llm_response(f"{no_entries_found.format()}"):
yield result
return
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
@@ -892,8 +990,10 @@ async def chat(
user,
partial(send_event, ChatEvent.STATUS),
custom_filters,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -916,8 +1016,10 @@ async def chat(
location,
user,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -944,13 +1046,43 @@ async def chat(
):
yield result
## Gather Code Results
if ConversationCommand.Code in conversation_commands:
try:
context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
async for result in run_code(
defiltered_query,
meta_log,
context,
location,
user,
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
code_results = result
async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"):
yield result
except ValueError as e:
logger.warning(
f"Failed to use code tool: {e}. Attempting to respond without code results",
exc_info=True,
)
## Send Gathered References
unique_online_results = deduplicate_organic_results(online_results)
async for result in send_event(
ChatEvent.REFERENCES,
{
"inferredQueries": inferred_queries,
"context": compiled_references,
"onlineContext": online_results,
"onlineContext": unique_online_results,
"codeContext": code_results,
},
):
yield result
@@ -966,20 +1098,22 @@ async def chat(
references=compiled_references,
online_results=online_results,
send_status_func=partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
image, status_code, improved_image_prompt, intent_type = result
generated_image, status_code, improved_image_prompt, intent_type = result
if image is None or status_code != 200:
if generated_image is None or status_code != 200:
content_obj = {
"content-type": "application/json",
"intentType": intent_type,
"detail": improved_image_prompt,
"image": image,
"image": None,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
@@ -987,7 +1121,7 @@ async def chat(
await sync_to_async(save_to_conversation_log)(
q,
image,
generated_image,
user,
meta_log,
user_message_time,
@@ -997,17 +1131,83 @@ async def chat(
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
uploaded_image_url=uploaded_image_url,
code_results=code_results,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
tracer=tracer,
)
content_obj = {
"intentType": intent_type,
"inferredQueries": [improved_image_prompt],
"image": image,
"image": generated_image,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
return
if ConversationCommand.Diagram in conversation_commands:
async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
yield result
intent_type = "excalidraw"
inferred_queries = []
diagram_description = ""
async for result in generate_excalidraw_diagram(
q=defiltered_query,
conversation_history=meta_log,
location_data=location,
note_references=compiled_references,
online_results=online_results,
query_images=uploaded_images,
user=user,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
better_diagram_description_prompt, excalidraw_diagram_description = result
if better_diagram_description_prompt and excalidraw_diagram_description:
inferred_queries.append(better_diagram_description_prompt)
diagram_description = excalidraw_diagram_description
else:
async for result in send_llm_response(f"Failed to generate diagram. Please try again later."):
yield result
return
content_obj = {
"intentType": intent_type,
"inferredQueries": inferred_queries,
"image": diagram_description,
}
await sync_to_async(save_to_conversation_log)(
q,
excalidraw_diagram_description,
user,
meta_log,
user_message_time,
intent_type="excalidraw",
inferred_queries=[better_diagram_description_prompt],
client_application=request.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
tracer=tracer,
)
async for result in send_llm_response(json.dumps(content_obj)):
yield result
return
## Generate Text Output
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
yield result
@@ -1017,6 +1217,7 @@ async def chat(
conversation,
compiled_references,
online_results,
code_results,
inferred_queries,
conversation_commands,
user,
@@ -1024,7 +1225,12 @@ async def chat(
conversation_id,
location,
user_name,
uploaded_image_url,
researched_results,
uploaded_images,
train_of_thought,
attached_file_context,
raw_query_files,
tracer,
)
# Send Response
@@ -1050,9 +1256,9 @@ async def chat(
## Stream Text Response
if stream:
return StreamingResponse(event_generator(q, image=image), media_type="text/plain")
return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
## Non-Streaming Text Response
else:
response_iterator = event_generator(q, image=image)
response_iterator = event_generator(q, images=raw_images)
response_data = await read_chat_stream(response_iterator)
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)

View File

@@ -36,16 +36,18 @@ from khoj.database.models import (
LocalPlaintextConfig,
NotionConfig,
)
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
from khoj.routers.helpers import (
ApiIndexedDataLimiter,
CommonQueryParams,
configure_content,
get_file_content,
get_user_config,
update_telemetry_state,
)
from khoj.utils import constants, state
from khoj.utils.config import SearchModels
from khoj.utils.helpers import get_file_type
from khoj.utils.rawconfig import (
ContentConfig,
FullConfig,
@@ -237,7 +239,7 @@ async def set_content_notion(
if updated_config.token:
# Trigger an async job to configure_content. Let it run without blocking the response.
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user)
background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion)
update_telemetry_state(
request=request,
@@ -375,6 +377,85 @@ async def delete_content_source(
return {"status": "ok"}
@api_content.post("/convert", status_code=200)
@requires(["authenticated"])
async def convert_documents(
request: Request,
files: List[UploadFile],
client: Optional[str] = None,
):
MAX_FILE_SIZE_MB = 10 # 10MB limit
MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024
converted_files = []
supported_files = ["org", "markdown", "pdf", "plaintext", "docx"]
for file in files:
# Check file size first
file_size = 0
content = await file.read()
file_size = len(content)
await file.seek(0) # Reset file pointer
if file_size > MAX_FILE_SIZE_BYTES:
logger.warning(
f"Skipped converting oversized file ({file_size / 1024 / 1024:.1f}MB) sent by {client} client: {file.filename}"
)
continue
file_data = get_file_content(file)
if file_data.file_type in supported_files:
extracted_content = (
file_data.content.decode(file_data.encoding) if file_data.encoding else file_data.content
)
if file_data.file_type == "docx":
entries_per_page = DocxToEntries.extract_text(file_data.content)
annotated_pages = [
f"Page {index} of {file_data.name}:\n\n{entry}" for index, entry in enumerate(entries_per_page)
]
extracted_content = "\n".join(annotated_pages)
elif file_data.file_type == "pdf":
entries_per_page = PdfToEntries.extract_text(file_data.content)
annotated_pages = [
f"Page {index} of {file_data.name}:\n\n{entry}" for index, entry in enumerate(entries_per_page)
]
extracted_content = "\n".join(annotated_pages)
else:
# Convert content to string
extracted_content = extracted_content.decode("utf-8")
# Calculate size in bytes. Some of the content might be in bytes, some in str.
if isinstance(extracted_content, str):
size_in_bytes = len(extracted_content.encode("utf-8"))
elif isinstance(extracted_content, bytes):
size_in_bytes = len(extracted_content)
else:
size_in_bytes = 0
logger.warning(f"Unexpected content type: {type(extracted_content)}")
converted_files.append(
{
"name": file_data.name,
"content": extracted_content,
"file_type": file_data.file_type,
"size": size_in_bytes,
}
)
else:
logger.warning(f"Skipped converting unsupported file type sent by {client} client: {file.filename}")
update_telemetry_state(
request=request,
telemetry_type="api",
api="convert_documents",
client=client,
)
return Response(content=json.dumps(converted_files), media_type="application/json", status_code=200)
async def indexer(
request: Request,
files: list[UploadFile],
@@ -398,12 +479,13 @@ async def indexer(
try:
logger.info(f"📬 Updating content index via API call by {client} client")
for file in files:
file_content = file.file.read()
file_type, encoding = get_file_type(file.content_type, file_content)
if file_type in index_files:
index_files[file_type][file.filename] = file_content.decode(encoding) if encoding else file_content
file_data = get_file_content(file)
if file_data.file_type in index_files:
index_files[file_data.file_type][file_data.name] = (
file_data.content.decode(file_data.encoding) if file_data.encoding else file_data.content
)
else:
logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")
logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file_data.name}")
indexer_input = IndexerInput(
org=index_files["org"],
@@ -440,10 +522,10 @@ async def indexer(
success = await loop.run_in_executor(
None,
configure_content,
user,
indexer_input.model_dump(),
regenerate,
t,
user,
)
if not success:
raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index")

View File

@@ -94,39 +94,6 @@ async def update_voice_model(
return Response(status_code=202, content=json.dumps({"status": "ok"}))
@api_model.post("/search", status_code=200)
@requires(["authenticated"])
async def update_search_model(
request: Request,
id: str,
client: Optional[str] = None,
):
user = request.user.object
prev_config = await adapters.aget_user_search_model(user)
new_config = await adapters.aset_user_search_model(user, int(id))
if prev_config and int(id) != prev_config.id and new_config:
await EntryAdapters.adelete_all_entries(user)
if not prev_config:
# If the use was just using the default config, delete all the entries and set the new config.
await EntryAdapters.adelete_all_entries(user)
if new_config is None:
return {"status": "error", "message": "Model not found"}
else:
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_search_model",
client=client,
metadata={"search_model": new_config.setting.name},
)
return {"status": "ok"}
@api_model.post("/paint", status_code=200)
@requires(["authenticated"])
async def update_paint_model(

View File

@@ -1,12 +1,14 @@
import json
import logging
import os
from datetime import datetime, timezone
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Response
from starlette.authentication import requires
from khoj.database import adapters
from khoj.database.models import KhojUser, Subscription
from khoj.routers.helpers import update_telemetry_state
from khoj.utils import state
@@ -73,7 +75,7 @@ async def subscribe(request: Request):
elif event_type in {"customer.subscription.deleted"}:
# Reset the user to trial state
user, is_new = await adapters.set_user_subscription(
customer_email, is_recurring=False, renewal_date=False, type="trial"
customer_email, is_recurring=False, renewal_date=False, type=Subscription.Type.TRIAL
)
success = user is not None
@@ -82,7 +84,7 @@ async def subscribe(request: Request):
request=request,
telemetry_type="api",
api="create_user",
metadata={"user_id": str(user.user.uuid)},
metadata={"server_id": str(user.user.uuid)},
)
logger.log(logging.INFO, f"🥳 New User Created: {user.user.uuid}")
@@ -92,8 +94,9 @@ async def subscribe(request: Request):
@subscription_router.patch("")
@requires(["authenticated"])
async def update_subscription(request: Request, email: str, operation: str):
async def update_subscription(request: Request, operation: str):
# Retrieve the customer's details
email = request.user.object.email
customers = stripe.Customer.list(email=email).auto_paging_iter()
customer = next(customers, None)
if customer is None:
@@ -116,3 +119,19 @@ async def update_subscription(request: Request, email: str, operation: str):
return {"success": False, "message": "No subscription found that is set to cancel"}
return {"success": False, "message": "Invalid operation"}
@subscription_router.post("/trial", response_class=Response)
@requires(["authenticated"])
async def start_trial(request: Request) -> Response:
user: KhojUser = request.user.object
# Start a trial for the user
updated_subscription = await adapters.astart_trial_subscription(user)
# Return trial status as a JSON response
return Response(
content=json.dumps({"trial_enabled": updated_subscription is not None}),
media_type="application/json",
status_code=200,
)

View File

@@ -90,7 +90,7 @@ async def login_magic_link(request: Request, form: MagicLinkForm):
request=request,
telemetry_type="api",
api="create_user",
metadata={"user_id": str(user.uuid)},
metadata={"server_id": str(user.uuid)},
)
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
@@ -175,7 +175,7 @@ async def auth(request: Request):
request=request,
telemetry_type="api",
api="create_user",
metadata={"user_id": str(khoj_user.uuid)},
metadata={"server_id": str(khoj_user.uuid)},
)
logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}")
return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)

Some files were not shown because too many files have changed in this diff Show More