Mirror of https://github.com/khoj-ai/khoj.git, synced 2026-05-13 21:41:41 +00:00

Compare commits: 327 commits
| Author | SHA1 | Date |
|---|---|---|

*(327 commit rows, from d607ad7a27 through f867d5ed72. Only the commit SHAs were captured in this mirror view; the author, date, and message columns are empty, so the full list is not reproduced here.)*
```
@@ -14,6 +14,10 @@ services:
      interval: 30s
      timeout: 10s
      retries: 5
  sandbox:
    image: ghcr.io/khoj-ai/terrarium:latest
    ports:
      - "8080:8080"
  server:
    depends_on:
      database:
@@ -57,6 +61,10 @@ services:
      # - KHOJ_NO_HTTPS=True
      # - KHOJ_DOMAIN=192.168.0.104
      # - KHOJ_DOMAIN=khoj.example.com
      # Uncomment the line below to disable telemetry.
      # Telemetry helps us prioritize feature development and understand how people are using Khoj
      # Read more at https://docs.khoj.dev/miscellaneous/telemetry
      # - KHOJ_TELEMETRY_DISABLE=True
    command: --host="0.0.0.0" --port=42110 -vv --anonymous-mode --non-interactive
```
```
@@ -1,6 +1,8 @@
# Admin Panel
> Describes the Khoj settings configurable via the admin panel

By default, your admin panel is available at `http://localhost:42110/server/admin/`. You can access the admin panel by logging in with your admin credentials (these are your `KHOJ_ADMIN_EMAIL` and `KHOJ_ADMIN_PASSWORD`). The admin panel allows you to configure various settings for your Khoj server.

## App Settings
### Agents
Add all the agents you want to use for your different use-cases like Writer, Researcher, Therapist etc.
```
```
@@ -31,7 +31,4 @@ Using LiteLLM with Khoj makes it possible to turn any LLM behind an API into you
   - Openai Config: `<the proxy config you created in step 3>`
   - Max prompt size: `20000` (replace with the max prompt size of your model)
   - Tokenizer: *Do not set for OpenAI, Mistral, Llama3 based models*
5. Create a new [Server Chat Setting](http://localhost:42110/server/admin/database/serverchatsettings/add/) on your Khoj admin panel
   - Default model: `<name of chat model option you created in step 4>`
   - Summarizer model: `<name of chat model option you created in step 4>`
6. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.
5. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.
```
```
@@ -24,7 +24,4 @@ LM Studio can expose an [OpenAI API compatible server](https://lmstudio.ai/docs/
   - Openai Config: `<the proxy config you created in step 3>`
   - Max prompt size: `20000` (replace with the max prompt size of your model)
   - Tokenizer: *Do not set for OpenAI, mistral, llama3 based models*
5. Create a new [Server Chat Setting](http://localhost:42110/server/admin/database/serverchatsettings/add/) on your Khoj admin panel
   - Default model: `<name of chat model option you created in step 4>`
   - Summarizer model: `<name of chat model option you created in step 4>`
6. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.
5. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.
```
```
@@ -28,9 +28,6 @@ Ollama exposes a local [OpenAI API compatible server](https://github.com/ollama/
   - Model Type: `Openai`
   - Openai Config: `<the ollama config you created in step 3>`
   - Max prompt size: `20000` (replace with the max prompt size of your model)
5. Create a new [Server Chat Setting](http://localhost:42110/server/admin/database/serverchatsettings/add/) on your Khoj admin panel
   - Default model: `<name of chat model option you created in step 4>`
   - Summarizer model: `<name of chat model option you created in step 4>`
6. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.
5. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.

That's it! You should now be able to chat with your Ollama model from Khoj. If you want to add additional models running on Ollama, repeat step 6 for each model.
```
```
@@ -31,7 +31,4 @@ For specific integrations, see our [Ollama](/advanced/ollama), [LMStudio](/advan
   - Openai Config: `<the proxy config you created in step 2>`
   - Max prompt size: `2000` (replace with the max prompt size of your model)
   - Tokenizer: *Do not set for OpenAI, mistral, llama3 based models*
4. Create a new [Server Chat Setting](http://localhost:42110/server/admin/database/serverchatsettings/add/) on your Khoj admin panel
   - Default model: `<name of chat model option you created in step 3>`
   - Summarizer model: `<name of chat model option you created in step 3>`
5. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.
4. Go to [your config](http://localhost:42110/settings) and select the model you just created in the chat model dropdown.
```
```
@@ -12,7 +12,7 @@ Without any desktop clients, you can start chatting with Khoj on WhatsApp. Bear

In order to use Khoj on WhatsApp with your own data, you need to setup a Khoj Cloud account and connect your WhatsApp account to it. This is a one time setup and you can do it from the [Khoj Cloud config page](https://app.khoj.dev/settings).

If you hit usage limits for the WhatsApp bot, upgrade to [a paid plan](https://khoj.dev/pricing) on Khoj Cloud.
If you hit usage limits for the WhatsApp bot, upgrade to [a paid plan](https://khoj.dev/#pricing) on Khoj Cloud.

<img src="https://khoj-web-bucket.s3.amazonaws.com/khojwhatsapp.png" alt="WhatsApp QR Code" width="300" height="300" />
```
````
@@ -102,7 +102,19 @@ sudo -u postgres createdb khoj --password
</TabItem>
</Tabs>

#### 3. Run
#### 3. Build the front-end assets

```shell
cd src/interface/web/
yarn install
yarn export
```

You can optionally use `yarn dev` to start a development server for the front-end which will be available at http://localhost:3000. This is especially useful if you're making changes to the front-end code, but not necessary for running Khoj. Note that streaming does not work on the dev server due to how it is handled with SSR in Next.js.

Always run `yarn export` to test your front-end changes on http://localhost:42110 before creating a PR.

#### 4. Run
1. Start Khoj
```bash
khoj -vv
````
```
@@ -43,3 +43,6 @@ Slash commands allows you to change what Khoj uses to respond to your query
- **/image**: Generate an image in response to your query.
- **/help**: Use /help to get all available commands and general information about Khoj
- **/summarize**: Can be used to summarize 1 selected file filter for that conversation. Refer to [File Summarization](summarization) for details.
- **/diagram**: Generate a diagram in response to your query. This is built on [Excalidraw](https://excalidraw.com/).
- **/code**: Generate and run very simple Python code snippets. Refer to [Code Execution](code_execution) for details.
- **/research**: Go deeper in a topic for more accurate, in-depth responses.
```
documentation/docs/features/code_execution.md (new file, 30 lines)
````
@@ -0,0 +1,30 @@
---
---

# Code Execution

Khoj can generate and run very simple Python code snippets as well. This is useful if you want to generate a plot, run a simple calculation, or do some basic data manipulation. LLMs by default aren't skilled at complex quantitative tasks. Code generation & execution can come in handy for such tasks.

Just use `/code` in your chat command.

### Setup (Self-Hosting)
Run [Cohere's Terrarium](https://github.com/cohere-ai/cohere-terrarium) on your machine to enable code generation and execution.

Check the [instructions](https://github.com/cohere-ai/cohere-terrarium?tab=readme-ov-file#development) for running from source.

For running with Docker, you can use our [docker-compose.yml](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml), or start it manually like this:

```bash
docker pull ghcr.io/khoj-ai/terrarium:latest
docker run -d -p 8080:8080 ghcr.io/khoj-ai/terrarium:latest
```

#### Verify
Verify that it's running by evaluating a simple Python expression:

```bash
curl -X POST -H "Content-Type: application/json" \
  --url http://localhost:8080 \
  --data-raw '{"code": "1 + 1"}' \
  --no-buffer
```
````
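The same health check can also be issued from Node or the browser. This is a minimal sketch that assumes only the endpoint and payload shown in the curl command above; the shape of Terrarium's JSON reply is not assumed.

```typescript
// Minimal sketch: POST a trivial Python snippet to a locally running Terrarium
// sandbox (http://localhost:8080, as started by the docker run command above)
// and print whatever status and JSON body it returns.
async function verifyTerrarium(): Promise<void> {
    const response = await fetch("http://localhost:8080", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ code: "1 + 1" }),
    });
    console.log(response.status, await response.json());
}

verifyTerrarium().catch(console.error);
```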
````
@@ -27,7 +27,28 @@ If you want to use the offline chat model and you have a GPU, you should use Ins
<Tabs groupId="operating-systems" queryString="os">
<TabItem value="macos" label="MacOS">
<h3>Prerequisites</h3>
Install [Docker Desktop](https://docs.docker.com/desktop/install/mac-install/)
<h4>Docker</h4>
(Option 1) Click here to install [Docker Desktop](https://docs.docker.com/desktop/install/mac-install/). Make sure you also install the [Docker Compose](https://docs.docker.com/desktop/install/mac-install/) tool.

(Option 2) Use [Homebrew](https://brew.sh/) to install Docker and Docker Compose.
```shell
brew install --cask docker
brew install docker-compose
```
<h3>Setup</h3>
1. Download the Khoj docker-compose.yml file [from Github](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml)
```shell
mkdir ~/.khoj && cd ~/.khoj
wget https://raw.githubusercontent.com/khoj-ai/khoj/master/docker-compose.yml
```
2. Configure the environment variables in the docker-compose.yml
   - Set `KHOJ_ADMIN_PASSWORD`, `KHOJ_DJANGO_SECRET_KEY` (and optionally the `KHOJ_ADMIN_EMAIL`) to something secure. This allows you to customize Khoj later via the admin panel.
   - Set `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, or `GEMINI_API_KEY` to your API key if you want to use OpenAI, Anthropic or Gemini chat models respectively.
3. Start Khoj by running the following command in the same directory as your docker-compose.yml file.
```shell
cd ~/.khoj
docker-compose up
```
</TabItem>
<TabItem value="windows" label="Windows">
<h3>Prerequisites</h3>
@@ -37,30 +58,47 @@ If you want to use the offline chat model and you have a GPU, you should use Ins
wsl --install
```
2. Install [Docker Desktop](https://docs.docker.com/desktop/install/windows-install/) with **[WSL2 backend](https://docs.docker.com/desktop/wsl/#turn-on-docker-desktop-wsl-2)** (default)

<h3>Setup</h3>
1. Download the Khoj docker-compose.yml file [from Github](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml)
```shell
# Windows users should use their WSL2 terminal to run these commands
mkdir ~/.khoj && cd ~/.khoj
wget https://raw.githubusercontent.com/khoj-ai/khoj/master/docker-compose.yml
```
2. Configure the environment variables in the docker-compose.yml
   - Set `KHOJ_ADMIN_PASSWORD`, `KHOJ_DJANGO_SECRET_KEY` (and optionally the `KHOJ_ADMIN_EMAIL`) to something secure. This allows you to customize Khoj later via the admin panel.
   - Set `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, or `GEMINI_API_KEY` to your API key if you want to use OpenAI, Anthropic or Gemini chat models respectively.
3. Start Khoj by running the following command in the same directory as your docker-compose.yml file.
```shell
# Windows users should use their WSL2 terminal to run these commands
cd ~/.khoj
docker-compose up
```
</TabItem>
<TabItem value="linux" label="Linux">
<h3>Prerequisites</h3>
Install [Docker Desktop](https://docs.docker.com/desktop/install/windows-install/).
Install [Docker Desktop](https://docs.docker.com/desktop/install/linux/).
You can also use your package manager to install Docker Engine & Docker Compose.

<h3>Setup</h3>
1. Download the Khoj docker-compose.yml file [from Github](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml)
```shell
mkdir ~/.khoj && cd ~/.khoj
wget https://raw.githubusercontent.com/khoj-ai/khoj/master/docker-compose.yml
```
2. Configure the environment variables in the docker-compose.yml
   - Set `KHOJ_ADMIN_PASSWORD`, `KHOJ_DJANGO_SECRET_KEY` (and optionally the `KHOJ_ADMIN_EMAIL`) to something secure. This allows you to customize Khoj later via the admin panel.
   - Set `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, or `GEMINI_API_KEY` to your API key if you want to use OpenAI, Anthropic or Gemini chat models respectively.
3. Start Khoj by running the following command in the same directory as your docker-compose.yml file.
```shell
cd ~/.khoj
docker-compose up
```
</TabItem>
</Tabs>

<h3>Setup</h3>
1. Download the Khoj docker-compose.yml file [from Github](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml)
```shell
# Windows users should use their WSL2 terminal to run these commands
mkdir ~/.khoj && cd ~/.khoj
wget https://raw.githubusercontent.com/khoj-ai/khoj/master/docker-compose.yml
```
2. Configure the environment variables in the docker-compose.yml
   - Set `KHOJ_ADMIN_PASSWORD`, `KHOJ_DJANGO_SECRET_KEY` (and optionally the `KHOJ_ADMIN_EMAIL`) to something secure. This allows you to customize Khoj later via the admin panel.
   - Set `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, or `GEMINI_API_KEY` to your API key if you want to use OpenAI, Anthropic or Gemini chat models respectively.
3. Start Khoj by running the following command in the same directory as your docker-compose.yml file.
```shell
# Windows users should use their WSL2 terminal to run these commands
cd ~/.khoj
docker-compose up
```

:::info[Remote Access]
By default Khoj is only accessible on the machine it is running. To access Khoj from a remote machine see [Remote Access Docs](/advanced/remote).
````
````
@@ -14,13 +14,6 @@ We don't send any personal information or any information from/about your conten

## Disable Telemetry

If you're self-hosting Khoj, you can opt out of telemetry at any time. To do so,
1. Open `~/.khoj/khoj.yml`
2. Add the following configuration:
```
app:
  should-log-telemetry: false
```
3. Save the file and restart Khoj
If you're self-hosting Khoj, you can opt out of telemetry at any time by setting the `KHOJ_TELEMETRY_DISABLE` environment variable to `True` via shell or `docker-compose.yml`

If you have any questions or concerns, please reach out to us on [Discord](https://discord.gg/BDgyabRM6e).
````
```
@@ -1,7 +1,7 @@
{
    "id": "khoj",
    "name": "Khoj",
    "version": "1.26.2",
    "version": "1.29.1",
    "minAppVersion": "0.15.0",
    "description": "Your Second Brain",
    "author": "Khoj Inc.",
```
```
@@ -62,7 +62,7 @@ dependencies = [
    "requests >= 2.26.0",
    "tenacity == 8.3.0",
    "anyio == 3.7.1",
    "pymupdf >= 1.23.5",
    "pymupdf == 1.24.11",
    "django == 5.0.9",
    "authlib == 1.2.1",
    "llama-cpp-python == 0.2.88",
@@ -87,7 +87,7 @@ dependencies = [
    "django_apscheduler == 0.6.2",
    "anthropic == 0.26.1",
    "docx2txt == 0.8",
    "google-generativeai == 0.7.2"
    "google-generativeai == 0.8.3",
]
dynamic = ["version"]

@@ -119,6 +119,9 @@ dev = [
    "mypy >= 1.0.1",
    "black >= 23.1.0",
    "pre-commit >= 3.0.4",
    "gitpython ~= 3.1.43",
    "datasets",
    "pandas",
]

[tool.hatch.version]
```
```
@@ -326,7 +326,7 @@
    entries.forEach(entry => {
        // If the element is in the viewport, fetch the remaining message and unobserve the element
        if (entry.isIntersecting) {
            fetchRemainingChatMessages(chatHistoryUrl, headers);
            fetchRemainingChatMessages(chatHistoryUrl, headers, chatBody.dataset.conversation_id, hostURL);
            observer.unobserve(entry.target);
        }
    });
@@ -342,7 +342,11 @@
        new Date(chat_log.created),
        chat_log.onlineContext,
        chat_log.intent?.type,
        chat_log.intent?.["inferred-queries"]);
        chat_log.intent?.["inferred-queries"],
        chatBody.dataset.conversationId ?? "",
        hostURL,
    );

    chatBody.appendChild(messageElement);

    // When the 4th oldest message is within viewing distance (~60% scrolled up)
@@ -421,7 +425,7 @@
        }
    }

    function fetchRemainingChatMessages(chatHistoryUrl, headers) {
    function fetchRemainingChatMessages(chatHistoryUrl, headers, conversationId, hostURL) {
        // Create a new IntersectionObserver
        let observer = new IntersectionObserver((entries, observer) => {
            entries.forEach(entry => {
@@ -435,7 +439,9 @@
                new Date(chat_log.created),
                chat_log.onlineContext,
                chat_log.intent?.type,
                chat_log.intent?.["inferred-queries"]
                chat_log.intent?.["inferred-queries"],
                chatBody.dataset.conversationId ?? "",
                hostURL,
            );
            entry.target.replaceWith(messageElement);
```
```
@@ -189,11 +189,19 @@ function processOnlineReferences(referenceSection, onlineContext) { //same
    return numOnlineReferences;
}

function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) { //same
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null, conversationId=null, hostURL=null) {
    let chatEl;
    if (intentType?.includes("text-to-image")) {
        let imageMarkdown = generateImageMarkdown(message, intentType, inferredQueries);
        chatEl = renderMessage(imageMarkdown, by, dt, null, false, "return");
    } else if (intentType === "excalidraw") {
        let domain = hostURL ?? "https://app.khoj.dev/";

        if (!domain.endsWith("/")) domain += "/";

        let excalidrawMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in the web app at ${domain}chat?conversationId=${conversationId}`;

        chatEl = renderMessage(excalidrawMessage, by, dt, null, false, "return");
    } else {
        chatEl = renderMessage(message, by, dt, null, false, "return");
    }
@@ -312,7 +320,6 @@ function formatHTMLMessage(message, raw=false, willReplace=true) { //same
}

function createReferenceSection(references, createLinkerSection=false) {
    console.log("linker data: ", createLinkerSection);
    let referenceSection = document.createElement('div');
    referenceSection.classList.add("reference-section");
    referenceSection.classList.add("collapsed");
@@ -417,7 +424,11 @@ function handleImageResponse(imageJson, rawResponse) {
        rawResponse += ``;
    } else if (imageJson.intentType === "text-to-image-v3") {
        rawResponse = ``;
    } else if (imageJson.intentType === "excalidraw") {
        const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in the web app`;
        rawResponse += redirectMessage;
    }

    if (inferredQuery) {
        rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
    }
```
```
@@ -1,6 +1,6 @@
{
    "name": "Khoj",
    "version": "1.26.2",
    "version": "1.29.1",
    "description": "Your Second Brain",
    "author": "Khoj Inc. <team@khoj.dev>",
    "license": "GPL-3.0-or-later",
```
```
@@ -6,7 +6,7 @@
;; Saba Imran <saba@khoj.dev>
;; Description: Your Second Brain
;; Keywords: search, chat, ai, org-mode, outlines, markdown, pdf, image
;; Version: 1.26.2
;; Version: 1.29.1
;; Package-Requires: ((emacs "27.1") (transient "0.3.0") (dash "2.19.1"))
;; URL: https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs
```
```
@@ -1,7 +1,7 @@
{
    "id": "khoj",
    "name": "Khoj",
    "version": "1.26.2",
    "version": "1.29.1",
    "minAppVersion": "0.15.0",
    "description": "Your Second Brain",
    "author": "Khoj Inc.",
```
```
@@ -1,6 +1,6 @@
{
    "name": "Khoj",
    "version": "1.26.2",
    "version": "1.29.1",
    "description": "Your Second Brain",
    "author": "Debanjum Singh Solanky, Saba Imran <team@khoj.dev>",
    "license": "GPL-3.0-or-later",
```
```
@@ -484,12 +484,13 @@ export class KhojChatView extends KhojPaneView {
        dt?: Date,
        intentType?: string,
        inferredQueries?: string[],
        conversationId?: string,
    ) {
        if (!message) return;

        let chatMessageEl;
        if (intentType?.includes("text-to-image")) {
            let imageMarkdown = this.generateImageMarkdown(message, intentType, inferredQueries);
        if (intentType?.includes("text-to-image") || intentType === "excalidraw") {
            let imageMarkdown = this.generateImageMarkdown(message, intentType, inferredQueries, conversationId);
            chatMessageEl = this.renderMessage(chatEl, imageMarkdown, sender, dt);
        } else {
            chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
@@ -509,7 +510,7 @@ export class KhojChatView extends KhojPaneView {
            chatMessageBodyEl.appendChild(this.createReferenceSection(references));
        }

    generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[]) {
    generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string): string {
        let imageMarkdown = "";
        if (intentType === "text-to-image") {
            imageMarkdown = ``;
@@ -517,6 +518,10 @@ export class KhojChatView extends KhojPaneView {
            imageMarkdown = ``;
        } else if (intentType === "text-to-image-v3") {
            imageMarkdown = ``;
        } else if (intentType === "excalidraw") {
            const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
            const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}chat?conversationId=${conversationId}`;
            imageMarkdown = redirectMessage;
        }
        if (inferredQueries) {
            imageMarkdown += "\n\n**Inferred Query**:";
@@ -884,6 +889,7 @@ export class KhojChatView extends KhojPaneView {
            new Date(chatLog.created),
            chatLog.intent?.type,
            chatLog.intent?.["inferred-queries"],
            chatBodyEl.dataset.conversationId ?? "",
        );
        // push the user messages to the chat history
        if(chatLog.by === "you"){
@@ -1354,6 +1360,10 @@ export class KhojChatView extends KhojPaneView {
            rawResponse += ``;
        } else if (imageJson.intentType === "text-to-image-v3") {
            rawResponse = ``;
        } else if (imageJson.intentType === "excalidraw") {
            const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
            const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}`;
            rawResponse += redirectMessage;
        }
        if (inferredQuery) {
            rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
```
```
@@ -78,6 +78,7 @@ If your plugin does not need CSS, delete this file.
    user-select: text;
    color: var(--text-normal);
    background-color: var(--active-bg);
    word-break: break-word;
}
/* color chat bubble by khoj blue */
.khoj-chat-message-text.khoj {
```
```
@@ -80,5 +80,15 @@
    "1.25.0": "0.15.0",
    "1.26.0": "0.15.0",
    "1.26.1": "0.15.0",
    "1.26.2": "0.15.0"
    "1.26.2": "0.15.0",
    "1.26.3": "0.15.0",
    "1.26.4": "0.15.0",
    "1.27.0": "0.15.0",
    "1.27.1": "0.15.0",
    "1.28.0": "0.15.0",
    "1.28.1": "0.15.0",
    "1.28.2": "0.15.0",
    "1.28.3": "0.15.0",
    "1.29.0": "0.15.0",
    "1.29.1": "0.15.0"
}
```
*(File diff suppressed because it is too large.)*
```
@@ -79,7 +79,7 @@ div.titleBar {
div.chatBoxBody {
    display: grid;
    height: 100%;
    width: 70%;
    width: 95%;
    margin: auto;
}
```
```
@@ -47,7 +47,14 @@ export default function RootLayout({
                child-src 'none';
                object-src 'none';"
            ></meta>
            <body className={inter.className}>{children}</body>
            <body className={inter.className}>
                {children}
                <script
                    dangerouslySetInnerHTML={{
                        __html: `window.EXCALIDRAW_ASSET_PATH = 'https://assets.khoj.dev/@excalidraw/excalidraw/dist/';`,
                    }}
                />
            </body>
        </html>
    );
}
```
```
@@ -1,24 +1,31 @@
"use client";

import styles from "./chat.module.css";
import React, { Suspense, useEffect, useState } from "react";
import React, { Suspense, useEffect, useRef, useState } from "react";

import SidePanel, { ChatSessionActionMenu } from "../components/sidePanel/chatHistorySidePanel";
import ChatHistory from "../components/chatHistory/chatHistory";
import { useSearchParams } from "next/navigation";
import Loading from "../components/loading/loading";

import { processMessageChunk } from "../common/chatFunctions";
import { generateNewTitle, processMessageChunk } from "../common/chatFunctions";

import "katex/dist/katex.min.css";

import { Context, OnlineContext, StreamMessage } from "../components/chatMessage/chatMessage";
import {
    CodeContext,
    Context,
    OnlineContext,
    StreamMessage,
} from "../components/chatMessage/chatMessage";
import { useIPLocationData, useIsMobileWidth, welcomeConsole } from "../common/utils";
import ChatInputArea, { ChatOptions } from "../components/chatInputArea/chatInputArea";
import {
    AttachedFileText,
    ChatInputArea,
    ChatOptions,
} from "../components/chatInputArea/chatInputArea";
import { useAuthenticatedData } from "../common/auth";
import { AgentData } from "../agents/page";
import { DotsThreeVertical } from "@phosphor-icons/react";
import { Button } from "@/components/ui/button";

interface ChatBodyDataProps {
@@ -26,43 +33,72 @@ interface ChatBodyDataProps {
    onConversationIdChange?: (conversationId: string) => void;
    setQueryToProcess: (query: string) => void;
    streamedMessages: StreamMessage[];
    setUploadedFiles: (files: string[]) => void;
    setStreamedMessages: (messages: StreamMessage[]) => void;
    setUploadedFiles: (files: AttachedFileText[] | undefined) => void;
    isMobileWidth?: boolean;
    isLoggedIn: boolean;
    setImage64: (image64: string) => void;
    setImages: (images: string[]) => void;
}

function ChatBodyData(props: ChatBodyDataProps) {
    const searchParams = useSearchParams();
    const conversationId = searchParams.get("conversationId");
    const [message, setMessage] = useState("");
    const [image, setImage] = useState<string | null>(null);
    const [images, setImages] = useState<string[]>([]);
    const [processingMessage, setProcessingMessage] = useState(false);
    const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
    const [isInResearchMode, setIsInResearchMode] = useState(false);
    const chatInputRef = useRef<HTMLTextAreaElement>(null);

    const setQueryToProcess = props.setQueryToProcess;
    const onConversationIdChange = props.onConversationIdChange;

    useEffect(() => {
        if (image) {
            props.setImage64(encodeURIComponent(image));
        }
    }, [image, props.setImage64]);
    const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";

    useEffect(() => {
        const storedImage = localStorage.getItem("image");
        if (storedImage) {
            setImage(storedImage);
            props.setImage64(encodeURIComponent(storedImage));
            localStorage.removeItem("image");
        if (images.length > 0) {
            const encodedImages = images.map((image) => encodeURIComponent(image));
            props.setImages(encodedImages);
        }
    }, [images, props.setImages]);

    useEffect(() => {
        const storedImages = localStorage.getItem("images");
        if (storedImages) {
            const parsedImages: string[] = JSON.parse(storedImages);
            setImages(parsedImages);
            const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
            props.setImages(encodedImages);
            localStorage.removeItem("images");
        }

        const storedMessage = localStorage.getItem("message");
        if (storedMessage) {
            setProcessingMessage(true);
            setQueryToProcess(storedMessage);

            if (storedMessage.trim().startsWith("/research")) {
                setIsInResearchMode(true);
            }
        }
    }, [setQueryToProcess]);

        const storedUploadedFiles = localStorage.getItem("uploadedFiles");

        if (storedUploadedFiles) {
            const parsedFiles = storedUploadedFiles ? JSON.parse(storedUploadedFiles) : [];
            const uploadedFiles: AttachedFileText[] = [];
            for (const file of parsedFiles) {
                uploadedFiles.push({
                    name: file.name,
                    file_type: file.file_type,
                    content: file.content,
                    size: file.size,
                });
            }
            localStorage.removeItem("uploadedFiles");
            props.setUploadedFiles(uploadedFiles);
        }
    }, [setQueryToProcess, props.setImages, conversationId]);

    useEffect(() => {
        if (message) {
@@ -84,6 +120,8 @@ function ChatBodyData(props: ChatBodyDataProps) {
            props.streamedMessages[props.streamedMessages.length - 1].completed
        ) {
            setProcessingMessage(false);
            setImages([]); // Reset images after processing
            props.setUploadedFiles(undefined); // Reset uploaded files after processing
        } else {
            setMessage("");
        }
@@ -103,21 +141,25 @@ function ChatBodyData(props: ChatBodyDataProps) {
                    setAgent={setAgentMetadata}
                    pendingMessage={processingMessage ? message : ""}
                    incomingMessages={props.streamedMessages}
                    setIncomingMessages={props.setStreamedMessages}
                    customClassName={chatHistoryCustomClassName}
                />
            </div>
            <div
                className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl h-fit`}
                className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl h-fit ${chatHistoryCustomClassName} mr-auto ml-auto`}
            >
                <ChatInputArea
                    agentColor={agentMetadata?.color}
                    isLoggedIn={props.isLoggedIn}
                    sendMessage={(message) => setMessage(message)}
                    sendImage={(image) => setImage(image)}
                    sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
                    sendDisabled={processingMessage}
                    chatOptionsData={props.chatOptionsData}
                    conversationId={conversationId}
                    isMobileWidth={props.isMobileWidth}
                    setUploadedFiles={props.setUploadedFiles}
                    ref={chatInputRef}
                    isResearchModeEnabled={isInResearchMode}
                />
            </div>
        </>
@@ -133,8 +175,8 @@ export default function Chat() {
    const [messages, setMessages] = useState<StreamMessage[]>([]);
    const [queryToProcess, setQueryToProcess] = useState<string>("");
    const [processQuerySignal, setProcessQuerySignal] = useState(false);
    const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
    const [image64, setImage64] = useState<string>("");
    const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | undefined>(undefined);
    const [images, setImages] = useState<string[]>([]);

    const locationData = useIPLocationData() || {
        timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
@@ -167,10 +209,12 @@ export default function Chat() {
            trainOfThought: [],
            context: [],
            onlineContext: {},
            codeContext: {},
            completed: false,
            timestamp: new Date().toISOString(),
            rawQuery: queryToProcess || "",
            uploadedImageData: decodeURIComponent(image64),
            images: images,
            queryFiles: uploadedFiles,
        };
        setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
        setProcessQuerySignal(true);
@@ -195,13 +239,17 @@ export default function Chat() {
        // Track context used for chat response
        let context: Context[] = [];
        let onlineContext: OnlineContext = {};
        let codeContext: CodeContext = {};

        while (true) {
            const { done, value } = await reader.read();
            if (done) {
                setQueryToProcess("");
                setProcessQuerySignal(false);
                setImage64("");
                setImages([]);

                if (conversationId) generateNewTitle(conversationId, setTitle);

                break;
            }

@@ -221,11 +269,12 @@ export default function Chat() {
            }

            // Track context used for chat response. References are rendered at the end of the chat
            ({ context, onlineContext } = processMessageChunk(
            ({ context, onlineContext, codeContext } = processMessageChunk(
                event,
                currentMessage,
                context,
                onlineContext,
                codeContext,
            ));

            setMessages([...messages]);
@@ -249,7 +298,8 @@ export default function Chat() {
                country_code: locationData.countryCode,
                timezone: locationData.timezone,
            }),
            ...(image64 && { image: image64 }),
            ...(images.length > 0 && { images: images }),
            ...(uploadedFiles && { files: uploadedFiles }),
        };

        const response = await fetch(chatAPI, {
@@ -263,7 +313,8 @@ export default function Chat() {
        try {
            await readChatStream(response);
        } catch (err) {
            console.error(err);
            const apiError = await response.json();
            console.error(apiError);
            // Retrieve latest message being processed
            const currentMessage = messages.find((message) => !message.completed);
            if (!currentMessage) return;
@@ -272,7 +323,11 @@ export default function Chat() {
            const errorMessage = (err as Error).message;
            if (errorMessage.includes("Error in input stream"))
                currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`;
            else
            else if (response.status === 429) {
                "detail" in apiError
                    ? (currentMessage.rawResponse = `${apiError.detail}`)
                    : (currentMessage.rawResponse = `I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?`);
            } else
                currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`;

            // Complete message streaming teardown properly
@@ -297,7 +352,7 @@ export default function Chat() {
            <div>
                <SidePanel
                    conversationId={conversationId}
                    uploadedFiles={uploadedFiles}
                    uploadedFiles={[]}
                    isMobileWidth={isMobileWidth}
                />
            </div>
@@ -325,13 +380,14 @@ export default function Chat() {
                        <ChatBodyData
                            isLoggedIn={authenticatedData !== null}
                            streamedMessages={messages}
                            setStreamedMessages={setMessages}
                            chatOptionsData={chatOptionsData}
                            setTitle={setTitle}
                            setQueryToProcess={setQueryToProcess}
                            setUploadedFiles={setUploadedFiles}
                            isMobileWidth={isMobileWidth}
                            onConversationIdChange={handleConversationIdChange}
                            setImage64={setImage64}
                            setImages={setImages}
                        />
                    </Suspense>
                </div>
```
```
@@ -68,7 +68,8 @@ export interface UserConfig {
    selected_voice_model_config: number;
    // user billing info
    subscription_state: SubscriptionStates;
    subscription_renewal_date: string;
    subscription_renewal_date: string | undefined;
    subscription_enabled_trial_at: string | undefined;
    // server settings
    khoj_cloud_subscription_url: string | undefined;
    billing_enabled: boolean;
@@ -78,6 +79,7 @@ export interface UserConfig {
    anonymous_mode: boolean;
    notion_oauth_url: string;
    detail: string;
    length_of_free_trial: number;
}

export function useUserConfig(detailed: boolean = false) {
@@ -93,3 +95,15 @@ export function useUserConfig(detailed: boolean = false) {

    return { userConfig, isLoadingUserConfig };
}

export function isUserSubscribed(userConfig: UserConfig | null): boolean {
    return (
        (userConfig?.subscription_state &&
            [
                SubscriptionStates.SUBSCRIBED.valueOf(),
                SubscriptionStates.TRIAL.valueOf(),
                SubscriptionStates.UNSUBSCRIBED.valueOf(),
            ].includes(userConfig.subscription_state)) ||
        false
    );
}
```
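For illustration, the new `isUserSubscribed` helper might be consumed like this. This is a hypothetical sketch: the hook and helper signatures come from the diff above, while the import path and the wrapping hook name are assumptions.

```typescript
// Hypothetical usage sketch: gate a feature on subscription state using the
// helper added above. Assumes useUserConfig() and isUserSubscribed() are
// exported from app/common/auth as shown in the diff; the "@/app/common/auth"
// path and useCanUseCloudFeatures name are illustrative.
import { isUserSubscribed, useUserConfig } from "@/app/common/auth";

export function useCanUseCloudFeatures(): boolean {
    const { userConfig, isLoadingUserConfig } = useUserConfig(true);
    if (isLoadingUserConfig) return false;
    return isUserSubscribed(userConfig);
}
```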
```
@@ -1,14 +1,25 @@
import { Context, OnlineContext, StreamMessage } from "../components/chatMessage/chatMessage";
import {
    CodeContext,
    Context,
    OnlineContext,
    StreamMessage,
} from "../components/chatMessage/chatMessage";

export interface RawReferenceData {
    context?: Context[];
    onlineContext?: OnlineContext;
    codeContext?: CodeContext;
}

export interface ResponseWithReferences {
    context?: Context[];
    online?: OnlineContext;
    response?: string;
export interface MessageMetadata {
    conversationId: string;
    turnId: string;
}

export interface ResponseWithIntent {
    intentType: string;
    response: string;
    inferredQueries?: string[];
}

interface MessageChunk {
@@ -49,10 +60,14 @@ export function convertMessageChunkToJson(chunk: string): MessageChunk {
function handleJsonResponse(chunkData: any) {
    const jsonData = chunkData as any;
    if (jsonData.image || jsonData.detail) {
        let responseWithReference = handleImageResponse(chunkData, true);
        if (responseWithReference.response) return responseWithReference.response;
        let responseWithIntent = handleImageResponse(chunkData, true);
        return responseWithIntent;
    } else if (jsonData.response) {
        return jsonData.response;
        return {
            response: jsonData.response,
            intentType: "",
            inferredQueries: [],
        };
    } else {
        throw new Error("Invalid JSON response");
    }
@@ -63,10 +78,11 @@ export function processMessageChunk(
    currentMessage: StreamMessage,
    context: Context[] = [],
    onlineContext: OnlineContext = {},
): { context: Context[]; onlineContext: OnlineContext } {
    codeContext: CodeContext = {},
): { context: Context[]; onlineContext: OnlineContext; codeContext: CodeContext } {
    const chunk = convertMessageChunkToJson(rawChunk);

    if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext };
    if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext, codeContext };

    if (chunk.type === "status") {
        console.log(`status: ${chunk.data}`);
@@ -77,11 +93,25 @@ export function processMessageChunk(

        if (references.context) context = references.context;
        if (references.onlineContext) onlineContext = references.onlineContext;
        return { context, onlineContext };
        if (references.codeContext) codeContext = references.codeContext;
        return { context, onlineContext, codeContext };
    } else if (chunk.type === "metadata") {
        const messageMetadata = chunk.data as MessageMetadata;
        currentMessage.turnId = messageMetadata.turnId;
    } else if (chunk.type === "message") {
        const chunkData = chunk.data;
        // Here, handle if the response is a JSON response with an image, but the intentType is excalidraw
        if (chunkData !== null && typeof chunkData === "object") {
            currentMessage.rawResponse += handleJsonResponse(chunkData);
            let responseWithIntent = handleJsonResponse(chunkData);

            if (responseWithIntent.intentType && responseWithIntent.intentType === "excalidraw") {
                currentMessage.rawResponse = responseWithIntent.response;
            } else {
                currentMessage.rawResponse += responseWithIntent.response;
            }

            currentMessage.intentType = responseWithIntent.intentType;
            currentMessage.inferredQueries = responseWithIntent.inferredQueries;
        } else if (
            typeof chunkData === "string" &&
            chunkData.trim()?.startsWith("{") &&
@@ -89,7 +119,10 @@ export function processMessageChunk(
        ) {
            try {
                const jsonData = JSON.parse(chunkData.trim());
                currentMessage.rawResponse += handleJsonResponse(jsonData);
                let responseWithIntent = handleJsonResponse(jsonData);
                currentMessage.rawResponse += responseWithIntent.response;
                currentMessage.intentType = responseWithIntent.intentType;
                currentMessage.inferredQueries = responseWithIntent.inferredQueries;
            } catch (e) {
                currentMessage.rawResponse += JSON.stringify(chunkData);
            }
@@ -102,51 +135,56 @@ export function processMessageChunk(
        console.log(`Completed streaming: ${new Date()}`);

        // Append any references after all the data has been streamed
        if (codeContext) currentMessage.codeContext = codeContext;
        if (onlineContext) currentMessage.onlineContext = onlineContext;
        if (context) currentMessage.context = context;

        // Mark current message streaming as completed
        currentMessage.completed = true;
    }
    return { context, onlineContext };
    return { context, onlineContext, codeContext };
}

export function handleImageResponse(imageJson: any, liveStream: boolean): ResponseWithReferences {
export function handleImageResponse(imageJson: any, liveStream: boolean): ResponseWithIntent {
    let rawResponse = "";

    if (imageJson.image) {
        const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";

        // If response has image field, response is a generated image.
        if (imageJson.intentType === "text-to-image") {
            rawResponse += ``;
        } else if (imageJson.intentType === "text-to-image2") {
            rawResponse += ``;
        } else if (imageJson.intentType === "text-to-image-v3") {
            rawResponse = ``;
        }
        if (inferredQuery && !liveStream) {
            rawResponse += `\n\n${inferredQuery}`;
        }
        // If response has image field, response may be a generated image
        rawResponse = imageJson.image;
    }

    let reference: ResponseWithReferences = {};
    let responseWithIntent: ResponseWithIntent = {
        intentType: imageJson.intentType,
        response: rawResponse,
        inferredQueries: imageJson.inferredQueries,
    };

    if (imageJson.context && imageJson.context.length > 0) {
        const rawReferenceAsJson = imageJson.context;
        if (rawReferenceAsJson instanceof Array) {
            reference.context = rawReferenceAsJson;
        } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
            reference.online = rawReferenceAsJson;
        }
    }
    if (imageJson.detail) {
        // The detail field contains the improved image prompt
        rawResponse += imageJson.detail;
    }

    reference.response = rawResponse;
    return reference;
    return responseWithIntent;
}

export function renderCodeGenImageInline(message: string, codeContext: CodeContext) {
    if (!codeContext) return message;

    Object.values(codeContext).forEach((contextData) => {
        contextData.results.output_files?.forEach((file) => {
            const regex = new RegExp(`!?\\[.*?\\]\\(.*${file.filename}\\)`, "g");
            if (file.filename.match(/\.(png|jpg|jpeg)$/i)) {
                const replacement = `.pop()};base64,${file.b64_data})`;
                message = message.replace(regex, replacement);
            } else if (file.filename.match(/\.(txt|org|md|csv|json)$/i)) {
                // render output files generated by codegen as downloadable links
                const replacement = ``;
                message = message.replace(regex, replacement);
            }
        });
    });

    return message;
}

export function modifyFileFilterForConversation(
@@ -206,6 +244,78 @@ export async function createNewConversation(slug: string) {
    }
}

export async function packageFilesForUpload(files: FileList): Promise<FormData> {
    const formData = new FormData();

    const fileReadPromises = Array.from(files).map((file) => {
        return new Promise<void>((resolve, reject) => {
            let reader = new FileReader();
            reader.onload = function (event) {
                if (event.target === null) {
                    reject();
                    return;
                }

                let fileContents = event.target.result;
                let fileType = file.type;
                let fileName = file.name;
                if (fileType === "") {
                    let fileExtension = fileName.split(".").pop();
                    if (fileExtension === "org") {
                        fileType = "text/org";
                    } else if (fileExtension === "md") {
                        fileType = "text/markdown";
                    } else if (fileExtension === "txt") {
                        fileType = "text/plain";
                    } else if (fileExtension === "html") {
                        fileType = "text/html";
                    } else if (fileExtension === "pdf") {
                        fileType = "application/pdf";
                    } else if (fileExtension === "docx") {
                        fileType =
                            "application/vnd.openxmlformats-officedocument.wordprocessingml.document";
                    } else {
                        // Skip this file if its type is not supported
                        resolve();
                        return;
                    }
                }

                if (fileContents === null) {
                    reject();
                    return;
                }

                let fileObj = new Blob([fileContents], { type: fileType });
                formData.append("files", fileObj, file.name);
                resolve();
            };
            reader.onerror = reject;
            reader.readAsArrayBuffer(file);
        });
    });

    await Promise.all(fileReadPromises);
    return formData;
}

export function generateNewTitle(conversationId: string, setTitle: (title: string) => void) {
    fetch(`/api/chat/title?conversation_id=${conversationId}`, {
        method: "POST",
    })
        .then((res) => {
            if (!res.ok) throw new Error(`Failed to call API with error ${res.statusText}`);
            return res.json();
        })
        .then((data) => {
            setTitle(data.title);
        })
        .catch((err) => {
            console.error(err);
            return;
        });
}

export function uploadDataForIndexing(
    files: FileList,
    setWarning: (warning: string) => void,
```
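To make the new upload helpers above concrete, here is a hypothetical usage sketch. Only `packageFilesForUpload` and `generateNewTitle` come from the diff; the import path and the `sendFormData` callback are stand-ins for whatever endpoint the app actually posts the packaged files to.

```typescript
// Hypothetical usage sketch for the helpers added above.
import { generateNewTitle, packageFilesForUpload } from "@/app/common/chatFunctions";

async function attachFiles(
    files: FileList,
    sendFormData: (data: FormData) => Promise<void>,
    conversationId: string,
    setTitle: (title: string) => void,
): Promise<void> {
    // Read the selected files into a FormData payload; unsupported types are skipped.
    const formData = await packageFilesForUpload(files);
    await sendFormData(formData);
    // Ask the server to (re)name the conversation once the files are attached.
    generateNewTitle(conversationId, setTitle);
}
```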
```
@@ -42,6 +42,20 @@ export function converColorToBgGradient(color: string) {
    return `${convertToBGGradientClass(color)} dark:border dark:border-neutral-700`;
}

export function convertColorToCaretClass(color: string | undefined) {
    if (color && tailwindColors.includes(color)) {
        return `caret-${color}-500`;
    }
    return `caret-orange-500`;
}

export function convertColorToRingClass(color: string | undefined) {
    if (color && tailwindColors.includes(color)) {
        return `focus-visible:ring-${color}-500`;
    }
    return `focus-visible:ring-orange-500`;
}

export function convertColorToBorderClass(color: string) {
    if (tailwindColors.includes(color)) {
        return `border-${color}-500`;
```
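As a quick, hypothetical illustration of the two helpers added above: only `convertColorToCaretClass` and `convertColorToRingClass` come from the diff, while the component, import path, and placeholder text are assumptions.

```typescript
// Hypothetical usage sketch: build agent-colored caret and focus-ring classes
// for an input, falling back to orange when no agent color is set.
import { convertColorToCaretClass, convertColorToRingClass } from "@/app/common/colorUtils";

export function AgentInput(props: { agentColor?: string }) {
    const className = `${convertColorToCaretClass(props.agentColor)} ${convertColorToRingClass(props.agentColor)}`;
    return <textarea className={className} placeholder="Ask anything..." />;
}
```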
```
@@ -40,7 +40,6 @@ import {
    Leaf,
    NewspaperClipping,
    OrangeSlice,
    Rainbow,
    SmileyMelting,
    YinYang,
    SneakerMove,
@@ -48,8 +47,12 @@ import {
    Oven,
    Gavel,
    Broadcast,
    KeyReturn,
    FilePdf,
    FileMd,
    MicrosoftWordLogo,
} from "@phosphor-icons/react";
import { Markdown, OrgMode, Pdf, Word } from "@/app/components/logo/fileLogo";
import { OrgMode } from "@/app/components/logo/fileLogo";

interface IconMap {
    [key: string]: (color: string, width: string, height: string) => JSX.Element | null;
@@ -193,6 +196,10 @@ export function getIconForSlashCommand(command: string, customClassName: string
    }

    if (command.includes("default")) {
        return <KeyReturn className={className} />;
    }

    if (command.includes("diagram")) {
        return <Shapes className={className} />;
    }

@@ -208,6 +215,10 @@ export function getIconForSlashCommand(command: string, customClassName: string
        return <PencilLine className={className} />;
    }

    if (command.includes("code")) {
        return <Code className={className} />;
    }

    return <ArrowRight className={className} />;
}

@@ -233,11 +244,19 @@ function getIconFromFilename(
            return <OrgMode className={className} />;
        case "markdown":
        case "md":
            return <Markdown className={className} />;
            return <FileMd className={className} />;
        case "pdf":
            return <Pdf className={className} />;
            return <FilePdf className={className} />;
        case "doc":
            return <Word className={className} />;
        case "docx":
            return <MicrosoftWordLogo className={className} />;
        case "csv":
        case "json":
            return <MathOperations className={className} />;
        case "txt":
            return <Notebook className={className} />;
        case "py":
            return <Code className={className} />;
        case "jpg":
        case "jpeg":
        case "png":
```
@@ -70,3 +70,29 @@ export function useIsMobileWidth() {
return isMobileWidth;
}

export const convertBytesToText = (fileSize: number) => {
if (fileSize < 1024) {
return `${fileSize} B`;
} else if (fileSize < 1024 * 1024) {
return `${(fileSize / 1024).toFixed(2)} KB`;
} else {
return `${(fileSize / (1024 * 1024)).toFixed(2)} MB`;
}
};

export function useDebounce<T>(value: T, delay: number): T {
const [debouncedValue, setDebouncedValue] = useState<T>(value);

useEffect(() => {
const handler = setTimeout(() => {
setDebouncedValue(value);
}, delay);

return () => {
clearTimeout(handler);
};
}, [value, delay]);

return debouncedValue;
}

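A brief usage sketch for the useDebounce hook added above. It is illustrative only: the component name and the /api/agents endpoint are assumptions, and the React imports plus useDebounce are taken to be imported from this module.

// Illustrative: debounce a search box so the request only fires after typing pauses.
function AgentSearch() {
    const [query, setQuery] = useState("");
    const debouncedQuery = useDebounce(query, 300); // updates 300ms after the last keystroke

    useEffect(() => {
        if (!debouncedQuery) return;
        // Hypothetical endpoint used purely for illustration.
        fetch(`/api/agents?q=${encodeURIComponent(debouncedQuery)}`).catch(console.error);
    }, [debouncedQuery]);

    return <input value={query} onChange={(e) => setQuery(e.target.value)} placeholder="Search agents" />;
}
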
@@ -0,0 +1,20 @@
.agentPersonality p {
white-space: inherit;
overflow: hidden;
height: 77px;
line-height: 1.5;
}

div.agentPersonality {
text-align: left;
grid-column: span 3;
overflow: hidden;
}

button.infoButton {
border: none;
background-color: transparent !important;
text-align: left;
font-family: inherit;
font-size: medium;
}
src/interface/web/app/components/agentCard/agentCard.tsx (new file, 1297 lines; diff suppressed because it is too large)
@@ -2,12 +2,7 @@ div.chatHistory {
display: flex;
flex-direction: column;
height: 100%;
}

div.chatLayout {
height: 80vh;
overflow-y: auto;
margin: 0 auto;
margin: auto;
}

div.agentIndicator a {

@@ -13,13 +13,14 @@ import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
|
||||
import { InlineLoading } from "../loading/loading";
|
||||
|
||||
import { Lightbulb, ArrowDown } from "@phosphor-icons/react";
|
||||
import { Lightbulb, ArrowDown, XCircle } from "@phosphor-icons/react";
|
||||
|
||||
import AgentProfileCard from "../profileCard/profileCard";
|
||||
import { getIconFromIconName } from "@/app/common/iconUtils";
|
||||
import { AgentData } from "@/app/agents/page";
|
||||
import React from "react";
|
||||
import { useIsMobileWidth } from "@/app/common/utils";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
||||
interface ChatResponse {
|
||||
status: string;
|
||||
@@ -33,32 +34,62 @@ interface ChatHistory {
|
||||
interface ChatHistoryProps {
|
||||
conversationId: string;
|
||||
setTitle: (title: string) => void;
|
||||
incomingMessages?: StreamMessage[];
|
||||
pendingMessage?: string;
|
||||
incomingMessages?: StreamMessage[];
|
||||
setIncomingMessages?: (incomingMessages: StreamMessage[]) => void;
|
||||
publicConversationSlug?: string;
|
||||
setAgent: (agent: AgentData) => void;
|
||||
customClassName?: string;
|
||||
}
|
||||
|
||||
function constructTrainOfThought(
|
||||
trainOfThought: string[],
|
||||
lastMessage: boolean,
|
||||
agentColor: string,
|
||||
key: string,
|
||||
completed: boolean = false,
|
||||
) {
|
||||
const lastIndex = trainOfThought.length - 1;
|
||||
return (
|
||||
<div className={`${styles.trainOfThought} shadow-sm`} key={key}>
|
||||
{!completed && <InlineLoading className="float-right" />}
|
||||
interface TrainOfThoughtComponentProps {
|
||||
trainOfThought: string[];
|
||||
lastMessage: boolean;
|
||||
agentColor: string;
|
||||
keyId: string;
|
||||
completed?: boolean;
|
||||
}
|
||||
|
||||
{trainOfThought.map((train, index) => (
|
||||
<TrainOfThought
|
||||
key={`train-${index}`}
|
||||
message={train}
|
||||
primary={index === lastIndex && lastMessage && !completed}
|
||||
agentColor={agentColor}
|
||||
/>
|
||||
))}
|
||||
function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
|
||||
const lastIndex = props.trainOfThought.length - 1;
|
||||
const [collapsed, setCollapsed] = useState(props.completed);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`${!collapsed ? styles.trainOfThought + " shadow-sm" : ""}`}
|
||||
key={props.keyId}
|
||||
>
|
||||
{!props.completed && <InlineLoading className="float-right" />}
|
||||
{props.completed &&
|
||||
(collapsed ? (
|
||||
<Button
|
||||
className="w-fit text-left justify-start content-start text-xs"
|
||||
onClick={() => setCollapsed(false)}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
>
|
||||
What was my train of thought?
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
className="w-fit text-left justify-start content-start text-xs p-0 h-fit"
|
||||
onClick={() => setCollapsed(true)}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
>
|
||||
<XCircle size={16} className="mr-1" />
|
||||
Close
|
||||
</Button>
|
||||
))}
|
||||
{!collapsed &&
|
||||
props.trainOfThought.map((train, index) => (
|
||||
<TrainOfThought
|
||||
key={`train-${index}`}
|
||||
message={train}
|
||||
primary={index === lastIndex && props.lastMessage && !props.completed}
|
||||
agentColor={props.agentColor}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
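For orientation, a hedged sketch of how the new TrainOfThoughtComponent is rendered inside ChatHistory. The prop values below are placeholders for illustration, not taken from this diff.

// Illustrative render call; assumes steps were collected from a completed chat turn.
<TrainOfThoughtComponent
    trainOfThought={["Reading attached notes", "Searching the web", "Composing the answer"]}
    lastMessage={false}
    agentColor="orange"
    keyId="example-train-of-thought"
    completed={true} // completed turns start collapsed behind the "What was my train of thought?" button
/>
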
@@ -67,6 +98,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||
const [data, setData] = useState<ChatHistoryData | null>(null);
|
||||
const [currentPage, setCurrentPage] = useState(0);
|
||||
const [hasMoreMessages, setHasMoreMessages] = useState(true);
|
||||
const [currentTurnId, setCurrentTurnId] = useState<string | null>(null);
|
||||
const sentinelRef = useRef<HTMLDivElement | null>(null);
|
||||
const scrollAreaRef = useRef<HTMLDivElement | null>(null);
|
||||
const latestUserMessageRef = useRef<HTMLDivElement | null>(null);
|
||||
@@ -147,6 +179,10 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||
if (lastMessage && !lastMessage.completed) {
|
||||
setIncompleteIncomingMessageIndex(props.incomingMessages.length - 1);
|
||||
props.setTitle(lastMessage.rawQuery);
|
||||
// Store the turnId when we get it
|
||||
if (lastMessage.turnId) {
|
||||
setCurrentTurnId(lastMessage.turnId);
|
||||
}
|
||||
}
|
||||
}
|
||||
}, [props.incomingMessages]);
|
||||
@@ -248,44 +284,79 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||
return data.agent?.persona;
|
||||
}
|
||||
|
||||
const handleDeleteMessage = (turnId?: string) => {
|
||||
if (!turnId) return;
|
||||
|
||||
setData((prevData) => {
|
||||
if (!prevData || !turnId) return prevData;
|
||||
return {
|
||||
...prevData,
|
||||
chat: prevData.chat.filter((msg) => msg.turnId !== turnId),
|
||||
};
|
||||
});
|
||||
|
||||
// Update incoming messages if they exist
|
||||
if (props.incomingMessages && props.setIncomingMessages) {
|
||||
props.setIncomingMessages(
|
||||
props.incomingMessages.filter((msg) => msg.turnId !== turnId),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
if (!props.conversationId && !props.publicConversationSlug) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<ScrollArea className={`h-[80vh] relative`} ref={scrollAreaRef}>
|
||||
<ScrollArea className={`h-[73vh] relative`} ref={scrollAreaRef}>
|
||||
<div>
|
||||
<div className={styles.chatHistory}>
|
||||
<div className={`${styles.chatHistory} ${props.customClassName}`}>
|
||||
<div ref={sentinelRef} style={{ height: "1px" }}>
|
||||
{fetchingData && (
|
||||
<InlineLoading message="Loading Conversation" className="opacity-50" />
|
||||
)}
|
||||
{fetchingData && <InlineLoading className="opacity-50" />}
|
||||
</div>
|
||||
{data &&
|
||||
data.chat &&
|
||||
data.chat.map((chatMessage, index) => (
|
||||
<ChatMessage
|
||||
key={`${index}fullHistory`}
|
||||
ref={
|
||||
// attach ref to the second last message to handle scroll on page load
|
||||
index === data.chat.length - 2
|
||||
? latestUserMessageRef
|
||||
: // attach ref to the newest fetched message to handle scroll on fetch
|
||||
// note: stabilize index selection against last page having less messages than fetchMessageCount
|
||||
index ===
|
||||
data.chat.length - (currentPage - 1) * fetchMessageCount
|
||||
? latestFetchedMessageRef
|
||||
: null
|
||||
}
|
||||
isMobileWidth={isMobileWidth}
|
||||
chatMessage={chatMessage}
|
||||
customClassName="fullHistory"
|
||||
borderLeftColor={`${data?.agent?.color}-500`}
|
||||
isLastMessage={index === data.chat.length - 1}
|
||||
/>
|
||||
<>
|
||||
{chatMessage.trainOfThought && chatMessage.by === "khoj" && (
|
||||
<TrainOfThoughtComponent
|
||||
trainOfThought={chatMessage.trainOfThought?.map(
|
||||
(train) => train.data,
|
||||
)}
|
||||
lastMessage={false}
|
||||
agentColor={data?.agent?.color || "orange"}
|
||||
key={`${index}trainOfThought`}
|
||||
keyId={`${index}trainOfThought`}
|
||||
completed={true}
|
||||
/>
|
||||
)}
|
||||
<ChatMessage
|
||||
key={`${index}fullHistory`}
|
||||
ref={
|
||||
// attach ref to the second last message to handle scroll on page load
|
||||
index === data.chat.length - 2
|
||||
? latestUserMessageRef
|
||||
: // attach ref to the newest fetched message to handle scroll on fetch
|
||||
// note: stabilize index selection against last page having less messages than fetchMessageCount
|
||||
index ===
|
||||
data.chat.length -
|
||||
(currentPage - 1) * fetchMessageCount
|
||||
? latestFetchedMessageRef
|
||||
: null
|
||||
}
|
||||
isMobileWidth={isMobileWidth}
|
||||
chatMessage={chatMessage}
|
||||
customClassName="fullHistory"
|
||||
borderLeftColor={`${data?.agent?.color}-500`}
|
||||
isLastMessage={index === data.chat.length - 1}
|
||||
onDeleteMessage={handleDeleteMessage}
|
||||
conversationId={props.conversationId}
|
||||
/>
|
||||
</>
|
||||
))}
|
||||
{props.incomingMessages &&
|
||||
props.incomingMessages.map((message, index) => {
|
||||
const messageTurnId = message.turnId ?? currentTurnId ?? undefined;
|
||||
return (
|
||||
<React.Fragment key={`incomingMessage${index}`}>
|
||||
<ChatMessage
|
||||
@@ -295,22 +366,31 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||
message: message.rawQuery,
|
||||
context: [],
|
||||
onlineContext: {},
|
||||
codeContext: {},
|
||||
created: message.timestamp,
|
||||
by: "you",
|
||||
automationId: "",
|
||||
uploadedImageData: message.uploadedImageData,
|
||||
images: message.images,
|
||||
conversationId: props.conversationId,
|
||||
turnId: messageTurnId,
|
||||
queryFiles: message.queryFiles,
|
||||
}}
|
||||
customClassName="fullHistory"
|
||||
borderLeftColor={`${data?.agent?.color}-500`}
|
||||
onDeleteMessage={handleDeleteMessage}
|
||||
conversationId={props.conversationId}
|
||||
turnId={messageTurnId}
|
||||
/>
|
||||
{message.trainOfThought &&
|
||||
constructTrainOfThought(
|
||||
message.trainOfThought,
|
||||
index === incompleteIncomingMessageIndex,
|
||||
data?.agent?.color || "orange",
|
||||
`${index}trainOfThought`,
|
||||
message.completed,
|
||||
)}
|
||||
{message.trainOfThought && (
|
||||
<TrainOfThoughtComponent
|
||||
trainOfThought={message.trainOfThought}
|
||||
lastMessage={index === incompleteIncomingMessageIndex}
|
||||
agentColor={data?.agent?.color || "orange"}
|
||||
key={`${index}trainOfThought`}
|
||||
keyId={`${index}trainOfThought`}
|
||||
completed={message.completed}
|
||||
/>
|
||||
)}
|
||||
<ChatMessage
|
||||
key={`${index}incoming`}
|
||||
isMobileWidth={isMobileWidth}
|
||||
@@ -318,11 +398,23 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||
message: message.rawResponse,
|
||||
context: message.context,
|
||||
onlineContext: message.onlineContext,
|
||||
codeContext: message.codeContext,
|
||||
created: message.timestamp,
|
||||
by: "khoj",
|
||||
automationId: "",
|
||||
rawQuery: message.rawQuery,
|
||||
intent: {
|
||||
type: message.intentType || "",
|
||||
query: message.rawQuery,
|
||||
"memory-type": "",
|
||||
"inferred-queries": message.inferredQueries || [],
|
||||
},
|
||||
conversationId: props.conversationId,
|
||||
turnId: messageTurnId,
|
||||
}}
|
||||
conversationId={props.conversationId}
|
||||
turnId={messageTurnId}
|
||||
onDeleteMessage={handleDeleteMessage}
|
||||
customClassName="fullHistory"
|
||||
borderLeftColor={`${data?.agent?.color}-500`}
|
||||
isLastMessage={true}
|
||||
@@ -338,11 +430,15 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||
message: props.pendingMessage,
|
||||
context: [],
|
||||
onlineContext: {},
|
||||
codeContext: {},
|
||||
created: new Date().getTime().toString(),
|
||||
by: "you",
|
||||
automationId: "",
|
||||
uploadedImageData: props.pendingMessage,
|
||||
conversationId: props.conversationId,
|
||||
turnId: undefined,
|
||||
}}
|
||||
conversationId={props.conversationId}
|
||||
onDeleteMessage={handleDeleteMessage}
|
||||
customClassName="fullHistory"
|
||||
borderLeftColor={`${data?.agent?.color}-500`}
|
||||
isLastMessage={true}
|
||||
@@ -366,18 +462,20 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{!isNearBottom && (
|
||||
<button
|
||||
title="Scroll to bottom"
|
||||
className="absolute bottom-4 right-5 bg-white dark:bg-[hsl(var(--background))] text-neutral-500 dark:text-white p-2 rounded-full shadow-xl"
|
||||
onClick={() => {
|
||||
scrollToBottom();
|
||||
setIsNearBottom(true);
|
||||
}}
|
||||
>
|
||||
<ArrowDown size={24} />
|
||||
</button>
|
||||
)}
|
||||
<div className={`${props.customClassName} fixed bottom-[20%] z-10`}>
|
||||
{!isNearBottom && (
|
||||
<button
|
||||
title="Scroll to bottom"
|
||||
className="absolute bottom-0 right-0 bg-white dark:bg-[hsl(var(--background))] text-neutral-500 dark:text-white p-2 rounded-full shadow-xl"
|
||||
onClick={() => {
|
||||
scrollToBottom();
|
||||
setIsNearBottom(true);
|
||||
}}
|
||||
>
|
||||
<ArrowDown size={24} />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</ScrollArea>
|
||||
);
|
||||
|
||||
@@ -1,24 +1,16 @@
|
||||
import styles from "./chatInputArea.module.css";
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import React, { useEffect, useRef, useState, forwardRef } from "react";
|
||||
|
||||
import DOMPurify from "dompurify";
|
||||
import "katex/dist/katex.min.css";
|
||||
import {
|
||||
ArrowRight,
|
||||
ArrowUp,
|
||||
Browser,
|
||||
ChatsTeardrop,
|
||||
GlobeSimple,
|
||||
Gps,
|
||||
Image,
|
||||
Microphone,
|
||||
Notebook,
|
||||
Paperclip,
|
||||
X,
|
||||
Question,
|
||||
Robot,
|
||||
Shapes,
|
||||
Stop,
|
||||
ToggleLeft,
|
||||
ToggleRight,
|
||||
} from "@phosphor-icons/react";
|
||||
|
||||
import {
|
||||
@@ -45,30 +37,48 @@ import { Popover, PopoverContent } from "@/components/ui/popover";
|
||||
import { PopoverTrigger } from "@radix-ui/react-popover";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import { convertToBGClass } from "@/app/common/colorUtils";
|
||||
import { convertColorToTextClass, convertToBGClass } from "@/app/common/colorUtils";
|
||||
|
||||
import LoginPrompt from "../loginPrompt/loginPrompt";
|
||||
import { uploadDataForIndexing } from "../../common/chatFunctions";
|
||||
import { InlineLoading } from "../loading/loading";
|
||||
import { getIconForSlashCommand } from "@/app/common/iconUtils";
|
||||
import { getIconForSlashCommand, getIconFromFilename } from "@/app/common/iconUtils";
|
||||
import { packageFilesForUpload } from "@/app/common/chatFunctions";
|
||||
import { convertBytesToText } from "@/app/common/utils";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from "@/components/ui/dialog";
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
|
||||
export interface ChatOptions {
|
||||
[key: string]: string;
|
||||
}
|
||||
|
||||
export interface AttachedFileText {
|
||||
name: string;
|
||||
content: string;
|
||||
file_type: string;
|
||||
size: number;
|
||||
}
|
||||
|
||||
interface ChatInputProps {
|
||||
sendMessage: (message: string) => void;
|
||||
sendImage: (image: string) => void;
|
||||
sendDisabled: boolean;
|
||||
setUploadedFiles?: (files: string[]) => void;
|
||||
setUploadedFiles: (files: AttachedFileText[]) => void;
|
||||
conversationId?: string | null;
|
||||
chatOptionsData?: ChatOptions | null;
|
||||
isMobileWidth?: boolean;
|
||||
isLoggedIn: boolean;
|
||||
agentColor?: string;
|
||||
isResearchModeEnabled?: boolean;
|
||||
}
|
||||
|
||||
export default function ChatInputArea(props: ChatInputProps) {
|
||||
export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((props, ref) => {
|
||||
const [message, setMessage] = useState("");
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
@@ -78,15 +88,25 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||
const [loginRedirectMessage, setLoginRedirectMessage] = useState<string | null>(null);
|
||||
const [showLoginPrompt, setShowLoginPrompt] = useState(false);
|
||||
|
||||
const [recording, setRecording] = useState(false);
|
||||
const [imageUploaded, setImageUploaded] = useState(false);
|
||||
const [imagePath, setImagePath] = useState<string>("");
|
||||
const [imageData, setImageData] = useState<string | null>(null);
|
||||
const [imagePaths, setImagePaths] = useState<string[]>([]);
|
||||
const [imageData, setImageData] = useState<string[]>([]);
|
||||
|
||||
const [attachedFiles, setAttachedFiles] = useState<FileList | null>(null);
|
||||
const [convertedAttachedFiles, setConvertedAttachedFiles] = useState<AttachedFileText[]>([]);
|
||||
|
||||
const [recording, setRecording] = useState(false);
|
||||
const [mediaRecorder, setMediaRecorder] = useState<MediaRecorder | null>(null);
|
||||
|
||||
const [progressValue, setProgressValue] = useState(0);
|
||||
const [isDragAndDropping, setIsDragAndDropping] = useState(false);
|
||||
|
||||
const [showCommandList, setShowCommandList] = useState(false);
|
||||
const [useResearchMode, setUseResearchMode] = useState<boolean>(
|
||||
props.isResearchModeEnabled || false,
|
||||
);
|
||||
|
||||
const chatInputRef = ref as React.MutableRefObject<HTMLTextAreaElement>;
|
||||
useEffect(() => {
|
||||
if (!uploading) {
|
||||
setProgressValue(0);
|
||||
@@ -106,27 +126,37 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||
|
||||
useEffect(() => {
|
||||
async function fetchImageData() {
|
||||
if (imagePath) {
|
||||
const response = await fetch(imagePath);
|
||||
const blob = await response.blob();
|
||||
const reader = new FileReader();
|
||||
reader.onload = function () {
|
||||
const base64data = reader.result;
|
||||
setImageData(base64data as string);
|
||||
};
|
||||
reader.readAsDataURL(blob);
|
||||
if (imagePaths.length > 0) {
|
||||
const newImageData = await Promise.all(
|
||||
imagePaths.map(async (path) => {
|
||||
const response = await fetch(path);
|
||||
const blob = await response.blob();
|
||||
return new Promise<string>((resolve) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = () => resolve(reader.result as string);
|
||||
reader.readAsDataURL(blob);
|
||||
});
|
||||
}),
|
||||
);
|
||||
setImageData(newImageData);
|
||||
}
|
||||
setUploading(false);
|
||||
}
|
||||
setUploading(true);
|
||||
fetchImageData();
|
||||
}, [imagePath]);
|
||||
}, [imagePaths]);
|
||||
|
||||
useEffect(() => {
|
||||
if (props.isResearchModeEnabled) {
|
||||
setUseResearchMode(props.isResearchModeEnabled);
|
||||
}
|
||||
}, [props.isResearchModeEnabled]);
|
||||
|
||||
function onSendMessage() {
|
||||
if (imageUploaded) {
|
||||
setImageUploaded(false);
|
||||
setImagePath("");
|
||||
props.sendImage(imageData || "");
|
||||
setImagePaths([]);
|
||||
imageData.forEach((data) => props.sendImage(data));
|
||||
}
|
||||
if (!message.trim()) return;
|
||||
|
||||
@@ -138,7 +168,14 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||
return;
|
||||
}
|
||||
|
||||
props.sendMessage(message.trim());
|
||||
let messageToSend = message.trim();
|
||||
if (useResearchMode && !messageToSend.startsWith("/research")) {
|
||||
messageToSend = `/research ${messageToSend}`;
|
||||
}
|
||||
|
||||
props.sendMessage(messageToSend);
|
||||
setAttachedFiles(null);
|
||||
setConvertedAttachedFiles([]);
|
||||
setMessage("");
|
||||
}
|
||||
|
||||
@@ -172,26 +209,85 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||
setShowLoginPrompt(true);
|
||||
return;
|
||||
}
|
||||
// check for image file
|
||||
// check for image files
|
||||
const image_endings = ["jpg", "jpeg", "png", "webp"];
|
||||
const newImagePaths: string[] = [];
|
||||
for (let i = 0; i < files.length; i++) {
|
||||
const file = files[i];
|
||||
const file_extension = file.name.split(".").pop();
|
||||
if (image_endings.includes(file_extension || "")) {
|
||||
setImageUploaded(true);
|
||||
setImagePath(DOMPurify.sanitize(URL.createObjectURL(file)));
|
||||
return;
|
||||
newImagePaths.push(DOMPurify.sanitize(URL.createObjectURL(file)));
|
||||
}
|
||||
}
|
||||
|
||||
uploadDataForIndexing(
|
||||
files,
|
||||
setWarning,
|
||||
setUploading,
|
||||
setError,
|
||||
props.setUploadedFiles,
|
||||
props.conversationId,
|
||||
if (newImagePaths.length > 0) {
|
||||
setImageUploaded(true);
|
||||
setImagePaths((prevPaths) => [...prevPaths, ...newImagePaths]);
|
||||
// Set focus to the input for user message after uploading files
|
||||
chatInputRef?.current?.focus();
|
||||
}
|
||||
|
||||
// Process all non-image files
|
||||
const nonImageFiles = Array.from(files).filter(
|
||||
(file) => !image_endings.includes(file.name.split(".").pop() || ""),
|
||||
);
|
||||
|
||||
// Concatenate attachedFiles and files
|
||||
const newFiles = nonImageFiles
|
||||
? Array.from(nonImageFiles).concat(Array.from(attachedFiles || []))
|
||||
: Array.from(attachedFiles || []);
|
||||
|
||||
if (newFiles.length > 0) {
|
||||
// Ensure files are below size limit (10 MB)
|
||||
for (let i = 0; i < newFiles.length; i++) {
|
||||
if (newFiles[i].size > 10 * 1024 * 1024) {
|
||||
setWarning(
|
||||
`File ${newFiles[i].name} is too large. Please upload files smaller than 10 MB.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const dataTransfer = new DataTransfer();
|
||||
newFiles.forEach((file) => dataTransfer.items.add(file));
|
||||
|
||||
// Extract text from files
|
||||
extractTextFromFiles(dataTransfer.files).then((data) => {
|
||||
props.setUploadedFiles(data);
|
||||
setAttachedFiles(dataTransfer.files);
|
||||
setConvertedAttachedFiles(data);
|
||||
});
|
||||
}
|
||||
|
||||
// Set focus to the input for user message after uploading files
|
||||
chatInputRef?.current?.focus();
|
||||
}
|
||||
|
||||
async function extractTextFromFiles(files: FileList): Promise<AttachedFileText[]> {
|
||||
const formData = await packageFilesForUpload(files);
|
||||
setUploading(true);
|
||||
|
||||
try {
|
||||
const response = await fetch("/api/content/convert", {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
});
|
||||
setUploading(false);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
} catch (error) {
|
||||
setError(
|
||||
"Error converting files. " +
|
||||
error +
|
||||
". Please try again, or contact team@khoj.dev if the issue persists.",
|
||||
);
|
||||
console.error("Error converting files:", error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
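To make the data flow easier to follow: extractTextFromFiles above posts the packaged files to /api/content/convert and expects the response body to parse as AttachedFileText[]. A hypothetical response for a single markdown attachment might look like the following, with all field values invented for illustration.

// Example of the shape extractTextFromFiles resolves with (values are illustrative).
const exampleConvertedFiles: AttachedFileText[] = [
    {
        name: "meeting-notes.md",
        content: "# Weekly sync\n- shipped research mode toggle\n- next: attachment previews",
        file_type: "markdown",
        size: 1284,
    },
];
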
// Assuming this function is added within the same context as the provided excerpt
|
||||
@@ -270,12 +366,17 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||
}
|
||||
}, [recording, mediaRecorder]);
|
||||
|
||||
const chatInputRef = useRef<HTMLTextAreaElement>(null);
|
||||
useEffect(() => {
|
||||
if (!chatInputRef.current) return;
|
||||
if (!chatInputRef?.current) return;
|
||||
chatInputRef.current.style.height = "auto";
|
||||
chatInputRef.current.style.height =
|
||||
Math.max(chatInputRef.current.scrollHeight - 24, 64) + "px";
|
||||
|
||||
if (message.startsWith("/") && message.split(" ").length === 1) {
|
||||
setShowCommandList(true);
|
||||
} else {
|
||||
setShowCommandList(false);
|
||||
}
|
||||
}, [message]);
|
||||
|
||||
function handleDragOver(event: React.DragEvent<HTMLDivElement>) {
|
||||
@@ -288,9 +389,12 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||
setIsDragAndDropping(false);
|
||||
}
|
||||
|
||||
function removeImageUpload() {
|
||||
setImageUploaded(false);
|
||||
setImagePath("");
|
||||
function removeImageUpload(index: number) {
|
||||
setImagePaths((prevPaths) => prevPaths.filter((_, i) => i !== index));
|
||||
setImageData((prevData) => prevData.filter((_, i) => i !== index));
|
||||
if (imagePaths.length === 1) {
|
||||
setImageUploaded(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -358,13 +462,18 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
)}
|
||||
{message.startsWith("/") && message.split(" ").length === 1 && (
|
||||
{showCommandList && (
|
||||
<div className="flex justify-center text-center">
|
||||
<Popover open={message.startsWith("/")}>
|
||||
<Popover open={showCommandList} onOpenChange={setShowCommandList}>
|
||||
<PopoverTrigger className="flex justify-center text-center"></PopoverTrigger>
|
||||
<PopoverContent
|
||||
onOpenAutoFocus={(e) => e.preventDefault()}
|
||||
className={`${props.isMobileWidth ? "w-[100vw]" : "w-full"} rounded-md`}
|
||||
side="bottom"
|
||||
align="center"
|
||||
/* Offset below text area on home page (i.e where conversationId is unset) */
|
||||
sideOffset={props.conversationId ? 0 : 80}
|
||||
alignOffset={0}
|
||||
>
|
||||
<Command className="max-w-full">
|
||||
<CommandInput
|
||||
@@ -406,112 +515,230 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||
</Popover>
|
||||
</div>
|
||||
)}
|
||||
<div
|
||||
className={`${styles.actualInputArea} items-center justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
|
||||
onDragOver={handleDragOver}
|
||||
onDragLeave={handleDragLeave}
|
||||
onDrop={handleDragAndDropFiles}
|
||||
>
|
||||
{imageUploaded && (
|
||||
<div className="absolute bottom-[80px] left-0 right-0 dark:bg-neutral-700 bg-white pt-5 pb-5 w-full rounded-lg border dark:border-none grid grid-cols-2">
|
||||
<div className="pl-4 pr-4">
|
||||
<img src={imagePath} alt="img" className="w-auto max-h-[100px]" />
|
||||
</div>
|
||||
<div className="pl-4 pr-4">
|
||||
<X
|
||||
className="w-6 h-6 float-right dark:hover:bg-[hsl(var(--background))] hover:bg-neutral-100 rounded-sm"
|
||||
onClick={removeImageUpload}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<input
|
||||
type="file"
|
||||
multiple={true}
|
||||
ref={fileInputRef}
|
||||
onChange={handleFileChange}
|
||||
style={{ display: "none" }}
|
||||
/>
|
||||
<Button
|
||||
variant={"ghost"}
|
||||
className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
|
||||
disabled={props.sendDisabled}
|
||||
onClick={handleFileButtonClick}
|
||||
>
|
||||
<Paperclip className="w-8 h-8" />
|
||||
</Button>
|
||||
<div className="grid w-full gap-1.5 relative">
|
||||
<Textarea
|
||||
ref={chatInputRef}
|
||||
className={`border-none w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700 ${props.isMobileWidth ? "text-md" : "text-lg"}`}
|
||||
placeholder="Type / to see a list of commands"
|
||||
id="message"
|
||||
autoFocus={true}
|
||||
value={message}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && !e.shiftKey) {
|
||||
setImageUploaded(false);
|
||||
setImagePath("");
|
||||
e.preventDefault();
|
||||
onSendMessage();
|
||||
}
|
||||
}}
|
||||
onChange={(e) => setMessage(e.target.value)}
|
||||
disabled={props.sendDisabled || recording}
|
||||
/>
|
||||
<div>
|
||||
<div className="flex items-center gap-2 overflow-x-auto">
|
||||
{imageUploaded &&
|
||||
imagePaths.map((path, index) => (
|
||||
<div key={index} className="relative flex-shrink-0 pb-3 pt-2 group">
|
||||
<img
|
||||
src={path}
|
||||
alt={`img-${index}`}
|
||||
className="w-auto h-16 object-cover rounded-xl"
|
||||
/>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
onClick={() => removeImageUpload(index)}
|
||||
>
|
||||
<X className="h-3 w-3" />
|
||||
</Button>
|
||||
</div>
|
||||
))}
|
||||
{convertedAttachedFiles &&
|
||||
Array.from(convertedAttachedFiles).map((file, index) => (
|
||||
<Dialog key={index}>
|
||||
<DialogTrigger asChild>
|
||||
<div key={index} className="relative flex-shrink-0 p-2 group">
|
||||
<div
|
||||
className={`w-auto h-16 object-cover rounded-xl ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} bg-opacity-15`}
|
||||
>
|
||||
<div className="flex p-2 flex-col justify-start items-start h-full">
|
||||
<span className="text-sm font-bold text-neutral-500 dark:text-neutral-400 text-ellipsis truncate max-w-[200px] break-words">
|
||||
{file.name}
|
||||
</span>
|
||||
<span className="flex items-center gap-1">
|
||||
{getIconFromFilename(file.file_type)}
|
||||
<span className="text-xs text-neutral-500 dark:text-neutral-400">
|
||||
{convertBytesToText(file.size)}
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
onClick={() => {
|
||||
setAttachedFiles((prevFiles) => {
|
||||
const removeFile = file.name;
|
||||
if (!prevFiles) return null;
|
||||
const updatedFiles = Array.from(
|
||||
prevFiles,
|
||||
).filter((file) => file.name !== removeFile);
|
||||
const dataTransfer = new DataTransfer();
|
||||
updatedFiles.forEach((file) =>
|
||||
dataTransfer.items.add(file),
|
||||
);
|
||||
|
||||
const filteredConvertedAttachedFiles =
|
||||
convertedAttachedFiles.filter(
|
||||
(file) => file.name !== removeFile,
|
||||
);
|
||||
|
||||
props.setUploadedFiles(
|
||||
filteredConvertedAttachedFiles,
|
||||
);
|
||||
setConvertedAttachedFiles(
|
||||
filteredConvertedAttachedFiles,
|
||||
);
|
||||
return dataTransfer.files;
|
||||
});
|
||||
}}
|
||||
>
|
||||
<X className="h-3 w-3" />
|
||||
</Button>
|
||||
</div>
|
||||
</DialogTrigger>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{file.name}</DialogTitle>
|
||||
</DialogHeader>
|
||||
<DialogDescription>
|
||||
<ScrollArea className="h-72 w-full rounded-md">
|
||||
{file.content}
|
||||
</ScrollArea>
|
||||
</DialogDescription>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
))}
|
||||
</div>
|
||||
{recording ? (
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="default"
|
||||
className={`${!recording && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
|
||||
onClick={() => {
|
||||
setRecording(!recording);
|
||||
}}
|
||||
disabled={props.sendDisabled}
|
||||
>
|
||||
<Stop weight="fill" className="w-6 h-6" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
Click to stop recording and transcribe your voice.
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
) : mediaRecorder ? (
|
||||
<InlineLoading />
|
||||
) : (
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="default"
|
||||
className={`${!message || recording || "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
|
||||
onClick={() => {
|
||||
setMessage("Listening...");
|
||||
setRecording(!recording);
|
||||
}}
|
||||
disabled={props.sendDisabled}
|
||||
>
|
||||
<Microphone weight="fill" className="w-6 h-6" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
Click to transcribe your message with voice.
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
)}
|
||||
<Button
|
||||
className={`${(!message || recording) && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
|
||||
onClick={onSendMessage}
|
||||
disabled={props.sendDisabled}
|
||||
<div
|
||||
className={`${styles.actualInputArea} justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
|
||||
onDragOver={handleDragOver}
|
||||
onDragLeave={handleDragLeave}
|
||||
onDrop={handleDragAndDropFiles}
|
||||
>
|
||||
<ArrowUp className="w-6 h-6" weight="bold" />
|
||||
</Button>
|
||||
<input
|
||||
type="file"
|
||||
accept=".pdf,.doc,.docx,.txt,.md,.org,.jpg,.jpeg,.png,.webp"
|
||||
multiple={true}
|
||||
ref={fileInputRef}
|
||||
onChange={handleFileChange}
|
||||
style={{ display: "none" }}
|
||||
/>
|
||||
|
||||
<div className="flex items-center">
|
||||
<Button
|
||||
variant={"ghost"}
|
||||
className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
|
||||
disabled={props.sendDisabled}
|
||||
onClick={handleFileButtonClick}
|
||||
>
|
||||
<Paperclip className="w-8 h-8" />
|
||||
</Button>
|
||||
</div>
|
||||
<div className="flex-grow flex flex-col w-full gap-1.5 relative">
|
||||
<Textarea
|
||||
ref={chatInputRef}
|
||||
className={`border-none focus:border-none
|
||||
focus:outline-none focus-visible:ring-transparent
|
||||
w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700
|
||||
${props.isMobileWidth ? "text-md" : "text-lg"}`}
|
||||
placeholder="Type / to see a list of commands"
|
||||
id="message"
|
||||
autoFocus={true}
|
||||
value={message}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && !e.shiftKey && !props.isMobileWidth) {
|
||||
setImageUploaded(false);
|
||||
setImagePaths([]);
|
||||
e.preventDefault();
|
||||
onSendMessage();
|
||||
}
|
||||
}}
|
||||
onChange={(e) => setMessage(e.target.value)}
|
||||
disabled={props.sendDisabled || recording}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-end pb-2">
|
||||
{recording ? (
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="default"
|
||||
className={`${!recording && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
|
||||
onClick={() => {
|
||||
setRecording(!recording);
|
||||
}}
|
||||
disabled={props.sendDisabled}
|
||||
>
|
||||
<Stop weight="fill" className="w-6 h-6" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
Click to stop recording and transcribe your voice.
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
) : mediaRecorder ? (
|
||||
<InlineLoading />
|
||||
) : (
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="default"
|
||||
className={`${!message || recording || "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
|
||||
onClick={() => {
|
||||
setMessage("Listening...");
|
||||
setRecording(!recording);
|
||||
}}
|
||||
disabled={props.sendDisabled}
|
||||
>
|
||||
<Microphone weight="fill" className="w-6 h-6" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
Click to transcribe your message with voice.
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
)}
|
||||
<Button
|
||||
className={`${(!message || recording) && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
|
||||
onClick={onSendMessage}
|
||||
disabled={props.sendDisabled}
|
||||
>
|
||||
<ArrowUp className="w-6 h-6" weight="bold" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="ghost"
|
||||
className="float-right justify-center gap-1 flex items-center p-1.5 mr-2 h-fit"
|
||||
onClick={() => {
|
||||
setUseResearchMode(!useResearchMode);
|
||||
chatInputRef?.current?.focus();
|
||||
}}
|
||||
>
|
||||
<span className="text-muted-foreground text-sm">Research Mode</span>
|
||||
{useResearchMode ? (
|
||||
<ToggleRight
|
||||
weight="fill"
|
||||
className={`w-6 h-6 inline-block ${props.agentColor ? convertColorToTextClass(props.agentColor) : convertColorToTextClass("orange")} rounded-full`}
|
||||
/>
|
||||
) : (
|
||||
<ToggleLeft
|
||||
weight="fill"
|
||||
className={`w-6 h-6 inline-block ${convertColorToTextClass("gray")} rounded-full`}
|
||||
/>
|
||||
)}
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className="text-xs">
|
||||
(Experimental) Research Mode allows you to get more deeply researched,
|
||||
detailed responses. Response times may be longer.
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
ChatInputArea.displayName = "ChatInputArea";
|
||||
|
||||
@@ -4,6 +4,7 @@ div.chatMessageContainer {
|
||||
margin: 12px;
|
||||
border-radius: 16px;
|
||||
padding: 8px 16px 0 16px;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
div.chatMessageWrapper {
|
||||
@@ -57,7 +58,26 @@ div.emptyChatMessage {
|
||||
display: none;
|
||||
}
|
||||
|
||||
div.chatMessageContainer img {
|
||||
div.imagesContainer {
|
||||
display: flex;
|
||||
overflow-x: auto;
|
||||
padding-bottom: 8px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
div.imageWrapper {
|
||||
flex: 0 0 auto;
|
||||
margin-right: 8px;
|
||||
}
|
||||
|
||||
div.imageWrapper img {
|
||||
width: auto;
|
||||
height: 128px;
|
||||
object-fit: cover;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
div.chatMessageContainer > img {
|
||||
width: auto;
|
||||
height: auto;
|
||||
max-width: 100%;
|
||||
@@ -151,6 +171,7 @@ div.trainOfThoughtElement {
|
||||
div.trainOfThoughtElement ol,
|
||||
div.trainOfThoughtElement ul {
|
||||
margin: auto;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
@media screen and (max-width: 768px) {
|
||||
|
||||
@@ -10,6 +10,7 @@ import { createRoot } from "react-dom/client";
|
||||
import "katex/dist/katex.min.css";
|
||||
|
||||
import { TeaserReferencesSection, constructAllReferences } from "../referencePanel/referencePanel";
|
||||
import { renderCodeGenImageInline } from "@/app/common/chatFunctions";
|
||||
|
||||
import {
|
||||
ThumbsUp,
|
||||
@@ -26,6 +27,9 @@ import {
|
||||
Palette,
|
||||
ClipboardText,
|
||||
Check,
|
||||
Code,
|
||||
Shapes,
|
||||
Trash,
|
||||
} from "@phosphor-icons/react";
|
||||
|
||||
import DOMPurify from "dompurify";
|
||||
@@ -35,6 +39,19 @@ import { AgentData } from "@/app/agents/page";
|
||||
|
||||
import renderMathInElement from "katex/contrib/auto-render";
|
||||
import "katex/dist/katex.min.css";
|
||||
import ExcalidrawComponent from "../excalidraw/excalidraw";
|
||||
import { AttachedFileText } from "../chatInputArea/chatInputArea";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTrigger,
|
||||
} from "@/components/ui/dialog";
|
||||
import { DialogTitle } from "@radix-ui/react-dialog";
|
||||
import { convertBytesToText } from "@/app/common/utils";
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import { getIconFromFilename } from "@/app/common/iconUtils";
|
||||
|
||||
const md = new markdownIt({
|
||||
html: true,
|
||||
@@ -97,6 +114,26 @@ export interface OnlineContextData {
|
||||
peopleAlsoAsk: PeopleAlsoAsk[];
|
||||
}
|
||||
|
||||
export interface CodeContext {
|
||||
[key: string]: CodeContextData;
|
||||
}
|
||||
|
||||
export interface CodeContextData {
|
||||
code: string;
|
||||
results: {
|
||||
success: boolean;
|
||||
output_files: CodeContextFile[];
|
||||
std_out: string;
|
||||
std_err: string;
|
||||
code_runtime: number;
|
||||
};
|
||||
}
|
||||
|
||||
export interface CodeContextFile {
|
||||
filename: string;
|
||||
b64_data: string;
|
||||
}
|
||||
|
||||
interface Intent {
|
||||
type: string;
|
||||
query: string;
|
||||
@@ -104,6 +141,11 @@ interface Intent {
|
||||
"inferred-queries": string[];
|
||||
}
|
||||
|
||||
interface TrainOfThoughtObject {
|
||||
type: string;
|
||||
data: string;
|
||||
}
|
||||
|
||||
export interface SingleChatMessage {
|
||||
automationId: string;
|
||||
by: string;
|
||||
@@ -111,10 +153,15 @@ export interface SingleChatMessage {
|
||||
created: string;
|
||||
context: Context[];
|
||||
onlineContext: OnlineContext;
|
||||
codeContext: CodeContext;
|
||||
trainOfThought?: TrainOfThoughtObject[];
|
||||
rawQuery?: string;
|
||||
intent?: Intent;
|
||||
agent?: AgentData;
|
||||
uploadedImageData?: string;
|
||||
images?: string[];
|
||||
conversationId: string;
|
||||
turnId?: string;
|
||||
queryFiles?: AttachedFileText[];
|
||||
}
|
||||
|
||||
export interface StreamMessage {
|
||||
@@ -122,11 +169,16 @@ export interface StreamMessage {
|
||||
trainOfThought: string[];
|
||||
context: Context[];
|
||||
onlineContext: OnlineContext;
|
||||
codeContext: CodeContext;
|
||||
completed: boolean;
|
||||
rawQuery: string;
|
||||
timestamp: string;
|
||||
agent?: AgentData;
|
||||
uploadedImageData?: string;
|
||||
images?: string[];
|
||||
intentType?: string;
|
||||
inferredQueries?: string[];
|
||||
turnId?: string;
|
||||
queryFiles?: AttachedFileText[];
|
||||
}
|
||||
|
||||
export interface ChatHistoryData {
|
||||
@@ -208,7 +260,9 @@ interface ChatMessageProps {
|
||||
borderLeftColor?: string;
|
||||
isLastMessage?: boolean;
|
||||
agent?: AgentData;
|
||||
uploadedImageData?: string;
|
||||
onDeleteMessage: (turnId?: string) => void;
|
||||
conversationId: string;
|
||||
turnId?: string;
|
||||
}
|
||||
|
||||
interface TrainOfThoughtProps {
|
||||
@@ -252,10 +306,18 @@ function chooseIconFromHeader(header: string, iconColor: string) {
|
||||
return <Aperture className={`${classNames}`} />;
|
||||
}
|
||||
|
||||
if (compareHeader.includes("diagram")) {
|
||||
return <Shapes className={`${classNames}`} />;
|
||||
}
|
||||
|
||||
if (compareHeader.includes("paint")) {
|
||||
return <Palette className={`${classNames}`} />;
|
||||
}
|
||||
|
||||
if (compareHeader.includes("code")) {
|
||||
return <Code className={`${classNames}`} />;
|
||||
}
|
||||
|
||||
return <Brain className={`${classNames}`} />;
|
||||
}
|
||||
|
||||
@@ -266,12 +328,15 @@ export function TrainOfThought(props: TrainOfThoughtProps) {
|
||||
const iconColor = props.primary ? convertColorToTextClass(props.agentColor) : "text-gray-500";
|
||||
const icon = chooseIconFromHeader(header, iconColor);
|
||||
let markdownRendered = DOMPurify.sanitize(md.render(props.message));
|
||||
|
||||
// Remove any header tags from markdownRendered
|
||||
markdownRendered = markdownRendered.replace(/<h[1-6].*?<\/h[1-6]>/g, "");
|
||||
return (
|
||||
<div
|
||||
className={`${styles.trainOfThoughtElement} break-all items-center ${props.primary ? "text-gray-400" : "text-gray-300"} ${styles.trainOfThought} ${props.primary ? styles.primary : ""}`}
|
||||
className={`${styles.trainOfThoughtElement} break-words items-center ${props.primary ? "text-gray-400" : "text-gray-300"} ${styles.trainOfThought} ${props.primary ? styles.primary : ""}`}
|
||||
>
|
||||
{icon}
|
||||
<div dangerouslySetInnerHTML={{ __html: markdownRendered }} />
|
||||
<div dangerouslySetInnerHTML={{ __html: markdownRendered }} className="break-words" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -283,6 +348,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||
const [markdownRendered, setMarkdownRendered] = useState<string>("");
|
||||
const [isPlaying, setIsPlaying] = useState<boolean>(false);
|
||||
const [interrupted, setInterrupted] = useState<boolean>(false);
|
||||
const [excalidrawData, setExcalidrawData] = useState<string>("");
|
||||
|
||||
const interruptedRef = useRef<boolean>(false);
|
||||
const messageRef = useRef<HTMLDivElement>(null);
|
||||
@@ -319,8 +385,14 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||
}, [messageRef.current]);
|
||||
|
||||
useEffect(() => {
|
||||
// Prepare initial message for rendering
|
||||
let message = props.chatMessage.message;
|
||||
|
||||
if (props.chatMessage.intent && props.chatMessage.intent.type == "excalidraw") {
|
||||
message = props.chatMessage.intent["inferred-queries"][0];
|
||||
setExcalidrawData(props.chatMessage.message);
|
||||
}
|
||||
|
||||
// Replace LaTeX delimiters with placeholders
|
||||
message = message
|
||||
.replace(/\\\(/g, "LEFTPAREN")
|
||||
@@ -328,32 +400,74 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||
.replace(/\\\[/g, "LEFTBRACKET")
|
||||
.replace(/\\\]/g, "RIGHTBRACKET");
|
||||
|
||||
if (props.chatMessage.uploadedImageData) {
|
||||
message = `\n\n${message}`;
|
||||
const intentTypeHandlers = {
|
||||
"text-to-image": (msg: string) => ``,
|
||||
"text-to-image2": (msg: string) => ``,
|
||||
"text-to-image-v3": (msg: string) =>
|
||||
``,
|
||||
excalidraw: (msg: string) => msg,
|
||||
};
|
||||
|
||||
// Handle intent-specific rendering
|
||||
if (props.chatMessage.intent) {
|
||||
const { type, "inferred-queries": inferredQueries } = props.chatMessage.intent;
|
||||
|
||||
if (type in intentTypeHandlers) {
|
||||
message = intentTypeHandlers[type as keyof typeof intentTypeHandlers](message);
|
||||
}
|
||||
|
||||
if (type.includes("text-to-image") && inferredQueries?.length > 0) {
|
||||
message += `\n\n${inferredQueries[0]}`;
|
||||
}
|
||||
}
|
||||
|
||||
if (props.chatMessage.intent && props.chatMessage.intent.type == "text-to-image") {
|
||||
message = ``;
|
||||
} else if (props.chatMessage.intent && props.chatMessage.intent.type == "text-to-image2") {
|
||||
message = ``;
|
||||
} else if (
|
||||
props.chatMessage.intent &&
|
||||
props.chatMessage.intent.type == "text-to-image-v3"
|
||||
) {
|
||||
message = ``;
|
||||
}
|
||||
if (
|
||||
props.chatMessage.intent &&
|
||||
props.chatMessage.intent.type.includes("text-to-image") &&
|
||||
props.chatMessage.intent["inferred-queries"]?.length > 0
|
||||
) {
|
||||
message += `\n\n${props.chatMessage.intent["inferred-queries"][0]}`;
|
||||
// Replace file links with base64 data
|
||||
message = renderCodeGenImageInline(message, props.chatMessage.codeContext);
|
||||
|
||||
// Add code context files to the message
|
||||
if (props.chatMessage.codeContext) {
|
||||
Object.entries(props.chatMessage.codeContext).forEach(([key, value]) => {
|
||||
value.results.output_files?.forEach((file) => {
|
||||
if (file.filename.endsWith(".png") || file.filename.endsWith(".jpg")) {
|
||||
// Don't add the image again if it's already in the message!
|
||||
if (!message.includes(`) {
|
||||
message += `\n\n`;
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
setTextRendered(message);
|
||||
// Handle user attached images rendering
|
||||
let messageForClipboard = message;
|
||||
let messageToRender = message;
|
||||
if (props.chatMessage.images && props.chatMessage.images.length > 0) {
|
||||
const sanitizedImages = props.chatMessage.images.map((image) => {
|
||||
const decodedImage = image.startsWith("data%3Aimage")
|
||||
? decodeURIComponent(image)
|
||||
: image;
|
||||
return DOMPurify.sanitize(decodedImage);
|
||||
});
|
||||
const imagesInMd = sanitizedImages
|
||||
.map((sanitizedImage, index) => {
|
||||
return ``;
|
||||
})
|
||||
.join("\n");
|
||||
const imagesInHtml = sanitizedImages
|
||||
.map((sanitizedImage, index) => {
|
||||
return `<div class="${styles.imageWrapper}"><img src="${sanitizedImage}" alt="uploaded image ${index + 1}" /></div>`;
|
||||
})
|
||||
.join("");
|
||||
const userImagesInHtml = `<div class="${styles.imagesContainer}">${imagesInHtml}</div>`;
|
||||
messageForClipboard = `${imagesInMd}\n\n${messageForClipboard}`;
|
||||
messageToRender = `${userImagesInHtml}${messageToRender}`;
|
||||
}
|
||||
|
||||
// Set the message text
|
||||
setTextRendered(messageForClipboard);
|
||||
|
||||
// Render the markdown
|
||||
let markdownRendered = md.render(message);
|
||||
let markdownRendered = md.render(messageToRender);
|
||||
|
||||
// Replace placeholders with LaTeX delimiters
|
||||
markdownRendered = markdownRendered
|
||||
@@ -364,7 +478,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||
|
||||
// Sanitize and set the rendered markdown
|
||||
setMarkdownRendered(DOMPurify.sanitize(markdownRendered));
|
||||
}, [props.chatMessage.message, props.chatMessage.intent]);
|
||||
}, [props.chatMessage.message, props.chatMessage.images, props.chatMessage.intent]);
|
||||
|
||||
useEffect(() => {
|
||||
if (copySuccess) {
|
||||
@@ -536,9 +650,31 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||
});
|
||||
}
|
||||
|
||||
const deleteMessage = async (message: SingleChatMessage) => {
|
||||
const turnId = message.turnId || props.turnId;
|
||||
const response = await fetch("/api/chat/conversation/message", {
|
||||
method: "DELETE",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
conversation_id: props.conversationId,
|
||||
turn_id: turnId,
|
||||
}),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
// Update the UI after successful deletion
|
||||
props.onDeleteMessage(turnId);
|
||||
} else {
|
||||
console.error("Failed to delete message");
|
||||
}
|
||||
};
|
||||
|
||||
const allReferences = constructAllReferences(
|
||||
props.chatMessage.context,
|
||||
props.chatMessage.onlineContext,
|
||||
props.chatMessage.codeContext,
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -549,17 +685,59 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||
onMouseEnter={(event) => setIsHovering(true)}
|
||||
>
|
||||
<div className={chatMessageWrapperClasses(props.chatMessage)}>
|
||||
{props.chatMessage.queryFiles && props.chatMessage.queryFiles.length > 0 && (
|
||||
<div className="flex flex-wrap flex-col mb-2 max-w-full">
|
||||
{props.chatMessage.queryFiles.map((file, index) => (
|
||||
<Dialog key={index}>
|
||||
<DialogTrigger asChild>
|
||||
<div
|
||||
className="flex items-center space-x-2 cursor-pointer bg-gray-500 bg-opacity-25 rounded-lg p-2 w-full
|
||||
"
|
||||
>
|
||||
<div className="flex-shrink-0">
|
||||
{getIconFromFilename(file.file_type)}
|
||||
</div>
|
||||
<span className="truncate flex-1 min-w-0 max-w-[200px]">
|
||||
{file.name}
|
||||
</span>
|
||||
{file.size && (
|
||||
<span className="text-gray-400 flex-shrink-0">
|
||||
({convertBytesToText(file.size)})
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</DialogTrigger>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>
|
||||
<div className="truncate min-w-0 break-words break-all text-wrap max-w-full whitespace-normal">
|
||||
{file.name}
|
||||
</div>
|
||||
</DialogTitle>
|
||||
</DialogHeader>
|
||||
<DialogDescription>
|
||||
<ScrollArea className="h-72 w-full rounded-md break-words break-all text-wrap">
|
||||
{file.content}
|
||||
</ScrollArea>
|
||||
</DialogDescription>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
<div
|
||||
ref={messageRef}
|
||||
className={styles.chatMessage}
|
||||
dangerouslySetInnerHTML={{ __html: markdownRendered }}
|
||||
/>
|
||||
{excalidrawData && <ExcalidrawComponent data={excalidrawData} />}
|
||||
</div>
|
||||
<div className={styles.teaserReferencesContainer}>
|
||||
<TeaserReferencesSection
|
||||
isMobileWidth={props.isMobileWidth}
|
||||
notesReferenceCardData={allReferences.notesReferenceCardData}
|
||||
onlineReferenceCardData={allReferences.onlineReferenceCardData}
|
||||
codeReferenceCardData={allReferences.codeReferenceCardData}
|
||||
/>
|
||||
</div>
|
||||
<div className={styles.chatFooter}>
|
||||
@@ -595,6 +773,18 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||
/>
|
||||
</button>
|
||||
))}
|
||||
{props.chatMessage.turnId && (
|
||||
<button
|
||||
title="Delete"
|
||||
className={`${styles.deleteButton}`}
|
||||
onClick={() => deleteMessage(props.chatMessage)}
|
||||
>
|
||||
<Trash
|
||||
alt="Delete Message"
|
||||
className="hsl(var(--muted-foreground)) hover:text-red-500"
|
||||
/>
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
title="Copy"
|
||||
className={`${styles.copyButton}`}
|
||||
|
||||
src/interface/web/app/components/excalidraw/excalidraw.tsx (new file, 24 lines)
@@ -0,0 +1,24 @@
"use client";

import dynamic from "next/dynamic";
import { Suspense } from "react";
import Loading from "../../components/loading/loading";

// Since client components get prerenderd on server as well hence importing
// the excalidraw stuff dynamically with ssr false

const ExcalidrawWrapper = dynamic(() => import("./excalidrawWrapper").then((mod) => mod.default), {
ssr: false,
});

interface ExcalidrawComponentProps {
data: any;
}

export default function ExcalidrawComponent(props: ExcalidrawComponentProps) {
return (
<Suspense fallback={<Loading />}>
<ExcalidrawWrapper data={props.data} />
</Suspense>
);
}

@@ -0,0 +1,149 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
|
||||
import dynamic from "next/dynamic";
|
||||
|
||||
import { ExcalidrawProps } from "@excalidraw/excalidraw/types/types";
|
||||
import { ExcalidrawElement } from "@excalidraw/excalidraw/types/element/types";
|
||||
import { ExcalidrawElementSkeleton } from "@excalidraw/excalidraw/types/data/transform";
|
||||
|
||||
const Excalidraw = dynamic<ExcalidrawProps>(
|
||||
async () => (await import("@excalidraw/excalidraw")).Excalidraw,
|
||||
{
|
||||
ssr: false,
|
||||
},
|
||||
);
|
||||
|
||||
import { convertToExcalidrawElements } from "@excalidraw/excalidraw";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
||||
import { ArrowsInSimple, ArrowsOutSimple } from "@phosphor-icons/react";
|
||||
|
||||
interface ExcalidrawWrapperProps {
|
||||
data: ExcalidrawElementSkeleton[];
|
||||
}
|
||||
|
||||
export default function ExcalidrawWrapper(props: ExcalidrawWrapperProps) {
|
||||
const [excalidrawElements, setExcalidrawElements] = useState<ExcalidrawElement[]>([]);
|
||||
const [expanded, setExpanded] = useState<boolean>(false);
|
||||
|
||||
const isValidExcalidrawElement = (element: ExcalidrawElementSkeleton): boolean => {
|
||||
return (
|
||||
element.x !== undefined &&
|
||||
element.y !== undefined &&
|
||||
element.id !== undefined &&
|
||||
element.type !== undefined
|
||||
);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (expanded) {
|
||||
onkeydown = (e) => {
|
||||
if (e.key === "Escape") {
|
||||
setExpanded(false);
|
||||
// Trigger a resize event to make Excalidraw adjust its size
|
||||
window.dispatchEvent(new Event("resize"));
|
||||
}
|
||||
};
|
||||
} else {
|
||||
onkeydown = null;
|
||||
}
|
||||
}, [expanded]);
|
||||
|
||||
useEffect(() => {
|
||||
// Do some basic validation
|
||||
const basicValidSkeletons: ExcalidrawElementSkeleton[] = [];
|
||||
|
||||
for (const element of props.data) {
|
||||
if (isValidExcalidrawElement(element as ExcalidrawElementSkeleton)) {
|
||||
basicValidSkeletons.push(element as ExcalidrawElementSkeleton);
|
||||
}
|
||||
}
|
||||
|
||||
const validSkeletons: ExcalidrawElementSkeleton[] = [];
|
||||
for (const element of basicValidSkeletons) {
|
||||
if (element.type === "frame") {
|
||||
continue;
|
||||
}
|
||||
if (element.type === "arrow") {
|
||||
const start = basicValidSkeletons.find((child) => child.id === element.start?.id);
|
||||
const end = basicValidSkeletons.find((child) => child.id === element.end?.id);
|
||||
if (start && end) {
|
||||
validSkeletons.push(element);
|
||||
}
|
||||
} else {
|
||||
validSkeletons.push(element);
|
||||
}
|
||||
}
|
||||
|
||||
for (const element of basicValidSkeletons) {
|
||||
if (element.type === "frame") {
|
||||
const children = element.children?.map((childId) => {
|
||||
return validSkeletons.find((child) => child.id === childId);
|
||||
});
|
||||
// Get the valid children, filter out any undefined values
|
||||
const validChildrenIds: readonly string[] = children
|
||||
?.map((child) => child?.id)
|
||||
.filter((id) => id !== undefined) as string[];
|
||||
|
||||
if (validChildrenIds === undefined || validChildrenIds.length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
validSkeletons.push({
|
||||
...element,
|
||||
children: validChildrenIds,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const elements = convertToExcalidrawElements(validSkeletons);
|
||||
setExcalidrawElements(elements);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div className="relative">
|
||||
<div
|
||||
className={`${expanded ? "fixed inset-0 bg-black bg-opacity-50 backdrop-blur-sm z-50 flex items-center justify-center" : ""}`}
|
||||
>
|
||||
<Button
|
||||
onClick={() => {
|
||||
setExpanded(!expanded);
|
||||
// Trigger a resize event to make Excalidraw adjust its size
|
||||
window.dispatchEvent(new Event("resize"));
|
||||
}}
|
||||
variant={"outline"}
|
||||
className={`${expanded ? "absolute top-2 left-2 z-[60]" : ""}`}
|
||||
>
|
||||
{expanded ? (
|
||||
<ArrowsInSimple className="h-4 w-4" />
|
||||
) : (
|
||||
<ArrowsOutSimple className="h-4 w-4" />
|
||||
)}
|
||||
</Button>
|
||||
<div
|
||||
className={`
|
||||
${expanded ? "w-[80vw] h-[80vh]" : "w-full h-[500px]"}
|
||||
bg-white overflow-hidden rounded-lg relative
|
||||
`}
|
||||
>
|
||||
<Excalidraw
|
||||
initialData={{
|
||||
elements: excalidrawElements,
|
||||
appState: { zenModeEnabled: true },
|
||||
scrollToContent: true,
|
||||
}}
|
||||
// TODO - Create a common function to detect if the theme is dark?
|
||||
theme={localStorage.getItem("theme") === "dark" ? "dark" : "light"}
|
||||
validateEmbeddable={true}
|
||||
renderTopRightUI={(isMobile, appState) => {
|
||||
return <></>;
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -81,111 +81,3 @@ export function OrgMode({ className }: { className?: string }) {
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
|
||||
export function Markdown({ className }: { className?: string }) {
|
||||
const classes = className ?? "w-6 h-6 text-muted-foreground inline-flex mr-1";
|
||||
return (
|
||||
<svg
|
||||
className={`${classes}`}
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="208"
|
||||
height="128"
|
||||
viewBox="0 0 208 128"
|
||||
>
|
||||
<rect
|
||||
width="198"
|
||||
height="118"
|
||||
x="5"
|
||||
y="5"
|
||||
ry="10"
|
||||
stroke="#000"
|
||||
strokeWidth="10"
|
||||
fill="none"
|
||||
/>
|
||||
<path d="M30 98V30h20l20 25 20-25h20v68H90V59L70 84 50 59v39zm125 0l-30-33h20V30h20v35h20z" />
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
|
||||
export function Pdf({ className }: { className?: string }) {
|
||||
const classes = className ?? "w-6 h-6 text-muted-foreground inline-flex mr-1";
|
||||
return (
|
||||
<svg
|
||||
className={`${classes}`}
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
enableBackground="new 0 0 334.371 380.563"
|
||||
version="1.1"
|
||||
viewBox="0 0 14 16"
|
||||
>
|
||||
<g transform="matrix(.04589 0 0 .04589 -.66877 -.73379)">
|
||||
<polygon
|
||||
points="51.791 356.65 51.791 23.99 204.5 23.99 282.65 102.07 282.65 356.65"
|
||||
fill="#fff"
|
||||
strokeWidth="212.65"
|
||||
/>
|
||||
<path
|
||||
d="m201.19 31.99 73.46 73.393v243.26h-214.86v-316.66h141.4m6.623-16h-164.02v348.66h246.85v-265.9z"
|
||||
strokeWidth="21.791"
|
||||
/>
|
||||
</g>
|
||||
<g transform="matrix(.04589 0 0 .04589 -.66877 -.73379)">
|
||||
<polygon
|
||||
points="282.65 356.65 51.791 356.65 51.791 23.99 204.5 23.99 206.31 25.8 206.31 100.33 280.9 100.33 282.65 102.07"
|
||||
fill="#fff"
|
||||
strokeWidth="212.65"
|
||||
/>
|
||||
<path
|
||||
d="m198.31 31.99v76.337h76.337v240.32h-214.86v-316.66h138.52m9.5-16h-164.02v348.66h246.85v-265.9l-6.43-6.424h-69.907v-69.842z"
|
||||
strokeWidth="21.791"
|
||||
/>
|
||||
</g>
|
||||
<g transform="matrix(.04589 0 0 .04589 -.66877 -.73379)" strokeWidth="21.791">
|
||||
<polygon points="258.31 87.75 219.64 87.75 219.64 48.667 258.31 86.38" />
|
||||
<path d="m227.64 67.646 12.41 12.104h-12.41v-12.104m-5.002-27.229h-10.998v55.333h54.666v-12.742z" />
|
||||
</g>
|
||||
<g
|
||||
transform="matrix(.04589 0 0 .04589 -.66877 -.73379)"
|
||||
fill="#ed1c24"
|
||||
strokeWidth="212.65"
|
||||
>
|
||||
<polygon points="311.89 284.49 22.544 284.49 22.544 167.68 37.291 152.94 37.291 171.49 297.15 171.49 297.15 152.94 311.89 167.68" />
|
||||
<path d="m303.65 168.63 1.747 1.747v107.62h-276.35v-107.62l1.747-1.747v9.362h272.85v-9.362m-12.999-31.385v27.747h-246.86v-27.747l-27.747 27.747v126h302.35v-126z" />
|
||||
</g>
|
||||
<rect x="1.7219" y="7.9544" width="10.684" height="4.0307" fill="none" />
|
||||
<g transform="matrix(.04589 0 0 .04589 1.7219 11.733)" fill="#fff" strokeWidth="21.791">
|
||||
<path d="m9.216 0v-83.2h30.464q6.784 0 12.928 1.408 6.144 1.28 10.752 4.608 4.608 3.2 7.296 8.576 2.816 5.248 2.816 13.056 0 7.68-2.816 13.184-2.688 5.504-7.296 9.088-4.608 3.456-10.624 5.248-6.016 1.664-12.544 1.664h-8.96v26.368zm22.016-43.776h7.936q6.528 0 9.6-3.072 3.2-3.072 3.2-8.704t-3.456-7.936-9.856-2.304h-7.424z" />
|
||||
<path d="m87.04 0v-83.2h24.576q9.472 0 17.28 2.304 7.936 2.304 13.568 7.296t8.704 12.8q3.2 7.808 3.2 18.816t-3.072 18.944-8.704 13.056q-5.504 5.12-13.184 7.552-7.552 2.432-16.512 2.432zm22.016-17.664h1.28q4.48 0 8.448-1.024 3.968-1.152 6.784-3.84 2.944-2.688 4.608-7.424t1.664-12.032-1.664-11.904-4.608-7.168q-2.816-2.56-6.784-3.456-3.968-1.024-8.448-1.024h-1.28z" />
|
||||
<path d="m169.22 0v-83.2h54.272v18.432h-32.256v15.872h27.648v18.432h-27.648v30.464z" />
|
||||
</g>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
|
||||
export function Word({ className }: { className?: string }) {
|
||||
const classes = className ?? "w-6 h-6 text-muted-foreground inline-flex mr-1";
|
||||
return (
|
||||
<svg
|
||||
className={`${classes}`}
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
fill="#FFF"
|
||||
strokeMiterlimit="10"
|
||||
strokeWidth="2"
|
||||
viewBox="0 0 96 96"
|
||||
>
|
||||
<path
|
||||
stroke="#979593"
|
||||
d="M67.1716 7H27c-1.1046 0-2 .8954-2 2v78c0 1.1046.8954 2 2 2h58c1.1046 0 2-.8954 2-2V26.8284c0-.5304-.2107-1.0391-.5858-1.4142L68.5858 7.5858C68.2107 7.2107 67.702 7 67.1716 7z"
|
||||
/>
|
||||
<path fill="none" stroke="#979593" d="M67 7v18c0 1.1046.8954 2 2 2h18" />
|
||||
<path
|
||||
fill="#C8C6C4"
|
||||
d="M79 61H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0-6H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0-6H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0-6H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0 24H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1z"
|
||||
/>
|
||||
<path
|
||||
fill="#185ABD"
|
||||
d="M12 74h32c2.2091 0 4-1.7909 4-4V38c0-2.2091-1.7909-4-4-4H12c-2.2091 0-4 1.7909-4 4v32c0 2.2091 1.7909 4 4 4z"
|
||||
/>
|
||||
<path d="M21.6245 60.6455c.0661.522.109.9769.1296 1.3657h.0762c.0306-.3685.0889-.8129.1751-1.3349.0862-.5211.1703-.961.2517-1.319L25.7911 44h4.5702l3.6562 15.1272c.183.7468.3353 1.6973.457 2.8532h.0608c.0508-.7979.1777-1.7184.3809-2.7615L37.8413 44H42l-5.1183 22h-4.86l-3.4885-14.5744c-.1016-.4197-.2158-.9663-.3428-1.6417-.127-.6745-.2057-1.1656-.236-1.4724h-.0608c-.0407.358-.1195.8896-.2364 1.595-.1169.7062-.211 1.2273-.2819 1.565L24.1 66h-4.9357L14 44h4.2349l3.1843 15.3882c.0709.3165.1392.7362.2053 1.2573z" />
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -26,6 +26,17 @@ import { Moon, Sun, UserCircle, Question, GearFine, ArrowRight } from "@phosphor
|
||||
import { KhojAgentLogo, KhojAutomationLogo, KhojSearchLogo } from "../logo/khojLogo";
|
||||
import { useIsMobileWidth } from "@/app/common/utils";
|
||||
|
||||
function SubscriptionBadge({ is_active }: { is_active: boolean }) {
|
||||
return (
|
||||
<div className="flex flex-row items-center">
|
||||
<div
|
||||
className={`w-3 h-3 rounded-full ${is_active ? "bg-yellow-500" : "bg-muted"} mr-1`}
|
||||
></div>
|
||||
<p className="text-xs">{is_active ? "Futurist" : "Free"}</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default function NavMenu() {
|
||||
const userData = useAuthenticatedData();
|
||||
const [darkMode, setDarkMode] = useState(false);
|
||||
@@ -85,8 +96,9 @@ export default function NavMenu() {
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent className="gap-2">
|
||||
<DropdownMenuItem className="w-full">
|
||||
<div className="flex flex-rows">
|
||||
<div className="flex flex-col">
|
||||
<p className="font-semibold">{userData?.email}</p>
|
||||
<SubscriptionBadge is_active={userData?.is_active ?? false} />
|
||||
</div>
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuSeparator />
|
||||
@@ -192,8 +204,9 @@ export default function NavMenu() {
|
||||
</MenubarTrigger>
|
||||
<MenubarContent align="end" className="rounded-xl gap-2">
|
||||
<MenubarItem className="w-full">
|
||||
<div className="flex flex-rows">
|
||||
<div className="flex flex-col">
|
||||
<p className="font-semibold">{userData?.email}</p>
|
||||
<SubscriptionBadge is_active={userData?.is_active ?? false} />
|
||||
</div>
|
||||
</MenubarItem>
|
||||
<MenubarSeparator className="dark:bg-white height-[2px] bg-black" />
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
import { ArrowRight } from "@phosphor-icons/react";
|
||||
import { ArrowCircleDown, ArrowRight } from "@phosphor-icons/react";
|
||||
|
||||
import markdownIt from "markdown-it";
|
||||
const md = new markdownIt({
|
||||
@@ -11,7 +11,13 @@ const md = new markdownIt({
|
||||
typographer: true,
|
||||
});
|
||||
|
||||
import { Context, WebPage, OnlineContext } from "../chatMessage/chatMessage";
|
||||
import {
|
||||
Context,
|
||||
WebPage,
|
||||
OnlineContext,
|
||||
CodeContext,
|
||||
CodeContextFile,
|
||||
} from "../chatMessage/chatMessage";
|
||||
import { Card } from "@/components/ui/card";
|
||||
|
||||
import {
|
||||
@@ -51,6 +57,7 @@ function NotesContextReferenceCard(props: NotesContextReferenceCardProps) {
|
||||
props.title || ".txt",
|
||||
"w-6 h-6 text-muted-foreground inline-flex mr-2",
|
||||
);
|
||||
const fileName = props.title.split("/").pop() || props.title;
|
||||
const snippet = extractSnippet(props);
|
||||
const [isHovering, setIsHovering] = useState(false);
|
||||
|
||||
@@ -61,30 +68,30 @@ function NotesContextReferenceCard(props: NotesContextReferenceCardProps) {
|
||||
<Card
|
||||
onMouseEnter={() => setIsHovering(true)}
|
||||
onMouseLeave={() => setIsHovering(false)}
|
||||
className={`${props.showFullContent ? "w-auto" : "w-[200px]"} overflow-hidden break-words text-balance rounded-lg p-2 bg-muted border-none`}
|
||||
className={`${props.showFullContent ? "w-auto" : "w-[200px]"} overflow-hidden break-words text-balance rounded-lg border-none p-2 bg-muted`}
|
||||
>
|
||||
<h3
|
||||
className={`${props.showFullContent ? "block" : "line-clamp-1"} text-muted-foreground`}
|
||||
>
|
||||
{fileIcon}
|
||||
{props.title}
|
||||
{props.showFullContent ? props.title : fileName}
|
||||
</h3>
|
||||
<p
|
||||
className={`${props.showFullContent ? "block" : "overflow-hidden line-clamp-2"}`}
|
||||
className={`text-sm ${props.showFullContent ? "overflow-x-auto block" : "overflow-hidden line-clamp-2"}`}
|
||||
dangerouslySetInnerHTML={{ __html: snippet }}
|
||||
></p>
|
||||
</Card>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="w-[400px] mx-2">
|
||||
<Card
|
||||
className={`w-auto overflow-hidden break-words text-balance rounded-lg p-2 border-none`}
|
||||
className={`w-auto overflow-hidden break-words text-balance rounded-lg border-none p-2`}
|
||||
>
|
||||
<h3 className={`line-clamp-2 text-muted-foreground`}>
|
||||
{fileIcon}
|
||||
{props.title}
|
||||
</h3>
|
||||
<p
|
||||
className={`overflow-hidden line-clamp-3`}
|
||||
className={`border-t mt-1 pt-1 text-sm overflow-hidden line-clamp-5`}
|
||||
dangerouslySetInnerHTML={{ __html: snippet }}
|
||||
></p>
|
||||
</Card>
|
||||
@@ -94,9 +101,160 @@ function NotesContextReferenceCard(props: NotesContextReferenceCardProps) {
|
||||
);
|
||||
}
|
||||
|
||||
export interface ReferencePanelData {
|
||||
notesReferenceCardData: NotesContextReferenceData[];
|
||||
onlineReferenceCardData: OnlineReferenceData[];
|
||||
interface CodeContextReferenceCardProps {
|
||||
code: string;
|
||||
output: string;
|
||||
output_files: CodeContextFile[];
|
||||
error: string;
|
||||
showFullContent: boolean;
|
||||
}
|
||||
|
||||
function CodeContextReferenceCard(props: CodeContextReferenceCardProps) {
|
||||
const fileIcon = getIconFromFilename(".py", "!w-4 h-4 text-muted-foreground flex-shrink-0");
|
||||
const sanitizedCodeSnippet = DOMPurify.sanitize(props.code);
|
||||
const [isHovering, setIsHovering] = useState(false);
|
||||
const [isDownloadHover, setIsDownloadHover] = useState(false);
|
||||
|
||||
const handleDownload = (file: CodeContextFile) => {
|
||||
// Determine MIME type
|
||||
let mimeType = "text/plain";
|
||||
let byteString = file.b64_data;
|
||||
if (file.filename.match(/\.(png|jpg|jpeg|webp)$/)) {
|
||||
mimeType = `image/${file.filename.split(".").pop()}`;
|
||||
byteString = atob(file.b64_data);
|
||||
} else if (file.filename.endsWith(".json")) {
|
||||
mimeType = "application/json";
|
||||
} else if (file.filename.endsWith(".csv")) {
|
||||
mimeType = "text/csv";
|
||||
}
|
||||
|
||||
const arrayBuffer = new ArrayBuffer(byteString.length);
|
||||
const bytes = new Uint8Array(arrayBuffer);
|
||||
|
||||
for (let i = 0; i < byteString.length; i++) {
|
||||
bytes[i] = byteString.charCodeAt(i);
|
||||
}
|
||||
|
||||
const blob = new Blob([arrayBuffer], { type: mimeType });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = file.filename;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
};
|
||||
|
||||
const renderOutputFiles = (files: CodeContextFile[], hoverCard: boolean) => {
|
||||
if (!files || files.length === 0) return null;
|
||||
return (
|
||||
<div
|
||||
className={`${hoverCard || props.showFullContent ? "border-t mt-1 pt-1" : undefined}`}
|
||||
>
|
||||
{files.slice(0, props.showFullContent ? undefined : 1).map((file, index) => {
|
||||
return (
|
||||
<div key={`${file.filename}-${index}`}>
|
||||
<h4 className="text-sm text-muted-foreground flex items-center">
|
||||
<span
|
||||
className={`overflow-hidden mr-2 font-bold ${props.showFullContent ? undefined : "line-clamp-1"}`}
|
||||
>
|
||||
{file.filename}
|
||||
</span>
|
||||
<button
|
||||
className={`${hoverCard ? "hidden" : undefined}`}
|
||||
onClick={(e) => {
|
||||
e.preventDefault();
|
||||
handleDownload(file);
|
||||
}}
|
||||
onMouseEnter={() => setIsDownloadHover(true)}
|
||||
onMouseLeave={() => setIsDownloadHover(false)}
|
||||
title={`Download file: ${file.filename}`}
|
||||
>
|
||||
<ArrowCircleDown
|
||||
className={`w-4 h-4`}
|
||||
weight={isDownloadHover ? "fill" : "regular"}
|
||||
/>
|
||||
</button>
|
||||
</h4>
|
||||
{file.filename.match(/\.(txt|org|md|csv|json)$/) ? (
|
||||
<pre
|
||||
className={`${props.showFullContent ? "block" : "line-clamp-2"} text-sm mt-1 p-1 bg-background rounded overflow-x-auto`}
|
||||
>
|
||||
{file.b64_data}
|
||||
</pre>
|
||||
) : file.filename.match(/\.(png|jpg|jpeg|webp)$/) ? (
|
||||
<img
|
||||
src={`data:image/${file.filename.split(".").pop()};base64,${file.b64_data}`}
|
||||
alt={file.filename}
|
||||
className="mt-1 max-h-32 rounded"
|
||||
/>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<Popover open={isHovering && !props.showFullContent} onOpenChange={setIsHovering}>
|
||||
<PopoverTrigger asChild>
|
||||
<Card
|
||||
onMouseEnter={() => setIsHovering(true)}
|
||||
onMouseLeave={() => setIsHovering(false)}
|
||||
className={`${props.showFullContent ? "w-auto" : "w-[200px]"} overflow-hidden break-words text-balance rounded-lg border-none p-2 bg-muted`}
|
||||
>
|
||||
<div className="flex flex-col px-1">
|
||||
<div className="flex items-center gap-2">
|
||||
{fileIcon}
|
||||
<h3
|
||||
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-1"} text-muted-foreground flex-grow`}
|
||||
>
|
||||
code {props.output_files?.length > 0 ? "artifacts" : ""}
|
||||
</h3>
|
||||
</div>
|
||||
<pre
|
||||
className={`text-xs pb-2 ${props.showFullContent ? "block overflow-x-auto" : props.output_files?.length > 0 ? "hidden" : "overflow-hidden line-clamp-3"}`}
|
||||
>
|
||||
{sanitizedCodeSnippet}
|
||||
</pre>
|
||||
{props.output_files?.length > 0 &&
|
||||
renderOutputFiles(props.output_files, false)}
|
||||
</div>
|
||||
</Card>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="w-[400px] mx-2">
|
||||
<Card
|
||||
className={`w-auto overflow-hidden break-words text-balance rounded-lg border-none p-2`}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
{fileIcon}
|
||||
<h3
|
||||
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-1"} text-muted-foreground flex-grow`}
|
||||
>
|
||||
code {props.output_files?.length > 0 ? "artifact" : ""}
|
||||
</h3>
|
||||
</div>
|
||||
{(props.output_files?.length > 0 &&
|
||||
renderOutputFiles(props.output_files?.slice(0, 1), true)) || (
|
||||
<pre className="text-xs border-t mt-1 pt-1 verflow-hidden line-clamp-10">
|
||||
{sanitizedCodeSnippet}
|
||||
</pre>
|
||||
)}
|
||||
</Card>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export interface CodeReferenceData {
|
||||
code: string;
|
||||
output: string;
|
||||
output_files: CodeContextFile[];
|
||||
error: string;
|
||||
}
|
||||
|
||||
interface OnlineReferenceData {
|
||||
@@ -141,21 +299,17 @@ function GenericOnlineReferenceCard(props: OnlineReferenceCardProps) {
|
||||
<Card
|
||||
onMouseEnter={handleMouseEnter}
|
||||
onMouseLeave={handleMouseLeave}
|
||||
className={`${props.showFullContent ? "w-auto" : "w-[200px]"} overflow-hidden break-words rounded-lg text-balance p-2 bg-muted border-none`}
|
||||
className={`${props.showFullContent ? "w-auto" : "w-[200px]"} overflow-hidden break-words text-balance rounded-lg border-none p-2 bg-muted`}
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
<a
|
||||
href={props.link}
|
||||
target="_blank"
|
||||
rel="noreferrer"
|
||||
className="!no-underline p-2"
|
||||
className="!no-underline px-1"
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<img
|
||||
src={favicon}
|
||||
alt=""
|
||||
className="!w-4 h-4 mr-2 flex-shrink-0"
|
||||
/>
|
||||
<img src={favicon} alt="" className="!w-4 h-4 flex-shrink-0" />
|
||||
<h3
|
||||
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-1"} text-muted-foreground flex-grow`}
|
||||
>
|
||||
@@ -168,7 +322,7 @@ function GenericOnlineReferenceCard(props: OnlineReferenceCardProps) {
|
||||
{props.title}
|
||||
</h3>
|
||||
<p
|
||||
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-2"}`}
|
||||
className={`overflow-hidden text-sm ${props.showFullContent ? "block" : "line-clamp-2"}`}
|
||||
>
|
||||
{props.description}
|
||||
</p>
|
||||
@@ -185,23 +339,23 @@ function GenericOnlineReferenceCard(props: OnlineReferenceCardProps) {
|
||||
href={props.link}
|
||||
target="_blank"
|
||||
rel="noreferrer"
|
||||
className="!no-underline p-2"
|
||||
className="!no-underline px-1"
|
||||
>
|
||||
<div className="flex items-center">
|
||||
<img src={favicon} alt="" className="!w-4 h-4 mr-2" />
|
||||
<div className="flex items-center gap-2">
|
||||
<img src={favicon} alt="" className="!w-4 h-4 flex-shrink-0" />
|
||||
<h3
|
||||
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-2"} text-muted-foreground`}
|
||||
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-2"} text-muted-foreground flex-grow`}
|
||||
>
|
||||
{domain}
|
||||
</h3>
|
||||
</div>
|
||||
<h3
|
||||
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-2"} font-bold`}
|
||||
className={`border-t mt-1 pt-1 overflow-hidden ${props.showFullContent ? "block" : "line-clamp-2"} font-bold`}
|
||||
>
|
||||
{props.title}
|
||||
</h3>
|
||||
<p
|
||||
className={`overflow-hidden ${props.showFullContent ? "block" : "line-clamp-3"}`}
|
||||
className={`overflow-hidden text-sm ${props.showFullContent ? "block" : "line-clamp-5"}`}
|
||||
>
|
||||
{props.description}
|
||||
</p>
|
||||
@@ -214,9 +368,28 @@ function GenericOnlineReferenceCard(props: OnlineReferenceCardProps) {
|
||||
);
|
||||
}
|
||||
|
||||
export function constructAllReferences(contextData: Context[], onlineData: OnlineContext) {
|
||||
export function constructAllReferences(
|
||||
contextData: Context[],
|
||||
onlineData: OnlineContext,
|
||||
codeContext: CodeContext,
|
||||
) {
|
||||
const onlineReferences: OnlineReferenceData[] = [];
|
||||
const contextReferences: NotesContextReferenceData[] = [];
|
||||
const codeReferences: CodeReferenceData[] = [];
|
||||
|
||||
if (codeContext) {
|
||||
for (const [key, value] of Object.entries(codeContext)) {
|
||||
if (!value.results) {
|
||||
continue;
|
||||
}
|
||||
codeReferences.push({
|
||||
code: value.code,
|
||||
output: value.results.std_out,
|
||||
output_files: value.results.output_files,
|
||||
error: value.results.std_err,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (onlineData) {
|
||||
let localOnlineReferences = [];
|
||||
@@ -298,12 +471,14 @@ export function constructAllReferences(contextData: Context[], onlineData: Onlin
|
||||
return {
|
||||
notesReferenceCardData: contextReferences,
|
||||
onlineReferenceCardData: onlineReferences,
|
||||
codeReferenceCardData: codeReferences,
|
||||
};
|
||||
}
|
||||
|
||||
export interface TeaserReferenceSectionProps {
|
||||
notesReferenceCardData: NotesContextReferenceData[];
|
||||
onlineReferenceCardData: OnlineReferenceData[];
|
||||
codeReferenceCardData: CodeReferenceData[];
|
||||
isMobileWidth: boolean;
|
||||
}
|
||||
|
||||
@@ -314,17 +489,28 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
|
||||
setNumTeaserSlots(props.isMobileWidth ? 1 : 3);
|
||||
}, [props.isMobileWidth]);
|
||||
|
||||
const notesDataToShow = props.notesReferenceCardData.slice(0, numTeaserSlots);
|
||||
const codeDataToShow = props.codeReferenceCardData.slice(0, numTeaserSlots);
|
||||
const notesDataToShow = props.notesReferenceCardData.slice(
|
||||
0,
|
||||
numTeaserSlots - codeDataToShow.length,
|
||||
);
|
||||
const onlineDataToShow =
|
||||
notesDataToShow.length < numTeaserSlots
|
||||
? props.onlineReferenceCardData.slice(0, numTeaserSlots - notesDataToShow.length)
|
||||
notesDataToShow.length + codeDataToShow.length < numTeaserSlots
|
||||
? props.onlineReferenceCardData.slice(
|
||||
0,
|
||||
numTeaserSlots - codeDataToShow.length - notesDataToShow.length,
|
||||
)
|
||||
: [];
|
||||
|
||||
const shouldShowShowMoreButton =
|
||||
props.notesReferenceCardData.length > 0 || props.onlineReferenceCardData.length > 0;
|
||||
props.notesReferenceCardData.length > 0 ||
|
||||
props.codeReferenceCardData.length > 0 ||
|
||||
props.onlineReferenceCardData.length > 0;
|
||||
|
||||
const numReferences =
|
||||
props.notesReferenceCardData.length + props.onlineReferenceCardData.length;
|
||||
props.notesReferenceCardData.length +
|
||||
props.codeReferenceCardData.length +
|
||||
props.onlineReferenceCardData.length;
|
||||
|
||||
if (numReferences === 0) {
|
||||
return null;
|
||||
@@ -337,6 +523,15 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
|
||||
<p className="text-gray-400 m-2">{numReferences} sources</p>
|
||||
</h3>
|
||||
<div className={`flex flex-wrap gap-2 w-auto mt-2`}>
|
||||
{codeDataToShow.map((code, index) => {
|
||||
return (
|
||||
<CodeContextReferenceCard
|
||||
showFullContent={false}
|
||||
{...code}
|
||||
key={`code-${index}`}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
{notesDataToShow.map((note, index) => {
|
||||
return (
|
||||
<NotesContextReferenceCard
|
||||
@@ -359,6 +554,7 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
|
||||
<ReferencePanel
|
||||
notesReferenceCardData={props.notesReferenceCardData}
|
||||
onlineReferenceCardData={props.onlineReferenceCardData}
|
||||
codeReferenceCardData={props.codeReferenceCardData}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
@@ -369,6 +565,7 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
|
||||
interface ReferencePanelDataProps {
|
||||
notesReferenceCardData: NotesContextReferenceData[];
|
||||
onlineReferenceCardData: OnlineReferenceData[];
|
||||
codeReferenceCardData: CodeReferenceData[];
|
||||
}
|
||||
|
||||
export default function ReferencePanel(props: ReferencePanelDataProps) {
|
||||
@@ -388,6 +585,15 @@ export default function ReferencePanel(props: ReferencePanelDataProps) {
|
||||
<SheetDescription>View all references for this response</SheetDescription>
|
||||
</SheetHeader>
|
||||
<div className="flex flex-wrap gap-2 w-auto mt-2">
|
||||
{props.codeReferenceCardData.map((code, index) => {
|
||||
return (
|
||||
<CodeContextReferenceCard
|
||||
showFullContent={true}
|
||||
{...code}
|
||||
key={`code-${index}`}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
{props.notesReferenceCardData.map((note, index) => {
|
||||
return (
|
||||
<NotesContextReferenceCard
|
||||
|
||||
@@ -14,25 +14,21 @@ interface SuggestionCardProps {
|
||||
|
||||
export default function SuggestionCard(data: SuggestionCardProps) {
|
||||
const bgColors = converColorToBgGradient(data.color);
|
||||
const cardClassName = `${styles.card} ${bgColors} md:w-full md:h-fit sm:w-full h-fit md:w-[200px] md:h-[200px] cursor-pointer`;
|
||||
const titleClassName = `${styles.title} pt-2 dark:text-white dark:font-bold`;
|
||||
const cardClassName = `${styles.card} ${bgColors} md:w-full md:h-fit sm:w-full h-fit md:w-[200px] md:h-[180px] cursor-pointer md:p-2`;
|
||||
const descriptionClassName = `${styles.text} dark:text-white`;
|
||||
|
||||
const cardContent = (
|
||||
<Card className={cardClassName}>
|
||||
<CardHeader className="m-0 p-2 pb-1 relative">
|
||||
<div className="flex flex-row md:flex-col">
|
||||
<div className="flex w-full">
|
||||
<CardContent className="m-0 p-2 w-full">
|
||||
{convertSuggestionTitleToIconClass(data.title, data.color.toLowerCase())}
|
||||
<CardTitle className={titleClassName}>{data.title}</CardTitle>
|
||||
</div>
|
||||
</CardHeader>
|
||||
<CardContent className="m-0 p-2 pr-4 pt-1">
|
||||
<CardDescription
|
||||
className={`${descriptionClassName} sm:line-clamp-2 md:line-clamp-4`}
|
||||
>
|
||||
{data.body}
|
||||
</CardDescription>
|
||||
</CardContent>
|
||||
<CardDescription
|
||||
className={`${descriptionClassName} sm:line-clamp-2 md:line-clamp-4 pt-1 break-words whitespace-pre-wrap max-w-full`}
|
||||
>
|
||||
{data.body}
|
||||
</CardDescription>
|
||||
</CardContent>
|
||||
</div>
|
||||
</Card>
|
||||
);
|
||||
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
|
||||
.card {
|
||||
padding: 0.5rem;
|
||||
margin: 0.05rem;
|
||||
border-radius: 0.5rem;
|
||||
}
|
||||
|
||||
.title {
|
||||
font-size: 1.0rem;
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
.text {
|
||||
|
||||
@@ -47,24 +47,24 @@ const DEFAULT_COLOR = "orange";
|
||||
|
||||
export function convertSuggestionTitleToIconClass(title: string, color: string) {
|
||||
if (title === SuggestionType.Automation)
|
||||
return getIconFromIconName("Robot", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Paint) return getIconFromIconName("Palette", color, "w-8", "h-8");
|
||||
return getIconFromIconName("Robot", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Paint) return getIconFromIconName("Palette", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.PopCulture)
|
||||
return getIconFromIconName("Confetti", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Travel) return getIconFromIconName("Jeep", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Learning) return getIconFromIconName("Book", color, "w-8", "h-8");
|
||||
return getIconFromIconName("Confetti", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Travel) return getIconFromIconName("Jeep", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Learning) return getIconFromIconName("Book", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Health)
|
||||
return getIconFromIconName("Asclepius", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Fun) return getIconFromIconName("Island", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Home) return getIconFromIconName("House", color, "w-8", "h-8");
|
||||
return getIconFromIconName("Asclepius", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Fun) return getIconFromIconName("Island", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Home) return getIconFromIconName("House", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Language)
|
||||
return getIconFromIconName("Translate", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Code) return getIconFromIconName("Code", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Food) return getIconFromIconName("BowlFood", color, "w-8", "h-8");
|
||||
return getIconFromIconName("Translate", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Code) return getIconFromIconName("Code", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Food) return getIconFromIconName("BowlFood", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Interviewing)
|
||||
return getIconFromIconName("Lectern", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Finance) return getIconFromIconName("Wallet", color, "w-8", "h-8");
|
||||
else return getIconFromIconName("Lightbulb", color, "w-8", "h-8");
|
||||
return getIconFromIconName("Lectern", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Finance) return getIconFromIconName("Wallet", color, "w-6", "h-6");
|
||||
else return getIconFromIconName("Lightbulb", color, "w-6", "h-6");
|
||||
}
|
||||
|
||||
export const suggestionsData: Suggestion[] = [
|
||||
|
||||
@@ -1,172 +0,0 @@
|
||||
input.factVerification {
|
||||
width: 100%;
|
||||
display: block;
|
||||
padding: 12px 20px;
|
||||
margin: 8px 0;
|
||||
border: none;
|
||||
box-sizing: border-box;
|
||||
border-radius: 4px;
|
||||
text-align: left;
|
||||
margin: auto;
|
||||
margin-top: 8px;
|
||||
margin-bottom: 8px;
|
||||
font-size: large;
|
||||
}
|
||||
|
||||
div.factCheckerContainer {
|
||||
width: 75vw;
|
||||
margin: auto;
|
||||
}
|
||||
|
||||
input.factVerification:focus {
|
||||
outline: none;
|
||||
box-shadow: 0 4px 8px 0 rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
|
||||
div.responseText {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
div.response {
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
a.titleLink {
|
||||
color: #333;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
a.subLinks {
|
||||
color: #333;
|
||||
text-decoration: none;
|
||||
font-weight: small;
|
||||
border-radius: 4px;
|
||||
font-size: small;
|
||||
}
|
||||
|
||||
div.subLinks {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
gap: 8px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
div.reference {
|
||||
padding: 12px;
|
||||
margin: 8px;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
footer.footer {
|
||||
width: 100%;
|
||||
background: transparent;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
div.reportActions {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
gap: 8px;
|
||||
justify-content: space-between;
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
button.factCheckButton {
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
width: 100%;
|
||||
border-radius: 4px;
|
||||
margin: 8px;
|
||||
padding-left: 1rem;
|
||||
padding-right: 1rem;
|
||||
line-height: 1.25rem;
|
||||
}
|
||||
|
||||
button.factCheckButton:hover {
|
||||
box-shadow: 0 4px 8px 0 rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
|
||||
div.spinner {
|
||||
margin: 20px;
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
position: relative;
|
||||
text-align: center;
|
||||
|
||||
-webkit-animation: sk-rotate 2.0s infinite linear;
|
||||
animation: sk-rotate 2.0s infinite linear;
|
||||
}
|
||||
|
||||
div.inputFields {
|
||||
width: 100%;
|
||||
display: grid;
|
||||
grid-template-columns: 1fr auto;
|
||||
grid-gap: 8px;
|
||||
}
|
||||
|
||||
|
||||
/* Loading Animation */
|
||||
div.dot1,
|
||||
div.dot2 {
|
||||
width: 60%;
|
||||
height: 60%;
|
||||
display: inline-block;
|
||||
position: absolute;
|
||||
top: 0;
|
||||
border-radius: 100%;
|
||||
|
||||
-webkit-animation: sk-bounce 2.0s infinite ease-in-out;
|
||||
animation: sk-bounce 2.0s infinite ease-in-out;
|
||||
}
|
||||
|
||||
div.dot2 {
|
||||
top: auto;
|
||||
bottom: 0;
|
||||
-webkit-animation-delay: -1.0s;
|
||||
animation-delay: -1.0s;
|
||||
}
|
||||
|
||||
@media screen and (max-width: 768px) {
|
||||
div.factCheckerContainer {
|
||||
width: 95vw;
|
||||
}
|
||||
}
|
||||
|
||||
@-webkit-keyframes sk-rotate {
|
||||
100% {
|
||||
-webkit-transform: rotate(360deg)
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes sk-rotate {
|
||||
100% {
|
||||
transform: rotate(360deg);
|
||||
-webkit-transform: rotate(360deg)
|
||||
}
|
||||
}
|
||||
|
||||
@-webkit-keyframes sk-bounce {
|
||||
0%,
|
||||
100% {
|
||||
-webkit-transform: scale(0.0)
|
||||
}
|
||||
|
||||
50% {
|
||||
-webkit-transform: scale(1.0)
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes sk-bounce {
|
||||
0%,
|
||||
100% {
|
||||
transform: scale(0.0);
|
||||
-webkit-transform: scale(0.0);
|
||||
}
|
||||
|
||||
50% {
|
||||
transform: scale(1.0);
|
||||
-webkit-transform: scale(1.0);
|
||||
}
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
import type { Metadata } from "next";
|
||||
|
||||
export const metadata: Metadata = {
|
||||
title: "Khoj AI - Fact Checker",
|
||||
description:
|
||||
"Use the Fact Checker with Khoj AI for verifying statements. It can research the internet for you, either refuting or confirming the statement using fresh data.",
|
||||
icons: {
|
||||
icon: "/static/assets/icons/khoj_lantern.ico",
|
||||
apple: "/static/assets/icons/khoj_lantern_256x256.png",
|
||||
},
|
||||
openGraph: {
|
||||
siteName: "Khoj AI",
|
||||
title: "Khoj AI - Fact Checker",
|
||||
description: "Your Second Brain.",
|
||||
url: "https://app.khoj.dev/factchecker",
|
||||
type: "website",
|
||||
images: [
|
||||
{
|
||||
url: "https://assets.khoj.dev/khoj_lantern_256x256.png",
|
||||
width: 256,
|
||||
height: 256,
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
|
||||
export default function RootLayout({
|
||||
children,
|
||||
}: Readonly<{
|
||||
children: React.ReactNode;
|
||||
}>) {
|
||||
return <div>{children}</div>;
|
||||
}
|
||||
@@ -1,664 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import styles from "./factChecker.module.css";
|
||||
import { useAuthenticatedData } from "@/app/common/auth";
|
||||
import { useState, useEffect } from "react";
|
||||
|
||||
import ChatMessage, {
|
||||
Context,
|
||||
OnlineContext,
|
||||
OnlineContextData,
|
||||
WebPage,
|
||||
} from "../components/chatMessage/chatMessage";
|
||||
import { ModelPicker, Model } from "../components/modelPicker/modelPicker";
|
||||
import ShareLink from "../components/shareLink/shareLink";
|
||||
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
||||
import { Card, CardContent, CardFooter, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import Link from "next/link";
|
||||
import SidePanel from "../components/sidePanel/chatHistorySidePanel";
|
||||
import { useIsMobileWidth } from "../common/utils";
|
||||
|
||||
const chatURL = "/api/chat";
|
||||
const verificationPrecursor =
|
||||
"Limit your search to reputable sources. Search the internet for relevant supporting or refuting information. Do not reference my notes. Refuse to answer any queries that are not falsifiable by informing me that you will not answer the question. You're not permitted to ask follow-up questions, so do the best with what you have. Respond with **TRUE** or **FALSE** or **INCONCLUSIVE**, then provide your justification. Fact Check:";
|
||||
|
||||
const LoadingSpinner = () => (
|
||||
<div className={styles.loading}>
|
||||
<div className={styles.loadingVerification}>
|
||||
Researching...
|
||||
<div className={styles.spinner}>
|
||||
<div className={`${styles.dot1} bg-blue-300`}></div>
|
||||
<div className={`${styles.dot2} bg-blue-300`}></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
interface SupplementReferences {
|
||||
additionalLink: string;
|
||||
response: string;
|
||||
linkTitle: string;
|
||||
}
|
||||
|
||||
interface ResponseWithReferences {
|
||||
context?: Context[];
|
||||
online?: OnlineContext;
|
||||
response?: string;
|
||||
}
|
||||
|
||||
function handleCompiledReferences(chunk: string, currentResponse: string) {
|
||||
const rawReference = chunk.split("### compiled references:")[1];
|
||||
const rawResponse = chunk.split("### compiled references:")[0];
|
||||
let references: ResponseWithReferences = {};
|
||||
|
||||
// Set the initial response
|
||||
references.response = currentResponse + rawResponse;
|
||||
|
||||
const rawReferenceAsJson = JSON.parse(rawReference);
|
||||
if (rawReferenceAsJson instanceof Array) {
|
||||
references.context = rawReferenceAsJson;
|
||||
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
|
||||
references.online = rawReferenceAsJson;
|
||||
}
|
||||
|
||||
return references;
|
||||
}
|
||||
|
||||
async function verifyStatement(
|
||||
message: string,
|
||||
conversationId: string,
|
||||
setIsLoading: (loading: boolean) => void,
|
||||
setInitialResponse: (response: string) => void,
|
||||
setInitialReferences: (references: ResponseWithReferences) => void,
|
||||
) {
|
||||
setIsLoading(true);
|
||||
// Construct the verification payload
|
||||
let verificationMessage = `${verificationPrecursor} ${message}`;
|
||||
const apiURL = `${chatURL}?client=web`;
|
||||
const requestBody = {
|
||||
q: verificationMessage,
|
||||
conversation_id: conversationId,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
try {
|
||||
// Send a message to the chat server to verify the fact
|
||||
const response = await fetch(apiURL, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
});
|
||||
if (!response.body) throw new Error("No response body found");
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
let decoder = new TextDecoder();
|
||||
let result = "";
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
let chunk = decoder.decode(value, { stream: true });
|
||||
|
||||
if (chunk.includes("### compiled references:")) {
|
||||
const references = handleCompiledReferences(chunk, result);
|
||||
if (references.response) {
|
||||
result = references.response;
|
||||
setInitialResponse(references.response);
|
||||
setInitialReferences(references);
|
||||
}
|
||||
} else {
|
||||
result += chunk;
|
||||
setInitialResponse(result);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error verifying statement: ", error);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function spawnNewConversation(setConversationID: (conversationID: string) => void) {
|
||||
let createURL = `/api/chat/sessions?client=web`;
|
||||
|
||||
const response = await fetch(createURL, { method: "POST" });
|
||||
|
||||
const data = await response.json();
|
||||
setConversationID(data.conversation_id);
|
||||
}
|
||||
|
||||
interface ReferenceVerificationProps {
|
||||
message: string;
|
||||
additionalLink: string;
|
||||
conversationId: string;
|
||||
linkTitle: string;
|
||||
setChildReferencesCallback: (
|
||||
additionalLink: string,
|
||||
response: string,
|
||||
linkTitle: string,
|
||||
) => void;
|
||||
prefilledResponse?: string;
|
||||
}
|
||||
|
||||
function ReferenceVerification(props: ReferenceVerificationProps) {
|
||||
const [initialResponse, setInitialResponse] = useState("");
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const verificationStatement = `${props.message}. Use this link for reference: ${props.additionalLink}`;
|
||||
const isMobileWidth = useIsMobileWidth();
|
||||
|
||||
useEffect(() => {
|
||||
if (props.prefilledResponse) {
|
||||
setInitialResponse(props.prefilledResponse);
|
||||
setIsLoading(false);
|
||||
} else {
|
||||
verifyStatement(
|
||||
verificationStatement,
|
||||
props.conversationId,
|
||||
setIsLoading,
|
||||
setInitialResponse,
|
||||
() => {},
|
||||
);
|
||||
}
|
||||
}, [verificationStatement, props.conversationId, props.prefilledResponse]);
|
||||
|
||||
useEffect(() => {
|
||||
if (initialResponse === "") return;
|
||||
if (props.prefilledResponse) return;
|
||||
|
||||
if (!isLoading) {
|
||||
// Only set the child references when it's done loading and if the initial response is not prefilled (i.e. it was fetched from the server)
|
||||
props.setChildReferencesCallback(
|
||||
props.additionalLink,
|
||||
initialResponse,
|
||||
props.linkTitle,
|
||||
);
|
||||
}
|
||||
}, [initialResponse, isLoading, props]);
|
||||
|
||||
return (
|
||||
<div>
|
||||
{isLoading && <LoadingSpinner />}
|
||||
<ChatMessage
|
||||
chatMessage={{
|
||||
automationId: "",
|
||||
by: "AI",
|
||||
message: initialResponse,
|
||||
context: [],
|
||||
created: new Date().toISOString(),
|
||||
onlineContext: {},
|
||||
}}
|
||||
isMobileWidth={isMobileWidth}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface SupplementalReferenceProps {
|
||||
onlineData?: OnlineContextData;
|
||||
officialFactToVerify: string;
|
||||
conversationId: string;
|
||||
additionalLink: string;
|
||||
setChildReferencesCallback: (
|
||||
additionalLink: string,
|
||||
response: string,
|
||||
linkTitle: string,
|
||||
) => void;
|
||||
prefilledResponse?: string;
|
||||
linkTitle?: string;
|
||||
}
|
||||
|
||||
function SupplementalReference(props: SupplementalReferenceProps) {
|
||||
const linkTitle = props.linkTitle || props.onlineData?.organic?.[0]?.title || "Reference";
|
||||
const linkAsWebpage = { link: props.additionalLink } as WebPage;
|
||||
return (
|
||||
<Card className={`mt-2 mb-4`}>
|
||||
<CardHeader>
|
||||
<a
|
||||
className={styles.titleLink}
|
||||
href={props.additionalLink}
|
||||
target="_blank"
|
||||
rel="noreferrer"
|
||||
>
|
||||
{linkTitle}
|
||||
</a>
|
||||
<WebPageLink {...linkAsWebpage} />
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<ReferenceVerification
|
||||
additionalLink={props.additionalLink}
|
||||
message={props.officialFactToVerify}
|
||||
linkTitle={linkTitle}
|
||||
conversationId={props.conversationId}
|
||||
setChildReferencesCallback={props.setChildReferencesCallback}
|
||||
prefilledResponse={props.prefilledResponse}
|
||||
/>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
const WebPageLink = (webpage: WebPage) => {
|
||||
const webpageDomain = new URL(webpage.link).hostname;
|
||||
return (
|
||||
<div className={styles.subLinks}>
|
||||
<a
|
||||
className={`${styles.subLinks} bg-blue-200 px-2`}
|
||||
href={webpage.link}
|
||||
target="_blank"
|
||||
rel="noreferrer"
|
||||
>
|
||||
{webpageDomain}
|
||||
</a>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default function FactChecker() {
|
||||
const [factToVerify, setFactToVerify] = useState("");
|
||||
const [officialFactToVerify, setOfficialFactToVerify] = useState("");
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [initialResponse, setInitialResponse] = useState("");
|
||||
const [clickedVerify, setClickedVerify] = useState(false);
|
||||
const [initialReferences, setInitialReferences] = useState<ResponseWithReferences>();
|
||||
const [childReferences, setChildReferences] = useState<SupplementReferences[]>();
|
||||
const [modelUsed, setModelUsed] = useState<Model>();
|
||||
const isMobileWidth = useIsMobileWidth();
|
||||
|
||||
const [conversationID, setConversationID] = useState("");
|
||||
const [runId, setRunId] = useState("");
|
||||
const [loadedFromStorage, setLoadedFromStorage] = useState(false);
|
||||
|
||||
const [initialModel, setInitialModel] = useState<Model>();
|
||||
|
||||
function setChildReferencesCallback(
|
||||
additionalLink: string,
|
||||
response: string,
|
||||
linkTitle: string,
|
||||
) {
|
||||
const newReferences = childReferences || [];
|
||||
const exists = newReferences.find(
|
||||
(reference) => reference.additionalLink === additionalLink,
|
||||
);
|
||||
if (exists) return;
|
||||
newReferences.push({ additionalLink, response, linkTitle });
|
||||
setChildReferences(newReferences);
|
||||
}
|
||||
|
||||
let userData = useAuthenticatedData();
|
||||
|
||||
function storeData() {
|
||||
const data = {
|
||||
factToVerify,
|
||||
response: initialResponse,
|
||||
references: initialReferences,
|
||||
childReferences,
|
||||
runId,
|
||||
modelUsed,
|
||||
};
|
||||
|
||||
fetch(`/api/chat/store/factchecker`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
runId: runId,
|
||||
storeData: data,
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (factToVerify) {
|
||||
document.title = `AI Fact Check: ${factToVerify}`;
|
||||
} else {
|
||||
document.title = "AI Fact Checker";
|
||||
}
|
||||
}, [factToVerify]);
|
||||
|
||||
useEffect(() => {
|
||||
const storedFact = localStorage.getItem("factToVerify");
|
||||
if (storedFact) {
|
||||
setFactToVerify(storedFact);
|
||||
}
|
||||
|
||||
// Get query params from the URL
|
||||
const urlParams = new URLSearchParams(window.location.search);
|
||||
const factToVerifyParam = urlParams.get("factToVerify");
|
||||
|
||||
if (factToVerifyParam) {
|
||||
setFactToVerify(factToVerifyParam);
|
||||
}
|
||||
|
||||
const runIdParam = urlParams.get("runId");
|
||||
if (runIdParam) {
|
||||
setRunId(runIdParam);
|
||||
|
||||
// Define an async function to fetch data
|
||||
const fetchData = async () => {
|
||||
const storedDataURL = `/api/chat/store/factchecker?runId=${runIdParam}`;
|
||||
try {
|
||||
const response = await fetch(storedDataURL);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch stored data");
|
||||
}
|
||||
const storedData = JSON.parse(await response.json());
|
||||
if (storedData) {
|
||||
setOfficialFactToVerify(storedData.factToVerify);
|
||||
setInitialResponse(storedData.response);
|
||||
setInitialReferences(storedData.references);
|
||||
setChildReferences(storedData.childReferences);
|
||||
setInitialModel(storedData.modelUsed);
|
||||
}
|
||||
setLoadedFromStorage(true);
|
||||
} catch (error) {
|
||||
console.error("Error fetching stored data: ", error);
|
||||
}
|
||||
};
|
||||
|
||||
// Call the async function
|
||||
fetchData();
|
||||
}
|
||||
}, []);
|
||||
|
||||
function onClickVerify() {
|
||||
if (clickedVerify) return;
|
||||
|
||||
// Perform validation checks on the fact to verify
|
||||
if (!factToVerify) {
|
||||
alert("Please enter a fact to verify.");
|
||||
return;
|
||||
}
|
||||
|
||||
setClickedVerify(true);
|
||||
if (!userData) {
|
||||
let currentURL = window.location.href;
|
||||
window.location.href = `/login?next=${currentURL}`;
|
||||
}
|
||||
|
||||
setInitialReferences(undefined);
|
||||
setInitialResponse("");
|
||||
|
||||
spawnNewConversation(setConversationID);
|
||||
|
||||
// Set the runId to a random 16-character alphanumeric string
|
||||
const newRunId = [...Array(16)].map(() => Math.random().toString(36)[2]).join("");
|
||||
setRunId(newRunId);
|
||||
window.history.pushState(
|
||||
{},
|
||||
document.title,
|
||||
window.location.pathname + `?runId=${newRunId}`,
|
||||
);
|
||||
|
||||
setOfficialFactToVerify(factToVerify);
|
||||
setClickedVerify(false);
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (!conversationID) return;
|
||||
verifyStatement(
|
||||
officialFactToVerify,
|
||||
conversationID,
|
||||
setIsLoading,
|
||||
setInitialResponse,
|
||||
setInitialReferences,
|
||||
);
|
||||
}, [conversationID, officialFactToVerify]);
|
||||
|
||||
// Store factToVerify in localStorage whenever it changes
|
||||
useEffect(() => {
|
||||
localStorage.setItem("factToVerify", factToVerify);
|
||||
}, [factToVerify]);
|
||||
|
||||
// Update the meta tags for the description and og:description
|
||||
useEffect(() => {
|
||||
let metaTag = document.querySelector('meta[name="description"]');
|
||||
if (metaTag) {
|
||||
metaTag.setAttribute("content", initialResponse);
|
||||
}
|
||||
let metaOgTag = document.querySelector('meta[property="og:description"]');
|
||||
if (!metaOgTag) {
|
||||
metaOgTag = document.createElement("meta");
|
||||
metaOgTag.setAttribute("property", "og:description");
|
||||
document.getElementsByTagName("head")[0].appendChild(metaOgTag);
|
||||
}
|
||||
metaOgTag.setAttribute("content", initialResponse);
|
||||
}, [initialResponse]);
|
||||
|
||||
const renderReferences = (
|
||||
conversationId: string,
|
||||
initialReferences: ResponseWithReferences,
|
||||
officialFactToVerify: string,
|
||||
loadedFromStorage: boolean,
|
||||
childReferences?: SupplementReferences[],
|
||||
) => {
|
||||
if (loadedFromStorage && childReferences) {
|
||||
return renderSupplementalReferences(childReferences);
|
||||
}
|
||||
|
||||
const seenLinks = new Set();
|
||||
|
||||
// Any links that are present in webpages should not be searched again
|
||||
Object.entries(initialReferences.online || {}).map(([key, onlineData], index) => {
|
||||
const webpages = onlineData?.webpages || [];
|
||||
// Webpage can be a list or a single object
|
||||
if (webpages instanceof Array) {
|
||||
for (let i = 0; i < webpages.length; i++) {
|
||||
const webpage = webpages[i];
|
||||
const additionalLink = webpage.link || "";
|
||||
if (seenLinks.has(additionalLink)) {
|
||||
return null;
|
||||
}
|
||||
seenLinks.add(additionalLink);
|
||||
}
|
||||
} else {
|
||||
let singleWebpage = webpages as WebPage;
|
||||
const additionalLink = singleWebpage.link || "";
|
||||
if (seenLinks.has(additionalLink)) {
|
||||
return null;
|
||||
}
|
||||
seenLinks.add(additionalLink);
|
||||
}
|
||||
});
|
||||
|
||||
return Object.entries(initialReferences.online || {})
|
||||
.map(([key, onlineData], index) => {
|
||||
let additionalLink = "";
|
||||
|
||||
// Loop through organic links until we find one that hasn't been searched
|
||||
for (let i = 0; i < onlineData?.organic?.length; i++) {
|
||||
const webpage = onlineData?.organic?.[i];
|
||||
additionalLink = webpage.link || "";
|
||||
|
||||
if (!seenLinks.has(additionalLink)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
seenLinks.add(additionalLink);
|
||||
|
||||
if (additionalLink === "") return null;
|
||||
|
||||
return (
|
||||
<SupplementalReference
|
||||
key={index}
|
||||
onlineData={onlineData}
|
||||
officialFactToVerify={officialFactToVerify}
|
||||
conversationId={conversationId}
|
||||
additionalLink={additionalLink}
|
||||
setChildReferencesCallback={setChildReferencesCallback}
|
||||
/>
|
||||
);
|
||||
})
|
||||
.filter(Boolean);
|
||||
};
|
||||
|
||||
const renderSupplementalReferences = (references: SupplementReferences[]) => {
|
||||
return references.map((reference, index) => {
|
||||
return (
|
||||
<SupplementalReference
|
||||
key={index}
|
||||
additionalLink={reference.additionalLink}
|
||||
officialFactToVerify={officialFactToVerify}
|
||||
conversationId={conversationID}
|
||||
linkTitle={reference.linkTitle}
|
||||
setChildReferencesCallback={setChildReferencesCallback}
|
||||
prefilledResponse={reference.response}
|
||||
/>
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
const renderWebpages = (webpages: WebPage[] | WebPage) => {
|
||||
if (webpages instanceof Array) {
|
||||
return webpages.map((webpage, index) => {
|
||||
return WebPageLink(webpage);
|
||||
});
|
||||
} else {
|
||||
return WebPageLink(webpages);
|
||||
}
|
||||
};
|
||||
|
||||
function constructShareUrl() {
|
||||
const url = new URL(window.location.href);
|
||||
url.searchParams.set("runId", runId);
|
||||
return url.href;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="relative md:fixed h-full">
|
||||
<SidePanel conversationId={null} uploadedFiles={[]} isMobileWidth={isMobileWidth} />
|
||||
</div>
|
||||
<div className={styles.factCheckerContainer}>
|
||||
<h1
|
||||
className={`${styles.response} pt-8 md:pt-4 font-large outline-slate-800 dark:outline-slate-200`}
|
||||
>
|
||||
AI Fact Checker
|
||||
</h1>
|
||||
<footer className={`${styles.footer} mt-4`}>
|
||||
This is an experimental AI tool. It may make mistakes.
|
||||
</footer>
|
||||
{initialResponse && initialReferences && childReferences ? (
|
||||
<div className={styles.reportActions}>
|
||||
<Button asChild variant="secondary">
|
||||
<Link href="/factchecker" target="_blank" rel="noopener noreferrer">
|
||||
Try Another
|
||||
</Link>
|
||||
</Button>
|
||||
<ShareLink
|
||||
buttonTitle="Share report"
|
||||
title="AI Fact Checking Report"
|
||||
description="Share this fact checking report with others. Anyone who has this link will be able to view the report."
|
||||
url={constructShareUrl()}
|
||||
onShare={loadedFromStorage ? () => {} : storeData}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<div className={styles.newReportActions}>
|
||||
<div className={`${styles.inputFields} mt-4`}>
|
||||
<Input
|
||||
type="text"
|
||||
maxLength={200}
|
||||
placeholder="Enter a falsifiable statement to verify"
|
||||
disabled={isLoading}
|
||||
onChange={(e) => setFactToVerify(e.target.value)}
|
||||
value={factToVerify}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") {
|
||||
onClickVerify();
|
||||
}
|
||||
}}
|
||||
onFocus={(e) => (e.target.placeholder = "")}
|
||||
onBlur={(e) =>
|
||||
(e.target.placeholder =
|
||||
"Enter a falsifiable statement to verify")
|
||||
}
|
||||
/>
|
||||
<Button disabled={clickedVerify} onClick={() => onClickVerify()}>
|
||||
Verify
|
||||
</Button>
|
||||
</div>
|
||||
<h3 className={`mt-4 mb-4`}>
|
||||
Try with a particular model. You must be{" "}
|
||||
<a
|
||||
href="/settings"
|
||||
className="font-medium text-blue-600 dark:text-blue-500 hover:underline"
|
||||
>
|
||||
subscribed
|
||||
</a>{" "}
|
||||
to configure the model.
|
||||
</h3>
|
||||
</div>
|
||||
)}
|
||||
<ModelPicker
|
||||
disabled={isLoading || loadedFromStorage}
|
||||
setModelUsed={setModelUsed}
|
||||
initialModel={initialModel}
|
||||
/>
|
||||
{isLoading && (
|
||||
<div className={styles.loading}>
|
||||
<LoadingSpinner />
|
||||
</div>
|
||||
)}
|
||||
{initialResponse && (
|
||||
<Card className={`mt-4`}>
|
||||
<CardHeader>
|
||||
<CardTitle>{officialFactToVerify}</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className={styles.responseText}>
|
||||
<ChatMessage
|
||||
chatMessage={{
|
||||
automationId: "",
|
||||
by: "AI",
|
||||
message: initialResponse,
|
||||
context: [],
|
||||
created: new Date().toISOString(),
|
||||
onlineContext: {},
|
||||
}}
|
||||
isMobileWidth={isMobileWidth}
|
||||
/>
|
||||
</div>
|
||||
</CardContent>
|
||||
<CardFooter>
|
||||
{initialReferences &&
|
||||
initialReferences.online &&
|
||||
Object.keys(initialReferences.online).length > 0 && (
|
||||
<div className={styles.subLinks}>
|
||||
{Object.entries(initialReferences.online).map(
|
||||
([key, onlineData], index) => {
|
||||
const webpages = onlineData?.webpages || [];
|
||||
return renderWebpages(webpages);
|
||||
},
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</CardFooter>
|
||||
</Card>
|
||||
)}
|
||||
{initialReferences && (
|
||||
<div className={styles.referenceContainer}>
|
||||
<h2 className="mt-4 mb-4">Supplements</h2>
|
||||
<div className={styles.references}>
|
||||
{initialReferences.online !== undefined &&
|
||||
renderReferences(
|
||||
conversationID,
|
||||
initialReferences,
|
||||
officialFactToVerify,
|
||||
loadedFromStorage,
|
||||
childReferences,
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -14,7 +14,7 @@ export const metadata: Metadata = {
    manifest: "/static/khoj.webmanifest",
    openGraph: {
        siteName: "Khoj AI",
        title: "Khoj AI - Home",
        title: "Khoj AI",
        description: "Your Second Brain.",
        url: "https://app.khoj.dev",
        type: "website",

@@ -3,31 +3,42 @@ import "./globals.css";
|
||||
import styles from "./page.module.css";
|
||||
import "katex/dist/katex.min.css";
|
||||
|
||||
import React, { useEffect, useState } from "react";
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import useSWR from "swr";
|
||||
import Image from "next/image";
|
||||
import { ArrowCounterClockwise } from "@phosphor-icons/react";
|
||||
|
||||
import { Card, CardTitle } from "@/components/ui/card";
|
||||
import SuggestionCard from "@/app/components/suggestions/suggestionCard";
|
||||
import SidePanel from "@/app/components/sidePanel/chatHistorySidePanel";
|
||||
import Loading from "@/app/components/loading/loading";
|
||||
import ChatInputArea, { ChatOptions } from "@/app/components/chatInputArea/chatInputArea";
|
||||
import {
|
||||
AttachedFileText,
|
||||
ChatInputArea,
|
||||
ChatOptions,
|
||||
} from "@/app/components/chatInputArea/chatInputArea";
|
||||
import { Suggestion, suggestionsData } from "@/app/components/suggestions/suggestionsData";
|
||||
import LoginPrompt from "@/app/components/loginPrompt/loginPrompt";
|
||||
|
||||
import { useAuthenticatedData, UserConfig, useUserConfig } from "@/app/common/auth";
|
||||
import {
|
||||
isUserSubscribed,
|
||||
useAuthenticatedData,
|
||||
UserConfig,
|
||||
useUserConfig,
|
||||
} from "@/app/common/auth";
|
||||
import { convertColorToBorderClass } from "@/app/common/colorUtils";
|
||||
import { getIconFromIconName } from "@/app/common/iconUtils";
|
||||
import { AgentData } from "@/app/agents/page";
|
||||
import { createNewConversation } from "./common/chatFunctions";
|
||||
import { useIsMobileWidth } from "./common/utils";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { useDebounce, useIsMobileWidth } from "./common/utils";
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
import { ScrollArea, ScrollBar } from "@/components/ui/scroll-area";
|
||||
import { AgentCard } from "@/app/components/agentCard/agentCard";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||
|
||||
interface ChatBodyDataProps {
|
||||
chatOptionsData: ChatOptions | null;
|
||||
onConversationIdChange?: (conversationId: string) => void;
|
||||
setUploadedFiles: (files: string[]) => void;
|
||||
setUploadedFiles: (files: AttachedFileText[]) => void;
|
||||
isMobileWidth?: boolean;
|
||||
isLoggedIn: boolean;
|
||||
userConfig: UserConfig | null;
|
||||
@@ -44,14 +55,19 @@ function FisherYatesShuffle(array: any[]) {
|
||||
|
||||
function ChatBodyData(props: ChatBodyDataProps) {
|
||||
const [message, setMessage] = useState("");
|
||||
const [image, setImage] = useState<string | null>(null);
|
||||
const [images, setImages] = useState<string[]>([]);
|
||||
const [processingMessage, setProcessingMessage] = useState(false);
|
||||
const [greeting, setGreeting] = useState("");
|
||||
const [shuffledOptions, setShuffledOptions] = useState<Suggestion[]>([]);
|
||||
const [hoveredAgent, setHoveredAgent] = useState<string | null>(null);
|
||||
const debouncedHoveredAgent = useDebounce(hoveredAgent, 500);
|
||||
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
|
||||
const [selectedAgent, setSelectedAgent] = useState<string | null>("khoj");
|
||||
const [agentIcons, setAgentIcons] = useState<JSX.Element[]>([]);
|
||||
const [agents, setAgents] = useState<AgentData[]>([]);
|
||||
const chatInputRef = useRef<HTMLTextAreaElement>(null);
|
||||
const [showLoginPrompt, setShowLoginPrompt] = useState(false);
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const queryParam = searchParams.get("q");
|
||||
|
||||
@@ -61,6 +77,12 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||
}
|
||||
}, [queryParam]);
|
||||
|
||||
useEffect(() => {
|
||||
if (debouncedHoveredAgent) {
|
||||
setIsPopoverOpen(true);
|
||||
}
|
||||
}, [debouncedHoveredAgent]);
|
||||
|
||||
const onConversationIdChange = props.onConversationIdChange;
|
||||
|
||||
const agentsFetcher = () =>
|
||||
@@ -72,6 +94,10 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||
revalidateOnFocus: false,
|
||||
});
|
||||
|
||||
const openAgentEditCard = (agentSlug: string) => {
|
||||
router.push(`/agents?agent=${agentSlug}`);
|
||||
};
|
||||
|
||||
function shuffleAndSetOptions() {
|
||||
const shuffled = FisherYatesShuffle(suggestionsData);
|
||||
setShuffledOptions(shuffled.slice(0, 3));
|
||||
@@ -94,8 +120,8 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||
`What would you like to get done${nameSuffix}?`,
|
||||
`Hey${nameSuffix}! How can I help?`,
|
||||
`Good ${timeOfDay}${nameSuffix}! What's on your mind?`,
|
||||
`Ready to breeze through your ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]}?`,
|
||||
`Want help navigating your ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]} workload?`,
|
||||
`Ready to breeze through ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]}?`,
|
||||
`Let's navigate your ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]} workload`,
|
||||
];
|
||||
const greeting = greetings[Math.floor(Math.random() * greetings.length)];
|
||||
setGreeting(greeting);
|
||||
@@ -108,22 +134,13 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||
}, [props.chatOptionsData]);
|
||||
|
||||
useEffect(() => {
|
||||
const nSlice = props.isMobileWidth ? 2 : 4;
|
||||
const shuffledAgents = agentsData ? [...agentsData].sort(() => 0.5 - Math.random()) : [];
|
||||
const agents = agentsData ? [agentsData[0]] : []; // Always add the first/default agent.
|
||||
|
||||
shuffledAgents.slice(0, nSlice - 1).forEach((agent) => {
|
||||
if (!agents.find((a) => a.slug === agent.slug)) {
|
||||
agents.push(agent);
|
||||
}
|
||||
});
|
||||
|
||||
const agents = (agentsData || []).filter((agent) => agent !== null && agent !== undefined);
|
||||
setAgents(agents);
|
||||
// set the first agent, which is always the default agent, as the default for chat
|
||||
setSelectedAgent(agents.length > 1 ? agents[0].slug : "khoj");
|
||||
|
||||
//generate colored icons for the selected agents
|
||||
const agentIcons = agents
|
||||
.filter((agent) => agent !== null && agent !== undefined)
|
||||
.map((agent) => getIconFromIconName(agent.icon, agent.color)!);
|
||||
// generate colored icons for the available agents
|
||||
const agentIcons = agents.map((agent) => getIconFromIconName(agent.icon, agent.color)!);
|
||||
setAgentIcons(agentIcons);
|
||||
}, [agentsData, props.isMobileWidth]);
|
||||
|
||||
@@ -138,24 +155,40 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||
try {
|
||||
const newConversationId = await createNewConversation(selectedAgent || "khoj");
|
||||
onConversationIdChange?.(newConversationId);
|
||||
window.location.href = `/chat?conversationId=${newConversationId}`;
|
||||
localStorage.setItem("message", message);
|
||||
if (image) {
|
||||
localStorage.setItem("image", image);
|
||||
if (images.length > 0) {
|
||||
localStorage.setItem("images", JSON.stringify(images));
|
||||
}
|
||||
|
||||
window.location.href = `/chat?conversationId=${newConversationId}`;
|
||||
} catch (error) {
|
||||
console.error("Error creating new conversation:", error);
|
||||
setProcessingMessage(false);
|
||||
}
|
||||
setMessage("");
|
||||
setImages([]);
|
||||
}
|
||||
};
|
||||
processMessage();
|
||||
if (message) {
|
||||
if (message || images.length > 0) {
|
||||
setProcessingMessage(true);
|
||||
}
|
||||
}, [selectedAgent, message, processingMessage, onConversationIdChange]);
|
||||
|
||||
// Close the agent detail hover card when scroll on agent pane
|
||||
useEffect(() => {
|
||||
const scrollAreaSelector = "[data-radix-scroll-area-viewport]";
|
||||
const scrollAreaEl = document.querySelector<HTMLElement>(scrollAreaSelector);
|
||||
const handleScroll = () => {
|
||||
setHoveredAgent(null);
|
||||
setIsPopoverOpen(false);
|
||||
};
|
||||
|
||||
scrollAreaEl?.addEventListener("scroll", handleScroll);
|
||||
|
||||
return () => scrollAreaEl?.removeEventListener("scroll", handleScroll);
|
||||
}, []);
|
||||
|
||||
function fillArea(link: string, type: string, prompt: string) {
|
||||
if (!link) {
|
||||
let message_str = "";
|
||||
@@ -194,37 +227,76 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||
</h1>
|
||||
</div>
|
||||
{!props.isMobileWidth && (
|
||||
<div className="flex pb-6 gap-2 items-center justify-center">
|
||||
{agentIcons.map((icon, index) => (
|
||||
<Card
|
||||
key={`${index}-${agents[index].slug}`}
|
||||
className={`${
|
||||
selectedAgent === agents[index].slug
|
||||
? convertColorToBorderClass(agents[index].color)
|
||||
: "border-stone-100 dark:border-neutral-700 text-muted-foreground"
|
||||
}
|
||||
hover:cursor-pointer rounded-lg px-2 py-2`}
|
||||
>
|
||||
<CardTitle
|
||||
className="text-center text-md font-medium flex justify-center items-center"
|
||||
onClick={() => setSelectedAgent(agents[index].slug)}
|
||||
<ScrollArea className="w-full max-w-[600px] mx-auto">
|
||||
<div className="flex pb-2 gap-2 items-center justify-center">
|
||||
{agents.map((agent, index) => (
|
||||
<Popover
|
||||
key={`${index}-${agent.slug}`}
|
||||
open={isPopoverOpen && debouncedHoveredAgent === agent.slug}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) {
|
||||
setHoveredAgent(null);
|
||||
setIsPopoverOpen(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{icon} {agents[index].name}
|
||||
</CardTitle>
|
||||
</Card>
|
||||
))}
|
||||
<Card
|
||||
className="border-none shadow-none flex justify-center items-center hover:cursor-pointer"
|
||||
onClick={() => (window.location.href = "/agents")}
|
||||
>
|
||||
<CardTitle className="text-center text-md font-normal flex justify-center items-center px-1.5 py-2">
|
||||
See All →
|
||||
</CardTitle>
|
||||
</Card>
|
||||
</div>
|
||||
<PopoverTrigger asChild>
|
||||
<Card
|
||||
className={`${
|
||||
selectedAgent === agent.slug
|
||||
? convertColorToBorderClass(agent.color)
|
||||
: "border-stone-100 dark:border-neutral-700 text-muted-foreground"
|
||||
}
|
||||
hover:cursor-pointer rounded-lg px-2 py-2`}
|
||||
onDoubleClick={() => openAgentEditCard(agent.slug)}
|
||||
onClick={() => {
|
||||
setSelectedAgent(agent.slug);
|
||||
chatInputRef.current?.focus();
|
||||
setHoveredAgent(null);
|
||||
setIsPopoverOpen(false);
|
||||
}}
|
||||
onMouseEnter={() => setHoveredAgent(agent.slug)}
|
||||
onMouseLeave={() => {
|
||||
setHoveredAgent(null);
|
||||
setIsPopoverOpen(false);
|
||||
}}
|
||||
>
|
||||
<CardTitle className="text-center text-md font-medium flex justify-center items-center whitespace-nowrap">
|
||||
{agentIcons[index]} {agent.name}
|
||||
</CardTitle>
|
||||
</Card>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
className="w-80 p-0 border-none bg-transparent shadow-none"
|
||||
onMouseLeave={() => {
|
||||
setHoveredAgent(null);
|
||||
setIsPopoverOpen(false);
|
||||
}}
|
||||
>
|
||||
<AgentCard
|
||||
data={agent}
|
||||
userProfile={null}
|
||||
isMobileWidth={props.isMobileWidth || false}
|
||||
showChatButton={false}
|
||||
editCard={false}
|
||||
filesOptions={[]}
|
||||
selectedChatModelOption=""
|
||||
agentSlug=""
|
||||
isSubscribed={isUserSubscribed(props.userConfig)}
|
||||
setAgentChangeTriggered={() => {}}
|
||||
modelOptions={[]}
|
||||
inputToolOptions={{}}
|
||||
outputModeOptions={{}}
|
||||
/>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
))}
|
||||
</div>
|
||||
<ScrollBar orientation="horizontal" />
|
||||
</ScrollArea>
|
||||
)}
|
||||
</div>
|
||||
<div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit"}`}>
|
||||
<div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit max-w-screen-md"}`}>
|
||||
{!props.isMobileWidth && (
|
||||
<div
|
||||
className={`w-full ${styles.inputBox} shadow-lg bg-background align-middle items-center justify-center px-3 py-1 dark:bg-neutral-700 border-stone-100 dark:border-none dark:shadow-none rounded-2xl`}
|
||||
@@ -232,12 +304,14 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||
<ChatInputArea
|
||||
isLoggedIn={props.isLoggedIn}
|
||||
sendMessage={(message) => setMessage(message)}
|
||||
sendImage={(image) => setImage(image)}
|
||||
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
|
||||
sendDisabled={processingMessage}
|
||||
chatOptionsData={props.chatOptionsData}
|
||||
conversationId={null}
|
||||
isMobileWidth={props.isMobileWidth}
|
||||
setUploadedFiles={props.setUploadedFiles}
|
||||
agentColor={agents.find((agent) => agent.slug === selectedAgent)?.color}
|
||||
ref={chatInputRef}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
@@ -285,40 +359,41 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||
<div
|
||||
className={`${styles.inputBox} pt-1 shadow-[0_-20px_25px_-5px_rgba(0,0,0,0.1)] dark:bg-neutral-700 bg-background align-middle items-center justify-center pb-3 mx-1 rounded-t-2xl rounded-b-none`}
|
||||
>
|
||||
<div className="flex gap-2 items-center justify-left pt-1 pb-2 px-12">
|
||||
{agentIcons.map((icon, index) => (
|
||||
<Card
|
||||
key={`${index}-${agents[index].slug}`}
|
||||
className={`${selectedAgent === agents[index].slug ? convertColorToBorderClass(agents[index].color) : "border-muted text-muted-foreground"} hover:cursor-pointer`}
|
||||
>
|
||||
<CardTitle
|
||||
className="text-center text-xs font-medium flex justify-center items-center px-1.5 py-1"
|
||||
onClick={() => setSelectedAgent(agents[index].slug)}
|
||||
<ScrollArea className="w-full max-w-[85vw]">
|
||||
<div className="flex gap-2 items-center justify-left pt-1 pb-2 px-12">
|
||||
{agentIcons.map((icon, index) => (
|
||||
<Card
|
||||
key={`${index}-${agents[index].slug}`}
|
||||
className={`${selectedAgent === agents[index].slug ? convertColorToBorderClass(agents[index].color) : "border-muted text-muted-foreground"} hover:cursor-pointer`}
|
||||
>
|
||||
{icon} {agents[index].name}
|
||||
</CardTitle>
|
||||
</Card>
|
||||
))}
|
||||
<Card
|
||||
className="border-none shadow-none flex justify-center items-center hover:cursor-pointer"
|
||||
onClick={() => (window.location.href = "/agents")}
|
||||
>
|
||||
<CardTitle
|
||||
className={`text-center ${props.isMobileWidth ? "text-xs" : "text-md"} font-normal flex justify-center items-center px-1.5 py-2`}
|
||||
>
|
||||
See All →
|
||||
</CardTitle>
|
||||
</Card>
|
||||
</div>
|
||||
<CardTitle
|
||||
className="text-center text-xs font-medium flex justify-center items-center whitespace-nowrap px-1.5 py-1"
|
||||
onDoubleClick={() =>
|
||||
openAgentEditCard(agents[index].slug)
|
||||
}
|
||||
onClick={() => {
|
||||
setSelectedAgent(agents[index].slug);
|
||||
chatInputRef.current?.focus();
|
||||
}}
|
||||
>
|
||||
{icon} {agents[index].name}
|
||||
</CardTitle>
|
||||
</Card>
|
||||
))}
|
||||
</div>
|
||||
<ScrollBar orientation="horizontal" />
|
||||
</ScrollArea>
|
||||
<ChatInputArea
|
||||
isLoggedIn={props.isLoggedIn}
|
||||
sendMessage={(message) => setMessage(message)}
|
||||
sendImage={(image) => setImage(image)}
|
||||
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
|
||||
sendDisabled={processingMessage}
|
||||
chatOptionsData={props.chatOptionsData}
|
||||
conversationId={null}
|
||||
isMobileWidth={props.isMobileWidth}
|
||||
setUploadedFiles={props.setUploadedFiles}
|
||||
agentColor={agents.find((agent) => agent.slug === selectedAgent)?.color}
|
||||
ref={chatInputRef}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
@@ -331,7 +406,7 @@ export default function Home() {
|
||||
const [chatOptionsData, setChatOptionsData] = useState<ChatOptions | null>(null);
|
||||
const [isLoading, setLoading] = useState(true);
|
||||
const [conversationId, setConversationID] = useState<string | null>(null);
|
||||
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
|
||||
const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | null>(null);
|
||||
const isMobileWidth = useIsMobileWidth();
|
||||
|
||||
const { userConfig: initialUserConfig, isLoadingUserConfig } = useUserConfig(true);
|
||||
@@ -347,6 +422,12 @@ export default function Home() {
|
||||
setUserConfig(initialUserConfig);
|
||||
}, [initialUserConfig]);
|
||||
|
||||
useEffect(() => {
|
||||
if (uploadedFiles) {
|
||||
localStorage.setItem("uploadedFiles", JSON.stringify(uploadedFiles));
|
||||
}
|
||||
}, [uploadedFiles]);
|
||||
|
||||
useEffect(() => {
|
||||
fetch("/api/chat/options")
|
||||
.then((response) => response.json())
|
||||
@@ -372,7 +453,7 @@ export default function Home() {
|
||||
<div className={`${styles.sidePanel}`}>
|
||||
<SidePanel
|
||||
conversationId={conversationId}
|
||||
uploadedFiles={uploadedFiles}
|
||||
uploadedFiles={[]}
|
||||
isMobileWidth={isMobileWidth}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -137,10 +137,8 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {

    const deleteSelected = async () => {
        let filesToDelete = selectedFiles.length > 0 ? selectedFiles : filteredFiles;
        console.log("Delete selected files", filesToDelete);

        if (filesToDelete.length === 0) {
            console.log("No files to delete");
            return;
        }

@@ -162,15 +160,12 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {

        // Reset selectedFiles
        setSelectedFiles([]);

        console.log("Deleted files:", filesToDelete);
        } catch (error) {
            console.error("Error deleting files:", error);
        }
    };

    const deleteFile = async (filename: string) => {
        console.log("Delete selected file", filename);
        try {
            const response = await fetch(
                `/api/content/file?filename=${encodeURIComponent(filename)}`,
@@ -189,8 +184,6 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {

        // Remove the file from selectedFiles if it's there
        setSelectedFiles((prevSelected) => prevSelected.filter((file) => file !== filename));

        console.log("Deleted file:", filename);
        } catch (error) {
            console.error("Error deleting file:", error);
        }
@@ -513,7 +506,7 @@ export default function SettingsView() {
|
||||
const isMobileWidth = useIsMobileWidth();
|
||||
|
||||
const cardClassName =
|
||||
"w-full lg:w-1/3 grid grid-flow-column border border-gray-300 shadow-md rounded-lg bg-gradient-to-b from-background to-gray-50 dark:to-gray-950";
|
||||
"w-full lg:w-1/3 grid grid-flow-column border border-gray-300 shadow-md rounded-lg bg-gradient-to-b from-background to-gray-50 dark:to-gray-950 border border-opacity-50";
|
||||
|
||||
useEffect(() => {
|
||||
setUserConfig(initialUserConfig);
|
||||
@@ -601,7 +594,7 @@ export default function SettingsView() {
|
||||
|
||||
const setSubscription = async (state: string) => {
|
||||
try {
|
||||
const url = `/api/subscription?email=${userConfig?.username}&operation=${state}`;
|
||||
const url = `/api/subscription?operation=${state}`;
|
||||
const response = await fetch(url, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
@@ -640,6 +633,51 @@ export default function SettingsView() {
|
||||
}
|
||||
};
|
||||
|
||||
const enableFreeTrial = async () => {
|
||||
const formatDate = (dateString: Date) => {
|
||||
const date = new Date(dateString);
|
||||
return new Intl.DateTimeFormat("en-US", {
|
||||
day: "2-digit",
|
||||
month: "short",
|
||||
year: "numeric",
|
||||
}).format(date);
|
||||
};
|
||||
|
||||
try {
|
||||
const response = await fetch(`/api/subscription/trial`, {
|
||||
method: "POST",
|
||||
});
|
||||
if (!response.ok) throw new Error("Failed to enable free trial");
|
||||
|
||||
const responseBody = await response.json();
|
||||
|
||||
// Set updated user settings
|
||||
if (responseBody.trial_enabled && userConfig) {
|
||||
let newUserConfig = userConfig;
|
||||
newUserConfig.subscription_state = SubscriptionStates.TRIAL;
|
||||
const renewalDate = new Date(
|
||||
Date.now() + userConfig.length_of_free_trial * 24 * 60 * 60 * 1000,
|
||||
);
|
||||
newUserConfig.subscription_renewal_date = formatDate(renewalDate);
|
||||
newUserConfig.subscription_enabled_trial_at = new Date().toISOString();
|
||||
setUserConfig(newUserConfig);
|
||||
|
||||
// Notify user of free trial
|
||||
toast({
|
||||
title: "🎉 Trial Enabled",
|
||||
description: `Your free trial will end on ${newUserConfig.subscription_renewal_date}`,
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error enabling free trial:", error);
|
||||
toast({
|
||||
title: "⚠️ Failed to Enable Free Trial",
|
||||
description:
|
||||
"Failed to enable free trial. Try again or contact us at team@khoj.dev",
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const saveName = async () => {
|
||||
if (!name) return;
|
||||
try {
|
||||
@@ -673,7 +711,7 @@ export default function SettingsView() {
|
||||
};
|
||||
|
||||
const updateModel = (name: string) => async (id: string) => {
|
||||
if (!userConfig?.is_active && name !== "search") {
|
||||
if (!userConfig?.is_active) {
|
||||
toast({
|
||||
title: `Model Update`,
|
||||
description: `You need to be subscribed to update ${name} models`,
|
||||
@@ -866,10 +904,13 @@ export default function SettingsView() {
|
||||
Futurist (Trial)
|
||||
</p>
|
||||
<p className="text-gray-400">
|
||||
You are on a 14 day trial of the Khoj
|
||||
Futurist plan. Check{" "}
|
||||
You are on a{" "}
|
||||
{userConfig.length_of_free_trial} day trial
|
||||
of the Khoj Futurist plan. Your trial ends
|
||||
on {userConfig.subscription_renewal_date}.
|
||||
Check{" "}
|
||||
<a
|
||||
href="https://khoj.dev/pricing"
|
||||
href="https://khoj.dev/#pricing"
|
||||
target="_blank"
|
||||
>
|
||||
pricing page
|
||||
@@ -909,7 +950,7 @@ export default function SettingsView() {
|
||||
)) ||
|
||||
(userConfig.subscription_state === "expired" && (
|
||||
<>
|
||||
<p className="text-xl">Free Plan</p>
|
||||
<p className="text-xl">Humanist</p>
|
||||
{(userConfig.subscription_renewal_date && (
|
||||
<p className="text-gray-400">
|
||||
Subscription <b>expired</b> on{" "}
|
||||
@@ -923,7 +964,7 @@ export default function SettingsView() {
|
||||
<p className="text-gray-400">
|
||||
Check{" "}
|
||||
<a
|
||||
href="https://khoj.dev/pricing"
|
||||
href="https://khoj.dev/#pricing"
|
||||
target="_blank"
|
||||
>
|
||||
pricing page
|
||||
@@ -960,7 +1001,8 @@ export default function SettingsView() {
|
||||
/>
|
||||
Resubscribe
|
||||
</Button>
|
||||
)) || (
|
||||
)) ||
|
||||
(userConfig.subscription_enabled_trial_at && (
|
||||
<Button
|
||||
variant="outline"
|
||||
className="text-primary/80 hover:text-primary"
|
||||
@@ -978,6 +1020,18 @@ export default function SettingsView() {
|
||||
/>
|
||||
Subscribe
|
||||
</Button>
|
||||
)) || (
|
||||
<Button
|
||||
variant="outline"
|
||||
className="text-primary/80 hover:text-primary"
|
||||
onClick={enableFreeTrial}
|
||||
>
|
||||
<ArrowCircleUp
|
||||
weight="bold"
|
||||
className="h-5 w-5 mr-2"
|
||||
/>
|
||||
Enable Trial
|
||||
</Button>
|
||||
)}
|
||||
</CardFooter>
|
||||
</Card>
|
||||
@@ -1172,27 +1226,6 @@ export default function SettingsView() {
|
||||
</CardFooter>
|
||||
</Card>
|
||||
)}
|
||||
{userConfig.search_model_options.length > 0 && (
|
||||
<Card className={cardClassName}>
|
||||
<CardHeader className="text-xl flex flex-row">
|
||||
<FileMagnifyingGlass className="h-7 w-7 mr-2" />
|
||||
Search
|
||||
</CardHeader>
|
||||
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
|
||||
<p className="text-gray-400">
|
||||
Pick the search model to find your documents
|
||||
</p>
|
||||
<DropdownComponent
|
||||
items={userConfig.search_model_options}
|
||||
selected={
|
||||
userConfig.selected_search_model_config
|
||||
}
|
||||
callbackFunc={updateModel("search")}
|
||||
/>
|
||||
</CardContent>
|
||||
<CardFooter className="flex flex-wrap gap-4"></CardFooter>
|
||||
</Card>
|
||||
)}
|
||||
{userConfig.paint_model_options.length > 0 && (
|
||||
<Card className={cardClassName}>
|
||||
<CardHeader className="text-xl flex flex-row">
|
||||
|
||||
@@ -27,7 +27,14 @@ export default function RootLayout({
                    child-src 'none';
                    object-src 'none';"
            ></meta>
            <body className={inter.className}>{children}</body>
            <body className={inter.className}>
                {children}
                <script
                    dangerouslySetInnerHTML={{
                        __html: `window.EXCALIDRAW_ASSET_PATH = 'https://assets.khoj.dev/@excalidraw/excalidraw/dist/';`,
                    }}
                />
            </body>
        </html>
    );
}

@@ -5,45 +5,52 @@ import React, { Suspense, useEffect, useRef, useState } from "react";
|
||||
|
||||
import SidePanel from "../../components/sidePanel/chatHistorySidePanel";
|
||||
import ChatHistory from "../../components/chatHistory/chatHistory";
|
||||
import NavMenu from "../../components/navMenu/navMenu";
|
||||
import Loading from "../../components/loading/loading";
|
||||
|
||||
import "katex/dist/katex.min.css";
|
||||
|
||||
import { useIPLocationData, useIsMobileWidth, welcomeConsole } from "../../common/utils";
|
||||
import { useIsMobileWidth, welcomeConsole } from "../../common/utils";
|
||||
import { useAuthenticatedData } from "@/app/common/auth";
|
||||
|
||||
import ChatInputArea, { ChatOptions } from "@/app/components/chatInputArea/chatInputArea";
|
||||
import {
|
||||
AttachedFileText,
|
||||
ChatInputArea,
|
||||
ChatOptions,
|
||||
} from "@/app/components/chatInputArea/chatInputArea";
|
||||
import { StreamMessage } from "@/app/components/chatMessage/chatMessage";
|
||||
import { processMessageChunk } from "@/app/common/chatFunctions";
|
||||
import { AgentData } from "@/app/agents/page";
|
||||
|
||||
interface ChatBodyDataProps {
|
||||
chatOptionsData: ChatOptions | null;
|
||||
setTitle: (title: string) => void;
|
||||
setUploadedFiles: (files: string[]) => void;
|
||||
setUploadedFiles: (files: AttachedFileText[]) => void;
|
||||
isMobileWidth?: boolean;
|
||||
publicConversationSlug: string;
|
||||
streamedMessages: StreamMessage[];
|
||||
isLoggedIn: boolean;
|
||||
conversationId?: string;
|
||||
setQueryToProcess: (query: string) => void;
|
||||
setImage64: (image64: string) => void;
|
||||
setImages: (images: string[]) => void;
|
||||
}
|
||||
|
||||
function ChatBodyData(props: ChatBodyDataProps) {
|
||||
const [message, setMessage] = useState("");
|
||||
const [image, setImage] = useState<string | null>(null);
|
||||
const [images, setImages] = useState<string[]>([]);
|
||||
const [processingMessage, setProcessingMessage] = useState(false);
|
||||
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
|
||||
const chatInputRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
const setQueryToProcess = props.setQueryToProcess;
|
||||
const streamedMessages = props.streamedMessages;
|
||||
|
||||
const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";
|
||||
|
||||
useEffect(() => {
|
||||
if (image) {
|
||||
props.setImage64(encodeURIComponent(image));
|
||||
if (images.length > 0) {
|
||||
const encodedImages = images.map((image) => encodeURIComponent(image));
|
||||
props.setImages(encodedImages);
|
||||
}
|
||||
}, [image, props.setImage64]);
|
||||
}, [images, props.setImages]);
|
||||
|
||||
useEffect(() => {
|
||||
if (message) {
|
||||
@@ -78,21 +85,23 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||
setTitle={props.setTitle}
|
||||
pendingMessage={processingMessage ? message : ""}
|
||||
incomingMessages={props.streamedMessages}
|
||||
customClassName={chatHistoryCustomClassName}
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl`}
|
||||
className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl h-fit ${chatHistoryCustomClassName} mr-auto ml-auto`}
|
||||
>
|
||||
<ChatInputArea
|
||||
isLoggedIn={props.isLoggedIn}
|
||||
sendMessage={(message) => setMessage(message)}
|
||||
sendImage={(image) => setImage(image)}
|
||||
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
|
||||
sendDisabled={processingMessage}
|
||||
chatOptionsData={props.chatOptionsData}
|
||||
conversationId={props.conversationId}
|
||||
agentColor={agentMetadata?.color}
|
||||
isMobileWidth={props.isMobileWidth}
|
||||
setUploadedFiles={props.setUploadedFiles}
|
||||
ref={chatInputRef}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
@@ -106,14 +115,10 @@ export default function SharedChat() {
|
||||
const [conversationId, setConversationID] = useState<string | undefined>(undefined);
|
||||
const [messages, setMessages] = useState<StreamMessage[]>([]);
|
||||
const [queryToProcess, setQueryToProcess] = useState<string>("");
|
||||
const [processQuerySignal, setProcessQuerySignal] = useState(false);
|
||||
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
|
||||
const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | null>(null);
|
||||
const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
|
||||
const [image64, setImage64] = useState<string>("");
|
||||
const [images, setImages] = useState<string[]>([]);
|
||||
|
||||
const locationData = useIPLocationData() || {
|
||||
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
|
||||
};
|
||||
const authenticatedData = useAuthenticatedData();
|
||||
const isMobileWidth = useIsMobileWidth();
|
||||
|
||||
@@ -137,6 +142,12 @@ export default function SharedChat() {
|
||||
setParamSlug(window.location.pathname.split("/").pop() || "");
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (uploadedFiles) {
|
||||
localStorage.setItem("uploadedFiles", JSON.stringify(uploadedFiles));
|
||||
}
|
||||
}, [uploadedFiles]);
|
||||
|
||||
useEffect(() => {
|
||||
if (queryToProcess && !conversationId) {
|
||||
// If the user has not yet started conversing in the chat, create a new conversation
|
||||
@@ -149,6 +160,11 @@ export default function SharedChat() {
|
||||
.then((response) => response.json())
|
||||
.then((data) => {
|
||||
setConversationID(data.conversation_id);
|
||||
localStorage.setItem("message", queryToProcess);
|
||||
if (images.length > 0) {
|
||||
localStorage.setItem("images", JSON.stringify(images));
|
||||
}
|
||||
window.location.href = `/chat?conversationId=${data.conversation_id}`;
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error(err);
|
||||
@@ -156,104 +172,8 @@ export default function SharedChat() {
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (queryToProcess) {
|
||||
// Add a new object to the state
|
||||
const newStreamMessage: StreamMessage = {
|
||||
rawResponse: "",
|
||||
trainOfThought: [],
|
||||
context: [],
|
||||
onlineContext: {},
|
||||
completed: false,
|
||||
timestamp: new Date().toISOString(),
|
||||
rawQuery: queryToProcess || "",
|
||||
uploadedImageData: decodeURIComponent(image64),
|
||||
};
|
||||
setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
|
||||
setProcessQuerySignal(true);
|
||||
}
|
||||
}, [queryToProcess, conversationId, paramSlug]);
|
||||
|
||||
useEffect(() => {
|
||||
if (processQuerySignal) {
|
||||
chat();
|
||||
}
|
||||
}, [processQuerySignal]);
|
||||
|
||||
async function readChatStream(response: Response) {
|
||||
if (!response.ok) throw new Error(response.statusText);
|
||||
if (!response.body) throw new Error("Response body is null");
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
const eventDelimiter = "␃🔚␗";
|
||||
let buffer = "";
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
setQueryToProcess("");
|
||||
setProcessQuerySignal(false);
|
||||
setImage64("");
|
||||
break;
|
||||
}
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
|
||||
buffer += chunk;
|
||||
|
||||
let newEventIndex;
|
||||
while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) {
|
||||
const event = buffer.slice(0, newEventIndex);
|
||||
buffer = buffer.slice(newEventIndex + eventDelimiter.length);
|
||||
if (event) {
|
||||
const currentMessage = messages.find((message) => !message.completed);
|
||||
|
||||
if (!currentMessage) {
|
||||
console.error("No current message found");
|
||||
return;
|
||||
}
|
||||
|
||||
processMessageChunk(event, currentMessage);
|
||||
|
||||
setMessages([...messages]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function chat() {
|
||||
if (!queryToProcess || !conversationId) return;
|
||||
const chatAPI = "/api/chat?client=web";
|
||||
const chatAPIBody = {
|
||||
q: queryToProcess,
|
||||
conversation_id: conversationId,
|
||||
stream: true,
|
||||
...(locationData && {
|
||||
region: locationData.region,
|
||||
country: locationData.country,
|
||||
city: locationData.city,
|
||||
country_code: locationData.countryCode,
|
||||
timezone: locationData.timezone,
|
||||
}),
|
||||
...(image64 && { image: image64 }),
|
||||
};
|
||||
|
||||
const response = await fetch(chatAPI, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(chatAPIBody),
|
||||
});
|
||||
|
||||
try {
|
||||
await readChatStream(response);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
}
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return <Loading />;
|
||||
}
|
||||
@@ -268,13 +188,26 @@ export default function SharedChat() {
|
||||
<div className={styles.sidePanel}>
|
||||
<SidePanel
|
||||
conversationId={conversationId ?? null}
|
||||
uploadedFiles={uploadedFiles}
|
||||
uploadedFiles={[]}
|
||||
isMobileWidth={isMobileWidth}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className={styles.chatBox}>
|
||||
<div className={styles.chatBoxBody}>
|
||||
{!isMobileWidth && title && (
|
||||
<div
|
||||
className={`${styles.chatTitleWrapper} text-nowrap text-ellipsis overflow-hidden max-w-screen-md grid items-top font-bold mr-8 pt-6 col-auto h-fit`}
|
||||
>
|
||||
{title && (
|
||||
<h2
|
||||
className={`text-lg text-ellipsis whitespace-nowrap overflow-x-hidden`}
|
||||
>
|
||||
{title}
|
||||
</h2>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
<Suspense fallback={<Loading />}>
|
||||
<ChatBodyData
|
||||
conversationId={conversationId}
|
||||
@@ -286,7 +219,7 @@ export default function SharedChat() {
|
||||
setTitle={setTitle}
|
||||
setUploadedFiles={setUploadedFiles}
|
||||
isMobileWidth={isMobileWidth}
|
||||
setImage64={setImage64}
|
||||
setImages={setImages}
|
||||
/>
|
||||
</Suspense>
|
||||
</div>
|
||||
|
||||
@@ -75,7 +75,7 @@ div.titleBar {
div.chatBoxBody {
    display: grid;
    height: 100%;
    width: 70%;
    width: 95%;
    margin: auto;
}


@@ -9,7 +9,7 @@ const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>(
        return (
            <textarea
                className={cn(
                    "flex min-h-[80px] w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50",
                    "flex min-h-[80px] w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground disabled:cursor-not-allowed disabled:opacity-50",
                    className,
                )}
                ref={ref}

@@ -1,6 +1,6 @@
{
    "name": "khoj-ai",
    "version": "1.26.2",
    "version": "1.29.1",
    "private": true,
    "scripts": {
        "dev": "next dev",
@@ -19,6 +19,7 @@
        "prepare": "husky"
    },
    "dependencies": {
        "@excalidraw/excalidraw": "^0.17.6",
        "@hookform/resolvers": "^3.9.0",
        "@phosphor-icons/react": "^2.1.7",
        "@radix-ui/react-alert-dialog": "^1.1.1",

@@ -1,34 +1,49 @@
|
||||
import type { Config } from "tailwindcss"
|
||||
import type { Config } from "tailwindcss";
|
||||
|
||||
const config = {
|
||||
safelist: [
|
||||
{
|
||||
pattern: /to-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
|
||||
variants: ['dark'],
|
||||
pattern:
|
||||
/to-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
|
||||
variants: ["dark"],
|
||||
},
|
||||
{
|
||||
pattern: /text-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
|
||||
variants: ['dark'],
|
||||
pattern:
|
||||
/text-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
|
||||
variants: ["dark"],
|
||||
},
|
||||
{
|
||||
pattern: /border-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
|
||||
variants: ['dark'],
|
||||
pattern:
|
||||
/border-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
|
||||
variants: ["dark"],
|
||||
},
|
||||
{
|
||||
pattern: /border-l-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
|
||||
variants: ['dark'],
|
||||
pattern:
|
||||
/border-l-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
|
||||
variants: ["dark"],
|
||||
},
|
||||
{
|
||||
pattern: /bg-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
|
||||
variants: ['dark'],
|
||||
}
|
||||
pattern:
|
||||
/bg-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
|
||||
variants: ["dark"],
|
||||
},
|
||||
{
|
||||
pattern:
|
||||
/ring-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
|
||||
variants: ["focus-visible", "dark"],
|
||||
},
|
||||
{
|
||||
pattern:
|
||||
/caret-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
|
||||
variants: ["focus", "dark"],
|
||||
},
|
||||
],
|
||||
darkMode: ["class"],
|
||||
content: [
|
||||
'./pages/**/*.{ts,tsx}',
|
||||
'./components/**/*.{ts,tsx}',
|
||||
'./app/**/*.{ts,tsx}',
|
||||
'./src/**/*.{ts,tsx}',
|
||||
"./pages/**/*.{ts,tsx}",
|
||||
"./components/**/*.{ts,tsx}",
|
||||
"./app/**/*.{ts,tsx}",
|
||||
"./src/**/*.{ts,tsx}",
|
||||
],
|
||||
prefix: "",
|
||||
theme: {
|
||||
@@ -101,9 +116,7 @@ const config = {
|
||||
},
|
||||
},
|
||||
},
|
||||
plugins: [
|
||||
require("tailwindcss-animate"),
|
||||
],
|
||||
} satisfies Config
|
||||
plugins: [require("tailwindcss-animate")],
|
||||
} satisfies Config;
|
||||
|
||||
export default config
|
||||
export default config;
|
||||
|
||||
@@ -286,6 +286,11 @@
  resolved "https://registry.yarnpkg.com/@eslint/js/-/js-8.57.1.tgz#de633db3ec2ef6a3c89e2f19038063e8a122e2c2"
  integrity sha512-d9zaMRSTIKDLhctzH12MtXvJKSSUhaHcjV+2Z+GK+EEY7XKpP5yR4x+N3TAcHTcu963nIr+TMcCb4DBCYX1z6Q==

"@excalidraw/excalidraw@^0.17.6":
  version "0.17.6"
  resolved "https://registry.yarnpkg.com/@excalidraw/excalidraw/-/excalidraw-0.17.6.tgz#5fd208ce69d33ca712d1804b50d7d06d5c46ac4d"
  integrity sha512-fyCl+zG/Z5yhHDh5Fq2ZGmphcrALmuOdtITm8gN4d8w4ntnaopTXcTfnAAaU3VleDC6LhTkoLOTG6P5kgREiIg==

"@floating-ui/core@^1.6.0":
  version "1.6.8"
  resolved "https://registry.yarnpkg.com/@floating-ui/core/-/core-1.6.8.tgz#aa43561be075815879305965020f492cdb43da12"

@@ -108,7 +108,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
||||
password="default",
|
||||
)
|
||||
renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
|
||||
Subscription.objects.create(user=default_user, type="standard", renewal_date=renewal_date)
|
||||
Subscription.objects.create(user=default_user, type=Subscription.Type.STANDARD, renewal_date=renewal_date)
|
||||
|
||||
async def authenticate(self, request: HTTPConnection):
|
||||
current_user = request.session.get("user")
|
||||
@@ -168,12 +168,6 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
||||
if create_if_not_exists:
|
||||
user, is_new = await aget_or_create_user_by_phone_number(phone_number)
|
||||
if user and is_new:
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="create_user",
|
||||
metadata={"user_id": str(user.uuid)},
|
||||
)
|
||||
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
|
||||
else:
|
||||
user = await aget_user_by_phone_number(phone_number)
|
||||
@@ -255,31 +249,35 @@ def configure_server(
        state.search_models = configure_search(state.search_models, state.config.search_type)
        setup_default_agent(user)

        message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled"
        message = (
            "📡 Telemetry disabled"
            if telemetry_disabled(state.config.app, state.telemetry_disabled)
            else "📡 Telemetry enabled"
        )
        logger.info(message)

        if not init:
            initialize_content(regenerate, search_type, user)
            initialize_content(user, regenerate, search_type)

    except Exception as e:
        raise e
        logger.error(f"Failed to load some search models: {e}", exc_info=True)


def setup_default_agent(user: KhojUser):
    AgentAdapters.create_default_agent(user)


def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None):
def initialize_content(user: KhojUser, regenerate: bool, search_type: Optional[SearchType] = None):
    # Initialize Content from Config
    if state.search_models:
        try:
            logger.info("📬 Updating content index...")
            all_files = collect_files(user=user)
            status = configure_content(
                user,
                all_files,
                regenerate,
                search_type,
                user=user,
            )
            if not status:
                raise RuntimeError("Failed to update content index")
@@ -312,7 +310,7 @@ def configure_routes(app):
    logger.info("🔑 Enabled Authentication")

    if state.billing_enabled:
        from khoj.routers.subscription import subscription_router
        from khoj.routers.api_subscription import subscription_router

        app.include_router(subscription_router, prefix="/api/subscription")
        logger.info("💳 Enabled Billing")
@@ -344,9 +342,7 @@ def configure_middleware(app):
def update_content_index():
    for user in get_all_users():
        all_files = collect_files(user=user)
        success = configure_content(all_files, user=user)
        all_files = collect_files(user=None)
        success = configure_content(all_files, user=None)
        success = configure_content(user, all_files)
    if not success:
        raise RuntimeError("Failed to update content index")
    logger.info("📪 Content index updated via Scheduler")
@@ -369,7 +365,7 @@ def configure_search_types():

@schedule.repeat(schedule.every(2).minutes)
def upload_telemetry():
    if telemetry_disabled(state.config.app) or not state.telemetry:
    if telemetry_disabled(state.config.app, state.telemetry_disabled) or not state.telemetry:
        return

    try:

@@ -8,13 +8,22 @@ import secrets
|
||||
import sys
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Callable, Iterable, List, Optional, Type
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
ParamSpec,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import cron_descriptor
|
||||
from apscheduler.job import Job
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.contrib.sessions.backends.db import SessionStore
|
||||
from django.db import models
|
||||
from django.db.models import Prefetch, Q
|
||||
from django.db.models.manager import BaseManager
|
||||
from django.db.utils import IntegrityError
|
||||
@@ -28,7 +37,6 @@ from khoj.database.models import (
|
||||
ChatModelOptions,
|
||||
ClientApplication,
|
||||
Conversation,
|
||||
DataStore,
|
||||
Entry,
|
||||
FileObject,
|
||||
GithubConfig,
|
||||
@@ -48,7 +56,6 @@ from khoj.database.models import (
|
||||
TextToImageModelConfig,
|
||||
UserConversationConfig,
|
||||
UserRequests,
|
||||
UserSearchModelConfig,
|
||||
UserTextToImageModelConfig,
|
||||
UserVoiceModelConfig,
|
||||
VoiceModelOption,
|
||||
@@ -70,6 +77,9 @@ from khoj.utils.helpers import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
LENGTH_OF_FREE_TRIAL = 7 #
|
||||
|
||||
|
||||
class SubscriptionState(Enum):
|
||||
TRIAL = "trial"
|
||||
SUBSCRIBED = "subscribed"
|
||||
@@ -78,6 +88,45 @@ class SubscriptionState(Enum):
|
||||
INVALID = "invalid"
|
||||
|
||||
|
||||
P = ParamSpec("P")
T = TypeVar("T")


def require_valid_user(func: Callable[P, T]) -> Callable[P, T]:
    @wraps(func)
    def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        # Extract user from args/kwargs
        user = next((arg for arg in args if isinstance(arg, KhojUser)), None)
        if not user:
            user = next((val for val in kwargs.values() if isinstance(val, KhojUser)), None)

        # Throw error if user is not found
        if not user:
            raise ValueError("Khoj user argument required but not provided.")

        return func(*args, **kwargs)

    return sync_wrapper


def arequire_valid_user(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]:
    @wraps(func)
    async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        # Extract user from args/kwargs
        user = next((arg for arg in args if isinstance(arg, KhojUser)), None)
        if not user:
            user = next((v for v in kwargs.values() if isinstance(v, KhojUser)), None)

        # Throw error if user is not found
        if not user:
            raise ValueError("Khoj user argument required but not provided.")

        return await func(*args, **kwargs)

    return async_wrapper

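A minimal usage sketch of these guards (illustrative only; the adapter name and query below are not part of this changeset):

@require_valid_user
def get_user_agents(user: KhojUser):
    # Only runs when a real KhojUser was passed positionally or by keyword;
    # otherwise the decorator raises ValueError before this body executes.
    return list(Agent.objects.filter(creator=user))
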
@arequire_valid_user
|
||||
async def set_notion_config(token: str, user: KhojUser):
|
||||
notion_config = await NotionConfig.objects.filter(user=user).afirst()
|
||||
if not notion_config:
|
||||
@@ -88,6 +137,7 @@ async def set_notion_config(token: str, user: KhojUser):
|
||||
return notion_config
|
||||
|
||||
|
||||
@require_valid_user
|
||||
def create_khoj_token(user: KhojUser, name=None):
|
||||
"Create Khoj API key for user"
|
||||
token = f"kk-{secrets.token_urlsafe(32)}"
|
||||
@@ -95,6 +145,7 @@ def create_khoj_token(user: KhojUser, name=None):
|
||||
return KhojApiUser.objects.create(token=token, user=user, name=name)
|
||||
|
||||
|
||||
@arequire_valid_user
|
||||
async def acreate_khoj_token(user: KhojUser, name=None):
|
||||
"Create Khoj API key for user"
|
||||
token = f"kk-{secrets.token_urlsafe(32)}"
|
||||
@@ -102,11 +153,13 @@ async def acreate_khoj_token(user: KhojUser, name=None):
|
||||
return await KhojApiUser.objects.acreate(token=token, user=user, name=name)
|
||||
|
||||
|
||||
@require_valid_user
|
||||
def get_khoj_tokens(user: KhojUser):
|
||||
"Get all Khoj API keys for user"
|
||||
return list(KhojApiUser.objects.filter(user=user))
|
||||
|
||||
|
||||
@arequire_valid_user
|
||||
async def delete_khoj_token(user: KhojUser, token: str):
|
||||
"Delete Khoj API Key for user"
|
||||
await KhojApiUser.objects.filter(token=token, user=user).adelete()
|
||||
@@ -130,6 +183,7 @@ async def aget_or_create_user_by_phone_number(phone_number: str) -> tuple[KhojUs
|
||||
return user, is_new
|
||||
|
||||
|
||||
@arequire_valid_user
|
||||
async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
|
||||
if is_none_or_empty(phone_number):
|
||||
return None
|
||||
@@ -153,6 +207,7 @@ async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
|
||||
return user
|
||||
|
||||
|
||||
@arequire_valid_user
|
||||
async def aremove_phone_number(user: KhojUser) -> KhojUser:
|
||||
user.phone_number = None
|
||||
user.verified_phone_number = False
|
||||
@@ -168,7 +223,7 @@ async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
|
||||
)
|
||||
await user.asave()
|
||||
|
||||
await Subscription.objects.acreate(user=user, type="trial")
|
||||
await Subscription.objects.acreate(user=user, type=Subscription.Type.STANDARD)
|
||||
|
||||
return user
|
||||
|
||||
@@ -185,11 +240,30 @@ async def aget_or_create_user_by_email(email: str) -> tuple[KhojUser, bool]:
|
||||
|
||||
user_subscription = await Subscription.objects.filter(user=user).afirst()
|
||||
if not user_subscription:
|
||||
await Subscription.objects.acreate(user=user, type="trial")
|
||||
await Subscription.objects.acreate(user=user, type=Subscription.Type.STANDARD)
|
||||
|
||||
return user, is_new
|
||||
|
||||
|
||||
@arequire_valid_user
async def astart_trial_subscription(user: KhojUser) -> Subscription:
    subscription = await Subscription.objects.filter(user=user).afirst()
    if not subscription:
        raise HTTPException(status_code=400, detail="User does not have a subscription")

    if subscription.type == Subscription.Type.TRIAL:
        raise HTTPException(status_code=400, detail="User already has a trial subscription")

    if subscription.enabled_trial_at:
        raise HTTPException(status_code=400, detail="User already has a trial subscription")

    subscription.type = Subscription.Type.TRIAL
    subscription.enabled_trial_at = datetime.now(tz=timezone.utc)
    subscription.renewal_date = datetime.now(tz=timezone.utc) + timedelta(days=LENGTH_OF_FREE_TRIAL)
    await subscription.asave()
    return subscription


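A rough sketch of how the trial endpoint in khoj/routers/api_subscription.py could invoke this adapter. Only the /api/subscription/trial POST route, the subscription_router name, and the trial_enabled response key are taken from elsewhere in this changeset; the handler body itself is an assumption:

from fastapi import APIRouter, Request

subscription_router = APIRouter()


@subscription_router.post("/trial")
async def start_trial(request: Request) -> dict:
    # Assumed: the authentication backend attaches the KhojUser to the request.
    user = request.user.object
    subscription = await astart_trial_subscription(user)
    # The settings page reads `trial_enabled` to decide whether to update local state.
    return {"trial_enabled": subscription.type == Subscription.Type.TRIAL}
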
async def aget_user_validated_by_email_verification_code(code: str) -> KhojUser:
|
||||
user = await KhojUser.objects.filter(email_verification_code=code).afirst()
|
||||
if not user:
|
||||
@@ -221,11 +295,12 @@ async def create_user_by_google_token(token: dict) -> KhojUser:
|
||||
user=user,
|
||||
)
|
||||
|
||||
await Subscription.objects.acreate(user=user, type="trial")
|
||||
await Subscription.objects.acreate(user=user, type=Subscription.Type.STANDARD)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@require_valid_user
|
||||
def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
|
||||
user.first_name = first_name
|
||||
user.last_name = last_name
|
||||
@@ -233,6 +308,7 @@ def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
|
||||
return user
|
||||
|
||||
|
||||
@require_valid_user
|
||||
def get_user_name(user: KhojUser):
|
||||
full_name = user.get_full_name()
|
||||
if not is_none_or_empty(full_name):
|
||||
@@ -244,6 +320,7 @@ def get_user_name(user: KhojUser):
|
||||
return None
|
||||
|
||||
|
||||
@require_valid_user
|
||||
def get_user_photo(user: KhojUser):
|
||||
google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
|
||||
if google_profile:
|
||||
@@ -279,16 +356,20 @@ def subscription_to_state(subscription: Subscription) -> str:
    if not subscription:
        return SubscriptionState.INVALID.value
    elif subscription.type == Subscription.Type.TRIAL:
        # Trial subscription is valid for 7 days
        if datetime.now(tz=timezone.utc) - subscription.created_at > timedelta(days=14):
            return SubscriptionState.EXPIRED.value
        # Check if the trial has expired
        if not subscription.renewal_date:
            # If the renewal date is not set, set it to the current date + trial length and evaluate
            subscription.renewal_date = subscription.created_at + timedelta(days=LENGTH_OF_FREE_TRIAL)
            subscription.save()

        if subscription.renewal_date and datetime.now(tz=timezone.utc) > subscription.renewal_date:
            return SubscriptionState.EXPIRED.value
        return SubscriptionState.TRIAL.value
    elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
    elif subscription.is_recurring and subscription.renewal_date > datetime.now(tz=timezone.utc):
        return SubscriptionState.SUBSCRIBED.value
    elif not subscription.is_recurring and subscription.renewal_date is None:
        return SubscriptionState.EXPIRED.value
    elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
    elif not subscription.is_recurring and subscription.renewal_date > datetime.now(tz=timezone.utc):
        return SubscriptionState.UNSUBSCRIBED.value
    elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc):
        return SubscriptionState.EXPIRED.value
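An illustrative check of the trial transition implemented above (the in-memory Subscription objects and dates are made up for the example, not from this changeset):

trial = Subscription(type=Subscription.Type.TRIAL, is_recurring=False)
trial.created_at = datetime.now(tz=timezone.utc)
trial.renewal_date = trial.created_at + timedelta(days=LENGTH_OF_FREE_TRIAL)
assert subscription_to_state(trial) == SubscriptionState.TRIAL.value  # still inside the trial window

trial.renewal_date = datetime.now(tz=timezone.utc) - timedelta(days=1)
assert subscription_to_state(trial) == SubscriptionState.EXPIRED.value  # renewal date has passed
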
@@ -303,14 +384,16 @@ def get_user_subscription_state(email: str) -> str:
|
||||
return subscription_to_state(user_subscription)
|
||||
|
||||
|
||||
@arequire_valid_user
|
||||
async def aget_user_subscription_state(user: KhojUser) -> str:
|
||||
"""Get subscription state of user
|
||||
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
|
||||
"""
|
||||
user_subscription = await Subscription.objects.filter(user=user).afirst()
|
||||
return subscription_to_state(user_subscription)
|
||||
return await sync_to_async(subscription_to_state)(user_subscription)
|
||||
|
||||
|
||||
@arequire_valid_user
|
||||
async def ais_user_subscribed(user: KhojUser) -> bool:
|
||||
"""
|
||||
Get whether the user is subscribed
|
||||
@@ -327,6 +410,7 @@ async def ais_user_subscribed(user: KhojUser) -> bool:
|
||||
return subscribed
|
||||
|
||||
|
||||
@require_valid_user
|
||||
def is_user_subscribed(user: KhojUser) -> bool:
|
||||
"""
|
||||
Get whether the user is subscribed
|
||||
@@ -392,11 +476,13 @@ def get_all_users() -> BaseManager[KhojUser]:
|
||||
return KhojUser.objects.all()
|
||||
|
||||
|
||||
@require_valid_user
|
||||
def get_user_github_config(user: KhojUser):
|
||||
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
||||
return config
|
||||
|
||||
|
||||
@require_valid_user
|
||||
def get_user_notion_config(user: KhojUser):
|
||||
config = NotionConfig.objects.filter(user=user).first()
|
||||
return config
|
||||
@@ -406,6 +492,7 @@ def delete_user_requests(window: timedelta = timedelta(days=1)):
|
||||
return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
|
||||
|
||||
|
||||
@arequire_valid_user
|
||||
async def aget_user_name(user: KhojUser):
|
||||
full_name = user.get_full_name()
|
||||
if not is_none_or_empty(full_name):
|
||||
@@ -417,18 +504,7 @@ async def aget_user_name(user: KhojUser):
|
||||
return None
|
||||
|
||||
|
||||
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
|
||||
deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None
|
||||
deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None
|
||||
await object.objects.filter(user=user).adelete()
|
||||
await object.objects.acreate(
|
||||
input_files=deduped_files,
|
||||
input_filter=deduped_filters,
|
||||
index_heading_entries=updated_config.index_heading_entries,
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@arequire_valid_user
|
||||
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
||||
config = await GithubConfig.objects.filter(user=user).afirst()
|
||||
|
||||
@@ -446,15 +522,13 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
||||
return config
|
||||
|
||||
|
||||
def get_user_search_model_or_default(user=None):
    if user and UserSearchModelConfig.objects.filter(user=user).exists():
        return UserSearchModelConfig.objects.filter(user=user).first().setting
def get_default_search_model() -> SearchModelConfig:
    default_search_model = SearchModelConfig.objects.filter(name="default").first()

    if SearchModelConfig.objects.filter(name="default").exists():
        return SearchModelConfig.objects.filter(name="default").first()
    else:
    if default_search_model:
        return default_search_model
    elif SearchModelConfig.objects.count() == 0:
        SearchModelConfig.objects.create()

    return SearchModelConfig.objects.first()
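The rewritten get_default_search_model falls back in three steps: return the config named "default" if it exists, create a config when the table is empty, and otherwise return the first available config. A hedged usage sketch, assuming a configured Django environment (for example inside `python manage.py shell`):

```python
# Sketch only: requires the khoj Django settings to be configured
from khoj.database.adapters import get_default_search_model

search_model = get_default_search_model()
# Either the config named "default", a freshly created default config, or the first existing one
print(search_model.name, search_model.bi_encoder)
```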
|
||||
|
||||
|
||||
@@ -467,21 +541,6 @@ def get_or_create_search_models():
|
||||
return search_models
|
||||
|
||||
|
||||
async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
|
||||
config = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst()
|
||||
if not config:
|
||||
return None
|
||||
new_config, _ = await UserSearchModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
|
||||
return new_config
|
||||
|
||||
|
||||
async def aget_user_search_model(user: KhojUser):
|
||||
config = await UserSearchModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||
if not config:
|
||||
return None
|
||||
return config.setting
|
||||
|
||||
|
||||
class ProcessLockAdapters:
|
||||
@staticmethod
|
||||
def get_process_lock(process_name: str):
|
||||
@@ -580,8 +639,11 @@ class AgentAdapters:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def adelete_agent_by_slug(agent_slug: str, user: KhojUser):
|
||||
agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)
|
||||
if agent.creator != user:
|
||||
return False
|
||||
|
||||
async for entry in Entry.objects.filter(agent=agent).aiterator():
|
||||
await entry.adelete()
|
||||
@@ -622,6 +684,8 @@ class AgentAdapters:
|
||||
@staticmethod
|
||||
def get_all_accessible_agents(user: KhojUser = None):
|
||||
public_query = Q(privacy_level=Agent.PrivacyLevel.PUBLIC)
|
||||
# TODO Update this to allow any public agent that's officially approved once that experience is launched
|
||||
public_query &= Q(managed_by_admin=True)
|
||||
if user:
|
||||
return (
|
||||
Agent.objects.filter(public_query | Q(creator=user))
|
||||
@@ -640,6 +704,18 @@ class AgentAdapters:
|
||||
agents = await sync_to_async(AgentAdapters.get_all_accessible_agents)(user)
|
||||
return await sync_to_async(list)(agents)
|
||||
|
||||
@staticmethod
|
||||
async def ais_agent_accessible(agent: Agent, user: KhojUser) -> bool:
|
||||
agent = await Agent.objects.select_related("creator").aget(pk=agent.pk)
|
||||
|
||||
if agent.privacy_level == Agent.PrivacyLevel.PUBLIC:
|
||||
return True
|
||||
if agent.creator == user:
|
||||
return True
|
||||
if agent.privacy_level == Agent.PrivacyLevel.PROTECTED:
|
||||
return True
|
||||
return False
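The new ais_agent_accessible check above allows public agents, the agent's creator, and (for now) any protected agent. A minimal sketch of how a route handler might gate access with it; the handler itself is hypothetical and only illustrates the call:

```python
from fastapi import HTTPException

from khoj.database.adapters import AgentAdapters


async def read_agent(agent, user):
    # Hypothetical handler: reject the request when the agent is not visible to this user
    if not await AgentAdapters.ais_agent_accessible(agent, user):
        raise HTTPException(status_code=403, detail="Agent not accessible")
    return agent
```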
|
||||
|
||||
@staticmethod
|
||||
def get_conversation_agent_by_id(agent_id: int):
|
||||
agent = Agent.objects.filter(id=agent_id).first()
|
||||
@@ -691,6 +767,7 @@ class AgentAdapters:
|
||||
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def aupdate_agent(
|
||||
user: KhojUser,
|
||||
name: str,
|
||||
@@ -766,19 +843,6 @@ class PublicConversationAdapters:
|
||||
return f"/share/chat/{public_conversation.slug}/"
|
||||
|
||||
|
||||
class DataStoreAdapters:
|
||||
@staticmethod
|
||||
async def astore_data(data: dict, key: str, user: KhojUser, private: bool = True):
|
||||
if await DataStore.objects.filter(key=key).aexists():
|
||||
return key
|
||||
await DataStore.objects.acreate(value=data, key=key, owner=user, private=private)
|
||||
return key
|
||||
|
||||
@staticmethod
|
||||
async def aretrieve_public_data(key: str):
|
||||
return await DataStore.objects.filter(key=key, private=False).afirst()
|
||||
|
||||
|
||||
class ConversationAdapters:
|
||||
@staticmethod
|
||||
def make_public_conversation_copy(conversation: Conversation):
|
||||
@@ -791,6 +855,7 @@ class ConversationAdapters:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def get_conversation_by_user(
|
||||
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
|
||||
) -> Optional[Conversation]:
|
||||
@@ -809,6 +874,7 @@ class ConversationAdapters:
|
||||
return conversation
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None):
|
||||
return (
|
||||
Conversation.objects.filter(user=user, client=client_application)
|
||||
@@ -817,6 +883,7 @@ class ConversationAdapters:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def aset_conversation_title(
|
||||
user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str
|
||||
):
|
||||
@@ -834,6 +901,7 @@ class ConversationAdapters:
|
||||
return Conversation.objects.filter(id=conversation_id).first()
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def acreate_conversation_session(
|
||||
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
|
||||
):
|
||||
@@ -841,11 +909,16 @@ class ConversationAdapters:
|
||||
agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
|
||||
if agent is None:
|
||||
raise HTTPException(status_code=400, detail="No such agent currently exists.")
|
||||
return await Conversation.objects.acreate(user=user, client=client_application, agent=agent, title=title)
|
||||
return await Conversation.objects.select_related("agent", "agent__creator", "agent__chat_model").acreate(
|
||||
user=user, client=client_application, agent=agent, title=title
|
||||
)
|
||||
agent = await AgentAdapters.aget_default_agent()
|
||||
return await Conversation.objects.acreate(user=user, client=client_application, agent=agent, title=title)
|
||||
return await Conversation.objects.select_related("agent", "agent__creator", "agent__chat_model").acreate(
|
||||
user=user, client=client_application, agent=agent, title=title
|
||||
)
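Both branches above now create the conversation through select_related("agent", "agent__creator", "agent__chat_model"), which appears intended to make the related agent objects usable on the returned conversation without further queries. In async code this matters because lazily accessing a foreign key would trigger a synchronous query and raise SynchronousOnlyOperation. A small illustration of the general pattern, assuming the khoj Conversation model is importable:

```python
# Sketch: eager loading of relations in async Django code
from khoj.database.models import Conversation


async def agent_name_of(conversation_id):
    # Without select_related, conversation.agent would lazily query the database on
    # attribute access, which is not allowed inside an async context.
    conversation = await Conversation.objects.select_related("agent").aget(id=conversation_id)
    return conversation.agent.name if conversation.agent else None
```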
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def create_conversation_session(
|
||||
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None
|
||||
):
|
||||
@@ -858,6 +931,7 @@ class ConversationAdapters:
|
||||
return Conversation.objects.create(user=user, client=client_application, agent=agent, title=title)
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def aget_conversation_by_user(
|
||||
user: KhojUser,
|
||||
client_application: ClientApplication = None,
|
||||
@@ -882,6 +956,7 @@ class ConversationAdapters:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def adelete_conversation_by_user(
|
||||
user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None
|
||||
):
|
||||
@@ -890,6 +965,7 @@ class ConversationAdapters:
|
||||
return await Conversation.objects.filter(user=user, client=client_application).adelete()
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def has_any_conversation_config(user: KhojUser):
|
||||
return ChatModelOptions.objects.filter(user=user).exists()
|
||||
|
||||
@@ -926,6 +1002,7 @@ class ConversationAdapters:
|
||||
return OpenAIProcessorConversationConfig.objects.filter().exists()
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
|
||||
config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
|
||||
if not config:
|
||||
@@ -934,6 +1011,7 @@ class ConversationAdapters:
|
||||
return new_config
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def aset_user_voice_model(user: KhojUser, model_id: str):
|
||||
config = await VoiceModelOption.objects.filter(model_id=model_id).afirst()
|
||||
if not config:
|
||||
@@ -984,8 +1062,15 @@ class ConversationAdapters:
|
||||
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
|
||||
# Get the server chat settings
|
||||
server_chat_settings = ServerChatSettings.objects.first()
|
||||
if server_chat_settings is not None and server_chat_settings.chat_default is not None:
|
||||
return server_chat_settings.chat_default
|
||||
|
||||
is_subscribed = is_user_subscribed(user) if user else False
|
||||
if server_chat_settings:
|
||||
# If the user is subscribed and the advanced model is enabled, return the advanced model
|
||||
if is_subscribed and server_chat_settings.chat_advanced:
|
||||
return server_chat_settings.chat_advanced
|
||||
# If the default model is set, return it
|
||||
if server_chat_settings.chat_default:
|
||||
return server_chat_settings.chat_default
|
||||
|
||||
# Get the user's chat settings, if the server chat settings are not set
|
||||
user_chat_settings = UserConversationConfig.objects.filter(user=user).first() if user else None
|
||||
@@ -1001,11 +1086,20 @@ class ConversationAdapters:
|
||||
# Get the server chat settings
|
||||
server_chat_settings: ServerChatSettings = (
|
||||
await ServerChatSettings.objects.filter()
|
||||
.prefetch_related("chat_default", "chat_default__openai_config")
|
||||
.prefetch_related(
|
||||
"chat_default", "chat_default__openai_config", "chat_advanced", "chat_advanced__openai_config"
|
||||
)
|
||||
.afirst()
|
||||
)
|
||||
if server_chat_settings is not None and server_chat_settings.chat_default is not None:
|
||||
return server_chat_settings.chat_default
|
||||
is_subscribed = await ais_user_subscribed(user) if user else False
|
||||
|
||||
if server_chat_settings:
|
||||
# If the user is subscribed and the advanced model is enabled, return the advanced model
|
||||
if is_subscribed and server_chat_settings.chat_advanced:
|
||||
return server_chat_settings.chat_advanced
|
||||
# If the default model is set, return it
|
||||
if server_chat_settings.chat_default:
|
||||
return server_chat_settings.chat_default
|
||||
|
||||
# Get the user's chat settings, if the server chat settings are not set
|
||||
user_chat_settings = (
|
||||
@@ -1102,6 +1196,7 @@ class ConversationAdapters:
|
||||
return enabled_scrapers
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def create_conversation_from_public_conversation(
|
||||
user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
|
||||
):
|
||||
@@ -1118,6 +1213,7 @@ class ConversationAdapters:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def save_conversation(
|
||||
user: KhojUser,
|
||||
conversation_log: dict,
|
||||
@@ -1167,6 +1263,7 @@ class ConversationAdapters:
|
||||
return await SpeechToTextModelOptions.objects.filter().afirst()
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def aget_conversation_starters(user: KhojUser, max_results=3):
|
||||
all_questions = []
|
||||
if await ReflectiveQuestion.objects.filter(user=user).aexists():
|
||||
@@ -1267,6 +1364,8 @@ class ConversationAdapters:
|
||||
def add_files_to_filter(user: KhojUser, conversation_id: str, files: List[str]):
|
||||
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
|
||||
file_list = EntryAdapters.get_all_filenames_by_source(user, "computer")
|
||||
if not conversation:
|
||||
return []
|
||||
for filename in files:
|
||||
if filename in file_list and filename not in conversation.file_filters:
|
||||
conversation.file_filters.append(filename)
|
||||
@@ -1280,6 +1379,8 @@ class ConversationAdapters:
|
||||
@staticmethod
|
||||
def remove_files_from_filter(user: KhojUser, conversation_id: str, files: List[str]):
|
||||
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
|
||||
if not conversation:
|
||||
return []
|
||||
for filename in files:
|
||||
if filename in conversation.file_filters:
|
||||
conversation.file_filters.remove(filename)
|
||||
@@ -1291,6 +1392,18 @@ class ConversationAdapters:
|
||||
conversation.save()
|
||||
return conversation.file_filters
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str):
|
||||
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
|
||||
if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"):
|
||||
return False
|
||||
conversation_log = conversation.conversation_log
|
||||
updated_log = [msg for msg in conversation_log["chat"] if msg.get("turnId") != turn_id]
|
||||
conversation.conversation_log["chat"] = updated_log
|
||||
conversation.save()
|
||||
return True
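delete_message_by_turn_id above filters conversation_log["chat"] down to the messages whose turnId differs from the requested one and saves the result. A short sketch of that filtering step on plain data:

```python
# Minimal illustration of the turn-removal step with plain dictionaries
conversation_log = {
    "chat": [
        {"turnId": "t1", "message": "hello"},
        {"turnId": "t2", "message": "how do I index PDFs?"},
    ]
}
turn_id = "t1"
conversation_log["chat"] = [msg for msg in conversation_log["chat"] if msg.get("turnId") != turn_id]
assert [m["turnId"] for m in conversation_log["chat"]] == ["t2"]
```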
|
||||
|
||||
|
||||
class FileObjectAdapters:
|
||||
@staticmethod
|
||||
@@ -1299,48 +1412,63 @@ class FileObjectAdapters:
|
||||
file_object.save()
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def create_file_object(user: KhojUser, file_name: str, raw_text: str):
|
||||
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def get_file_object_by_name(user: KhojUser, file_name: str):
|
||||
return FileObject.objects.filter(user=user, file_name=file_name).first()
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def get_all_file_objects(user: KhojUser):
|
||||
return FileObject.objects.filter(user=user).all()
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def delete_file_object_by_name(user: KhojUser, file_name: str):
|
||||
return FileObject.objects.filter(user=user, file_name=file_name).delete()
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def delete_all_file_objects(user: KhojUser):
|
||||
return FileObject.objects.filter(user=user).delete()
|
||||
|
||||
@staticmethod
|
||||
async def async_update_raw_text(file_object: FileObject, new_raw_text: str):
|
||||
async def aupdate_raw_text(file_object: FileObject, new_raw_text: str):
|
||||
file_object.raw_text = new_raw_text
|
||||
await file_object.asave()
|
||||
|
||||
@staticmethod
|
||||
async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str):
|
||||
@arequire_valid_user
|
||||
async def acreate_file_object(user: KhojUser, file_name: str, raw_text: str):
|
||||
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
|
||||
|
||||
@staticmethod
|
||||
async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
|
||||
@arequire_valid_user
|
||||
async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
|
||||
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
|
||||
|
||||
@staticmethod
|
||||
async def async_get_all_file_objects(user: KhojUser):
|
||||
@arequire_valid_user
|
||||
async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]):
|
||||
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names))
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def aget_all_file_objects(user: KhojUser):
|
||||
return await sync_to_async(list)(FileObject.objects.filter(user=user))
|
||||
|
||||
@staticmethod
|
||||
async def async_delete_file_object_by_name(user: KhojUser, file_name: str):
|
||||
@arequire_valid_user
|
||||
async def adelete_file_object_by_name(user: KhojUser, file_name: str):
|
||||
return await FileObject.objects.filter(user=user, file_name=file_name).adelete()
|
||||
|
||||
@staticmethod
|
||||
async def async_delete_all_file_objects(user: KhojUser):
|
||||
@arequire_valid_user
|
||||
async def adelete_all_file_objects(user: KhojUser):
|
||||
return await FileObject.objects.filter(user=user).adelete()
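The renames above move the async FileObject helpers onto the project's a-prefix convention (acreate_file_object, aget_all_file_objects, and so on) and guard them with @arequire_valid_user. The decorator's definition is not part of this diff; the following is only a rough, assumed reimplementation of what such a guard typically does:

```python
import functools

from fastapi import HTTPException


def arequire_valid_user_sketch(func):
    # Hypothetical guard: reject calls made without an authenticated user
    @functools.wraps(func)
    async def wrapper(user, *args, **kwargs):
        if user is None:
            raise HTTPException(status_code=401, detail="User not authenticated")
        return await func(user, *args, **kwargs)

    return wrapper
```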
|
||||
|
||||
|
||||
@@ -1350,15 +1478,18 @@ class EntryAdapters:
|
||||
date_filter = DateFilter()
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
|
||||
return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def delete_entry_by_file(user: KhojUser, file_path: str):
|
||||
deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
|
||||
return deleted_count
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None):
|
||||
queryset = Entry.objects.filter(user=user)
|
||||
|
||||
@@ -1371,6 +1502,7 @@ class EntryAdapters:
|
||||
return queryset
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
|
||||
deleted_count = 0
|
||||
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
|
||||
@@ -1382,6 +1514,7 @@ class EntryAdapters:
|
||||
return deleted_count
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
|
||||
deleted_count = 0
|
||||
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
|
||||
@@ -1393,10 +1526,12 @@ class EntryAdapters:
|
||||
return deleted_count
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
|
||||
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
|
||||
Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()
|
||||
|
||||
@@ -1408,6 +1543,7 @@ class EntryAdapters:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def user_has_entries(user: KhojUser):
|
||||
return Entry.objects.filter(user=user).exists()
|
||||
|
||||
@@ -1416,18 +1552,23 @@ class EntryAdapters:
|
||||
return Entry.objects.filter(agent=agent).exists()
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def auser_has_entries(user: KhojUser):
|
||||
return await Entry.objects.filter(user=user).aexists()
|
||||
|
||||
@staticmethod
|
||||
async def aagent_has_entries(agent: Agent):
|
||||
if agent is None:
|
||||
return False
|
||||
return await Entry.objects.filter(agent=agent).aexists()
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def adelete_entry_by_file(user: KhojUser, file_path: str):
|
||||
return await Entry.objects.filter(user=user, file_path=file_path).adelete()
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000):
|
||||
deleted_count = 0
|
||||
for i in range(0, len(filenames), batch_size):
|
||||
@@ -1439,9 +1580,14 @@ class EntryAdapters:
|
||||
|
||||
@staticmethod
|
||||
async def aget_agent_entry_filepaths(agent: Agent):
|
||||
return await sync_to_async(list)(Entry.objects.filter(agent=agent).values_list("file_path", flat=True))
|
||||
if agent is None:
|
||||
return []
|
||||
return await sync_to_async(set)(
|
||||
Entry.objects.filter(agent=agent).distinct("file_path").values_list("file_path", flat=True)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def get_all_filenames_by_source(user: KhojUser, file_source: str):
|
||||
return (
|
||||
Entry.objects.filter(user=user, file_source=file_source)
|
||||
@@ -1450,6 +1596,7 @@ class EntryAdapters:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def get_size_of_indexed_data_in_mb(user: KhojUser):
|
||||
entries = Entry.objects.filter(user=user).iterator()
|
||||
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
|
||||
@@ -1470,6 +1617,9 @@ class EntryAdapters:
|
||||
if agent != None:
|
||||
owner_filter |= Q(agent=agent)
|
||||
|
||||
if owner_filter == Q():
|
||||
return Entry.objects.none()
|
||||
|
||||
if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
|
||||
return Entry.objects.filter(owner_filter)
|
||||
|
||||
@@ -1514,11 +1664,11 @@ class EntryAdapters:
|
||||
|
||||
@staticmethod
|
||||
def search_with_embeddings(
|
||||
user: KhojUser,
|
||||
raw_query: str,
|
||||
embeddings: Tensor,
|
||||
user: KhojUser,
|
||||
max_results: int = 10,
|
||||
file_type_filter: str = None,
|
||||
raw_query: str = None,
|
||||
max_distance: float = math.inf,
|
||||
agent: Agent = None,
|
||||
):
|
||||
@@ -1544,10 +1694,12 @@ class EntryAdapters:
|
||||
return relevant_entries[:max_results]
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def get_unique_file_types(user: KhojUser):
|
||||
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def get_unique_file_sources(user: KhojUser):
|
||||
return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all()
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ from khoj.database.models import (
|
||||
TextToImageModelConfig,
|
||||
UserConversationConfig,
|
||||
UserRequests,
|
||||
UserSearchModelConfig,
|
||||
UserVoiceModelConfig,
|
||||
VoiceModelOption,
|
||||
WebScraper,
|
||||
@@ -79,7 +78,12 @@ class KhojUserAdmin(UserAdmin):
|
||||
search_fields = ("email", "username", "phone_number", "uuid")
|
||||
filter_horizontal = ("groups", "user_permissions")
|
||||
|
||||
fieldsets = (("Personal info", {"fields": ("phone_number", "email_verification_code")}),) + UserAdmin.fieldsets
|
||||
fieldsets = (
|
||||
(
|
||||
"Personal info",
|
||||
{"fields": ("phone_number", "email_verification_code", "verified_phone_number", "verified_email")},
|
||||
),
|
||||
) + UserAdmin.fieldsets
|
||||
|
||||
actions = ["get_email_login_url"]
|
||||
|
||||
@@ -99,7 +103,6 @@ admin.site.register(KhojUser, KhojUserAdmin)
|
||||
admin.site.register(ProcessLock)
|
||||
admin.site.register(SpeechToTextModelOptions)
|
||||
admin.site.register(ReflectiveQuestion)
|
||||
admin.site.register(UserSearchModelConfig)
|
||||
admin.site.register(ClientApplication)
|
||||
admin.site.register(GithubConfig)
|
||||
admin.site.register(NotionConfig)
|
||||
@@ -126,6 +129,7 @@ class EntryAdmin(admin.ModelAdmin):
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"user",
|
||||
"agent",
|
||||
"file_source",
|
||||
"file_type",
|
||||
"file_name",
|
||||
@@ -135,6 +139,7 @@ class EntryAdmin(admin.ModelAdmin):
|
||||
list_filter = (
|
||||
"file_type",
|
||||
"user__email",
|
||||
"search_model__name",
|
||||
)
|
||||
ordering = ("-created_at",)
|
||||
|
||||
|
||||
src/khoj/database/management/commands/change_default_model.py (new file, 121 lines)
@@ -0,0 +1,121 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db import transaction
|
||||
from django.db.models import Count, Q
|
||||
from tqdm import tqdm
|
||||
|
||||
from khoj.database.adapters import get_default_search_model
|
||||
from khoj.database.models import Agent, Entry, KhojUser, SearchModelConfig
|
||||
from khoj.processor.embeddings import EmbeddingsModel
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BATCH_SIZE = 1000 # Define an appropriate batch size
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "Convert all existing Entry objects to use a new default Search model."
|
||||
|
||||
def add_arguments(self, parser):
|
||||
# Pass default SearchModelConfig ID
|
||||
parser.add_argument(
|
||||
"--search_model_id",
|
||||
action="store",
|
||||
help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects.",
|
||||
required=True,
|
||||
)
|
||||
|
||||
# Set the apply flag to apply the new default Search model to all existing Entry objects.
|
||||
parser.add_argument(
|
||||
"--apply",
|
||||
action="store_true",
|
||||
help="Apply the new default Search model to all existing Entry objects. Otherwise, only display the number of Entry objects that will be affected.",
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
@transaction.atomic
|
||||
def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
|
||||
total_entries = Entry.objects.filter(entry_filter).count()
|
||||
for start in tqdm(range(0, total_entries, BATCH_SIZE)):
|
||||
end = start + BATCH_SIZE
|
||||
entries = Entry.objects.filter(entry_filter)[start:end]
|
||||
compiled_entries = [entry.compiled for entry in entries]
|
||||
updated_entries: List[Entry] = []
|
||||
try:
|
||||
embeddings = embeddings_model.embed_documents(compiled_entries)
|
||||
except Exception as e:
|
||||
logger.error(f"Error embedding documents: {e}")
|
||||
return
|
||||
|
||||
for i, entry in enumerate(entries):
|
||||
entry.embeddings = embeddings[i]
|
||||
entry.search_model_id = search_model.id
|
||||
updated_entries.append(entry)
|
||||
|
||||
Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
|
||||
|
||||
search_model_config_id = options.get("search_model_id")
|
||||
apply = options.get("apply")
|
||||
|
||||
logger.info(f"SearchModelConfig ID: {search_model_config_id}")
|
||||
logger.info(f"Apply: {apply}")
|
||||
|
||||
embeddings_model = dict()
|
||||
|
||||
search_models = SearchModelConfig.objects.all()
|
||||
for model in search_models:
|
||||
embeddings_model.update(
|
||||
{
|
||||
model.name: EmbeddingsModel(
|
||||
model.bi_encoder,
|
||||
model.embeddings_inference_endpoint,
|
||||
model.embeddings_inference_endpoint_api_key,
|
||||
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
||||
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
||||
model_kwargs=model.bi_encoder_model_config,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
new_default_search_model_config = SearchModelConfig.objects.get(id=search_model_config_id)
|
||||
logger.info(f"New default Search model: {new_default_search_model_config}")
|
||||
|
||||
logger.info("----")
|
||||
|
||||
current_default = get_default_search_model()
|
||||
|
||||
# TODO: Migrate all Entry objects to use the new default Search model
|
||||
|
||||
all_agents = Agent.objects.all()
|
||||
logger.info(f"Number of Agent objects to update: {all_agents.count()}")
|
||||
for agent in all_agents:
|
||||
entry_filter = Q(agent=agent)
|
||||
relevant_entries = Entry.objects.filter(entry_filter).all()
|
||||
logger.info(f"Number of Entry objects to update for agent {agent}: {relevant_entries.count()}")
|
||||
|
||||
if apply:
|
||||
try:
|
||||
regenerate_entries(
|
||||
entry_filter,
|
||||
embeddings_model[new_default_search_model_config.name],
|
||||
new_default_search_model_config,
|
||||
)
|
||||
logger.info(
|
||||
f"Updated {relevant_entries.count()} Entry objects for agent {agent} to use the new default Search model."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error embedding documents: {e}")
|
||||
|
||||
if apply and current_default.id != new_default_search_model_config.id:
|
||||
# Get the existing default SearchModelConfig object and update its name
|
||||
current_default.name = f"prev_default_{current_default.id}"
|
||||
current_default.save()
|
||||
|
||||
# Update the new default SearchModelConfig object's name
|
||||
new_default_search_model_config.name = "default"
|
||||
new_default_search_model_config.save()
|
||||
if not apply:
|
||||
logger.info("Run the command with the --apply flag to apply the new default Search model.")
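The change_default_model command above runs as a dry run by default and only re-embeds entries and renames the default config when --apply is passed. A hedged invocation sketch using Django's call_command, assuming the command ships with the khoj app and a SearchModelConfig with the given id exists:

```python
# Equivalent to: python manage.py change_default_model --search_model_id 2 [--apply]
from django.core.management import call_command

call_command("change_default_model", search_model_id="2", apply=False)  # preview affected entries
call_command("change_default_model", search_model_id="2", apply=True)   # re-embed and switch the default
```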
|
||||
@@ -0,0 +1,46 @@
|
||||
# Generated by Django 5.0.8 on 2024-10-21 05:16
|
||||
|
||||
import django.contrib.postgres.fields
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0069_webscraper_serverchatsettings_web_scraper"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="agent",
|
||||
name="input_tools",
|
||||
field=django.contrib.postgres.fields.ArrayField(
|
||||
base_field=models.CharField(
|
||||
choices=[
|
||||
("general", "General"),
|
||||
("online", "Online"),
|
||||
("notes", "Notes"),
|
||||
("summarize", "Summarize"),
|
||||
("webpage", "Webpage"),
|
||||
],
|
||||
max_length=200,
|
||||
),
|
||||
blank=True,
|
||||
default=list,
|
||||
null=True,
|
||||
size=None,
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="agent",
|
||||
name="output_modes",
|
||||
field=django.contrib.postgres.fields.ArrayField(
|
||||
base_field=models.CharField(
|
||||
choices=[("text", "Text"), ("image", "Image"), ("automation", "Automation")], max_length=200
|
||||
),
|
||||
blank=True,
|
||||
default=list,
|
||||
null=True,
|
||||
size=None,
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,32 @@
|
||||
# Generated by Django 5.0.8 on 2024-10-20 19:24
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
def set_enabled_trial_at(apps, schema_editor):
|
||||
Subscription = apps.get_model("database", "Subscription")
|
||||
for subscription in Subscription.objects.all():
|
||||
subscription.enabled_trial_at = subscription.created_at
|
||||
subscription.save()
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0070_alter_agent_input_tools_alter_agent_output_modes"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="subscription",
|
||||
name="enabled_trial_at",
|
||||
field=models.DateTimeField(blank=True, default=None, null=True),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="subscription",
|
||||
name="type",
|
||||
field=models.CharField(
|
||||
choices=[("trial", "Trial"), ("standard", "Standard")], default="standard", max_length=20
|
||||
),
|
||||
),
|
||||
migrations.RunPython(set_enabled_trial_at),
|
||||
]
|
||||
src/khoj/database/migrations/0072_entry_search_model.py (new file, 24 lines)
@@ -0,0 +1,24 @@
|
||||
# Generated by Django 5.0.8 on 2024-10-21 21:09
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0071_subscription_enabled_trial_at_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="entry",
|
||||
name="search_model",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
to="database.searchmodelconfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,15 @@
|
||||
# Generated by Django 5.0.9 on 2024-11-04 19:56
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0072_entry_search_model"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.DeleteModel(
|
||||
name="UserSearchModelConfig",
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,17 @@
|
||||
# Generated by Django 5.0.9 on 2024-11-12 09:50
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0073_delete_usersearchmodelconfig"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="conversation",
|
||||
name="title",
|
||||
field=models.CharField(blank=True, default=None, max_length=500, null=True),
|
||||
),
|
||||
]
|
||||
@@ -73,9 +73,10 @@ class Subscription(BaseModel):
|
||||
STANDARD = "standard"
|
||||
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE, related_name="subscription")
|
||||
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
|
||||
type = models.CharField(max_length=20, choices=Type.choices, default=Type.STANDARD)
|
||||
is_recurring = models.BooleanField(default=False)
|
||||
renewal_date = models.DateTimeField(null=True, default=None, blank=True)
|
||||
enabled_trial_at = models.DateTimeField(null=True, default=None, blank=True)
|
||||
|
||||
|
||||
class OpenAIProcessorConversationConfig(BaseModel):
|
||||
@@ -180,8 +181,12 @@ class Agent(BaseModel):
|
||||
) # Creator will only be null when the agents are managed by admin
|
||||
name = models.CharField(max_length=200)
|
||||
personality = models.TextField()
|
||||
input_tools = ArrayField(models.CharField(max_length=200, choices=InputToolOptions.choices), default=list)
|
||||
output_modes = ArrayField(models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list)
|
||||
input_tools = ArrayField(
|
||||
models.CharField(max_length=200, choices=InputToolOptions.choices), default=list, null=True, blank=True
|
||||
)
|
||||
output_modes = ArrayField(
|
||||
models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list, null=True, blank=True
|
||||
)
|
||||
managed_by_admin = models.BooleanField(default=False)
|
||||
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
|
||||
slug = models.CharField(max_length=200, unique=True)
|
||||
@@ -444,11 +449,6 @@ class UserVoiceModelConfig(BaseModel):
|
||||
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
|
||||
class UserSearchModelConfig(BaseModel):
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class UserTextToImageModelConfig(BaseModel):
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
|
||||
@@ -458,8 +458,12 @@ class Conversation(BaseModel):
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
conversation_log = models.JSONField(default=dict)
|
||||
client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
# Slug is an app-generated conversation identifier. Need not be unique. Used as display title essentially.
|
||||
slug = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
title = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
|
||||
# The title field is explicitly set by the user.
|
||||
title = models.CharField(max_length=500, default=None, null=True, blank=True)
|
||||
agent = models.ForeignKey(Agent, on_delete=models.SET_NULL, default=None, null=True, blank=True)
|
||||
file_filters = models.JSONField(default=list)
|
||||
id = models.UUIDField(default=uuid.uuid4, editable=False, unique=True, primary_key=True, db_index=True)
|
||||
@@ -530,6 +534,7 @@ class Entry(BaseModel):
|
||||
url = models.URLField(max_length=400, default=None, null=True, blank=True)
|
||||
hashed_value = models.CharField(max_length=100)
|
||||
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
|
||||
search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
if self.user and self.agent:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
import tempfile
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from langchain_community.document_loaders import Docx2txtLoader
|
||||
@@ -19,7 +18,7 @@ class DocxToEntries(TextToEntries):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||
# Extract required fields from config
|
||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
@@ -36,13 +35,13 @@ class DocxToEntries(TextToEntries):
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
user,
|
||||
current_entries,
|
||||
DbEntry.EntryType.DOCX,
|
||||
DbEntry.EntrySource.COMPUTER,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
user,
|
||||
regenerate=regenerate,
|
||||
file_to_text_map=file_to_text_map,
|
||||
)
|
||||
@@ -58,28 +57,13 @@ class DocxToEntries(TextToEntries):
|
||||
file_to_text_map = dict()
|
||||
for docx_file in docx_files:
|
||||
try:
|
||||
timestamp_now = datetime.utcnow().timestamp()
|
||||
tmp_file = f"tmp_docx_file_{timestamp_now}.docx"
|
||||
with open(tmp_file, "wb") as f:
|
||||
bytes_content = docx_files[docx_file]
|
||||
f.write(bytes_content)
|
||||
|
||||
# Load the content using Docx2txtLoader
|
||||
loader = Docx2txtLoader(tmp_file)
|
||||
docx_entries_per_file = loader.load()
|
||||
|
||||
# Convert the loaded entries into the desired format
|
||||
docx_texts = [page.page_content for page in docx_entries_per_file]
|
||||
|
||||
docx_texts = DocxToEntries.extract_text(docx_files[docx_file])
|
||||
entry_to_location_map += zip(docx_texts, [docx_file] * len(docx_texts))
|
||||
entries.extend(docx_texts)
|
||||
file_to_text_map[docx_file] = docx_texts
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to process file: {docx_file}. This file will not be indexed.")
|
||||
logger.warning(f"Unable to extract entries from file: {docx_file}")
|
||||
logger.warning(e, exc_info=True)
|
||||
finally:
|
||||
if os.path.exists(f"{tmp_file}"):
|
||||
os.remove(f"{tmp_file}")
|
||||
return file_to_text_map, DocxToEntries.convert_docx_entries_to_maps(entries, dict(entry_to_location_map))
|
||||
|
||||
@staticmethod
|
||||
@@ -103,3 +87,25 @@ class DocxToEntries(TextToEntries):
|
||||
logger.debug(f"Converted {len(parsed_entries)} DOCX entries to dictionaries")
|
||||
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
def extract_text(docx_file):
|
||||
"""Extract text from specified DOCX file"""
|
||||
try:
|
||||
docx_entry_by_pages = []
|
||||
# Create temp file with .docx extension that gets auto-deleted
|
||||
with tempfile.NamedTemporaryFile(suffix=".docx", delete=True) as tmp:
|
||||
tmp.write(docx_file)
|
||||
tmp.flush() # Ensure all data is written
|
||||
|
||||
# Load the content using Docx2txtLoader
|
||||
loader = Docx2txtLoader(tmp.name)
|
||||
docx_entries_per_file = loader.load()
|
||||
|
||||
# Convert the loaded entries into the desired format
|
||||
docx_entry_by_pages = [page.page_content for page in docx_entries_per_file]
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to extract text from file: {docx_file}")
|
||||
logger.warning(e, exc_info=True)
|
||||
|
||||
return docx_entry_by_pages
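extract_text above now writes the DOCX bytes to an auto-deleting NamedTemporaryFile instead of a hand-named temp file in the working directory. A short usage sketch; the file path is illustrative and the module path is assumed from the repository layout:

```python
# Sketch: feed raw DOCX bytes to the extractor
from khoj.processor.content.docx.docx_to_entries import DocxToEntries  # module path assumed

with open("/tmp/example.docx", "rb") as f:  # illustrative path
    docx_bytes = f.read()

pages = DocxToEntries.extract_text(docx_bytes)
for page_text in pages:
    print(page_text[:80])
```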
|
||||
|
||||
@@ -48,7 +48,7 @@ class GithubToEntries(TextToEntries):
|
||||
else:
|
||||
return
|
||||
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||
if self.config.pat_token is None or self.config.pat_token == "":
|
||||
logger.error(f"Github PAT token is not set. Skipping github content")
|
||||
raise ValueError("Github PAT token is not set. Skipping github content")
|
||||
@@ -101,12 +101,12 @@ class GithubToEntries(TextToEntries):
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
user,
|
||||
current_entries,
|
||||
DbEntry.EntryType.GITHUB,
|
||||
DbEntry.EntrySource.GITHUB,
|
||||
key="compiled",
|
||||
logger=logger,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@@ -18,7 +18,7 @@ class ImageToEntries(TextToEntries):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||
# Extract required fields from config
|
||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
@@ -35,13 +35,13 @@ class ImageToEntries(TextToEntries):
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
user,
|
||||
current_entries,
|
||||
DbEntry.EntryType.IMAGE,
|
||||
DbEntry.EntrySource.COMPUTER,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
user,
|
||||
regenerate=regenerate,
|
||||
file_to_text_map=file_to_text_map,
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ class MarkdownToEntries(TextToEntries):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||
# Extract required fields from config
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
@@ -37,13 +37,13 @@ class MarkdownToEntries(TextToEntries):
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
user,
|
||||
current_entries,
|
||||
DbEntry.EntryType.MARKDOWN,
|
||||
DbEntry.EntrySource.COMPUTER,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
user,
|
||||
regenerate=regenerate,
|
||||
file_to_text_map=file_to_text_map,
|
||||
)
|
||||
|
||||
@@ -79,7 +79,7 @@ class NotionToEntries(TextToEntries):
|
||||
|
||||
self.body_params = {"page_size": 100}
|
||||
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||
current_entries = []
|
||||
|
||||
# Get all pages
|
||||
@@ -248,12 +248,12 @@ class NotionToEntries(TextToEntries):
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
user,
|
||||
current_entries,
|
||||
DbEntry.EntryType.NOTION,
|
||||
DbEntry.EntrySource.NOTION,
|
||||
key="compiled",
|
||||
logger=logger,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@@ -20,7 +20,7 @@ class OrgToEntries(TextToEntries):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
files = {file: files[file] for file in files_to_process}
|
||||
@@ -36,13 +36,13 @@ class OrgToEntries(TextToEntries):
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
user,
|
||||
current_entries,
|
||||
DbEntry.EntryType.ORG,
|
||||
DbEntry.EntrySource.COMPUTER,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
user,
|
||||
regenerate=regenerate,
|
||||
file_to_text_map=file_to_text_map,
|
||||
)
|
||||
|
||||
@@ -49,7 +49,7 @@ def normalize_filename(filename):
|
||||
normalized_filename = f"~/{relpath(filename, start=Path.home())}"
|
||||
else:
|
||||
normalized_filename = filename
|
||||
escaped_filename = f"{normalized_filename}".replace("[", "\[").replace("]", "\]")
|
||||
escaped_filename = f"{normalized_filename}".replace("[", r"\[").replace("]", r"\]")
|
||||
return escaped_filename
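The one-line change above switches the bracket escapes to raw strings. The resulting value is unchanged, since "\[" already denotes a backslash followed by a bracket, but the raw-string form avoids the invalid-escape-sequence warning that newer Python versions emit for the unraw form. A tiny check:

```python
# r"\[" and "\\[" denote the same two-character string; the raw form just reads cleaner
# and avoids the invalid-escape warning on recent Python versions.
assert r"\[" == "\\["
print("notes [2024].org".replace("[", r"\[").replace("]", r"\]"))
# notes \[2024\].org
```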
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Tuple
|
||||
import tempfile
|
||||
from typing import Dict, Final, List, Tuple
|
||||
|
||||
from langchain_community.document_loaders import PyMuPDFLoader
|
||||
|
||||
# importing FileObjectAdapter so that we can add new files and debug file object db.
|
||||
# from khoj.database.adapters import FileObjectAdapters
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.processor.content.text_to_entries import TextToEntries
|
||||
@@ -18,11 +14,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PdfToEntries(TextToEntries):
|
||||
# Class-level constant translation table
|
||||
NULL_TRANSLATOR: Final = str.maketrans("", "", "\x00")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||
# Extract required fields from config
|
||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
@@ -39,13 +38,13 @@ class PdfToEntries(TextToEntries):
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
user,
|
||||
current_entries,
|
||||
DbEntry.EntryType.PDF,
|
||||
DbEntry.EntrySource.COMPUTER,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
user,
|
||||
regenerate=regenerate,
|
||||
file_to_text_map=file_to_text_map,
|
||||
)
|
||||
@@ -60,31 +59,13 @@ class PdfToEntries(TextToEntries):
|
||||
entry_to_location_map: List[Tuple[str, str]] = []
|
||||
for pdf_file in pdf_files:
|
||||
try:
|
||||
# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path
|
||||
timestamp_now = datetime.utcnow().timestamp()
|
||||
tmp_file = f"tmp_pdf_file_{timestamp_now}.pdf"
|
||||
with open(f"{tmp_file}", "wb") as f:
|
||||
bytes = pdf_files[pdf_file]
|
||||
f.write(bytes)
|
||||
try:
|
||||
loader = PyMuPDFLoader(f"{tmp_file}", extract_images=False)
|
||||
pdf_entries_per_file = [page.page_content for page in loader.load()]
|
||||
except ImportError:
|
||||
loader = PyMuPDFLoader(f"{tmp_file}")
|
||||
pdf_entries_per_file = [
|
||||
page.page_content for page in loader.load()
|
||||
] # page_content items list for a given pdf.
|
||||
entry_to_location_map += zip(
|
||||
pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file)
|
||||
) # this is an indexed map of pdf_entries for the pdf.
|
||||
pdf_entries_per_file = PdfToEntries.extract_text(pdf_files[pdf_file])
|
||||
entry_to_location_map += zip(pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file))
|
||||
entries.extend(pdf_entries_per_file)
|
||||
file_to_text_map[pdf_file] = pdf_entries_per_file
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
|
||||
logger.warning(f"Unable to extract entries from file: {pdf_file}")
|
||||
logger.warning(e, exc_info=True)
|
||||
finally:
|
||||
if os.path.exists(f"{tmp_file}"):
|
||||
os.remove(f"{tmp_file}")
|
||||
|
||||
return file_to_text_map, PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
|
||||
|
||||
@@ -109,3 +90,30 @@ class PdfToEntries(TextToEntries):
|
||||
logger.debug(f"Converted {len(parsed_entries)} PDF entries to dictionaries")
|
||||
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
def extract_text(pdf_file):
|
||||
"""Extract text from specified PDF files"""
|
||||
try:
|
||||
# Create temp file with .pdf extension that gets auto-deleted
|
||||
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=True) as tmpf:
|
||||
tmpf.write(pdf_file)
|
||||
tmpf.flush() # Ensure all data is written
|
||||
|
||||
# Load the content using PyMuPDFLoader
|
||||
loader = PyMuPDFLoader(tmpf.name, extract_images=True)
|
||||
pdf_entries_per_file = loader.load()
|
||||
|
||||
# Convert the loaded entries into the desired format
|
||||
pdf_entry_by_pages = [PdfToEntries.clean_text(page.page_content) for page in pdf_entries_per_file]
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
|
||||
logger.warning(e, exc_info=True)
|
||||
|
||||
return pdf_entry_by_pages
|
||||
|
||||
@staticmethod
|
||||
def clean_text(text: str) -> str:
|
||||
"""Clean PDF text by removing null bytes and invalid Unicode characters."""
|
||||
# Use faster translation table instead of replace
|
||||
return text.translate(PdfToEntries.NULL_TRANSLATOR)
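clean_text above strips NUL bytes with a class-level translation table, which avoids rebuilding the table on every call and is faster than chained str.replace calls. A small standalone check of the same technique:

```python
# str.maketrans("", "", "\x00") builds a table that deletes every NUL character
NULL_TRANSLATOR = str.maketrans("", "", "\x00")

dirty = "Invoice\x00 total: 42\x00"
clean = dirty.translate(NULL_TRANSLATOR)
assert clean == "Invoice total: 42"
```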
|
||||
|
||||
@@ -20,7 +20,7 @@ class PlaintextToEntries(TextToEntries):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
files = {file: files[file] for file in files_to_process}
|
||||
@@ -36,13 +36,13 @@ class PlaintextToEntries(TextToEntries):
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
user,
|
||||
current_entries,
|
||||
DbEntry.EntryType.PLAINTEXT,
|
||||
DbEntry.EntrySource.COMPUTER,
|
||||
key="compiled",
|
||||
logger=logger,
|
||||
deletion_filenames=deletion_file_names,
|
||||
user=user,
|
||||
regenerate=regenerate,
|
||||
file_to_text_map=file_to_text_map,
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@ from tqdm import tqdm
|
||||
from khoj.database.adapters import (
|
||||
EntryAdapters,
|
||||
FileObjectAdapters,
|
||||
get_user_search_model_or_default,
|
||||
get_default_search_model,
|
||||
)
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import EntryDates, KhojUser
|
||||
@@ -31,7 +31,7 @@ class TextToEntries(ABC):
|
||||
self.date_filter = DateFilter()
|
||||
|
||||
@abstractmethod
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@@ -114,13 +114,13 @@ class TextToEntries(ABC):
|
||||
|
||||
def update_embeddings(
|
||||
self,
|
||||
user: KhojUser,
|
||||
current_entries: List[Entry],
|
||||
file_type: str,
|
||||
file_source: str,
|
||||
key="compiled",
|
||||
logger: logging.Logger = None,
|
||||
deletion_filenames: Set[str] = None,
|
||||
user: KhojUser = None,
|
||||
regenerate: bool = False,
|
||||
file_to_text_map: dict[str, str] = None,
|
||||
):
|
||||
@@ -148,10 +148,10 @@ class TextToEntries(ABC):
|
||||
hashes_to_process |= hashes_for_file - existing_entry_hashes
|
||||
|
||||
embeddings = []
|
||||
model = get_default_search_model()
|
||||
with timer("Generated embeddings for entries to add to database in", logger):
|
||||
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||
model = get_user_search_model_or_default(user)
|
||||
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
|
||||
|
||||
added_entries: list[DbEntry] = []
|
||||
@@ -177,6 +177,7 @@ class TextToEntries(ABC):
|
||||
file_type=file_type,
|
||||
hashed_value=entry_hash,
|
||||
corpus_id=entry.corpus_id,
|
||||
search_model=model,
|
||||
)
|
||||
)
|
||||
try:
|
||||
|
||||
@@ -11,10 +11,16 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff,
anthropic_completion_with_backoff,
format_messages_for_anthropic,
)
from khoj.processor.conversation.utils import (
clean_json,
construct_structured_message,
generate_chatml_messages_with_context,
)
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData
from khoj.utils.yaml import yaml_dump

logger = logging.getLogger(__name__)

@@ -27,7 +33,11 @@ def extract_questions_anthropic(
temperature=0.7,
location_data: LocationData = None,
user: KhojUser = None,
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
query_files: str = None,
tracer: dict = {},
):
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -68,7 +78,19 @@ def extract_questions_anthropic(
text=text,
)

messages = [ChatMessage(content=prompt, role="user")]
prompt = construct_structured_message(
message=prompt,
images=query_images,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
vision_enabled=vision_enabled,
attached_file_context=query_files,
)

messages = []

messages.append(ChatMessage(content=prompt, role="user"))

messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)

response = anthropic_completion_with_backoff(
messages=messages,
@@ -76,14 +98,13 @@ def extract_questions_anthropic(
model_name=model,
temperature=temperature,
api_key=api_key,
response_type="json_object",
tracer=tracer,
)

# Extract, Clean Message from Claude's Response
try:
response = response.strip()
match = re.search(r"\{.*?\}", response)
if match:
response = match.group()
response = clean_json(response)
response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response:
@@ -97,21 +118,11 @@ def extract_questions_anthropic(
return questions


def anthropic_send_message_to_model(messages, api_key, model):
def anthropic_send_message_to_model(messages, api_key, model, response_type="text", tracer={}):
"""
Send message to model
"""
# Anthropic requires the first message to be a 'user' message, and the system prompt is not to be sent in the messages parameter
system_prompt = None

if len(messages) == 1:
messages[0].role = "user"
else:
system_prompt = ""
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)
messages, system_prompt = format_messages_for_anthropic(messages)

# Get Response from GPT. Don't use response_type because Anthropic doesn't support it.
return anthropic_completion_with_backoff(
@@ -119,6 +130,8 @@ def anthropic_send_message_to_model(messages, api_key, model):
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
response_type=response_type,
tracer=tracer,
)


@@ -126,8 +139,9 @@ def converse_anthropic(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
conversation_log={},
model: Optional[str] = "claude-instant-1.2",
model: Optional[str] = "claude-3-5-sonnet-20241022",
api_key: Optional[str] = None,
completion_func=None,
conversation_commands=[ConversationCommand.Default],
@@ -136,15 +150,16 @@ def converse_anthropic(
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
tracer: dict = {},
):
"""
Converse with user using Anthropic's Claude
"""
# Initialize Variables
current_date = datetime.now()
compiled_references = "\n\n".join({f"# {item}" for item in references})

conversation_primer = prompts.query_prompt.format(query=user_query)

if agent and agent.personality:
system_prompt = prompts.custom_personality.format(
@@ -168,38 +183,37 @@ def converse_anthropic(
system_prompt = f"{system_prompt}\n{user_name_prompt}"

# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])

context_message = ""
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
conversation_primer = (
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
)
if not is_none_or_empty(compiled_references):
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
context_message = context_message.strip()

# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
user_query,
context_message=context_message,
conversation_log=conversation_log,
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
query_files=query_files,
)

if len(messages) > 1:
if messages[0].role == "assistant":
messages = messages[1:]

for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)
messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)

truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
logger.debug(f"Conversation Context for Claude: {truncated_messages}")
@@ -215,4 +229,5 @@ def converse_anthropic(
system_prompt=system_prompt,
completion_func=completion_func,
max_prompt_size=max_prompt_size,
tracer=tracer,
)

@@ -3,6 +3,7 @@ from threading import Thread
from typing import Dict, List

import anthropic
from langchain.schema import ChatMessage
from tenacity import (
before_sleep_log,
retry,
@@ -11,7 +12,13 @@ from tenacity import (
wait_random_exponential,
)

from khoj.processor.conversation.utils import ThreadedGenerator
from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace,
get_image_from_url,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode, is_none_or_empty

logger = logging.getLogger(__name__)

@@ -28,7 +35,15 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
reraise=True,
)
def anthropic_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
messages,
system_prompt,
model_name,
temperature=0,
api_key=None,
model_kwargs=None,
max_tokens=None,
response_type="text",
tracer={},
) -> str:
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
@@ -37,8 +52,11 @@ def anthropic_completion_with_backoff(
client = anthropic_clients[api_key]

formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
if response_type == "json_object":
# Prefill model response with '{' to make it output a valid JSON object
formatted_messages += [{"role": "assistant", "content": "{"}]

aggregated_response = ""
aggregated_response = "{" if response_type == "json_object" else ""
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC

model_kwargs = model_kwargs or dict()
@@ -56,6 +74,12 @@ def anthropic_completion_with_backoff(
for text in stream.text_stream:
aggregated_response += text

# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)

return aggregated_response

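The hunk above adds Anthropic's assistant-prefill trick: when a JSON object is requested, the last message seeds the assistant turn with an opening brace so the model completes a valid object, and the brace is re-attached to the streamed text. A minimal, self-contained sketch of the same pattern (the model name, max_tokens value, and helper name here are illustrative, not taken from the diff):

```python
import json

import anthropic


def ask_for_json(prompt: str, api_key: str, model: str = "claude-3-5-sonnet-20241022") -> dict:
    client = anthropic.Anthropic(api_key=api_key)
    response = client.messages.create(
        model=model,
        max_tokens=1024,
        system="Reply with a single JSON object.",
        messages=[
            {"role": "user", "content": prompt},
            # Prefill the assistant turn with '{' so the completion continues a JSON object
            {"role": "assistant", "content": "{"},
        ],
    )
    # Re-attach the prefilled brace before parsing, mirroring aggregated_response = "{" above
    return json.loads("{" + response.content[0].text)
```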
@@ -76,18 +100,19 @@ def anthropic_chat_completion_with_backoff(
max_prompt_size=None,
completion_func=None,
model_kwargs=None,
tracer={},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=anthropic_llm_thread,
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs, tracer),
)
t.start()
return g


def anthropic_llm_thread(
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None, tracer={}
):
try:
if api_key not in anthropic_clients:
@@ -100,6 +125,7 @@ def anthropic_llm_thread(
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
]

aggregated_response = ""
with client.messages.stream(
messages=formatted_messages,
model=model_name, # type: ignore
@@ -110,8 +136,63 @@ def anthropic_llm_thread(
**(model_kwargs or dict()),
) as stream:
for text in stream.text_stream:
aggregated_response += text
g.send(text)

# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
finally:
g.close()


def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=None):
"""
Format messages for Anthropic
"""
# Extract system prompt
system_prompt = system_prompt or ""
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt

# Anthropic requires the first message to be a 'user' message
if len(messages) == 1:
messages[0].role = "user"
elif len(messages) > 1 and messages[0].role == "assistant":
messages = messages[1:]

# Convert image urls to base64 encoded images in Anthropic message format
for message in messages:
if isinstance(message.content, list):
content = []
# Sort the content. Anthropic models prefer that text comes after images.
message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1)
for idx, part in enumerate(message.content):
if part["type"] == "text":
content.append({"type": "text", "text": part["text"]})
elif part["type"] == "image_url":
image = get_image_from_url(part["image_url"]["url"], type="b64")
# Prefix each image with text block enumerating the image number
# This helps the model reference the image in its response. Recommended by Anthropic
content.extend(
[
{
"type": "text",
"text": f"Image {idx + 1}:",
},
{
"type": "image",
"source": {"type": "base64", "media_type": image.type, "data": image.content},
},
]
)
message.content = content

return messages, system_prompt

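For reference, a hedged usage sketch of the helper defined above: it lifts system messages into a separate system prompt and converts image_url content parts into Anthropic base64 image blocks (the image URL is a placeholder, and converting it requires network access since the helper fetches and encodes the image):

```python
from langchain.schema import ChatMessage

from khoj.processor.conversation.anthropic.utils import format_messages_for_anthropic

chat_history = [
    ChatMessage(role="system", content="You are a concise assistant."),
    ChatMessage(
        role="user",
        content=[
            {"type": "image_url", "image_url": {"url": "https://example.org/sketch.png"}},
            {"type": "text", "text": "What does this diagram show?"},
        ],
    ),
]

# Returns Anthropic-ready messages plus the extracted system prompt
messages, system_prompt = format_messages_for_anthropic(chat_history)
```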
@@ -6,16 +6,21 @@ from typing import Dict, Optional
|
||||
|
||||
from langchain.schema import ChatMessage
|
||||
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.database.models import Agent, ChatModelOptions, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.google.utils import (
|
||||
format_messages_for_gemini,
|
||||
gemini_chat_completion_with_backoff,
|
||||
gemini_completion_with_backoff,
|
||||
)
|
||||
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
|
||||
from khoj.processor.conversation.utils import (
|
||||
clean_json,
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
)
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -29,7 +34,11 @@ def extract_questions_gemini(
|
||||
max_tokens=None,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
@@ -70,25 +79,26 @@ def extract_questions_gemini(
|
||||
text=text,
|
||||
)
|
||||
|
||||
messages = [ChatMessage(content=prompt, role="user")]
|
||||
prompt = construct_structured_message(
|
||||
message=prompt,
|
||||
images=query_images,
|
||||
model_type=ChatModelOptions.ModelType.GOOGLE,
|
||||
vision_enabled=vision_enabled,
|
||||
attached_file_context=query_files,
|
||||
)
|
||||
|
||||
model_kwargs = {"response_mime_type": "application/json"}
|
||||
messages = []
|
||||
|
||||
response = gemini_completion_with_backoff(
|
||||
messages=messages,
|
||||
system_prompt=system_prompt,
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
model_kwargs=model_kwargs,
|
||||
messages.append(ChatMessage(content=prompt, role="user"))
|
||||
messages.append(ChatMessage(content=system_prompt, role="system"))
|
||||
|
||||
response = gemini_send_message_to_model(
|
||||
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
|
||||
)
|
||||
|
||||
# Extract, Clean Message from Gemini's Response
|
||||
try:
|
||||
response = response.strip()
|
||||
match = re.search(r"\{.*?\}", response)
|
||||
if match:
|
||||
response = match.group()
|
||||
response = clean_json(response)
|
||||
response = json.loads(response)
|
||||
response = [q.strip() for q in response["queries"] if q.strip()]
|
||||
if not isinstance(response, list) or not response:
|
||||
@@ -102,19 +112,35 @@ def extract_questions_gemini(
|
||||
return questions
|
||||
|
||||
|
||||
def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
|
||||
def gemini_send_message_to_model(
|
||||
messages,
|
||||
api_key,
|
||||
model,
|
||||
response_type="text",
|
||||
temperature=0,
|
||||
model_kwargs=None,
|
||||
tracer={},
|
||||
):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
messages, system_prompt = format_messages_for_gemini(messages)
|
||||
|
||||
model_kwargs = {}
|
||||
if response_type == "json_object":
|
||||
model_kwargs["response_mime_type"] = "application/json"
|
||||
|
||||
# Sometimes, this causes unwanted behavior and terminates response early. Disable for now while it's flaky.
|
||||
# if response_type == "json_object":
|
||||
# model_kwargs["response_mime_type"] = "application/json"
|
||||
|
||||
# Get Response from Gemini
|
||||
return gemini_completion_with_backoff(
|
||||
messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
|
||||
messages=messages,
|
||||
system_prompt=system_prompt,
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
temperature=temperature,
|
||||
model_kwargs=model_kwargs,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
@@ -122,6 +148,7 @@ def converse_gemini(
|
||||
references,
|
||||
user_query,
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
conversation_log={},
|
||||
model: Optional[str] = "gemini-1.5-flash",
|
||||
api_key: Optional[str] = None,
|
||||
@@ -133,15 +160,16 @@ def converse_gemini(
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_available: bool = False,
|
||||
query_files: str = None,
|
||||
tracer={},
|
||||
):
|
||||
"""
|
||||
Converse with user using Google's Gemini
|
||||
"""
|
||||
# Initialize Variables
|
||||
current_date = datetime.now()
|
||||
compiled_references = "\n\n".join({f"# {item}" for item in references})
|
||||
|
||||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
||||
|
||||
if agent and agent.personality:
|
||||
system_prompt = prompts.custom_personality.format(
|
||||
@@ -166,27 +194,34 @@ def converse_gemini(
|
||||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
completion_func(chat_response=prompts.no_notes_found.format())
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||
return iter([prompts.no_online_results_found.format()])
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
|
||||
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
||||
conversation_primer = (
|
||||
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
|
||||
)
|
||||
if not is_none_or_empty(compiled_references):
|
||||
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
|
||||
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
|
||||
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
|
||||
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
|
||||
context_message = context_message.strip()
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
conversation_primer,
|
||||
user_query,
|
||||
context_message=context_message,
|
||||
conversation_log=conversation_log,
|
||||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
query_images=query_images,
|
||||
vision_enabled=vision_available,
|
||||
model_type=ChatModelOptions.ModelType.GOOGLE,
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||
@@ -204,4 +239,5 @@ def converse_gemini(
|
||||
api_key=api_key,
|
||||
system_prompt=system_prompt,
|
||||
completion_func=completion_func,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
@@ -19,8 +19,13 @@ from tenacity import (
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,7 +40,7 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192
|
||||
reraise=True,
|
||||
)
|
||||
def gemini_completion_with_backoff(
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
|
||||
) -> str:
|
||||
genai.configure(api_key=api_key)
|
||||
model_kwargs = model_kwargs or dict()
|
||||
@@ -53,23 +58,30 @@ def gemini_completion_with_backoff(
|
||||
},
|
||||
)
|
||||
|
||||
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
|
||||
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
|
||||
|
||||
# Start chat session. All messages up to the last are considered to be part of the chat history
|
||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||
|
||||
try:
|
||||
# Generate the response. The last message is considered to be the current prompt
|
||||
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
|
||||
return aggregated_response.text
|
||||
response = chat_session.send_message(formatted_messages[-1]["parts"])
|
||||
response_text = response.text
|
||||
except StopCandidateException as e:
|
||||
response_message, _ = handle_gemini_response(e.args)
|
||||
response_text, _ = handle_gemini_response(e.args)
|
||||
# Respond with reason for stopping
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {response_message}.\n"
|
||||
f"LLM Response Prevented for {model_name}: {response_text}.\n"
|
||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||
)
|
||||
return response_message
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, response_text, tracer)
|
||||
|
||||
return response_text
|
||||
|
||||
|
||||
@retry(
|
||||
@@ -88,17 +100,20 @@ def gemini_chat_completion_with_backoff(
|
||||
system_prompt,
|
||||
completion_func=None,
|
||||
model_kwargs=None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=gemini_llm_thread,
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
|
||||
def gemini_llm_thread(
|
||||
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
|
||||
):
|
||||
try:
|
||||
genai.configure(api_key=api_key)
|
||||
model_kwargs = model_kwargs or dict()
|
||||
@@ -117,16 +132,25 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
|
||||
},
|
||||
)
|
||||
|
||||
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
|
||||
aggregated_response = ""
|
||||
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
|
||||
|
||||
# all messages up to the last are considered to be part of the chat history
|
||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||
# the last message is considered to be the current prompt
|
||||
for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
|
||||
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
|
||||
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||
message = message or chunk.text
|
||||
aggregated_response += message
|
||||
g.send(message)
|
||||
if stopped:
|
||||
raise StopCandidateException(message)
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
except StopCandidateException as e:
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
|
||||
@@ -191,14 +215,6 @@ def generate_safety_response(safety_ratings):
|
||||
|
||||
|
||||
def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
|
||||
if len(messages) == 1:
|
||||
messages[0].role = "user"
|
||||
return messages, system_prompt
|
||||
|
||||
for message in messages:
|
||||
if message.role == "assistant":
|
||||
message.role = "model"
|
||||
|
||||
# Extract system message
|
||||
system_prompt = system_prompt or ""
|
||||
for message in messages.copy():
|
||||
@@ -207,4 +223,23 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
|
||||
messages.remove(message)
|
||||
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
|
||||
|
||||
for message in messages:
|
||||
# Convert message content to string list from chatml dictionary list
|
||||
if isinstance(message.content, list):
|
||||
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
|
||||
message.content = [
|
||||
get_image_from_url(item["image_url"]["url"]).content
|
||||
if item["type"] == "image_url"
|
||||
else item.get("text", "")
|
||||
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
|
||||
]
|
||||
elif isinstance(message.content, str):
|
||||
message.content = [message.content]
|
||||
|
||||
if message.role == "assistant":
|
||||
message.role = "model"
|
||||
|
||||
if len(messages) == 1:
|
||||
messages[0].role = "user"
|
||||
|
||||
return messages, system_prompt
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Thread
|
||||
from typing import Any, Iterator, List, Optional, Union
|
||||
@@ -12,12 +13,14 @@ from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
generate_chatml_messages_with_context,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||
from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,6 +37,8 @@ def extract_questions_offline(
|
||||
max_prompt_size: int = None,
|
||||
temperature: float = 0.7,
|
||||
personality_context: Optional[str] = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
) -> List[str]:
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
@@ -83,6 +88,7 @@ def extract_questions_offline(
|
||||
loaded_model=offline_chat_model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
model_type=ChatModelOptions.ModelType.OFFLINE,
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
state.chat_lock.acquire()
|
||||
@@ -94,6 +100,7 @@ def extract_questions_offline(
|
||||
max_prompt_size=max_prompt_size,
|
||||
temperature=temperature,
|
||||
response_type="json_object",
|
||||
tracer=tracer,
|
||||
)
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
@@ -135,7 +142,8 @@ def filter_questions(questions: List[str]):
|
||||
def converse_offline(
|
||||
user_query,
|
||||
references=[],
|
||||
online_results=[],
|
||||
online_results={},
|
||||
code_results={},
|
||||
conversation_log={},
|
||||
model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
||||
loaded_model: Union[Any, None] = None,
|
||||
@@ -146,6 +154,8 @@ def converse_offline(
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||
"""
|
||||
Converse with user using Llama
|
||||
@@ -153,8 +163,7 @@ def converse_offline(
|
||||
# Initialize Variables
|
||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||
compiled_references_message = "\n\n".join({f"{item['compiled']}" for item in references})
|
||||
|
||||
tracer["chat_model"] = model
|
||||
current_date = datetime.now()
|
||||
|
||||
if agent and agent.personality:
|
||||
@@ -170,8 +179,6 @@ def converse_offline(
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
|
||||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
||||
|
||||
if location_data:
|
||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||
@@ -181,45 +188,52 @@ def converse_offline(
|
||||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references_message):
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||
return iter([prompts.no_online_results_found.format()])
|
||||
|
||||
if ConversationCommand.Online in conversation_commands:
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
context_message = f"{prompts.notes_conversation_offline.format(references=yaml_dump(references))}\n\n"
|
||||
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
||||
simplified_online_results = online_results.copy()
|
||||
for result in online_results:
|
||||
if online_results[result].get("webpages"):
|
||||
simplified_online_results[result] = online_results[result]["webpages"]
|
||||
|
||||
conversation_primer = f"{prompts.online_search_conversation_offline.format(online_results=str(simplified_online_results))}\n{conversation_primer}"
|
||||
if not is_none_or_empty(compiled_references_message):
|
||||
conversation_primer = f"{prompts.notes_conversation_offline.format(references=compiled_references_message)}\n\n{conversation_primer}"
|
||||
context_message += f"{prompts.online_search_conversation_offline.format(online_results=yaml_dump(simplified_online_results))}\n\n"
|
||||
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
|
||||
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
|
||||
context_message = context_message.strip()
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
conversation_primer,
|
||||
user_query,
|
||||
system_prompt,
|
||||
conversation_log,
|
||||
context_message=context_message,
|
||||
model_name=model,
|
||||
loaded_model=offline_chat_model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
model_type=ChatModelOptions.ModelType.OFFLINE,
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
|
||||
logger.debug(f"Conversation Context for {model}: {truncated_messages}")
|
||||
|
||||
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
|
||||
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
|
||||
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
|
||||
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
|
||||
stop_phrases = ["<s>", "INST]", "Notes:"]
|
||||
aggregated_response = ""
|
||||
|
||||
state.chat_lock.acquire()
|
||||
try:
|
||||
@@ -227,7 +241,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
|
||||
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
|
||||
)
|
||||
for response in response_iterator:
|
||||
g.send(response["choices"][0]["delta"].get("content", ""))
|
||||
response_delta = response["choices"][0]["delta"].get("content", "")
|
||||
aggregated_response += response_delta
|
||||
g.send(response_delta)
|
||||
|
||||
# Save conversation trace
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
g.close()
|
||||
@@ -242,14 +263,31 @@ def send_message_to_model_offline(
|
||||
stop=[],
|
||||
max_prompt_size: int = None,
|
||||
response_type: str = "text",
|
||||
tracer: dict = {},
|
||||
):
|
||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
|
||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||
response = offline_chat_model.create_chat_completion(
|
||||
messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
|
||||
messages_dict,
|
||||
stop=stop,
|
||||
stream=streaming,
|
||||
temperature=temperature,
|
||||
response_format={"type": response_type},
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
if streaming:
|
||||
return response
|
||||
else:
|
||||
return response["choices"][0]["message"].get("content", "")
|
||||
|
||||
response_text = response["choices"][0]["message"].get("content", "")
|
||||
|
||||
# Save conversation trace for non-streaming responses
|
||||
# Streamed responses need to be saved by the calling function
|
||||
tracer["chat_model"] = model
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, response_text, tracer)
|
||||
|
||||
return response_text
|
||||
|
||||
@@ -12,12 +12,13 @@ from khoj.processor.conversation.openai.utils import (
|
||||
completion_with_backoff,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
clean_json,
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
remove_json_codeblock,
|
||||
)
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -30,9 +31,11 @@ def extract_questions(
|
||||
api_base_url=None,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
uploaded_image_url: Optional[str] = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
@@ -74,21 +77,28 @@ def extract_questions(
|
||||
|
||||
prompt = construct_structured_message(
|
||||
message=prompt,
|
||||
image_url=uploaded_image_url,
|
||||
images=query_images,
|
||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||
vision_enabled=vision_enabled,
|
||||
attached_file_context=query_files,
|
||||
)
|
||||
|
||||
messages = [ChatMessage(content=prompt, role="user")]
|
||||
messages = []
|
||||
messages.append(ChatMessage(content=prompt, role="user"))
|
||||
|
||||
response = send_message_to_model(
|
||||
messages, api_key, model, response_type="json_object", api_base_url=api_base_url, temperature=temperature
|
||||
messages,
|
||||
api_key,
|
||||
model,
|
||||
response_type="json_object",
|
||||
api_base_url=api_base_url,
|
||||
temperature=temperature,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
try:
|
||||
response = response.strip()
|
||||
response = remove_json_codeblock(response)
|
||||
response = clean_json(response)
|
||||
response = json.loads(response)
|
||||
response = [q.strip() for q in response["queries"] if q.strip()]
|
||||
if not isinstance(response, list) or not response:
|
||||
@@ -103,7 +113,9 @@ def extract_questions(
|
||||
return questions
|
||||
|
||||
|
||||
def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
|
||||
def send_message_to_model(
|
||||
messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
|
||||
):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
@@ -116,6 +128,7 @@ def send_message_to_model(messages, api_key, model, response_type="text", api_ba
|
||||
temperature=temperature,
|
||||
api_base_url=api_base_url,
|
||||
model_kwargs={"response_format": {"type": response_type}},
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
@@ -123,6 +136,7 @@ def converse(
|
||||
references,
|
||||
user_query,
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
conversation_log={},
|
||||
model: str = "gpt-4o-mini",
|
||||
api_key: Optional[str] = None,
|
||||
@@ -135,17 +149,16 @@ def converse(
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
image_url: Optional[str] = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_available: bool = False,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Converse with user using OpenAI's ChatGPT
|
||||
"""
|
||||
# Initialize Variables
|
||||
current_date = datetime.now()
|
||||
compiled_references = "\n\n".join({f"# {item['compiled']}" for item in references})
|
||||
|
||||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
||||
|
||||
if agent and agent.personality:
|
||||
system_prompt = prompts.custom_personality.format(
|
||||
@@ -169,31 +182,35 @@ def converse(
|
||||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
completion_func(chat_response=prompts.no_notes_found.format())
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||
return iter([prompts.no_online_results_found.format()])
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
context_message = f"{prompts.notes_conversation.format(references=yaml_dump(references))}\n\n"
|
||||
if not is_none_or_empty(online_results):
|
||||
conversation_primer = (
|
||||
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
|
||||
)
|
||||
if not is_none_or_empty(compiled_references):
|
||||
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
|
||||
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
|
||||
if not is_none_or_empty(code_results):
|
||||
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
|
||||
context_message = context_message.strip()
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
conversation_primer,
|
||||
user_query,
|
||||
system_prompt,
|
||||
conversation_log,
|
||||
context_message=context_message,
|
||||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
uploaded_image_url=image_url,
|
||||
query_images=query_images,
|
||||
vision_enabled=vision_available,
|
||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||
query_files=query_files,
|
||||
)
|
||||
truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
|
||||
logger.debug(f"Conversation Context for GPT: {truncated_messages}")
|
||||
@@ -209,4 +226,5 @@ def converse(
|
||||
api_base_url=api_base_url,
|
||||
completion_func=completion_func,
|
||||
model_kwargs={"stop": ["Notes:\n["]},
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
from threading import Thread
|
||||
from typing import Dict
|
||||
|
||||
@@ -12,7 +13,12 @@ from tenacity import (
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import in_debug_mode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -33,7 +39,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
|
||||
reraise=True,
|
||||
)
|
||||
def completion_with_backoff(
|
||||
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
|
||||
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
|
||||
) -> str:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client: openai.OpenAI | None = openai_clients.get(client_key)
|
||||
@@ -55,6 +61,9 @@ def completion_with_backoff(
|
||||
model_kwargs.pop("stop", None)
|
||||
model_kwargs.pop("response_format", None)
|
||||
|
||||
if os.getenv("KHOJ_LLM_SEED"):
|
||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||
|
||||
chat = client.chat.completions.create(
|
||||
stream=stream,
|
||||
messages=formatted_messages, # type: ignore
|
||||
@@ -77,6 +86,12 @@ def completion_with_backoff(
|
||||
elif delta_chunk.content:
|
||||
aggregated_response += delta_chunk.content
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
return aggregated_response
|
||||
|
||||
|
||||
@@ -103,26 +118,37 @@ def chat_completion_with_backoff(
|
||||
api_base_url=None,
|
||||
completion_func=None,
|
||||
model_kwargs=None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
|
||||
target=llm_thread,
|
||||
args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
|
||||
def llm_thread(
|
||||
g,
|
||||
messages,
|
||||
model_name,
|
||||
temperature,
|
||||
openai_api_key=None,
|
||||
api_base_url=None,
|
||||
model_kwargs=None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
try:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
if client_key not in openai_clients:
|
||||
client: openai.OpenAI = openai.OpenAI(
|
||||
client = openai.OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=api_base_url,
|
||||
)
|
||||
openai_clients[client_key] = client
|
||||
else:
|
||||
client: openai.OpenAI = openai_clients[client_key]
|
||||
client = openai_clients[client_key]
|
||||
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
stream = True
|
||||
@@ -135,6 +161,9 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
|
||||
model_kwargs.pop("stop", None)
|
||||
model_kwargs.pop("response_format", None)
|
||||
|
||||
if os.getenv("KHOJ_LLM_SEED"):
|
||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||
|
||||
chat = client.chat.completions.create(
|
||||
stream=stream,
|
||||
messages=formatted_messages,
|
||||
@@ -144,17 +173,29 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
|
||||
**(model_kwargs or dict()),
|
||||
)
|
||||
|
||||
aggregated_response = ""
|
||||
if not stream:
|
||||
g.send(chat.choices[0].message.content)
|
||||
aggregated_response = chat.choices[0].message.content
|
||||
g.send(aggregated_response)
|
||||
else:
|
||||
for chunk in chat:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
delta_chunk = chunk.choices[0].delta
|
||||
text_chunk = ""
|
||||
if isinstance(delta_chunk, str):
|
||||
g.send(delta_chunk)
|
||||
text_chunk = delta_chunk
|
||||
elif delta_chunk.content:
|
||||
g.send(delta_chunk.content)
|
||||
text_chunk = delta_chunk.content
|
||||
if text_chunk:
|
||||
aggregated_response += text_chunk
|
||||
g.send(text_chunk)
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in llm_thread: {e}", exc_info=True)
|
||||
finally:
|
||||
|
||||
@@ -118,6 +118,7 @@ Use my personal notes and our past conversations to inform your response.
|
||||
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided notes or past conversations.
|
||||
|
||||
User's Notes:
|
||||
-----
|
||||
{references}
|
||||
""".strip()
|
||||
)
|
||||
@@ -127,6 +128,7 @@ notes_conversation_offline = PromptTemplate.from_template(
|
||||
Use my personal notes and our past conversations to inform your response.
|
||||
|
||||
User's Notes:
|
||||
-----
|
||||
{references}
|
||||
""".strip()
|
||||
)
|
||||
@@ -176,6 +178,134 @@ Improved Prompt:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
## Diagram Generation
|
||||
## --
|
||||
|
||||
improve_diagram_description_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
you are an architect working with a novice artist using a diagramming tool.
|
||||
{personality_context}
|
||||
|
||||
you need to convert the user's query to a description format that the novice artist can use very well. you are allowed to use primitives like
|
||||
- text
|
||||
- rectangle
|
||||
- diamond
|
||||
- ellipse
|
||||
- line
|
||||
- arrow
|
||||
|
||||
use these primitives to describe what sort of diagram the drawer should create. the artist must recreate the diagram every time, so include all relevant prior information in your description.
|
||||
|
||||
use simple, concise language.
|
||||
|
||||
Today's Date: {current_date}
|
||||
User's Location: {location}
|
||||
|
||||
User's Notes:
|
||||
{references}
|
||||
|
||||
Online References:
|
||||
{online_results}
|
||||
|
||||
Conversation Log:
|
||||
{chat_history}
|
||||
|
||||
Query: {query}
|
||||
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
excalidraw_diagram_generation_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
You are a program manager with the ability to describe diagrams to compose in professional, fine detail.
|
||||
{personality_context}
|
||||
|
||||
You need to create a declarative description of the diagram and relevant components, using this base schema. Use the `label` property to specify the text to be rendered in the respective elements. Always use light colors for the `backgroundColor` property, like white, or light blue, green, red. "type", "x", "y", "id", are required properties for all elements.
|
||||
|
||||
{{
|
||||
type: string,
|
||||
x: number,
|
||||
y: number,
|
||||
strokeColor: string,
|
||||
backgroundColor: string,
|
||||
width: number,
|
||||
height: number,
|
||||
id: string,
|
||||
label: {{
|
||||
text: string,
|
||||
}}
|
||||
}}
|
||||
|
||||
Valid types:
|
||||
- text
|
||||
- rectangle
|
||||
- diamond
|
||||
- ellipse
|
||||
- line
|
||||
- arrow
|
||||
|
||||
For arrows and lines, you can use the `points` property to specify the start and end points of the arrow. You may also use the `label` property to specify the text to be rendered. You may use the `start` and `end` properties to connect the linear elements to other elements. The start and end point can either be the ID to map to an existing object, or the `type` to create a new object. Mapping to an existing object is useful if you want to connect it to multiple objects. Lines and arrows can only start and end at rectangle, text, diamond, or ellipse elements.
|
||||
|
||||
{{
|
||||
type: "arrow",
|
||||
id: string,
|
||||
x: number,
|
||||
y: number,
|
||||
width: number,
|
||||
height: number,
|
||||
strokeColor: string,
|
||||
start: {{
|
||||
id: string,
|
||||
type: string,
|
||||
}},
|
||||
end: {{
|
||||
id: string,
|
||||
type: string,
|
||||
}},
|
||||
label: {{
|
||||
text: string,
|
||||
}}
|
||||
points: [
|
||||
[number, number],
|
||||
[number, number],
|
||||
]
|
||||
}}
|
||||
|
||||
For text, you must use the `text` property to specify the text to be rendered. You may also use `fontSize` property to specify the font size of the text. Only use the `text` element for titles, subtitles, and overviews. For labels, use the `label` property in the respective elements.
|
||||
|
||||
{{
|
||||
type: "text",
|
||||
id: string,
|
||||
x: number,
|
||||
y: number,
|
||||
fontSize: number,
|
||||
text: string,
|
||||
}}
|
||||
|
||||
Here's an example of a valid diagram:
|
||||
|
||||
Design Description: Create a diagram describing a circular development process with 3 stages: design, implementation and feedback. The design stage is connected to the implementation stage and the implementation stage is connected to the feedback stage and the feedback stage is connected to the design stage. Each stage should be labeled with the stage name.
|
||||
|
||||
Response:
|
||||
|
||||
[
|
||||
{{"type":"text","x":-150,"y":50,"width":300,"height":40,"id":"title_text","text":"Circular Development Process","fontSize":24}},
|
||||
{{"type":"ellipse","x":-169,"y":113,"width":188,"height":202,"id":"design_ellipse", "label": {{"text": "Design"}}}},
|
||||
{{"type":"ellipse","x":62,"y":394,"width":186,"height":188,"id":"implement_ellipse", "label": {{"text": "Implement"}}}},
|
||||
{{"type":"ellipse","x":-348,"y":430,"width":184,"height":170,"id":"feedback_ellipse", "label": {{"text": "Feedback"}}}},
|
||||
{{"type":"arrow","x":21,"y":273,"id":"design_to_implement_arrow","points":[[0,0],[86,105]],"start":{{"id":"design_ellipse"}}, "end":{{"id":"implement_ellipse"}}}},
|
||||
{{"type":"arrow","x":50,"y":519,"id":"implement_to_feedback_arrow","points":[[0,0],[-198,-6]],"start":{{"id":"implement_ellipse"}}, "end":{{"id":"feedback_ellipse"}}}},
|
||||
{{"type":"arrow","x":-228,"y":417,"id":"feedback_to_design_arrow","points":[[0,0],[85,-123]],"start":{{"id":"feedback_ellipse"}}, "end":{{"id":"design_ellipse"}}}},
|
||||
]
|
||||
|
||||
Create a detailed diagram from the provided context and user prompt below. Return a valid JSON object:
|
||||
|
||||
Diagram Description: {query}
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
## Online Search Conversation
|
||||
## --
|
||||
online_search_conversation = PromptTemplate.from_template(
|
||||
@@ -184,6 +314,7 @@ Use this up-to-date information from the internet to inform your response.
|
||||
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the online data or past conversations.
|
||||
|
||||
Information from the internet:
|
||||
-----
|
||||
{online_results}
|
||||
""".strip()
|
||||
)
|
||||
@@ -193,6 +324,7 @@ online_search_conversation_offline = PromptTemplate.from_template(
|
||||
Use this up-to-date information from the internet to inform your response.
|
||||
|
||||
Information from the internet:
|
||||
-----
|
||||
{online_results}
|
||||
""".strip()
|
||||
)
|
||||
@@ -262,21 +394,23 @@ Q: {query}
|
||||
|
||||
extract_questions = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes. Disregard online search requests.
|
||||
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes and documents.
|
||||
Construct search queries to retrieve relevant information to answer the user's question.
|
||||
- You will be provided past questions(Q) and answers(A) for context.
|
||||
- You will be provided example and actual past user questions(Q), search queries(Khoj) and answers(A) for context.
|
||||
- Add as much context from the previous questions and answers as required into your search queries.
|
||||
- Break messages into multiple search queries when required to retrieve the relevant information.
|
||||
- Break your search down into multiple search queries from a diverse set of lenses to retrieve all related documents.
|
||||
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
|
||||
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
|
||||
{personality_context}
|
||||
What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object.
|
||||
What searches will you perform to answer the user's question? Respond with search queries as list of strings in a JSON object.
|
||||
Current Date: {day_of_week}, {current_date}
|
||||
User's Location: {location}
|
||||
{username}
|
||||
|
||||
Examples
|
||||
---
|
||||
Q: How was my trip to Cambodia?
|
||||
Khoj: {{"queries": ["How was my trip to Cambodia?"]}}
|
||||
Khoj: {{"queries": ["How was my trip to Cambodia?", "Angkor Wat temple visit", "Flight to Phnom Penh", "Expenses in Cambodia", "Stay in Cambodia"]}}
|
||||
A: The trip was amazing. You went to the Angkor Wat temple and it was beautiful.
|
||||
|
||||
Q: Who did i visit that temple with?
|
||||
@@ -311,6 +445,8 @@ Q: Who all did I meet here yesterday?
|
||||
Khoj: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}}
|
||||
A: Yesterday's note mentions your visit to your local beach with Ram and Shyam.
|
||||
|
||||
Actual
|
||||
---
|
||||
{chat_history}
|
||||
Q: {text}
|
||||
Khoj:
|
||||
@@ -319,11 +455,11 @@ Khoj:
|
||||
|
||||
extract_questions_anthropic_system_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes. Disregard online search requests.
|
||||
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes.
|
||||
Construct search queries to retrieve relevant information to answer the user's question.
|
||||
- You will be provided past questions(User), extracted queries(Assistant) and answers(A) for context.
|
||||
- You will be provided past questions(User), search queries(Assistant) and answers(A) for context.
|
||||
- Add as much context from the previous questions and answers as required into your search queries.
|
||||
- Break messages into multiple search queries when required to retrieve the relevant information.
|
||||
- Break your search down into multiple search queries from a diverse set of lenses to retrieve all related documents.
|
||||
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
|
||||
- When asked meta, vague or random questions, search for a variety of broad topics to answer the user's question.
|
||||
{personality_context}
|
||||
@@ -336,7 +472,7 @@ User's Location: {location}
|
||||
Here are some examples of how you can construct search queries to answer the user's question:
|
||||
|
||||
User: How was my trip to Cambodia?
|
||||
Assistant: {{"queries": ["How was my trip to Cambodia?"]}}
|
||||
Assistant: {{"queries": ["How was my trip to Cambodia?", "Angkor Wat temple visit", "Flight to Phnom Penh", "Expenses in Cambodia", "Stay in Cambodia"]}}
|
||||
A: The trip was amazing. You went to the Angkor Wat temple and it was beautiful.
|
||||
|
||||
User: What national parks did I go to last year?
|
||||
@@ -369,17 +505,14 @@ Assistant:
|
||||
)
|
||||
|
||||
system_prompt_extract_relevant_information = """
|
||||
As a professional analyst, create a comprehensive report of the most relevant information from a web page in response to a user's query.
|
||||
The text provided is directly from within the web page.
|
||||
The report you create should be multiple paragraphs, and it should represent the content of the website.
|
||||
Tell the user exactly what the website says in response to their query, while adhering to these guidelines:
|
||||
As a professional analyst, your job is to extract all pertinent information from documents to help answer the user's query.
|
||||
You will be provided raw text directly from within the document.
|
||||
Adhere to these guidelines while extracting information from the provided documents:
|
||||
|
||||
1. Answer the user's query as specifically as possible. Include many supporting details from the website.
|
||||
2. Craft a report that is detailed, thorough, in-depth, and complex, while maintaining clarity.
|
||||
3. Rely strictly on the provided text, without including external information.
|
||||
4. Format the report in multiple paragraphs with a clear structure.
|
||||
5. Be as specific as possible in your answer to the user's query.
|
||||
6. Reproduce as much of the provided text as possible, while maintaining readability.
|
||||
1. Extract all relevant text and links from the document that can assist with further research or answer the user's query.
|
||||
2. Craft a comprehensive but compact report with all the necessary data from the document to generate an informed response.
|
||||
3. Rely strictly on the provided text to generate your summary, without including external information.
|
||||
4. Provide specific, important snippets from the document in your report to establish trust in your summary.
|
||||
""".strip()
|
||||
|
||||
extract_relevant_information = PromptTemplate.from_template(
|
||||
@@ -387,10 +520,10 @@ extract_relevant_information = PromptTemplate.from_template(
|
||||
{personality_context}
|
||||
Target Query: {query}
|
||||
|
||||
Web Pages:
|
||||
Document:
|
||||
{corpus}
|
||||
|
||||
Collate only relevant information from the website to answer the target query.
|
||||
Collate only relevant information from the document to answer the target query.
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -475,7 +608,7 @@ AI: It's currently 28°C and partly cloudy in Bali.
|
||||
Q: Share a painting using the weather for Bali every morning.
|
||||
Khoj: {{"output": "automation"}}
|
||||
|
||||
Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON.
|
||||
Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON. Do not say anything else.
|
||||
|
||||
Chat History:
|
||||
{chat_history}
|
||||
@@ -485,6 +618,67 @@ Khoj:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
plan_function_execution = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, a smart, creative and methodical researcher. Use the provided tool AIs to investigate information to answer the query.
|
||||
Create a multi-step plan and intelligently iterate on the plan based on the retrieved information to find the requested information.
|
||||
{personality_context}
|
||||
|
||||
# Instructions
|
||||
- Ask highly diverse, detailed queries to the tool AIs, one tool AI at a time, to discover required information or run calculations. Their response will be shown to you in the next iteration.
|
||||
- Break down your research process into independent, self-contained steps that can be executed sequentially using the available tool AIs to answer the user's query. Write your step-by-step plan in the scratchpad.
|
||||
- Always ask a new query that was not asked to the tool AI in a previous iteration. Build on the results of the previous iterations.
|
||||
- Ensure that all required context is passed to the tool AIs for successful execution. They only know the context provided in your query.
|
||||
- Think step by step to come up with creative strategies when the previous iteration did not yield useful results.
|
||||
- You are allowed up to {max_iterations} iterations to use the help of the provided tool AIs to answer the user's question.
|
||||
- Stop when you have the required information by returning a JSON object with an empty "tool" field. E.g., {{scratchpad: "I have all I need", tool: "", query: ""}}
|
||||
|
||||
# Examples
|
||||
Assuming you can search the user's notes and the internet.
|
||||
- When the user asks for the population of their hometown
|
||||
1. Try to look up their hometown in their notes. Ask the note search AI to search for their birth certificate, childhood memories, school, resume, etc.
|
||||
2. If not found in their notes, try to infer their hometown from their online social media profiles. Ask the online search AI to look for {username}'s biography, school, resume on LinkedIn, Facebook, their website, etc.
|
||||
3. Only then try to find the latest population of their hometown by reading official websites with the help of the online search and web page reading AI.
|
||||
- When the user asks for their computer's specs
|
||||
1. Try to find their computer model in their notes.
|
||||
2. Now find webpages with their computer model's spec online.
|
||||
3. Ask the webpage tool AI to extract the required information from the relevant webpages.
|
||||
- When the user asks what clothes to carry for their upcoming trip
|
||||
1. Find the itinerary of their upcoming trip in their notes.
|
||||
2. Next find the weather forecast at the destination online.
|
||||
3. Then find if they mentioned what clothes they own in their notes.
|
||||
|
||||
# Background Context
|
||||
- Current Date: {day_of_week}, {current_date}
|
||||
- User Location: {location}
|
||||
- User Name: {username}
|
||||
|
||||
# Available Tool AIs
|
||||
Which of the tool AIs listed below would you use to answer the user's question? You **only** have access to the following tool AIs:
|
||||
|
||||
{tools}
|
||||
|
||||
# Previous Iterations
|
||||
{previous_iterations}
|
||||
|
||||
# Chat History:
|
||||
{chat_history}
|
||||
|
||||
Return the next tool AI to use and the query to ask it. Your response should always be a valid JSON object. Do not say anything else.
|
||||
Response format:
|
||||
{{"scratchpad": "<your_scratchpad_to_reason_about_which_tool_to_use>", "query": "<your_detailed_query_for_the_tool_ai>", "tool": "<name_of_tool_ai>"}}
|
||||
""".strip()
|
||||
)
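For orientation, here is a minimal, illustrative sketch of how a planner reply in the response format above could be parsed, using the clean_json helper added later in this changeset. The raw_response value and the surrounding control flow are hypothetical, not the actual call site.

import json

from khoj.processor.conversation.utils import clean_json

# Hypothetical planner reply, possibly wrapped in a markdown json code block.
raw_response = '```json\n{"scratchpad": "Need the trip itinerary first", "query": "upcoming trip itinerary", "tool": "notes"}\n```'

# Strip markdown fences and newlines, then parse the JSON object.
plan = json.loads(clean_json(raw_response))

if not plan.get("tool"):
    print("Planner has gathered enough information to answer.")
else:
    print(f"Next tool AI: {plan['tool']}, query: {plan['query']}")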
|
||||
|
||||
previous_iteration = PromptTemplate.from_template(
|
||||
"""
|
||||
## Iteration {index}:
|
||||
- tool: {tool}
|
||||
- query: {query}
|
||||
- result: {result}
|
||||
"""
|
||||
)
|
||||
|
||||
pick_relevant_information_collection_tools = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful search assistant.
|
||||
@@ -604,8 +798,8 @@ Khoj:
|
||||
online_search_conversation_subqueries = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
|
||||
- You will receive the conversation history as context.
|
||||
- Add as much context from the previous questions and answers as required into your search queries.
|
||||
- You will receive the actual chat history as context.
|
||||
- Add as much context from the chat history as required into your search queries.
|
||||
- Break messages into multiple search queries when required to retrieve the relevant information.
|
||||
- Use the site: Google search operator when appropriate
|
||||
- You have access to the whole internet to retrieve information.
|
||||
@@ -618,62 +812,122 @@ User's Location: {location}
|
||||
{username}
|
||||
|
||||
Here are some examples:
|
||||
History:
|
||||
Example Chat History:
|
||||
User: I like to use Hacker News to get my tech news.
|
||||
Khoj: {{queries: ["what is Hacker News?", "Hacker News website for tech news"]}}
|
||||
AI: Hacker News is an online forum for sharing and discussing the latest tech news. It is a great place to learn about new technologies and startups.
|
||||
|
||||
Q: Summarize the top posts on HackerNews
|
||||
User: Summarize the top posts on HackerNews
|
||||
Khoj: {{"queries": ["top posts on HackerNews"]}}
|
||||
|
||||
History:
|
||||
|
||||
Q: Tell me the latest news about the farmers protest in Colombia and China on Reuters
|
||||
Example Chat History:
|
||||
User: Tell me the latest news about the farmers protest in Colombia and China on Reuters
|
||||
Khoj: {{"queries": ["site:reuters.com farmers protest Colombia", "site:reuters.com farmers protest China"]}}
|
||||
|
||||
History:
|
||||
Example Chat History:
|
||||
User: I'm currently living in New York but I'm thinking about moving to San Francisco.
|
||||
Khoj: {{"queries": ["New York city vs San Francisco life", "San Francisco living cost", "New York city living cost"]}}
|
||||
AI: New York is a great city to live in. It has a lot of great restaurants and museums. San Francisco is also a great city to live in. It has good access to nature and a great tech scene.
|
||||
|
||||
Q: What is the climate like in those cities?
|
||||
Khoj: {{"queries": ["climate in new york city", "climate in san francisco"]}}
|
||||
User: What is the climate like in those cities?
|
||||
Khoj: {{"queries": ["climate in New York city", "climate in San Francisco"]}}
|
||||
|
||||
History:
|
||||
AI: Hey, how is it going?
|
||||
User: Going well. Ananya is in town tonight!
|
||||
Example Chat History:
|
||||
User: Hey, Ananya is in town tonight!
|
||||
Khoj: {{"queries": ["events in {location} tonight", "best restaurants in {location}", "places to visit in {location}"]}}
|
||||
AI: Oh that's awesome! What are your plans for the evening?
|
||||
|
||||
Q: She wants to see a movie. Any decent sci-fi movies playing at the local theater?
|
||||
User: She wants to see a movie. Any decent sci-fi movies playing at the local theater?
|
||||
Khoj: {{"queries": ["new sci-fi movies in theaters near {location}"]}}
|
||||
|
||||
History:
|
||||
Example Chat History:
|
||||
User: Can I chat with you over WhatsApp?
|
||||
Khoj: {{"queries": ["site:khoj.dev chat with Khoj on Whatsapp"]}}
|
||||
AI: Yes, you can chat with me using WhatsApp.
|
||||
|
||||
Q: How
|
||||
Khoj: {{"queries": ["site:khoj.dev chat with Khoj on Whatsapp"]}}
|
||||
|
||||
History:
|
||||
|
||||
|
||||
Q: How do I share my files with you?
|
||||
Example Chat History:
|
||||
User: How do I share my files with Khoj?
|
||||
Khoj: {{"queries": ["site:khoj.dev sync files with Khoj"]}}
|
||||
|
||||
History:
|
||||
Example Chat History:
|
||||
User: I need to transport a lot of oranges to the moon. Are there any rockets that can fit a lot of oranges?
|
||||
Khoj: {{"queries": ["current rockets with large cargo capacity", "rocket rideshare cost by cargo capacity"]}}
|
||||
AI: NASA's Saturn V rocket frequently makes lunar trips and has a large cargo capacity.
|
||||
|
||||
Q: How many oranges would fit in NASA's Saturn V rocket?
|
||||
Khoj: {{"queries": ["volume of an orange", "volume of saturn v rocket"]}}
|
||||
User: How many oranges would fit in NASA's Saturn V rocket?
|
||||
Khoj: {{"queries": ["volume of an orange", "volume of Saturn V rocket"]}}
|
||||
|
||||
Now it's your turn to construct Google search queries to answer the user's question. Provide them as a list of strings in a JSON object. Do not say anything else.
|
||||
History:
|
||||
Actual Chat History:
|
||||
{chat_history}
|
||||
|
||||
Q: {query}
|
||||
User: {query}
|
||||
Khoj:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
# Code Generation
|
||||
# --
|
||||
python_code_generation_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an advanced python programmer. You are tasked with constructing a python program to best answer the user query.
|
||||
- The python program will run in a pyodide python sandbox with no network access.
|
||||
- You can write programs to run complex calculations, analyze data, create charts, generate documents to meticulously answer the query.
|
||||
- The sandbox has access to the standard library, matplotlib, pandas, numpy, scipy, bs4, sympy, brotli, cryptography, fast-parquet.
|
||||
- List known file paths to required user documents in "input_files" and known links to required documents from the web in the "input_links" field.
|
||||
- The python program should be self-contained. It can only read data generated by the program itself and from provided input_files, input_links by their basename (i.e filename excluding file path).
|
||||
- Do not try to display images or plots in the code directly. The code should save the image or plot to a file instead.
|
||||
- Write any documents, charts, etc. to be shared with the user to a file. These files can be seen by the user.
|
||||
- Use as much context from the previous questions and answers as required to generate your code.
|
||||
{personality_context}
|
||||
What code will you need to write to answer the user's question?
|
||||
|
||||
Current Date: {current_date}
|
||||
User's Location: {location}
|
||||
{username}
|
||||
|
||||
The response JSON schema is of the form {{"code": "<python_code>", "input_files": ["file_path_1", "file_path_2"], "input_links": ["link_1", "link_2"]}}
|
||||
Examples:
|
||||
---
|
||||
{{
|
||||
"code": "# Input values\\nprincipal = 43235\\nrate = 5.24\\nyears = 5\\n\\n# Convert rate to decimal\\nrate_decimal = rate / 100\\n\\n# Calculate final amount\\nfinal_amount = principal * (1 + rate_decimal) ** years\\n\\n# Calculate interest earned\\ninterest_earned = final_amount - principal\\n\\n# Print results with formatting\\nprint(f"Interest Earned: ${{interest_earned:,.2f}}")\\nprint(f"Final Amount: ${{final_amount:,.2f}}")"
|
||||
}}
|
||||
|
||||
{{
|
||||
"code": "import re\\n\\n# Read org file\\nfile_path = 'tasks.org'\\nwith open(file_path, 'r') as f:\\n content = f.read()\\n\\n# Get today's date in YYYY-MM-DD format\\ntoday = datetime.now().strftime('%Y-%m-%d')\\npattern = r'\*+\s+.*\\n.*SCHEDULED:\s+<' + today + r'.*>'\\n\\n# Find all matches using multiline mode\\nmatches = re.findall(pattern, content, re.MULTILINE)\\ncount = len(matches)\\n\\n# Display count\\nprint(f'Count of scheduled tasks for today: {{count}}')",
|
||||
"input_files": ["/home/linux/tasks.org"]
|
||||
}}
|
||||
|
||||
{{
|
||||
"code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\n\\n# Load the CSV file\\ndf = pd.read_csv('world_population_by_year.csv')\\n\\n# Plot the data\\nplt.figure(figsize=(10, 6))\\nplt.plot(df['Year'], df['Population'], marker='o')\\n\\n# Add titles and labels\\nplt.title('Population by Year')\\nplt.xlabel('Year')\\nplt.ylabel('Population')\\n\\n# Save the plot to a file\\nplt.savefig('population_by_year_plot.png')",
|
||||
"input_links": ["https://population.un.org/world_population_by_year.csv"]
|
||||
}}
|
||||
|
||||
Now it's your turn to construct a python program to answer the user's question. Provide the code, required input files and input links in a JSON object. Do not say anything else.
|
||||
Context:
|
||||
---
|
||||
{context}
|
||||
|
||||
Chat History:
|
||||
---
|
||||
{chat_history}
|
||||
|
||||
User: {query}
|
||||
Khoj:
|
||||
""".strip()
|
||||
)
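As a rough sketch of how a reply matching the JSON schema above could be consumed, the snippet below maps it onto the GeneratedCode tuple introduced in the new run_code.py further down. The raw_response value is made up for illustration.

import json

from khoj.processor.conversation.utils import clean_json
from khoj.processor.tools.run_code import GeneratedCode

# Hypothetical model reply following the schema documented in the prompt above.
raw_response = '{"code": "print(40 + 2)", "input_files": ["tasks.org"], "input_links": []}'

response = json.loads(clean_json(raw_response))
generated = GeneratedCode(
    code=response.get("code", ""),
    input_files=response.get("input_files", []),
    input_links=response.get("input_links", []),
)
print(generated.code)         # -> print(40 + 2)
print(generated.input_files)  # -> ['tasks.org']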
|
||||
|
||||
code_executed_context = PromptTemplate.from_template(
|
||||
"""
|
||||
Use the provided code executions to inform your response.
|
||||
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided code execution results or past conversations.
|
||||
|
||||
Code Execution Results:
|
||||
{code_results}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
# Automations
|
||||
# --
|
||||
crontime_prompt = PromptTemplate.from_template(
|
||||
@@ -749,16 +1003,27 @@ You are an extremely smart and helpful title generator assistant. Given a user q
|
||||
|
||||
# Examples:
|
||||
User: Show a new Calvin and Hobbes quote every morning at 9am. My Current Location: Shanghai, China
|
||||
Khoj: Your daily Calvin and Hobbes Quote
|
||||
Assistant: Your daily Calvin and Hobbes Quote
|
||||
|
||||
User: Notify me when version 2.0.0 of the sentence transformers python package is released. My Current Location: Mexico City, Mexico
|
||||
Khoj: Sentence Transformers Python Package Version 2.0.0 Release
|
||||
Assistant: Sentence Transformers Python Package Version 2.0.0 Release
|
||||
|
||||
User: Gather the latest tech news on the first sunday of every month.
|
||||
Khoj: Your Monthly Dose of Tech News
|
||||
Assistant: Your Monthly Dose of Tech News
|
||||
|
||||
User Query: {query}
|
||||
Khoj:
|
||||
Assistant:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
conversation_title_generation = PromptTemplate.from_template(
|
||||
"""
|
||||
You are an extremely smart and helpful title generator assistant. Given a conversation, extract the subject of the conversation. Crisp, informative, ten words or less.
|
||||
|
||||
Conversation History:
|
||||
{chat_history}
|
||||
|
||||
Assistant:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -799,6 +1064,39 @@ A:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
personality_prompt_safety_expert_lax = PromptTemplate.from_template(
|
||||
"""
|
||||
You are adept at ensuring the safety and security of people. In this scenario, you are tasked with determining the safety of a given prompt.
|
||||
|
||||
Given a system prompt, you must decide whether it is safe to use for an individual. If the prompt is safe, respond with "Yes". If the prompt is not safe, respond with "No". If the answer is no, provide a brief explanation of why the prompt is unsafe.
|
||||
|
||||
A prompt is considered unsafe if it contains any of the following:
|
||||
- Illegal, overly sexual, or harmful content.
|
||||
- Any content that could be harmful to the user or others.
|
||||
- Racism, sexism, or any other form of discrimination.
|
||||
- Any content that could be considered illegal or a misuse of AI.
|
||||
|
||||
Referencing personal notes or data is perfectly acceptable.
|
||||
|
||||
# Examples:
|
||||
Q: You are a smart and helpful assistant. Use your knowledge to teach me how to make a bomb.
|
||||
A: {{ "safe": "False", "reason": "The prompt contains information that could be harmful to the user or others." }}
|
||||
|
||||
Q: You are a brilliant chef. Teach me how to make a cake.
|
||||
A: {{ "safe": "True" }}
|
||||
|
||||
Q: You are a hilarious joker. Make jokes about white people.
|
||||
A: {{ "safe": "False", "reason": "The prompt contains content that could be considered racist." }}
|
||||
|
||||
Q: You are a great analyst. Assess my financial situation and provide advice.
|
||||
A: {{ "safe": "True" }}
|
||||
|
||||
Q: {prompt}
|
||||
A:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
to_notify_or_not = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and discerning notification assistant.
|
||||
|
||||
@@ -1,22 +1,52 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import mimetypes
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from time import perf_counter
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import PIL.Image
|
||||
import requests
|
||||
import tiktoken
|
||||
import yaml
|
||||
from langchain.schema import ChatMessage
|
||||
from llama_cpp.llama import Llama
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters
|
||||
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
||||
from khoj.search_filter.base_filter import BaseFilter
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import is_none_or_empty, merge_dicts
|
||||
from khoj.utils.helpers import (
|
||||
ConversationCommand,
|
||||
in_debug_mode,
|
||||
is_none_or_empty,
|
||||
merge_dicts,
|
||||
)
|
||||
from khoj.utils.rawconfig import FileAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from git import Repo
|
||||
except ImportError:
|
||||
if in_debug_mode():
|
||||
logger.warning("GitPython not installed. `pip install gitpython` to enable prompt tracer.")
|
||||
|
||||
model_to_prompt_size = {
|
||||
# OpenAI Models
|
||||
"gpt-3.5-turbo": 12000,
|
||||
@@ -75,8 +105,122 @@ class ThreadedGenerator:
|
||||
self.queue.put(StopIteration)
|
||||
|
||||
|
||||
class InformationCollectionIteration:
|
||||
def __init__(
|
||||
self,
|
||||
tool: str,
|
||||
query: str,
|
||||
context: list = None,
|
||||
onlineContext: dict = None,
|
||||
codeContext: dict = None,
|
||||
summarizedResult: str = None,
|
||||
warning: str = None,
|
||||
):
|
||||
self.tool = tool
|
||||
self.query = query
|
||||
self.context = context
|
||||
self.onlineContext = onlineContext
|
||||
self.codeContext = codeContext
|
||||
self.summarizedResult = summarizedResult
|
||||
self.warning = warning
|
||||
|
||||
|
||||
def construct_iteration_history(
|
||||
previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str
|
||||
) -> str:
|
||||
previous_iterations_history = ""
|
||||
for idx, iteration in enumerate(previous_iterations):
|
||||
iteration_data = previous_iteration_prompt.format(
|
||||
tool=iteration.tool,
|
||||
query=iteration.query,
|
||||
result=iteration.summarizedResult,
|
||||
index=idx + 1,
|
||||
)
|
||||
|
||||
previous_iterations_history += iteration_data
|
||||
return previous_iterations_history
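A minimal usage sketch, with made-up queries and results, showing how completed research steps get rendered into the planner's "# Previous Iterations" section via the previous_iteration template defined earlier in this diff.

from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
    InformationCollectionIteration,
    construct_iteration_history,
)

# Two completed research steps with illustrative values.
iterations = [
    InformationCollectionIteration(
        tool="notes",
        query="upcoming trip itinerary",
        summarizedResult="Flying to Siem Reap on 12 March, back on 19 March.",
    ),
    InformationCollectionIteration(
        tool="online",
        query="Siem Reap weather forecast mid March",
        summarizedResult="Hot and humid, highs around 35C with afternoon showers.",
    ),
]

# Each step is formatted through the previous_iteration prompt template.
history = construct_iteration_history(iterations, prompts.previous_iteration)
print(history)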
|
||||
|
||||
|
||||
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
||||
chat_history = ""
|
||||
for chat in conversation_history.get("chat", [])[-n:]:
|
||||
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
|
||||
if chat["intent"].get("inferred-queries"):
|
||||
chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
|
||||
|
||||
chat_history += f"{agent_name}: {chat['message']}\n\n"
|
||||
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"{agent_name}: [generated image redacted for space]\n"
|
||||
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
|
||||
elif chat["by"] == "you":
|
||||
raw_query_files = chat.get("queryFiles")
|
||||
if raw_query_files:
|
||||
query_files: Dict[str, str] = {}
|
||||
for file in raw_query_files:
|
||||
query_files[file["name"]] = file["content"]
|
||||
|
||||
query_file_context = gather_raw_query_files(query_files)
|
||||
chat_history += f"User: {query_file_context}\n"
|
||||
|
||||
return chat_history
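A small, illustrative example of the conversation log shape this helper expects (the same structure message_to_log writes below); the messages themselves are invented.

from khoj.processor.conversation.utils import construct_chat_history

# Minimal conversation log in the stored format: a dict with a "chat" list.
conversation_history = {
    "chat": [
        {"by": "you", "message": "How was my trip to Cambodia?"},
        {
            "by": "khoj",
            "intent": {
                "type": "remember",
                "query": "How was my trip to Cambodia?",
                "inferred-queries": ["How was my trip to Cambodia?", "Angkor Wat temple visit"],
            },
            "message": "The trip was amazing. You visited the Angkor Wat temple.",
        },
    ]
}

# Renders the last n turns as alternating "User:" / "Khoj:" lines.
print(construct_chat_history(conversation_history, n=4, agent_name="Khoj"))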
|
||||
|
||||
|
||||
def construct_tool_chat_history(
|
||||
previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
|
||||
) -> Dict[str, list]:
|
||||
chat_history: list = []
|
||||
inferred_query_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: []
|
||||
if tool == ConversationCommand.Notes:
|
||||
inferred_query_extractor = (
|
||||
lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
|
||||
)
|
||||
elif tool == ConversationCommand.Online:
|
||||
inferred_query_extractor = (
|
||||
lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else []
|
||||
)
|
||||
elif tool == ConversationCommand.Code:
|
||||
inferred_query_extractor = lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
|
||||
for iteration in previous_iterations:
|
||||
chat_history += [
|
||||
{
|
||||
"by": "you",
|
||||
"message": iteration.query,
|
||||
},
|
||||
{
|
||||
"by": "khoj",
|
||||
"intent": {
|
||||
"type": "remember",
|
||||
"inferred-queries": inferred_query_extractor(iteration),
|
||||
"query": iteration.query,
|
||||
},
|
||||
"message": iteration.summarizedResult,
|
||||
},
|
||||
]
|
||||
|
||||
return {"chat": chat_history}
|
||||
|
||||
|
||||
class ChatEvent(Enum):
|
||||
START_LLM_RESPONSE = "start_llm_response"
|
||||
END_LLM_RESPONSE = "end_llm_response"
|
||||
MESSAGE = "message"
|
||||
REFERENCES = "references"
|
||||
STATUS = "status"
|
||||
METADATA = "metadata"
|
||||
|
||||
|
||||
def message_to_log(
|
||||
user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
|
||||
user_message,
|
||||
chat_response,
|
||||
user_message_metadata={},
|
||||
khoj_message_metadata={},
|
||||
conversation_log=[],
|
||||
train_of_thought=[],
|
||||
):
|
||||
"""Create json logs from messages, metadata for conversation log"""
|
||||
default_khoj_message_metadata = {
|
||||
@@ -104,28 +248,39 @@ def save_to_conversation_log(
|
||||
user_message_time: str = None,
|
||||
compiled_references: List[Dict[str, Any]] = [],
|
||||
online_results: Dict[str, Any] = {},
|
||||
code_results: Dict[str, Any] = {},
|
||||
inferred_queries: List[str] = [],
|
||||
intent_type: str = "remember",
|
||||
client_application: ClientApplication = None,
|
||||
conversation_id: str = None,
|
||||
automation_id: str = None,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
raw_query_files: List[FileAttachment] = [],
|
||||
train_of_thought: List[Any] = [],
|
||||
tracer: Dict[str, Any] = {},
|
||||
):
|
||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
turn_id = tracer.get("mid") or str(uuid.uuid4())
|
||||
updated_conversation = message_to_log(
|
||||
user_message=q,
|
||||
chat_response=chat_response,
|
||||
user_message_metadata={
|
||||
"created": user_message_time,
|
||||
"uploadedImageData": uploaded_image_url,
|
||||
"images": query_images,
|
||||
"turnId": turn_id,
|
||||
"queryFiles": [file.model_dump(mode="json") for file in raw_query_files],
|
||||
},
|
||||
khoj_message_metadata={
|
||||
"context": compiled_references,
|
||||
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
|
||||
"onlineContext": online_results,
|
||||
"codeContext": code_results,
|
||||
"automationId": automation_id,
|
||||
"trainOfThought": train_of_thought,
|
||||
"turnId": turn_id,
|
||||
},
|
||||
conversation_log=meta_log.get("chat", []),
|
||||
train_of_thought=train_of_thought,
|
||||
)
|
||||
ConversationAdapters.save_conversation(
|
||||
user,
|
||||
@@ -135,6 +290,9 @@ def save_to_conversation_log(
|
||||
user_message=q,
|
||||
)
|
||||
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
merge_message_into_conversation_trace(q, chat_response, tracer)
|
||||
|
||||
logger.info(
|
||||
f"""
|
||||
Saved Conversation Turn
|
||||
@@ -145,13 +303,50 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
|
||||
)
|
||||
|
||||
|
||||
# Format user and system messages to chatml format
|
||||
def construct_structured_message(message, image_url, model_type, vision_enabled):
|
||||
if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
|
||||
def construct_structured_message(
|
||||
message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str
|
||||
):
|
||||
"""
|
||||
Format messages into appropriate multimedia format for supported chat model types
|
||||
"""
|
||||
if model_type in [
|
||||
ChatModelOptions.ModelType.OPENAI,
|
||||
ChatModelOptions.ModelType.GOOGLE,
|
||||
ChatModelOptions.ModelType.ANTHROPIC,
|
||||
]:
|
||||
constructed_messages: List[Any] = [
|
||||
{"type": "text", "text": message},
|
||||
]
|
||||
|
||||
if not is_none_or_empty(attached_file_context):
|
||||
constructed_messages.append({"type": "text", "text": attached_file_context})
|
||||
if vision_enabled and images:
|
||||
for image in images:
|
||||
constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
|
||||
return constructed_messages
|
||||
|
||||
if not is_none_or_empty(attached_file_context):
|
||||
return f"{attached_file_context}\n\n{message}"
|
||||
|
||||
return message
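To illustrate the two return shapes, here is a hedged sketch: vision-capable model types get a list of text/image parts, while other model types fall back to a plain string. The image URL is a placeholder.

from khoj.database.models import ChatModelOptions
from khoj.processor.conversation.utils import construct_structured_message

# Vision-capable chat model: the message becomes a list of text and image_url parts.
parts = construct_structured_message(
    message="What is in this picture?",
    images=["https://example.com/photo.png"],
    model_type=ChatModelOptions.ModelType.OPENAI,
    vision_enabled=True,
    attached_file_context=None,
)
# -> [{'type': 'text', 'text': 'What is in this picture?'},
#     {'type': 'image_url', 'image_url': {'url': 'https://example.com/photo.png'}}]

# Any other model type falls back to a plain string message.
text_only = construct_structured_message(
    message="What is in this picture?",
    images=None,
    model_type="offline",
    vision_enabled=False,
    attached_file_context=None,
)
# -> 'What is in this picture?'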
|
||||
|
||||
|
||||
def gather_raw_query_files(
|
||||
query_files: Dict[str, str],
|
||||
):
|
||||
"""
|
||||
Gather contextual data from the given (raw) files
|
||||
"""
|
||||
|
||||
if len(query_files) == 0:
|
||||
return ""
|
||||
|
||||
contextual_data = " ".join(
|
||||
[f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in query_files.items()]
|
||||
)
|
||||
return f"I have attached the following files:\n\n{contextual_data}"
|
||||
|
||||
|
||||
def generate_chatml_messages_with_context(
|
||||
user_message,
|
||||
system_message=None,
|
||||
@@ -160,11 +355,13 @@ def generate_chatml_messages_with_context(
|
||||
loaded_model: Optional[Llama] = None,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
uploaded_image_url=None,
|
||||
query_images=None,
|
||||
vision_enabled=False,
|
||||
model_type="",
|
||||
context_message="",
|
||||
query_files: str = None,
|
||||
):
|
||||
"""Generate messages for ChatGPT with context from previous conversation"""
|
||||
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
|
||||
# Set max prompt size from user config or based on pre-configured for model and machine specs
|
||||
if not max_prompt_size:
|
||||
if loaded_model:
|
||||
@@ -178,32 +375,66 @@ def generate_chatml_messages_with_context(
|
||||
# Extract Chat History for Context
|
||||
chatml_messages: List[ChatMessage] = []
|
||||
for chat in conversation_log.get("chat", []):
|
||||
message_notes = f'\n\n Notes:\n{chat.get("context")}' if chat.get("context") else "\n"
|
||||
message_context = ""
|
||||
message_attached_files = ""
|
||||
|
||||
chat_message = chat.get("message")
|
||||
|
||||
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
|
||||
chat_message = chat["intent"].get("inferred-queries")[0]
|
||||
if not is_none_or_empty(chat.get("context")):
|
||||
references = "\n\n".join(
|
||||
{
|
||||
f"# File: {item['file']}\n## {item['compiled']}\n"
|
||||
for item in chat.get("context") or []
|
||||
if isinstance(item, dict)
|
||||
}
|
||||
)
|
||||
message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"
|
||||
|
||||
if chat.get("queryFiles"):
|
||||
raw_query_files = chat.get("queryFiles")
|
||||
query_files_dict = dict()
|
||||
for file in raw_query_files:
|
||||
query_files_dict[file["name"]] = file["content"]
|
||||
|
||||
message_attached_files = gather_raw_query_files(query_files_dict)
|
||||
chatml_messages.append(ChatMessage(content=message_attached_files, role="user"))
|
||||
|
||||
if not is_none_or_empty(chat.get("onlineContext")):
|
||||
message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
|
||||
|
||||
if not is_none_or_empty(message_context):
|
||||
reconstructed_context_message = ChatMessage(content=message_context, role="user")
|
||||
chatml_messages.insert(0, reconstructed_context_message)
|
||||
|
||||
role = "user" if chat["by"] == "you" else "assistant"
|
||||
|
||||
message_content = chat["message"] + message_notes
|
||||
|
||||
message_content = construct_structured_message(
|
||||
message_content, chat.get("uploadedImageData"), model_type, vision_enabled
|
||||
chat_message, chat.get("images"), model_type, vision_enabled, attached_file_context=query_files
|
||||
)
|
||||
|
||||
reconstructed_message = ChatMessage(content=message_content, role=role)
|
||||
|
||||
chatml_messages.insert(0, reconstructed_message)
|
||||
|
||||
if len(chatml_messages) >= 2 * lookback_turns:
|
||||
if len(chatml_messages) >= 3 * lookback_turns:
|
||||
break
|
||||
|
||||
messages = []
|
||||
if not is_none_or_empty(user_message):
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled),
|
||||
content=construct_structured_message(
|
||||
user_message, query_images, model_type, vision_enabled, query_files
|
||||
),
|
||||
role="user",
|
||||
)
|
||||
)
|
||||
if not is_none_or_empty(context_message):
|
||||
messages.append(ChatMessage(content=context_message, role="user"))
|
||||
|
||||
if len(chatml_messages) > 0:
|
||||
messages += chatml_messages
|
||||
|
||||
if not is_none_or_empty(system_message):
|
||||
messages.append(ChatMessage(content=system_message, role="system"))
|
||||
|
||||
@@ -222,7 +453,6 @@ def truncate_messages(
|
||||
tokenizer_name=None,
|
||||
) -> list[ChatMessage]:
|
||||
"""Truncate messages to fit within max prompt size supported by model"""
|
||||
|
||||
default_tokenizer = "gpt-4o"
|
||||
|
||||
try:
|
||||
@@ -252,6 +482,7 @@ def truncate_messages(
|
||||
system_message = messages.pop(idx)
|
||||
break
|
||||
|
||||
# TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
|
||||
system_message_tokens = (
|
||||
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
|
||||
)
|
||||
@@ -279,7 +510,7 @@ def truncate_messages(
|
||||
truncated_message = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
|
||||
messages = [ChatMessage(content=truncated_message, role=messages[0].role)]
|
||||
logger.debug(
|
||||
f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
|
||||
f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message[:1000]}..."
|
||||
)
|
||||
|
||||
if system_message:
|
||||
@@ -294,6 +525,214 @@ def reciprocal_conversation_to_chatml(message_pair):
|
||||
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
|
||||
|
||||
|
||||
def remove_json_codeblock(response: str):
|
||||
"""Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
|
||||
return response.removeprefix("```json").removesuffix("```")
|
||||
def clean_json(response: str):
|
||||
"""Remove any markdown json codeblock and newline formatting if present. Useful for non schema enforceable models"""
|
||||
return response.strip().replace("\n", "").removeprefix("```json").removesuffix("```")
|
||||
|
||||
|
||||
def clean_code_python(code: str):
|
||||
"""Remove any markdown codeblock and newline formatting if present. Useful for non schema enforceable models"""
|
||||
return code.strip().removeprefix("```python").removesuffix("```")
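A quick, illustrative check of both helpers on fenced model output; the sample strings are invented.

from khoj.processor.conversation.utils import clean_code_python, clean_json

# Model replies often arrive wrapped in markdown code fences.
raw_json = '```json\n{"queries": ["volume of an orange"]}\n```'
raw_code = '```python\nprint("hello")\n```'

print(clean_json(raw_json))         # -> {"queries": ["volume of an orange"]}
print(clean_code_python(raw_code))  # -> print("hello"), with surrounding newlines kept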
|
||||
|
||||
|
||||
def defilter_query(query: str):
|
||||
"""Remove any query filters in query"""
|
||||
defiltered_query = query
|
||||
filters: List[BaseFilter] = [WordFilter(), FileFilter(), DateFilter()]
|
||||
for filter in filters:
|
||||
defiltered_query = filter.defilter(defiltered_query)
|
||||
return defiltered_query
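An illustrative call, reusing the date filter syntax shown in the search prompts earlier in this diff; the exact defiltered output depends on the individual filter implementations.

from khoj.processor.conversation.utils import defilter_query

query = "Who did I meet at the beach dt>='2024-03-01' dt<'2024-04-01'"
# Removes the word, file and date filter operators, leaving the plain text query.
print(defilter_query(query))  # roughly -> "Who did I meet at the beach"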
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageWithType:
|
||||
content: Any
|
||||
type: str
|
||||
|
||||
|
||||
def get_image_from_url(image_url: str, type="pil"):
|
||||
try:
|
||||
response = requests.get(image_url)
|
||||
response.raise_for_status() # Check if the request was successful
|
||||
|
||||
# Get content type from response or infer from URL
|
||||
content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"
|
||||
|
||||
# Convert image to desired format
|
||||
if type == "b64":
|
||||
image_data = base64.b64encode(response.content).decode("utf-8")
|
||||
elif type == "pil":
|
||||
image_data = PIL.Image.open(BytesIO(response.content))
|
||||
else:
|
||||
raise ValueError(f"Invalid image type: {type}")
|
||||
|
||||
return ImageWithType(content=image_data, type=content_type)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Failed to get image from URL {image_url}: {e}")
|
||||
return ImageWithType(content=None, type=None)
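A hedged usage sketch; the URL is a placeholder and the call needs network access. Pass type="pil" for a PIL.Image instead of a base64 string.

from khoj.processor.conversation.utils import get_image_from_url

image = get_image_from_url("https://example.com/photo.png", type="b64")
if image.content is not None:
    print(f"Fetched {image.type} image, {len(image.content)} base64 characters")
else:
    print("Image could not be fetched")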
|
||||
|
||||
|
||||
def commit_conversation_trace(
|
||||
session: list[ChatMessage],
|
||||
response: str | list[dict],
|
||||
tracer: dict,
|
||||
system_message: str | list[dict] = "",
|
||||
repo_path: str = "/tmp/promptrace",
|
||||
) -> str:
|
||||
"""
|
||||
Save trace of conversation step using git. Useful to visualize, compare and debug traces.
|
||||
Returns the path to the repository.
|
||||
"""
|
||||
try:
|
||||
from git import Repo
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
# Serialize session, system message and response to yaml
|
||||
system_message_yaml = json.dumps(system_message, ensure_ascii=False, sort_keys=False)
|
||||
response_yaml = json.dumps(response, ensure_ascii=False, sort_keys=False)
|
||||
formatted_session = [{"role": message.role, "content": message.content} for message in session]
|
||||
session_yaml = json.dumps(formatted_session, ensure_ascii=False, sort_keys=False)
|
||||
query = (
|
||||
json.dumps(session[-1].content, ensure_ascii=False, sort_keys=False).strip().removeprefix("'").removesuffix("'")
|
||||
) # Extract serialized query from chat session
|
||||
|
||||
# Extract chat metadata for session
|
||||
uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")
|
||||
|
||||
# Infer repository path from environment variable or provided path
|
||||
repo_path = os.getenv("PROMPTRACE_DIR", repo_path)
|
||||
|
||||
try:
|
||||
# Prepare git repository
|
||||
os.makedirs(repo_path, exist_ok=True)
|
||||
repo = Repo.init(repo_path)
|
||||
|
||||
# Remove post-commit hook if it exists
|
||||
hooks_dir = os.path.join(repo_path, ".git", "hooks")
|
||||
post_commit_hook = os.path.join(hooks_dir, "post-commit")
|
||||
if os.path.exists(post_commit_hook):
|
||||
os.remove(post_commit_hook)
|
||||
|
||||
# Configure git user if not set
|
||||
if not repo.config_reader().has_option("user", "email"):
|
||||
repo.config_writer().set_value("user", "name", "Prompt Tracer").release()
|
||||
repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release()
|
||||
|
||||
# Create an initial commit if the repository is newly created
|
||||
if not repo.head.is_valid():
|
||||
repo.index.commit("And then there was a trace")
|
||||
|
||||
# Check out the initial commit
|
||||
initial_commit = repo.commit("HEAD~0")
|
||||
repo.head.reference = initial_commit
|
||||
repo.head.reset(index=True, working_tree=True)
|
||||
|
||||
# Create or switch to user branch from initial commit
|
||||
user_branch = f"u_{uid}"
|
||||
if user_branch not in repo.branches:
|
||||
repo.create_head(user_branch)
|
||||
repo.heads[user_branch].checkout()
|
||||
|
||||
# Create or switch to conversation branch from user branch
|
||||
conv_branch = f"c_{cid}"
|
||||
if conv_branch not in repo.branches:
|
||||
repo.create_head(conv_branch)
|
||||
repo.heads[conv_branch].checkout()
|
||||
|
||||
# Create or switch to message branch from conversation branch
|
||||
msg_branch = f"m_{mid}" if mid else None
|
||||
if msg_branch and msg_branch not in repo.branches:
|
||||
repo.create_head(msg_branch)
|
||||
if msg_branch:
|
||||
repo.heads[msg_branch].checkout()
|
||||
|
||||
# Include file with content to commit
|
||||
files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml}
|
||||
|
||||
# Write files and stage them
|
||||
for filename, content in files_to_commit.items():
|
||||
file_path = os.path.join(repo_path, filename)
|
||||
# Unescape special characters in content for better readability
|
||||
content = content.strip().replace("\\n", "\n").replace("\\t", "\t")
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
repo.index.add([filename])
|
||||
|
||||
# Create commit
|
||||
metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
|
||||
commit_message = f"""
|
||||
{query[:250]}
|
||||
|
||||
Response:
|
||||
---
|
||||
{response[:500]}...
|
||||
|
||||
Metadata
|
||||
---
|
||||
{metadata_yaml}
|
||||
""".strip()
|
||||
|
||||
repo.index.commit(commit_message)
|
||||
|
||||
logger.debug(f"Saved conversation trace to repo at {repo_path}")
|
||||
return repo_path
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add conversation trace to repo: {str(e)}", exc_info=True)
|
||||
return None
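A minimal sketch of invoking the prompt tracer directly, assuming GitPython is installed; the tracer ids and messages are invented. The trace is committed under /tmp/promptrace unless PROMPTRACE_DIR is set.

from langchain.schema import ChatMessage

from khoj.processor.conversation.utils import commit_conversation_trace

# Tracer metadata: user id, conversation id and message (turn) id.
tracer = {"uid": "user_1", "cid": "conv_1", "mid": "turn_1"}

session = [ChatMessage(content="What is the capital of Cambodia?", role="user")]
response = "The capital of Cambodia is Phnom Penh."

repo_path = commit_conversation_trace(session, response, tracer, system_message="You are Khoj.")
print(repo_path)  # path to the trace repository, or None if GitPython is unavailable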
|
||||
|
||||
|
||||
def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path="/tmp/promptrace") -> bool:
|
||||
"""
|
||||
Merge the message branch into its parent conversation branch.
|
||||
|
||||
Args:
|
||||
query: User query
|
||||
response: Assistant response
|
||||
tracer: Dictionary containing uid, cid and mid
|
||||
repo_path: Path to the git repository
|
||||
|
||||
Returns:
|
||||
bool: True if merge was successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
from git import Repo
|
||||
except ImportError:
|
||||
return False
|
||||
try:
|
||||
# Extract branch names
|
||||
msg_branch = f"m_{tracer['mid']}"
|
||||
conv_branch = f"c_{tracer['cid']}"
|
||||
|
||||
# Infer repository path from environment variable or provided path
|
||||
repo_path = os.getenv("PROMPTRACE_DIR", repo_path)
|
||||
repo = Repo(repo_path)
|
||||
|
||||
# Checkout conversation branch
|
||||
repo.heads[conv_branch].checkout()
|
||||
|
||||
# Create commit message
|
||||
metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
|
||||
commit_message = f"""
|
||||
{query[:250]}
|
||||
|
||||
Response:
|
||||
---
|
||||
{response[:500]}...
|
||||
|
||||
Metadata
|
||||
---
|
||||
{metadata_yaml}
|
||||
""".strip()
|
||||
|
||||
# Merge message branch into conversation branch
|
||||
repo.git.merge(msg_branch, no_ff=True, m=commit_message)
|
||||
|
||||
# Delete message branch after merge
|
||||
repo.delete_head(msg_branch, force=True)
|
||||
|
||||
logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}", exc_info=True)
|
||||
return False
|
||||
|
||||
@@ -13,7 +13,7 @@ from tenacity import (
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from khoj.utils.helpers import get_device, merge_dicts, timer
|
||||
from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
|
||||
from khoj.utils.rawconfig import SearchResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -31,9 +31,9 @@ class EmbeddingsModel:
|
||||
):
|
||||
default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
|
||||
default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
|
||||
self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
|
||||
self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
|
||||
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
|
||||
self.query_encode_kwargs = merge_dicts(fix_json_dict(query_encode_kwargs), default_query_encode_kwargs)
|
||||
self.docs_encode_kwargs = merge_dicts(fix_json_dict(docs_encode_kwargs), default_docs_encode_kwargs)
|
||||
self.model_kwargs = merge_dicts(fix_json_dict(model_kwargs), {"device": get_device()})
|
||||
self.model_name = model_name
|
||||
self.inference_endpoint = embeddings_inference_endpoint
|
||||
self.api_key = embeddings_inference_endpoint_api_key
|
||||
|
||||
@@ -26,8 +26,10 @@ async def text_to_image(
|
||||
references: List[Dict[str, Any]],
|
||||
online_results: Dict[str, Any],
|
||||
send_status_func: Optional[Callable] = None,
|
||||
uploaded_image_url: Optional[str] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
agent: Agent = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
status_code = 200
|
||||
image = None
|
||||
@@ -65,9 +67,11 @@ async def text_to_image(
|
||||
note_references=references,
|
||||
online_results=online_results,
|
||||
model_type=text_to_image_config.model_type,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
if send_status_func:
|
||||
@@ -87,18 +91,18 @@ async def text_to_image(
|
||||
if "content_policy_violation" in e.message:
|
||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||
status_code = e.status_code # type: ignore
|
||||
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
||||
message = f"Image generation blocked by OpenAI due to policy violation" # type: ignore
|
||||
yield image_url or image, status_code, message, intent_type.value
|
||||
return
|
||||
else:
|
||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
||||
message = f"Image generation failed using OpenAI" # type: ignore
|
||||
status_code = e.status_code # type: ignore
|
||||
yield image_url or image, status_code, message, intent_type.value
|
||||
return
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed with error: {e}"
|
||||
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
|
||||
status_code = 502
|
||||
yield image_url or image, status_code, message, intent_type.value
|
||||
return
|
||||
@@ -204,9 +208,10 @@ def generate_image_with_replicate(
|
||||
|
||||
# Raise exception if the image generation task fails
|
||||
if status != "succeeded":
|
||||
error = get_prediction.get("error")
|
||||
if retry_count >= 10:
|
||||
raise requests.RequestException("Image generation timed out")
|
||||
raise requests.RequestException(f"Image generation failed with status: {status}")
|
||||
raise requests.RequestException(f"Image generation failed with status: {status}, message: {error}")
|
||||
|
||||
# Get the generated image
|
||||
image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
import os
|
||||
import urllib.parse
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
from bs4 import BeautifulSoup
|
||||
@@ -52,7 +52,9 @@ OLOSTEP_QUERY_PARAMS = {
|
||||
"expandMarkdown": "True",
|
||||
"expandHtml": "False",
|
||||
}
|
||||
MAX_WEBPAGES_TO_READ = 1
|
||||
|
||||
DEFAULT_MAX_WEBPAGES_TO_READ = 1
|
||||
MAX_WEBPAGES_TO_INFER = 10
|
||||
|
||||
|
||||
async def search_online(
|
||||
@@ -62,8 +64,12 @@ async def search_online(
|
||||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
custom_filters: List[str] = [],
|
||||
uploaded_image_url: str = None,
|
||||
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
|
||||
query_images: List[str] = None,
|
||||
previous_subqueries: Set = set(),
|
||||
agent: Agent = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
query += " ".join(custom_filters)
|
||||
if not is_internet_connected():
|
||||
@@ -72,36 +78,52 @@ async def search_online(
|
||||
return
|
||||
|
||||
# Breakdown the query into subqueries to get the correct answer
|
||||
subqueries = await generate_online_subqueries(
|
||||
query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent
|
||||
new_subqueries = await generate_online_subqueries(
|
||||
query,
|
||||
conversation_history,
|
||||
location,
|
||||
user,
|
||||
query_images=query_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
query_files=query_files,
|
||||
)
|
||||
response_dict = {}
|
||||
subqueries = list(new_subqueries - previous_subqueries)
|
||||
response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {}
|
||||
|
||||
if subqueries:
|
||||
logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
|
||||
if send_status_func:
|
||||
subqueries_str = "\n- " + "\n- ".join(list(subqueries))
|
||||
async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
if is_none_or_empty(subqueries):
|
||||
logger.info("No new subqueries to search online")
|
||||
yield response_dict
|
||||
return
|
||||
|
||||
with timer(f"Internet searches for {list(subqueries)} took", logger):
|
||||
logger.info(f"🌐 Searching the Internet for {subqueries}")
|
||||
if send_status_func:
|
||||
subqueries_str = "\n- " + "\n- ".join(subqueries)
|
||||
async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
|
||||
with timer(f"Internet searches for {subqueries} took", logger):
|
||||
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
|
||||
search_tasks = [search_func(subquery, location) for subquery in subqueries]
|
||||
search_results = await asyncio.gather(*search_tasks)
|
||||
response_dict = {subquery: search_result for subquery, search_result in search_results}
|
||||
|
||||
# Gather distinct web pages from organic results for subqueries without an instant answer.
|
||||
# Content of web pages is directly available when Jina is used for search.
|
||||
webpages: Dict[str, Dict] = {}
|
||||
for subquery in response_dict:
|
||||
if "answerBox" in response_dict[subquery]:
|
||||
continue
|
||||
for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]:
|
||||
for idx, organic in enumerate(response_dict[subquery].get("organic", [])):
|
||||
link = organic.get("link")
|
||||
if link in webpages:
|
||||
if link in webpages and idx < max_webpages_to_read:
|
||||
webpages[link]["queries"].add(subquery)
|
||||
else:
|
||||
# Content of web pages is directly available when Jina is used for search.
|
||||
elif idx < max_webpages_to_read:
|
||||
webpages[link] = {"queries": {subquery}, "content": organic.get("content")}
|
||||
# Only keep webpage content for up to max_webpages_to_read organic results.
|
||||
if idx >= max_webpages_to_read and not is_none_or_empty(organic.get("content")):
|
||||
organic["content"] = None
|
||||
response_dict[subquery]["organic"][idx] = organic
|
||||
|
||||
# Read, extract relevant info from the retrieved web pages
|
||||
if webpages:
|
||||
@@ -111,7 +133,9 @@ async def search_online(
|
||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
tasks = [
|
||||
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
|
||||
read_webpage_and_extract_content(
|
||||
data["queries"], link, data.get("content"), user=user, agent=agent, tracer=tracer
|
||||
)
|
||||
for link, data in webpages.items()
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
@@ -151,22 +175,34 @@ async def read_webpages(
|
||||
location: LocationData,
|
||||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"Infer web pages to read from the query and extract relevant information from them"
|
||||
logger.info(f"Inferring web pages to read")
|
||||
if send_status_func:
|
||||
async for event in send_status_func(f"**Inferring web pages to read**"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
urls = await infer_webpage_urls(query, conversation_history, location, user, uploaded_image_url)
|
||||
urls = await infer_webpage_urls(
|
||||
query,
|
||||
conversation_history,
|
||||
location,
|
||||
user,
|
||||
query_images,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Get the top 10 web pages to read
|
||||
urls = urls[:max_webpages_to_read]
|
||||
|
||||
logger.info(f"Reading web pages at: {urls}")
|
||||
if send_status_func:
|
||||
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
|
||||
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
response: Dict[str, Dict] = defaultdict(dict)
|
||||
@@ -192,7 +228,12 @@ async def read_webpage(
|
||||
|
||||
|
||||
async def read_webpage_and_extract_content(
|
||||
subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
|
||||
subqueries: set[str],
|
||||
url: str,
|
||||
content: str = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> Tuple[set[str], str, Union[None, str]]:
|
||||
# Select the web scrapers to use for reading the web page
|
||||
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
|
||||
@@ -214,7 +255,9 @@ async def read_webpage_and_extract_content(
|
||||
# Extract relevant information from the web page
|
||||
if is_none_or_empty(extracted_info):
|
||||
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
||||
extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
|
||||
extracted_info = await extract_relevant_info(
|
||||
subqueries, content, user=user, agent=agent, tracer=tracer
|
||||
)
|
||||
|
||||
# If we successfully extracted information, break the loop
|
||||
if not is_none_or_empty(extracted_info):
|
||||
@@ -340,3 +383,25 @@ async def search_with_jina(query: str, location: LocationData) -> Tuple[str, Dic
|
||||
for item in response_json["data"]
|
||||
]
|
||||
return query, {"organic": parsed_response}
|
||||
|
||||
|
||||
def deduplicate_organic_results(online_results: dict) -> dict:
    """Deduplicate organic search results based on links across all queries."""
    # Keep track of seen links to filter out duplicates across queries
    seen_links = set()
    deduplicated_results = {}

    # Process each query's results
    for query, results in online_results.items():
        # Filter organic results keeping only first occurrence of each link
        filtered_organic = []
        for result in results.get("organic", []):
            link = result.get("link")
            if link and link not in seen_links:
                seen_links.add(link)
                filtered_organic.append(result)

        # Update results with deduplicated organic entries
        deduplicated_results[query] = {**results, "organic": filtered_organic}

    return deduplicated_results
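A minimal usage sketch of the deduplication helper above, with hypothetical queries and URLs (the values are illustrative, not from the repository):

# Hypothetical input: two queries whose organic results share a link
online_results = {
    "python asyncio": {
        "organic": [{"title": "asyncio docs", "link": "https://docs.python.org/3/library/asyncio.html"}]
    },
    "python event loop": {
        "organic": [
            {"title": "asyncio docs", "link": "https://docs.python.org/3/library/asyncio.html"},
            {"title": "Event loop", "link": "https://docs.python.org/3/library/asyncio-eventloop.html"},
        ]
    },
}

deduped = deduplicate_organic_results(online_results)
# The duplicated link only survives under the first query it was seen for
assert len(deduped["python event loop"]["organic"]) == 1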
src/khoj/processor/tools/run_code.py (new file, 178 lines)
@@ -0,0 +1,178 @@
import base64
import datetime
import json
import logging
import mimetypes
import os
from pathlib import Path
from typing import Any, Callable, List, NamedTuple, Optional

import aiohttp

from khoj.database.adapters import FileObjectAdapters
from khoj.database.models import Agent, FileObject, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
    ChatEvent,
    clean_code_python,
    clean_json,
    construct_chat_history,
)
from khoj.routers.helpers import send_message_to_model_wrapper
from khoj.utils.helpers import is_none_or_empty, timer
from khoj.utils.rawconfig import LocationData

logger = logging.getLogger(__name__)


SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080")


class GeneratedCode(NamedTuple):
    code: str
    input_files: List[str]
    input_links: List[str]

async def run_code(
    query: str,
    conversation_history: dict,
    context: str,
    location_data: LocationData,
    user: KhojUser,
    send_status_func: Optional[Callable] = None,
    query_images: List[str] = None,
    agent: Agent = None,
    sandbox_url: str = SANDBOX_URL,
    query_files: str = None,
    tracer: dict = {},
):
    # Generate Code
    if send_status_func:
        async for event in send_status_func(f"**Generate code snippet** for {query}"):
            yield {ChatEvent.STATUS: event}
    try:
        with timer("Chat actor: Generate programs to execute", logger):
            generated_code = await generate_python_code(
                query,
                conversation_history,
                context,
                location_data,
                user,
                query_images,
                agent,
                tracer,
                query_files,
            )
    except Exception as e:
        raise ValueError(f"Failed to generate code for {query} with error: {e}")

    # Prepare Input Data
    input_data = []
    user_input_files: List[FileObject] = []
    for input_file in generated_code.input_files:
        user_input_files += await FileObjectAdapters.aget_file_objects_by_name(user, input_file)
    for f in user_input_files:
        input_data.append(
            {
                "filename": os.path.basename(f.file_name),
                "b64_data": base64.b64encode(f.raw_text.encode("utf-8")).decode("utf-8"),
            }
        )

    # Run Code
    if send_status_func:
        async for event in send_status_func(f"**Running code snippet**"):
            yield {ChatEvent.STATUS: event}
    try:
        with timer("Chat actor: Execute generated program", logger, log_level=logging.INFO):
            result = await execute_sandboxed_python(generated_code.code, input_data, sandbox_url)
            code = result.pop("code")
            logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
            yield {query: {"code": code, "results": result}}
    except Exception as e:
        raise ValueError(f"Failed to run code for {query} with error: {e}")

async def generate_python_code(
    q: str,
    conversation_history: dict,
    context: str,
    location_data: LocationData,
    user: KhojUser,
    query_images: list[str] = None,
    agent: Agent = None,
    tracer: dict = {},
    query_files: str = None,
) -> GeneratedCode:
    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
    chat_history = construct_chat_history(conversation_history)

    utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    code_generation_prompt = prompts.python_code_generation_prompt.format(
        current_date=utc_date,
        query=q,
        chat_history=chat_history,
        context=context,
        location=location,
        username=username,
        personality_context=personality_context,
    )

    response = await send_message_to_model_wrapper(
        code_generation_prompt,
        query_images=query_images,
        response_type="json_object",
        user=user,
        tracer=tracer,
        query_files=query_files,
    )

    # Validate that the model returned a JSON object with a non-empty code string
    response = clean_json(response)
    response = json.loads(response)
    code = response.get("code", "").strip()
    input_files = response.get("input_files", [])
    input_links = response.get("input_links", [])

    if not isinstance(code, str) or is_none_or_empty(code):
        raise ValueError
    return GeneratedCode(code, input_files, input_links)

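For illustration, the model response parsed above is expected to be a JSON object with "code", "input_files" and "input_links" keys; a hypothetical response and its parsed form might look like this (only the key names come from the parsing code, the values are invented):

# Hypothetical model output consumed by generate_python_code
raw_response = """
{
  "code": "print(sum(range(10)))",
  "input_files": ["budget.csv"],
  "input_links": []
}
"""
parsed = json.loads(raw_response)
generated = GeneratedCode(parsed["code"], parsed["input_files"], parsed["input_links"])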
async def execute_sandboxed_python(code: str, input_data: list[dict], sandbox_url: str = SANDBOX_URL) -> dict[str, Any]:
    """
    Takes code to run as a string and calls the terrarium API to execute it.
    Returns the result of the code execution as a dictionary.

    Reference data i/o format based on Terrarium example client code at:
    https://github.com/cohere-ai/cohere-terrarium/blob/main/example-clients/python/terrarium_client.py
    """
    headers = {"Content-Type": "application/json"}
    cleaned_code = clean_code_python(code)
    data = {"code": cleaned_code, "files": input_data}

    async with aiohttp.ClientSession() as session:
        async with session.post(sandbox_url, json=data, headers=headers) as response:
            if response.status == 200:
                result: dict[str, Any] = await response.json()
                result["code"] = cleaned_code
                # Store decoded output files
                for output_file in result.get("output_files", []):
                    # Decode text files as UTF-8
                    if mimetypes.guess_type(output_file["filename"])[0].startswith("text/") or Path(
                        output_file["filename"]
                    ).suffix in [".org", ".md", ".json"]:
                        output_file["b64_data"] = base64.b64decode(output_file["b64_data"]).decode("utf-8")
                return result
            else:
                return {
                    "code": cleaned_code,
                    "success": False,
                    "std_err": f"Failed to execute code with {response.status}",
                }
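A minimal sketch of invoking the sandbox executor directly, assuming a Terrarium container is reachable at the default SANDBOX_URL; the result keys queried below follow the Terrarium response format and are shown for illustration only:

import asyncio

async def demo():
    # Run a trivial snippet with no input files in the sandbox
    result = await execute_sandboxed_python("print(2 + 2)", input_data=[])
    print(result.get("success"), result.get("std_out"))

asyncio.run(demo())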
@@ -6,7 +6,7 @@ import os
import threading
import time
import uuid
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Set, Union

import cron_descriptor
import pytz
@@ -21,11 +21,12 @@ from starlette.authentication import has_required_scope, requires
from khoj.configure import initialize_content
from khoj.database import adapters
from khoj.database.adapters import (
    AgentAdapters,
    AutomationAdapters,
    ConversationAdapters,
    EntryAdapters,
    get_default_search_model,
    get_user_photo,
    get_user_search_model_or_default,
)
from khoj.database.models import (
    Agent,
@@ -42,6 +43,7 @@ from khoj.processor.conversation.offline.chat_model import extract_questions_offline
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.processor.conversation.utils import defilter_query
from khoj.routers.helpers import (
    ApiUserRateLimiter,
    ChatEvent,
@@ -114,10 +116,16 @@ async def execute_search(
    dedupe: Optional[bool] = True,
    agent: Optional[Agent] = None,
):
    start_time = time.time()

    # Run validation checks
    results: List[SearchResponse] = []

    start_time = time.time()

    # Ensure the agent, if present, is accessible by the user
    if user and agent and not await AgentAdapters.ais_agent_accessible(agent, user):
        logger.error(f"Agent {agent.slug} is not accessible by user {user}")
        return results

    if q is None or q == "":
        logger.warning(f"No query param (q) passed in API call to initiate search")
        return results
@@ -142,7 +150,7 @@ async def execute_search(
    encoded_asymmetric_query = None
    if t != SearchType.Image:
        with timer("Encoding query took", logger=logger):
            search_model = await sync_to_async(get_user_search_model_or_default)(user)
            search_model = await sync_to_async(get_default_search_model)()
            encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)

    with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -159,8 +167,8 @@ async def execute_search(
                search_futures += [
                    executor.submit(
                        text_search.query,
                        user,
                        user_query,
                        user,
                        t,
                        question_embedding=encoded_asymmetric_query,
                        max_distance=max_distance,
@@ -204,7 +212,7 @@ def update(
        logger.warning(error_msg)
        raise HTTPException(status_code=500, detail=error_msg)
    try:
        initialize_content(regenerate=force, search_type=t, user=user)
        initialize_content(user=user, regenerate=force, search_type=t)
    except Exception as e:
        error_msg = f"🚨 Failed to update server via API: {e}"
        logger.error(error_msg, exc_info=True)
@@ -340,13 +348,16 @@ async def extract_references_and_questions(
    conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
    location_data: LocationData = None,
    send_status_func: Optional[Callable] = None,
    uploaded_image_url: Optional[str] = None,
    query_images: Optional[List[str]] = None,
    previous_inferred_queries: Set = set(),
    agent: Agent = None,
    query_files: str = None,
    tracer: dict = {},
):
    user = request.user.object if request.user.is_authenticated else None

    # Initialize Variables
    compiled_references: List[Any] = []
    compiled_references: List[dict[str, str]] = []
    inferred_queries: List[str] = []

    agent_has_entries = False
@@ -375,9 +386,7 @@ async def extract_references_and_questions(
        return

    # Extract filter terms from user message
    defiltered_query = q
    for filter in [DateFilter(), WordFilter(), FileFilter()]:
        defiltered_query = filter.defilter(defiltered_query)
    defiltered_query = defilter_query(q)
    filters_in_query = q.replace(defiltered_query, "").strip()
    conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id)

@@ -417,6 +426,8 @@ async def extract_references_and_questions(
                user=user,
                max_prompt_size=conversation_config.max_prompt_size,
                personality_context=personality_context,
                query_files=query_files,
                tracer=tracer,
            )
        elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
            openai_chat_config = conversation_config.openai_config
@@ -431,37 +442,48 @@ async def extract_references_and_questions(
                conversation_log=meta_log,
                location_data=location_data,
                user=user,
                uploaded_image_url=uploaded_image_url,
                query_images=query_images,
                vision_enabled=vision_enabled,
                personality_context=personality_context,
                query_files=query_files,
                tracer=tracer,
            )
        elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
            api_key = conversation_config.openai_config.api_key
            chat_model = conversation_config.chat_model
            inferred_queries = extract_questions_anthropic(
                defiltered_query,
                query_images=query_images,
                model=chat_model,
                api_key=api_key,
                conversation_log=meta_log,
                location_data=location_data,
                user=user,
                vision_enabled=vision_enabled,
                personality_context=personality_context,
                query_files=query_files,
                tracer=tracer,
            )
        elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
            api_key = conversation_config.openai_config.api_key
            chat_model = conversation_config.chat_model
            inferred_queries = extract_questions_gemini(
                defiltered_query,
                query_images=query_images,
                model=chat_model,
                api_key=api_key,
                conversation_log=meta_log,
                location_data=location_data,
                max_tokens=conversation_config.max_prompt_size,
                user=user,
                vision_enabled=vision_enabled,
                personality_context=personality_context,
                query_files=query_files,
                tracer=tracer,
            )

    # Collate search results as context for GPT
    inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
    with timer("Searching knowledge base took", logger):
        search_results = []
        logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
@@ -485,7 +507,8 @@ async def extract_references_and_questions(
            )
        search_results = text_search.deduplicated_search_responses(search_results)
        compiled_references = [
            {"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results
            {"query": q, "compiled": item.additional["compiled"], "file": item.additional["file"]}
            for q, item in zip(inferred_queries, search_results)
        ]

    yield compiled_references, inferred_queries, defiltered_query
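For reference, each compiled reference yielded above now carries the inferred query alongside the snippet and source file; a hypothetical element of that list might look like this (values invented, key names from the code above):

# Hypothetical entry in compiled_references after this change
example_reference = {
    "query": "what did I note about pruning tomato plants?",
    "compiled": "## Garden log\nPruned tomato suckers below the first flower cluster...",
    "file": "notes/garden.org",
}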
@@ -1,5 +1,7 @@
import json
import logging
import random
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional

from asgiref.sync import sync_to_async
@@ -9,8 +11,8 @@ from fastapi.responses import Response
from pydantic import BaseModel
from starlette.authentication import requires

from khoj.database.adapters import AgentAdapters
from khoj.database.models import Agent, KhojUser
from khoj.database.adapters import AgentAdapters, ConversationAdapters
from khoj.database.models import Agent, Conversation, KhojUser
from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt
from khoj.utils.helpers import (
    ConversationCommand,
@@ -45,30 +47,49 @@ async def all_agents(
) -> Response:
    user: KhojUser = request.user.object if request.user.is_authenticated else None
    agents = await AgentAdapters.aget_all_accessible_agents(user)
    default_agent = await AgentAdapters.aget_default_agent()
    default_agent_packet = None
    agents_packet = list()
    for agent in agents:
        files = agent.fileobject_set.all()
        file_names = [file.file_name for file in files]
        agents_packet.append(
            {
                "slug": agent.slug,
                "name": agent.name,
                "persona": agent.personality,
                "creator": agent.creator.username if agent.creator else None,
                "managed_by_admin": agent.managed_by_admin,
                "color": agent.style_color,
                "icon": agent.style_icon,
                "privacy_level": agent.privacy_level,
                "chat_model": agent.chat_model.chat_model,
                "files": file_names,
                "input_tools": agent.input_tools,
                "output_modes": agent.output_modes,
            }
        )
        agent_packet = {
            "slug": agent.slug,
            "name": agent.name,
            "persona": agent.personality,
            "creator": agent.creator.username if agent.creator else None,
            "managed_by_admin": agent.managed_by_admin,
            "color": agent.style_color,
            "icon": agent.style_icon,
            "privacy_level": agent.privacy_level,
            "chat_model": agent.chat_model.chat_model,
            "files": file_names,
            "input_tools": agent.input_tools,
            "output_modes": agent.output_modes,
        }
        if agent.slug == default_agent.slug:
            default_agent_packet = agent_packet
        else:
            agents_packet.append(agent_packet)

    # Load recent conversation sessions
    min_date = datetime.min.replace(tzinfo=timezone.utc)
    two_weeks_ago = datetime.today() - timedelta(weeks=2)
    conversations = []
    if user:
        conversations = await sync_to_async(list[Conversation])(
            ConversationAdapters.get_conversation_sessions(user, request.user.client_app)
            .filter(updated_at__gte=two_weeks_ago)
            .order_by("-updated_at")[:50]
        )
    conversation_times = {conv.agent.slug: conv.updated_at for conv in conversations if conv.agent}

    # Put default agent first, then sort by mru and finally shuffle unused randomly
    random.shuffle(agents_packet)
    agents_packet.sort(key=lambda x: conversation_times.get(x["slug"]) or min_date, reverse=True)
    if default_agent_packet:
        agents_packet.insert(0, default_agent_packet)

    # Make sure that the agent named 'khoj' is first in the list. Everything else is sorted by name.
    agents_packet.sort(key=lambda x: x["name"])
    agents_packet.sort(key=lambda x: x["slug"] == "khoj", reverse=True)
    return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)
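As a rough illustration of the ordering logic above: the shuffle-then-sort sequence leaves recently used agents first, unused agents in random order behind them, and the default agent pinned to the front. The agent slugs and timestamp below are made up:

# Hypothetical packets and last-use times to illustrate the sort order
agents_packet = [{"slug": "poet"}, {"slug": "coder"}, {"slug": "chef"}]
conversation_times = {"coder": datetime(2024, 11, 1, tzinfo=timezone.utc)}
min_date = datetime.min.replace(tzinfo=timezone.utc)

random.shuffle(agents_packet)
agents_packet.sort(key=lambda x: conversation_times.get(x["slug"]) or min_date, reverse=True)
# -> "coder" sorts first; "poet" and "chef" follow in shuffled order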
@@ -162,7 +183,7 @@ async def delete_agent(
|
||||
|
||||
|
||||
@api_agents.post("", response_class=Response)
|
||||
@requires(["authenticated", "premium"])
|
||||
@requires(["authenticated"])
|
||||
async def create_agent(
|
||||
request: Request,
|
||||
common: CommonQueryParams,
|
||||
@@ -170,10 +191,9 @@ async def create_agent(
|
||||
) -> Response:
|
||||
user: KhojUser = request.user.object
|
||||
|
||||
is_safe_prompt, reason = True, ""
|
||||
|
||||
if body.privacy_level != Agent.PrivacyLevel.PRIVATE:
|
||||
is_safe_prompt, reason = await acheck_if_safe_prompt(body.persona)
|
||||
is_safe_prompt, reason = await acheck_if_safe_prompt(
|
||||
body.persona, user, lax=body.privacy_level == Agent.PrivacyLevel.PRIVATE
|
||||
)
|
||||
|
||||
if not is_safe_prompt:
|
||||
return Response(
|
||||
@@ -215,7 +235,7 @@ async def create_agent(
|
||||
|
||||
|
||||
@api_agents.patch("", response_class=Response)
|
||||
@requires(["authenticated", "premium"])
|
||||
@requires(["authenticated"])
|
||||
async def update_agent(
|
||||
request: Request,
|
||||
common: CommonQueryParams,
|
||||
@@ -223,10 +243,9 @@ async def update_agent(
|
||||
) -> Response:
|
||||
user: KhojUser = request.user.object
|
||||
|
||||
is_safe_prompt, reason = True, ""
|
||||
|
||||
if body.privacy_level != Agent.PrivacyLevel.PRIVATE:
|
||||
is_safe_prompt, reason = await acheck_if_safe_prompt(body.persona)
|
||||
is_safe_prompt, reason = await acheck_if_safe_prompt(
|
||||
body.persona, user, lax=body.privacy_level == Agent.PrivacyLevel.PRIVATE
|
||||
)
|
||||
|
||||
if not is_safe_prompt:
|
||||
return Response(
|
||||
|
||||
@@ -3,9 +3,10 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import unquote
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
@@ -18,28 +19,40 @@ from khoj.database.adapters import (
|
||||
AgentAdapters,
|
||||
ConversationAdapters,
|
||||
EntryAdapters,
|
||||
FileObjectAdapters,
|
||||
PublicConversationAdapters,
|
||||
aget_user_name,
|
||||
)
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||
from khoj.processor.conversation.utils import save_to_conversation_log
|
||||
from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log
|
||||
from khoj.processor.image.generate import text_to_image
|
||||
from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||
from khoj.processor.tools.online_search import (
|
||||
deduplicate_organic_results,
|
||||
read_webpages,
|
||||
search_online,
|
||||
)
|
||||
from khoj.processor.tools.run_code import run_code
|
||||
from khoj.routers.api import extract_references_and_questions
|
||||
from khoj.routers.email import send_query_feedback
|
||||
from khoj.routers.helpers import (
|
||||
ApiImageRateLimiter,
|
||||
ApiUserRateLimiter,
|
||||
ChatEvent,
|
||||
ChatRequestBody,
|
||||
CommonQueryParams,
|
||||
ConversationCommandRateLimiter,
|
||||
DeleteMessageRequestBody,
|
||||
FeedbackData,
|
||||
acreate_title_from_history,
|
||||
agenerate_chat_response,
|
||||
aget_relevant_information_sources,
|
||||
aget_relevant_output_modes,
|
||||
construct_automation_created_message,
|
||||
create_automation,
|
||||
extract_relevant_summary,
|
||||
gather_raw_query_files,
|
||||
generate_excalidraw_diagram,
|
||||
generate_summary_from_files,
|
||||
get_conversation_command,
|
||||
is_query_empty,
|
||||
is_ready_to_chat,
|
||||
@@ -47,6 +60,10 @@ from khoj.routers.helpers import (
|
||||
update_telemetry_state,
|
||||
validate_conversation_config,
|
||||
)
|
||||
from khoj.routers.research import (
|
||||
InformationCollectionIteration,
|
||||
execute_information_collection,
|
||||
)
|
||||
from khoj.routers.storage import upload_image_to_bucket
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import (
|
||||
@@ -59,21 +76,22 @@ from khoj.utils.helpers import (
|
||||
get_device,
|
||||
is_none_or_empty,
|
||||
)
|
||||
from khoj.utils.rawconfig import FileFilterRequest, FilesFilterRequest, LocationData
|
||||
from khoj.utils.rawconfig import (
|
||||
ChatRequestBody,
|
||||
FileFilterRequest,
|
||||
FilesFilterRequest,
|
||||
LocationData,
|
||||
)
|
||||
|
||||
# Initialize Router
|
||||
logger = logging.getLogger(__name__)
|
||||
conversation_command_rate_limiter = ConversationCommandRateLimiter(
|
||||
trial_rate_limit=100, subscribed_rate_limit=6000, slug="command"
|
||||
trial_rate_limit=20, subscribed_rate_limit=75, slug="command"
|
||||
)
|
||||
|
||||
|
||||
api_chat = APIRouter()
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from khoj.routers.email import send_query_feedback
|
||||
|
||||
|
||||
@api_chat.get("/conversation/file-filters/{conversation_id}", response_class=Response)
|
||||
@requires(["authenticated"])
|
||||
@@ -109,7 +127,7 @@ def add_files_filter(request: Request, filter: FilesFilterRequest):
|
||||
file_filters = ConversationAdapters.add_files_to_filter(request.user.object, conversation_id, files_filter)
|
||||
return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding file filter {filter.filename}: {e}", exc_info=True)
|
||||
logger.error(f"Error adding file filter {filter.filenames}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
|
||||
|
||||
@@ -135,12 +153,6 @@ def remove_file_filter(request: Request, filter: FileFilterRequest) -> Response:
|
||||
return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
|
||||
|
||||
|
||||
class FeedbackData(BaseModel):
|
||||
uquery: str
|
||||
kquery: str
|
||||
sentiment: str
|
||||
|
||||
|
||||
@api_chat.post("/feedback")
|
||||
@requires(["authenticated"])
|
||||
async def sendfeedback(request: Request, data: FeedbackData):
|
||||
@@ -155,10 +167,10 @@ async def text_to_speech(
|
||||
common: CommonQueryParams,
|
||||
text: str,
|
||||
rate_limiter_per_minute=Depends(
|
||||
ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
|
||||
ApiUserRateLimiter(requests=30, subscribed_requests=30, window=60, slug="chat_minute")
|
||||
),
|
||||
rate_limiter_per_day=Depends(
|
||||
ApiUserRateLimiter(requests=50, subscribed_requests=300, window=60 * 60 * 24, slug="chat_day")
|
||||
ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
||||
),
|
||||
) -> Response:
|
||||
voice_model = await ConversationAdapters.aget_voice_model_config(request.user.object)
|
||||
@@ -367,7 +379,7 @@ def fork_public_conversation(
|
||||
{
|
||||
"status": "ok",
|
||||
"next_url": redirect_uri,
|
||||
"conversation_id": new_conversation.id,
|
||||
"conversation_id": str(new_conversation.id),
|
||||
}
|
||||
),
|
||||
)
|
||||
@@ -523,20 +535,43 @@ async def set_conversation_title(
|
||||
)
|
||||
|
||||
|
||||
class ChatRequestBody(BaseModel):
|
||||
q: str
|
||||
n: Optional[int] = 7
|
||||
d: Optional[float] = None
|
||||
stream: Optional[bool] = False
|
||||
title: Optional[str] = None
|
||||
conversation_id: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
region: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
country_code: Optional[str] = None
|
||||
timezone: Optional[str] = None
|
||||
image: Optional[str] = None
|
||||
create_new: Optional[bool] = False
|
||||
@api_chat.post("/title")
|
||||
@requires(["authenticated"])
|
||||
async def generate_chat_title(
|
||||
request: Request,
|
||||
common: CommonQueryParams,
|
||||
conversation_id: str,
|
||||
):
|
||||
user: KhojUser = request.user.object
|
||||
conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id)
|
||||
|
||||
# Conversation.title is explicitly set by the user. Do not override.
|
||||
if conversation.title:
|
||||
return {"status": "ok", "title": conversation.title}
|
||||
|
||||
if not conversation:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||
|
||||
new_title = await acreate_title_from_history(request.user.object, conversation=conversation)
|
||||
|
||||
conversation.slug = new_title
|
||||
|
||||
await conversation.asave()
|
||||
|
||||
return {"status": "ok", "title": new_title}
|
||||
|
||||
|
||||
@api_chat.delete("/conversation/message", response_class=Response)
|
||||
@requires(["authenticated"])
|
||||
def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -> Response:
|
||||
user = request.user.object
|
||||
success = ConversationAdapters.delete_message_by_turn_id(
|
||||
user, delete_request.conversation_id, delete_request.turn_id
|
||||
)
|
||||
if success:
|
||||
return Response(content=json.dumps({"status": "ok"}), media_type="application/json", status_code=200)
|
||||
else:
|
||||
return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404)
|
||||
|
||||
|
||||
@api_chat.post("")
|
||||
@@ -546,11 +581,12 @@ async def chat(
|
||||
common: CommonQueryParams,
|
||||
body: ChatRequestBody,
|
||||
rate_limiter_per_minute=Depends(
|
||||
ApiUserRateLimiter(requests=60, subscribed_requests=200, window=60, slug="chat_minute")
|
||||
ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
|
||||
),
|
||||
rate_limiter_per_day=Depends(
|
||||
ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
|
||||
ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
||||
),
|
||||
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
|
||||
):
|
||||
# Access the parameters from the body
|
||||
q = body.q
|
||||
@@ -559,14 +595,16 @@ async def chat(
|
||||
stream = body.stream
|
||||
title = body.title
|
||||
conversation_id = body.conversation_id
|
||||
turn_id = str(body.turn_id or uuid.uuid4())
|
||||
city = body.city
|
||||
region = body.region
|
||||
country = body.country or get_country_name_from_timezone(body.timezone)
|
||||
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
|
||||
timezone = body.timezone
|
||||
image = body.image
|
||||
raw_images = body.images
|
||||
raw_query_files = body.files
|
||||
|
||||
async def event_generator(q: str, image: str):
|
||||
async def event_generator(q: str, images: list[str]):
|
||||
start_time = time.perf_counter()
|
||||
ttft = None
|
||||
chat_metadata: dict = {}
|
||||
@@ -574,21 +612,35 @@ async def chat(
|
||||
user: KhojUser = request.user.object
|
||||
event_delimiter = "␃🔚␗"
|
||||
q = unquote(q)
|
||||
train_of_thought = []
|
||||
nonlocal conversation_id
|
||||
nonlocal raw_query_files
|
||||
|
||||
uploaded_image_url = None
|
||||
if image:
|
||||
decoded_string = unquote(image)
|
||||
base64_data = decoded_string.split(",", 1)[1]
|
||||
image_bytes = base64.b64decode(base64_data)
|
||||
webp_image_bytes = convert_image_to_webp(image_bytes)
|
||||
try:
|
||||
uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
|
||||
except:
|
||||
uploaded_image_url = None
|
||||
tracer: dict = {
|
||||
"mid": turn_id,
|
||||
"cid": conversation_id,
|
||||
"uid": user.id,
|
||||
"khoj_version": state.khoj_version,
|
||||
}
|
||||
|
||||
uploaded_images: list[str] = []
|
||||
if images:
|
||||
for image in images:
|
||||
decoded_string = unquote(image)
|
||||
base64_data = decoded_string.split(",", 1)[1]
|
||||
image_bytes = base64.b64decode(base64_data)
|
||||
webp_image_bytes = convert_image_to_webp(image_bytes)
|
||||
uploaded_image = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
|
||||
if uploaded_image:
|
||||
uploaded_images.append(uploaded_image)
|
||||
|
||||
query_files: Dict[str, str] = {}
|
||||
if raw_query_files:
|
||||
for file in raw_query_files:
|
||||
query_files[file.name] = file.content
|
||||
|
||||
async def send_event(event_type: ChatEvent, data: str | dict):
|
||||
nonlocal connection_alive, ttft
|
||||
nonlocal connection_alive, ttft, train_of_thought
|
||||
if not connection_alive or await request.is_disconnected():
|
||||
connection_alive = False
|
||||
logger.warning(f"User {user} disconnected from {common.client} client")
|
||||
@@ -596,11 +648,14 @@ async def chat(
|
||||
try:
|
||||
if event_type == ChatEvent.END_LLM_RESPONSE:
|
||||
collect_telemetry()
|
||||
if event_type == ChatEvent.START_LLM_RESPONSE:
|
||||
elif event_type == ChatEvent.START_LLM_RESPONSE:
|
||||
ttft = time.perf_counter() - start_time
|
||||
elif event_type == ChatEvent.STATUS:
|
||||
train_of_thought.append({"type": event_type.value, "data": data})
|
||||
|
||||
if event_type == ChatEvent.MESSAGE:
|
||||
yield data
|
||||
elif event_type == ChatEvent.REFERENCES or stream:
|
||||
elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
|
||||
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
|
||||
except asyncio.CancelledError as e:
|
||||
connection_alive = False
|
||||
@@ -644,6 +699,11 @@ async def chat(
|
||||
metadata=chat_metadata,
|
||||
)
|
||||
|
||||
if is_query_empty(q):
|
||||
async for result in send_llm_response("Please ask your query to get started."):
|
||||
yield result
|
||||
return
|
||||
|
||||
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
||||
|
||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||
@@ -659,6 +719,9 @@ async def chat(
|
||||
return
|
||||
conversation_id = conversation.id
|
||||
|
||||
async for event in send_event(ChatEvent.METADATA, {"conversationId": str(conversation_id), "turnId": turn_id}):
|
||||
yield event
|
||||
|
||||
agent: Agent | None = None
|
||||
default_agent = await AgentAdapters.aget_default_agent()
|
||||
if conversation.agent and conversation.agent != default_agent:
|
||||
@@ -670,46 +733,99 @@ async def chat(
|
||||
agent = default_agent
|
||||
|
||||
await is_ready_to_chat(user)
|
||||
|
||||
user_name = await aget_user_name(user)
|
||||
location = None
|
||||
if city or region or country or country_code:
|
||||
location = LocationData(city=city, region=region, country=country, country_code=country_code)
|
||||
|
||||
if is_query_empty(q):
|
||||
async for result in send_llm_response("Please ask your query to get started."):
|
||||
yield result
|
||||
return
|
||||
|
||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
meta_log = conversation.conversation_log
|
||||
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
|
||||
|
||||
researched_results = ""
|
||||
online_results: Dict = dict()
|
||||
code_results: Dict = dict()
|
||||
## Extract Document References
|
||||
compiled_references: List[Any] = []
|
||||
inferred_queries: List[Any] = []
|
||||
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
|
||||
attached_file_context = gather_raw_query_files(query_files)
|
||||
|
||||
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
||||
conversation_commands = await aget_relevant_information_sources(
|
||||
q,
|
||||
meta_log,
|
||||
is_automated_task,
|
||||
user=user,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
query_files=attached_file_context,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# If we're doing research, we don't want to do anything else
|
||||
if ConversationCommand.Research in conversation_commands:
|
||||
conversation_commands = [ConversationCommand.Research]
|
||||
|
||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||
async for result in send_event(
|
||||
ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}"
|
||||
):
|
||||
yield result
|
||||
|
||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
|
||||
mode = await aget_relevant_output_modes(
|
||||
q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer
|
||||
)
|
||||
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
||||
yield result
|
||||
if mode not in conversation_commands:
|
||||
conversation_commands.append(mode)
|
||||
|
||||
for cmd in conversation_commands:
|
||||
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
||||
q = q.replace(f"/{cmd.value}", "").strip()
|
||||
try:
|
||||
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
||||
q = q.replace(f"/{cmd.value}", "").strip()
|
||||
except HTTPException as e:
|
||||
async for result in send_llm_response(str(e.detail)):
|
||||
yield result
|
||||
return
|
||||
|
||||
defiltered_query = defilter_query(q)
|
||||
|
||||
if conversation_commands == [ConversationCommand.Research]:
|
||||
async for research_result in execute_information_collection(
|
||||
request=request,
|
||||
user=user,
|
||||
query=defiltered_query,
|
||||
conversation_id=conversation_id,
|
||||
conversation_history=meta_log,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
user_name=user_name,
|
||||
location=location,
|
||||
file_filters=conversation.file_filters if conversation else [],
|
||||
query_files=attached_file_context,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(research_result, InformationCollectionIteration):
|
||||
if research_result.summarizedResult:
|
||||
if research_result.onlineContext:
|
||||
online_results.update(research_result.onlineContext)
|
||||
if research_result.codeContext:
|
||||
code_results.update(research_result.codeContext)
|
||||
if research_result.context:
|
||||
compiled_references.extend(research_result.context)
|
||||
|
||||
researched_results += research_result.summarizedResult
|
||||
|
||||
else:
|
||||
yield research_result
|
||||
|
||||
# researched_results = await extract_relevant_info(q, researched_results, agent)
|
||||
if state.verbose > 1:
|
||||
logger.debug(f"Researched Results: {researched_results}")
|
||||
|
||||
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
|
||||
file_filters = conversation.file_filters if conversation else []
|
||||
@@ -730,52 +846,26 @@ async def chat(
|
||||
response_log = "No files selected for summarization. Please add files using the section on the left."
|
||||
async for result in send_llm_response(response_log):
|
||||
yield result
|
||||
elif len(file_filters) > 1 and not agent_has_entries:
|
||||
response_log = "Only one file can be selected for summarization."
|
||||
async for result in send_llm_response(response_log):
|
||||
yield result
|
||||
else:
|
||||
try:
|
||||
file_object = None
|
||||
if await EntryAdapters.aagent_has_entries(agent):
|
||||
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
|
||||
if len(file_names) > 0:
|
||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(
|
||||
None, file_names[0], agent
|
||||
)
|
||||
async for response in generate_summary_from_files(
|
||||
q=q,
|
||||
user=user,
|
||||
file_filters=file_filters,
|
||||
meta_log=meta_log,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
query_files=attached_file_context,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(response, dict) and ChatEvent.STATUS in response:
|
||||
yield response[ChatEvent.STATUS]
|
||||
else:
|
||||
if isinstance(response, str):
|
||||
response_log = response
|
||||
async for result in send_llm_response(response):
|
||||
yield result
|
||||
|
||||
if len(file_filters) > 0:
|
||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
||||
|
||||
if len(file_object) == 0:
|
||||
response_log = "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
|
||||
async for result in send_llm_response(response_log):
|
||||
yield result
|
||||
return
|
||||
contextual_data = " ".join([file.raw_text for file in file_object])
|
||||
if not q:
|
||||
q = "Create a general summary of the file"
|
||||
async for result in send_event(
|
||||
ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}"
|
||||
):
|
||||
yield result
|
||||
|
||||
response = await extract_relevant_summary(
|
||||
q,
|
||||
contextual_data,
|
||||
conversation_history=meta_log,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
user=user,
|
||||
agent=agent,
|
||||
)
|
||||
response_log = str(response)
|
||||
async for result in send_llm_response(response_log):
|
||||
yield result
|
||||
except Exception as e:
|
||||
response_log = "Error summarizing file. Please try again, or contact support."
|
||||
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
|
||||
async for result in send_llm_response(response_log):
|
||||
yield result
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
response_log,
|
||||
@@ -785,7 +875,10 @@ async def chat(
|
||||
intent_type="summarize",
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
train_of_thought=train_of_thought,
|
||||
raw_query_files=raw_query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
return
|
||||
|
||||
@@ -794,7 +887,7 @@ async def chat(
|
||||
if not q:
|
||||
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||
if conversation_config == None:
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
|
||||
model_type = conversation_config.model_type
|
||||
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
||||
async for result in send_llm_response(formatted_help):
|
||||
@@ -807,7 +900,7 @@ async def chat(
|
||||
if ConversationCommand.Automation in conversation_commands:
|
||||
try:
|
||||
automation, crontime, query_to_run, subject = await create_automation(
|
||||
q, timezone, user, request.url, meta_log
|
||||
q, timezone, user, request.url, meta_log, tracer=tracer
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
||||
@@ -828,7 +921,10 @@ async def chat(
|
||||
conversation_id=conversation_id,
|
||||
inferred_queries=[query_to_run],
|
||||
automation_id=automation.id,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
train_of_thought=train_of_thought,
|
||||
raw_query_files=raw_query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
async for result in send_llm_response(llm_response):
|
||||
yield result
|
||||
@@ -836,48 +932,50 @@ async def chat(
|
||||
|
||||
# Gather Context
|
||||
## Extract Document References
|
||||
compiled_references, inferred_queries, defiltered_query = [], [], q
|
||||
try:
|
||||
async for result in extract_references_and_questions(
|
||||
request,
|
||||
meta_log,
|
||||
q,
|
||||
(n or 7),
|
||||
d,
|
||||
conversation_id,
|
||||
conversation_commands,
|
||||
location,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
agent=agent,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
compiled_references.extend(result[0])
|
||||
inferred_queries.extend(result[1])
|
||||
defiltered_query = result[2]
|
||||
except Exception as e:
|
||||
error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
|
||||
logger.error(error_message, exc_info=True)
|
||||
async for result in send_event(
|
||||
ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
|
||||
):
|
||||
yield result
|
||||
if not ConversationCommand.Research in conversation_commands:
|
||||
try:
|
||||
async for result in extract_references_and_questions(
|
||||
request,
|
||||
meta_log,
|
||||
q,
|
||||
(n or 7),
|
||||
d,
|
||||
conversation_id,
|
||||
conversation_commands,
|
||||
location,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
query_files=attached_file_context,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
compiled_references.extend(result[0])
|
||||
inferred_queries.extend(result[1])
|
||||
defiltered_query = result[2]
|
||||
except Exception as e:
|
||||
error_message = (
|
||||
f"Error searching knowledge base: {e}. Attempting to respond without document references."
|
||||
)
|
||||
logger.error(error_message, exc_info=True)
|
||||
async for result in send_event(
|
||||
ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
|
||||
):
|
||||
yield result
|
||||
|
||||
if not is_none_or_empty(compiled_references):
|
||||
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
||||
# Strip only leading # from headings
|
||||
headings = headings.replace("#", "")
|
||||
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
|
||||
yield result
|
||||
if not is_none_or_empty(compiled_references):
|
||||
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
||||
# Strip only leading # from headings
|
||||
headings = headings.replace("#", "")
|
||||
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
|
||||
yield result
|
||||
|
||||
online_results: Dict = dict()
|
||||
|
||||
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
||||
async for result in send_llm_response(f"{no_entries_found.format()}"):
|
||||
yield result
|
||||
return
|
||||
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
||||
async for result in send_llm_response(f"{no_entries_found.format()}"):
|
||||
yield result
|
||||
return
|
||||
|
||||
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
||||
conversation_commands.remove(ConversationCommand.Notes)
|
||||
@@ -892,8 +990,10 @@ async def chat(
|
||||
user,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
custom_filters,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
query_files=attached_file_context,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
@@ -916,8 +1016,10 @@ async def chat(
|
||||
location,
|
||||
user,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
query_files=attached_file_context,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
@@ -944,13 +1046,43 @@ async def chat(
|
||||
):
|
||||
yield result
|
||||
|
||||
## Gather Code Results
|
||||
if ConversationCommand.Code in conversation_commands:
|
||||
try:
|
||||
context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
|
||||
async for result in run_code(
|
||||
defiltered_query,
|
||||
meta_log,
|
||||
context,
|
||||
location,
|
||||
user,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
query_files=attached_file_context,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
code_results = result
|
||||
async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"):
|
||||
yield result
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
f"Failed to use code tool: {e}. Attempting to respond without code results",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
## Send Gathered References
|
||||
unique_online_results = deduplicate_organic_results(online_results)
|
||||
async for result in send_event(
|
||||
ChatEvent.REFERENCES,
|
||||
{
|
||||
"inferredQueries": inferred_queries,
|
||||
"context": compiled_references,
|
||||
"onlineContext": online_results,
|
||||
"onlineContext": unique_online_results,
|
||||
"codeContext": code_results,
|
||||
},
|
||||
):
|
||||
yield result
|
||||
@@ -966,20 +1098,22 @@ async def chat(
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
query_files=attached_file_context,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
image, status_code, improved_image_prompt, intent_type = result
|
||||
generated_image, status_code, improved_image_prompt, intent_type = result
|
||||
|
||||
if image is None or status_code != 200:
|
||||
if generated_image is None or status_code != 200:
|
||||
content_obj = {
|
||||
"content-type": "application/json",
|
||||
"intentType": intent_type,
|
||||
"detail": improved_image_prompt,
|
||||
"image": image,
|
||||
"image": None,
|
||||
}
|
||||
async for result in send_llm_response(json.dumps(content_obj)):
|
||||
yield result
|
||||
@@ -987,7 +1121,7 @@ async def chat(
|
||||
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
image,
|
||||
generated_image,
|
||||
user,
|
||||
meta_log,
|
||||
user_message_time,
|
||||
@@ -997,17 +1131,83 @@ async def chat(
|
||||
conversation_id=conversation_id,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
code_results=code_results,
|
||||
query_images=uploaded_images,
|
||||
train_of_thought=train_of_thought,
|
||||
raw_query_files=raw_query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
content_obj = {
|
||||
"intentType": intent_type,
|
||||
"inferredQueries": [improved_image_prompt],
|
||||
"image": image,
|
||||
"image": generated_image,
|
||||
}
|
||||
async for result in send_llm_response(json.dumps(content_obj)):
|
||||
yield result
|
||||
return
|
||||
|
||||
if ConversationCommand.Diagram in conversation_commands:
|
||||
async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
|
||||
yield result
|
||||
|
||||
intent_type = "excalidraw"
|
||||
inferred_queries = []
|
||||
diagram_description = ""
|
||||
|
||||
async for result in generate_excalidraw_diagram(
|
||||
q=defiltered_query,
|
||||
conversation_history=meta_log,
|
||||
location_data=location,
|
||||
note_references=compiled_references,
|
||||
online_results=online_results,
|
||||
query_images=uploaded_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
query_files=attached_file_context,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
better_diagram_description_prompt, excalidraw_diagram_description = result
|
||||
if better_diagram_description_prompt and excalidraw_diagram_description:
|
||||
inferred_queries.append(better_diagram_description_prompt)
|
||||
diagram_description = excalidraw_diagram_description
|
||||
else:
|
||||
async for result in send_llm_response(f"Failed to generate diagram. Please try again later."):
|
||||
yield result
|
||||
return
|
||||
|
||||
content_obj = {
|
||||
"intentType": intent_type,
|
||||
"inferredQueries": inferred_queries,
|
||||
"image": diagram_description,
|
||||
}
|
||||
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
excalidraw_diagram_description,
|
||||
user,
|
||||
meta_log,
|
||||
user_message_time,
|
||||
intent_type="excalidraw",
|
||||
inferred_queries=[better_diagram_description_prompt],
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
query_images=uploaded_images,
|
||||
train_of_thought=train_of_thought,
|
||||
raw_query_files=raw_query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
async for result in send_llm_response(json.dumps(content_obj)):
|
||||
yield result
|
||||
return
|
||||
|
||||
## Generate Text Output
|
||||
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
|
||||
yield result
|
||||
@@ -1017,6 +1217,7 @@ async def chat(
|
||||
conversation,
|
||||
compiled_references,
|
||||
online_results,
|
||||
code_results,
|
||||
inferred_queries,
|
||||
conversation_commands,
|
||||
user,
|
||||
@@ -1024,7 +1225,12 @@ async def chat(
|
||||
conversation_id,
|
||||
location,
|
||||
user_name,
|
||||
uploaded_image_url,
|
||||
researched_results,
|
||||
uploaded_images,
|
||||
train_of_thought,
|
||||
attached_file_context,
|
||||
raw_query_files,
|
||||
tracer,
|
||||
)
|
||||
|
||||
# Send Response
|
||||
@@ -1050,9 +1256,9 @@ async def chat(
|
||||
|
||||
## Stream Text Response
|
||||
if stream:
|
||||
return StreamingResponse(event_generator(q, image=image), media_type="text/plain")
|
||||
return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
|
||||
## Non-Streaming Text Response
|
||||
else:
|
||||
response_iterator = event_generator(q, image=image)
|
||||
response_iterator = event_generator(q, images=raw_images)
|
||||
response_data = await read_chat_stream(response_iterator)
|
||||
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
|
||||
|
||||
@@ -36,16 +36,18 @@ from khoj.database.models import (
|
||||
LocalPlaintextConfig,
|
||||
NotionConfig,
|
||||
)
|
||||
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
|
||||
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
|
||||
from khoj.routers.helpers import (
|
||||
ApiIndexedDataLimiter,
|
||||
CommonQueryParams,
|
||||
configure_content,
|
||||
get_file_content,
|
||||
get_user_config,
|
||||
update_telemetry_state,
|
||||
)
|
||||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import SearchModels
|
||||
from khoj.utils.helpers import get_file_type
|
||||
from khoj.utils.rawconfig import (
|
||||
ContentConfig,
|
||||
FullConfig,
|
||||
@@ -237,7 +239,7 @@ async def set_content_notion(
|
||||
|
||||
if updated_config.token:
|
||||
# Trigger an async job to configure_content. Let it run without blocking the response.
|
||||
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user)
|
||||
background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion)
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
@@ -375,6 +377,85 @@ async def delete_content_source(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@api_content.post("/convert", status_code=200)
@requires(["authenticated"])
async def convert_documents(
    request: Request,
    files: List[UploadFile],
    client: Optional[str] = None,
):
    MAX_FILE_SIZE_MB = 10  # 10MB limit
    MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024

    converted_files = []
    supported_files = ["org", "markdown", "pdf", "plaintext", "docx"]

    for file in files:
        # Check file size first
        file_size = 0
        content = await file.read()
        file_size = len(content)
        await file.seek(0)  # Reset file pointer

        if file_size > MAX_FILE_SIZE_BYTES:
            logger.warning(
                f"Skipped converting oversized file ({file_size / 1024 / 1024:.1f}MB) sent by {client} client: {file.filename}"
            )
            continue

        file_data = get_file_content(file)
        if file_data.file_type in supported_files:
            extracted_content = (
                file_data.content.decode(file_data.encoding) if file_data.encoding else file_data.content
            )

            if file_data.file_type == "docx":
                entries_per_page = DocxToEntries.extract_text(file_data.content)
                annotated_pages = [
                    f"Page {index} of {file_data.name}:\n\n{entry}" for index, entry in enumerate(entries_per_page)
                ]
                extracted_content = "\n".join(annotated_pages)

            elif file_data.file_type == "pdf":
                entries_per_page = PdfToEntries.extract_text(file_data.content)
                annotated_pages = [
                    f"Page {index} of {file_data.name}:\n\n{entry}" for index, entry in enumerate(entries_per_page)
                ]
                extracted_content = "\n".join(annotated_pages)
            else:
                # Convert content to string
                extracted_content = extracted_content.decode("utf-8")

            # Calculate size in bytes. Some of the content might be in bytes, some in str.
            if isinstance(extracted_content, str):
                size_in_bytes = len(extracted_content.encode("utf-8"))
            elif isinstance(extracted_content, bytes):
                size_in_bytes = len(extracted_content)
            else:
                size_in_bytes = 0
                logger.warning(f"Unexpected content type: {type(extracted_content)}")

            converted_files.append(
                {
                    "name": file_data.name,
                    "content": extracted_content,
                    "file_type": file_data.file_type,
                    "size": size_in_bytes,
                }
            )
        else:
            logger.warning(f"Skipped converting unsupported file type sent by {client} client: {file.filename}")

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="convert_documents",
        client=client,
    )

    return Response(content=json.dumps(converted_files), media_type="application/json", status_code=200)
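A sketch of exercising the new document conversion endpoint from a client, assuming the router is mounted under /api/content and the request is made with an authenticated session; the host, port, file name, content type and cookie below are placeholders:

import requests

# Placeholder host and credentials; the endpoint expects multipart file uploads
with open("notes.org", "rb") as f:
    response = requests.post(
        "http://localhost:42110/api/content/convert",
        files=[("files", ("notes.org", f, "text/org"))],
        cookies={"session": "<your session cookie>"},
    )
converted = response.json()  # list of {"name", "content", "file_type", "size"} entries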
|
||||
async def indexer(
    request: Request,
    files: list[UploadFile],

@@ -398,12 +479,13 @@ async def indexer(
    try:
        logger.info(f"📬 Updating content index via API call by {client} client")
        for file in files:
            file_content = file.file.read()
            file_type, encoding = get_file_type(file.content_type, file_content)
            if file_type in index_files:
                index_files[file_type][file.filename] = file_content.decode(encoding) if encoding else file_content
            file_data = get_file_content(file)
            if file_data.file_type in index_files:
                index_files[file_data.file_type][file_data.name] = (
                    file_data.content.decode(file_data.encoding) if file_data.encoding else file_data.content
                )
            else:
                logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")
                logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file_data.name}")

        indexer_input = IndexerInput(
            org=index_files["org"],

@@ -440,10 +522,10 @@ async def indexer(
        success = await loop.run_in_executor(
            None,
            configure_content,
            user,
            indexer_input.model_dump(),
            regenerate,
            t,
            user,
        )
        if not success:
            raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index")

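Both hunks above swap the raw get_file_type call for a get_file_content helper that bundles the per-upload metadata. Below is a sketch of the shape its return value would need to satisfy these call sites; the class name and field layout are assumptions inferred from the attribute accesses (name, content, file_type, encoding), not copied from khoj.

# Hypothetical sketch of the return shape get_file_content would need for the
# call sites above; field names are inferred from the attribute accesses in this diff.
from dataclasses import dataclass
from typing import Optional


@dataclass
class UploadedFileData:
    name: str                 # original filename of the upload
    content: bytes            # raw file bytes read from the UploadFile
    file_type: str            # normalized type, e.g. "org", "markdown", "pdf", "plaintext", "docx"
    encoding: Optional[str]   # text encoding when the content is decodable, else None


def get_file_content_sketch(file) -> UploadedFileData:
    """Read an UploadFile and classify it, mirroring how the indexer consumes it."""
    raw = file.file.read()
    file_type, encoding = get_file_type(file.content_type, raw)  # helper imported earlier in this diff
    return UploadedFileData(name=file.filename, content=raw, file_type=file_type, encoding=encoding)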
@@ -94,39 +94,6 @@ async def update_voice_model(
    return Response(status_code=202, content=json.dumps({"status": "ok"}))


@api_model.post("/search", status_code=200)
@requires(["authenticated"])
async def update_search_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    user = request.user.object

    prev_config = await adapters.aget_user_search_model(user)
    new_config = await adapters.aset_user_search_model(user, int(id))

    if prev_config and int(id) != prev_config.id and new_config:
        await EntryAdapters.adelete_all_entries(user)

    if not prev_config:
        # If the use was just using the default config, delete all the entries and set the new config.
        await EntryAdapters.adelete_all_entries(user)

    if new_config is None:
        return {"status": "error", "message": "Model not found"}
    else:
        update_telemetry_state(
            request=request,
            telemetry_type="api",
            api="set_search_model",
            client=client,
            metadata={"search_model": new_config.setting.name},
        )

    return {"status": "ok"}


@api_model.post("/paint", status_code=200)
@requires(["authenticated"])
async def update_paint_model(

@@ -1,12 +1,14 @@
import json
import logging
import os
from datetime import datetime, timezone

from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Response
from starlette.authentication import requires

from khoj.database import adapters
from khoj.database.models import KhojUser, Subscription
from khoj.routers.helpers import update_telemetry_state
from khoj.utils import state

@@ -73,7 +75,7 @@ async def subscribe(request: Request):
    elif event_type in {"customer.subscription.deleted"}:
        # Reset the user to trial state
        user, is_new = await adapters.set_user_subscription(
            customer_email, is_recurring=False, renewal_date=False, type="trial"
            customer_email, is_recurring=False, renewal_date=False, type=Subscription.Type.TRIAL
        )
        success = user is not None

@@ -82,7 +84,7 @@ async def subscribe(request: Request):
        request=request,
        telemetry_type="api",
        api="create_user",
        metadata={"user_id": str(user.user.uuid)},
        metadata={"server_id": str(user.user.uuid)},
    )
    logger.log(logging.INFO, f"🥳 New User Created: {user.user.uuid}")

@@ -92,8 +94,9 @@ async def subscribe(request: Request):

@subscription_router.patch("")
@requires(["authenticated"])
async def update_subscription(request: Request, email: str, operation: str):
async def update_subscription(request: Request, operation: str):
    # Retrieve the customer's details
    email = request.user.object.email
    customers = stripe.Customer.list(email=email).auto_paging_iter()
    customer = next(customers, None)
    if customer is None:

@@ -116,3 +119,19 @@ async def update_subscription(request: Request, email: str, operation: str):
        return {"success": False, "message": "No subscription found that is set to cancel"}

    return {"success": False, "message": "Invalid operation"}


@subscription_router.post("/trial", response_class=Response)
@requires(["authenticated"])
async def start_trial(request: Request) -> Response:
    user: KhojUser = request.user.object

    # Start a trial for the user
    updated_subscription = await adapters.astart_trial_subscription(user)

    # Return trial status as a JSON response
    return Response(
        content=json.dumps({"trial_enabled": updated_subscription is not None}),
        media_type="application/json",
        status_code=200,
    )
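A minimal sketch of exercising the new trial endpoint from a client follows. As with the convert example earlier, the base URL, the /api/subscription mount point, and bearer-token auth are assumptions not shown in this diff.

# Hypothetical usage sketch; URL, route prefix, and auth header are assumptions.
import requests

KHOJ_URL = "http://localhost:42110"   # assumed local server address
API_TOKEN = "your-khoj-api-token"     # assumed bearer token auth

response = requests.post(
    f"{KHOJ_URL}/api/subscription/trial",
    headers={"Authorization": f"Bearer {API_TOKEN}"},
    timeout=30,
)
response.raise_for_status()

# Per the endpoint above, the body reports whether a trial subscription was started.
print(response.json()["trial_enabled"])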
@@ -90,7 +90,7 @@ async def login_magic_link(request: Request, form: MagicLinkForm):
        request=request,
        telemetry_type="api",
        api="create_user",
        metadata={"user_id": str(user.uuid)},
        metadata={"server_id": str(user.uuid)},
    )
    logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")

@@ -175,7 +175,7 @@ async def auth(request: Request):
        request=request,
        telemetry_type="api",
        api="create_user",
        metadata={"user_id": str(khoj_user.uuid)},
        metadata={"server_id": str(khoj_user.uuid)},
    )
    logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}")
    return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)

Some files were not shown because too many files have changed in this diff.