mirror of
https://github.com/khoj-ai/khoj.git
synced 2026-05-14 05:51:43 +00:00
Compare commits
55 Commits
master
...
create-dat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f3045ab436 | ||
|
|
899e9a6c23 | ||
|
|
a35002e93b | ||
|
|
3db3d7bbb4 | ||
|
|
f870a2f7a2 | ||
|
|
0281d8b715 | ||
|
|
947ce8f029 | ||
|
|
9815bbc8a1 | ||
|
|
c2180627b4 | ||
|
|
095184906e | ||
|
|
06a900069f | ||
|
|
61a2956de2 | ||
|
|
7a3794040a | ||
|
|
2654b36a56 | ||
|
|
022adbe63c | ||
|
|
6ffd93f16e | ||
|
|
f917ee7c05 | ||
|
|
af3c16ec43 | ||
|
|
28f28fc942 | ||
|
|
9d20b365e3 | ||
|
|
90df9de0fe | ||
|
|
7c1cd3d5e9 | ||
|
|
4359b83568 | ||
|
|
adcf9314fd | ||
|
|
79f3f61b70 | ||
|
|
72aceaba38 | ||
|
|
7b2bffbcd7 | ||
|
|
776ece8d22 | ||
|
|
784389b172 | ||
|
|
af7c0e9445 | ||
|
|
6d859135a1 | ||
|
|
47fded4ac5 | ||
|
|
da5357cdaf | ||
|
|
ac3f979919 | ||
|
|
b70a4ff92e | ||
|
|
6257d1bb62 | ||
|
|
ef794a78d5 | ||
|
|
9f8307b88a | ||
|
|
e5d5153fc4 | ||
|
|
11a254863b | ||
|
|
8b2798cbe0 | ||
|
|
4059d10470 | ||
|
|
509ae2f9f6 | ||
|
|
3509b50f92 | ||
|
|
33d870c907 | ||
|
|
7f001768c1 | ||
|
|
5c90db6fec | ||
|
|
fed510a32a | ||
|
|
1077dd5acd | ||
|
|
224eb228cf | ||
|
|
4560a64f6e | ||
|
|
1815927f6c | ||
|
|
e19e84e2eb | ||
|
|
818239c3c7 | ||
|
|
d999edcbd9 |
9
.github/workflows/run_evals.yml
vendored
9
.github/workflows/run_evals.yml
vendored
@@ -33,6 +33,12 @@ on:
|
||||
default: 200
|
||||
type: number
|
||||
|
||||
hf_repo_name:
|
||||
description: 'HuggingFace data tracer repo name to output the results'
|
||||
required: false
|
||||
default: 'khoj-ai/datatracer-frames'
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
eval:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -131,6 +137,9 @@ jobs:
|
||||
# Run evals
|
||||
python tests/evals/eval.py -d ${{ matrix.dataset }}
|
||||
|
||||
# Run data tracer script to upload filtered results to hf repo passed as arg
|
||||
python scripts/create_dataset_from_eval.py --repo_name ${{ inputs.hf_repo_name }} --eval_path ${{ matrix.dataset }}_evaluation_results_*.csv --datatrace_path datatracer.csv
|
||||
|
||||
- name: Upload Results
|
||||
if: always() # Upload results even if tests fail
|
||||
uses: actions/upload-artifact@v4
|
||||
|
||||
@@ -14,6 +14,11 @@ services:
|
||||
retries: 5
|
||||
sandbox:
|
||||
image: ghcr.io/khoj-ai/terrarium:latest
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "curl -f http://localhost:8080/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 2
|
||||
search:
|
||||
image: docker.io/searxng/searxng:latest
|
||||
volumes:
|
||||
@@ -25,16 +30,17 @@ services:
|
||||
database:
|
||||
condition: service_healthy
|
||||
# Use the following line to use the latest version of khoj. Otherwise, it will build from source. Set this to ghcr.io/khoj-ai/khoj-cloud:latest if you want to use the prod image.
|
||||
image: ghcr.io/khoj-ai/khoj:latest
|
||||
# image: ghcr.io/khoj-ai/khoj:latest
|
||||
# Uncomment the following line to build from source. This will take a few minutes. Comment the next two lines out if you want to use the official image.
|
||||
# build:
|
||||
# context: .
|
||||
build:
|
||||
context: .
|
||||
dockerfile: prod.Dockerfile
|
||||
ports:
|
||||
# If changing the local port (left hand side), no other changes required.
|
||||
# If changing the remote port (right hand side),
|
||||
# change the port in the args in the build section,
|
||||
# as well as the port in the command section to match
|
||||
- "42110:42110"
|
||||
- "42110"
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
working_dir: /app
|
||||
@@ -42,6 +48,7 @@ services:
|
||||
- khoj_config:/root/.khoj/
|
||||
- khoj_models:/root/.cache/torch/sentence_transformers
|
||||
- khoj_models:/root/.cache/huggingface
|
||||
- khoj_tmp:/app/tmp/
|
||||
# Use 0.0.0.0 to explicitly set the host ip for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/
|
||||
environment:
|
||||
- POSTGRES_DB=postgres
|
||||
@@ -51,6 +58,8 @@ services:
|
||||
- POSTGRES_PORT=5432
|
||||
- KHOJ_DJANGO_SECRET_KEY=secret
|
||||
- KHOJ_DEBUG=False
|
||||
- DATATRACE_PATH=tmp/thought_steps.jsonl
|
||||
- FULL_THOUGHTS_PATH=tmp/full_thoughts.jsonl
|
||||
- KHOJ_ADMIN_EMAIL=username@example.com
|
||||
- KHOJ_ADMIN_PASSWORD=password
|
||||
# Default URL of Terrarium, the Python sandbox used by Khoj to run code. Its container is specified above
|
||||
@@ -92,10 +101,17 @@ services:
|
||||
# Read more at https://docs.khoj.dev/miscellaneous/telemetry
|
||||
# - KHOJ_TELEMETRY_DISABLE=True
|
||||
# Comment out this line when you're using the official ghcr.io/khoj-ai/khoj-cloud:latest prod image.
|
||||
command: --host="0.0.0.0" --port=42110 -vv --anonymous-mode --non-interactive
|
||||
|
||||
nginx:
|
||||
image: nginx:latest
|
||||
volumes:
|
||||
- ./nginx.conf:/etc/nginx/nginx.conf
|
||||
depends_on:
|
||||
- server
|
||||
ports:
|
||||
- '42110:42110'
|
||||
volumes:
|
||||
khoj_config:
|
||||
khoj_db:
|
||||
khoj_models:
|
||||
khoj_search:
|
||||
khoj_tmp:
|
||||
|
||||
17
nginx.conf
Normal file
17
nginx.conf
Normal file
@@ -0,0 +1,17 @@
|
||||
user nginx;
|
||||
|
||||
events {
|
||||
worker_connections 1000;
|
||||
}
|
||||
http {
|
||||
server {
|
||||
listen 42110;
|
||||
location / {
|
||||
proxy_pass http://server:42110;
|
||||
proxy_read_timeout 1800;
|
||||
proxy_connect_timeout 1800;
|
||||
proxy_send_timeout 1800;
|
||||
send_timeout 1800;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -125,6 +125,7 @@ dev = [
|
||||
"gitpython ~= 3.1.43",
|
||||
"datasets",
|
||||
"pandas",
|
||||
"math-verify",
|
||||
]
|
||||
|
||||
[tool.hatch.version]
|
||||
|
||||
233
scripts/create_dataset_from_eval.py
Normal file
233
scripts/create_dataset_from_eval.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pandas as pd
|
||||
from datasets import Dataset, DatasetDict, load_dataset
|
||||
from dotenv import load_dotenv
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
Your role as an assistant involves thoroughly exploring questions through a systematic long thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution. In the Thought section, detail your reasoning process using the specified format: <|begin_of_thought|> {thought with steps separated with '\\n\\n'} <|end_of_thought|> Each step should include detailed considerations such as analisying questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps.In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The solution should remain a logical, accurate, concise expression style and detail necessary step needed to reach the conclusion, formatted as follows: <|begin_of_solution|> {final formatted, precise, and clear solution} <|end_of_solution|>
|
||||
|
||||
Make sure to use the specific LaTeX math mode delimiters for your response. LaTex math mode specific delimiters as following
|
||||
- inline math mode : \\( and \\)
|
||||
- display math mode: insert linebreak after opening $$, \\[ and before closing $$, \\]
|
||||
|
||||
Now, try to solve the following question through the above guidelines"""
|
||||
|
||||
|
||||
def load_dataset_from_jsonl(dataset_path: str, replace_system_prompt: bool = False) -> pd.DataFrame:
    """Load a data trace from a JSONL file into a pandas DataFrame.

    Args:
        dataset_path: Path to the JSONL file. Each line is one JSON record,
            expected to carry 'system' and 'conversations' fields.
        replace_system_prompt: When True, overwrite each record's 'system'
            field with the module-level SYSTEM_PROMPT.

    Returns:
        DataFrame of parsed records. An empty DataFrame with the expected
        columns is returned when the path is unset/missing, the file is
        empty, or reading fails — so callers always get a DataFrame.
    """
    # Single empty result for every failure path. The original returned a
    # datasets.Dataset from these branches, contradicting the declared
    # return type and breaking downstream pandas operations.
    empty = pd.DataFrame(columns=["system", "conversations"])

    # Guard None/empty path too: main() passes raw os.getenv() results.
    if not dataset_path or not os.path.exists(dataset_path):
        return empty

    try:
        # Read JSONL line by line so one malformed line is skipped instead
        # of aborting the whole load.
        data = []
        with open(dataset_path, "r") as f:
            for i, line in enumerate(f, 1):
                try:
                    loaded_data = json.loads(line.strip())
                    if "system" in loaded_data and replace_system_prompt:
                        loaded_data["system"] = SYSTEM_PROMPT
                    data.append(loaded_data)
                except json.JSONDecodeError as e:
                    logger.debug(f"Error on line {i}: {e}")
                    logger.debug(f"Problematic line: {line.strip()}")
                    continue

        return pd.DataFrame(data) if data else empty

    except Exception as e:
        logger.error(f"Error processing dataset: {e}")
        return empty
|
||||
|
||||
|
||||
def load_eval_results_from_csv(eval_results_paths: str | list[str]) -> pd.DataFrame:
|
||||
"""
|
||||
Load evaluation results from one or more CSV files into a single Pandas DataFrame.
|
||||
|
||||
Args:
|
||||
eval_results_paths: Single path string or list of paths to CSV files
|
||||
|
||||
Returns:
|
||||
Combined DataFrame with all evaluation results
|
||||
"""
|
||||
# Convert single path to list
|
||||
if isinstance(eval_results_paths, str):
|
||||
eval_results_paths = [eval_results_paths]
|
||||
|
||||
# Initialize empty DataFrame
|
||||
combined_df = pd.DataFrame()
|
||||
|
||||
# Load and concatenate each CSV file
|
||||
for path in eval_results_paths:
|
||||
if os.path.exists(path):
|
||||
df = pd.read_csv(path)
|
||||
combined_df = pd.concat([combined_df, df], ignore_index=True)
|
||||
|
||||
return combined_df
|
||||
|
||||
|
||||
def get_good_dataset_rows(datatrace: pd.DataFrame, eval: pd.DataFrame) -> pd.DataFrame:
    """
    Filter data trace rows to keep only those with successful evaluations.

    Args:
        datatrace: DataFrame with 'system' and 'conversations' columns.
        eval: DataFrame with 'prompt' and 'evaluation_decision' columns.

    Returns:
        Filtered data trace rows corresponding to the successful evaluations.
    """
    # Keep only evaluations judged correct.
    successful_evals = eval[eval["evaluation_decision"] == 1.0]

    # Work on a copy: the original added the temporary 'prompt' column to
    # the caller's DataFrame in place and never removed it from the input.
    datatrace = datatrace.copy()

    # The user prompt is the second-to-last conversation turn (the last is
    # the assistant answer); conversations shorter than 2 turns get None
    # and will not match any evaluation prompt.
    datatrace["prompt"] = datatrace["conversations"].apply(lambda x: x[-2]["value"] if len(x) >= 2 else None)

    # Inner-join on prompt keeps only traces whose evaluation succeeded.
    good_rows = pd.merge(datatrace, successful_evals[["prompt"]], on=["prompt"], how="inner")

    # Drop the temporary join key before returning.
    return good_rows.drop(columns=["prompt"])
|
||||
|
||||
|
||||
def deduplicate_rows(good_rows: pd.DataFrame, parent_dataset: DatasetDict) -> pd.DataFrame:
    """
    Merge good_rows into the parent dataset's 'train' split, dropping parent
    rows whose user prompt duplicates one in good_rows.

    A duplicate is detected by matching the 'value' field of the first item
    in each row's 'conversations' list (the user prompt).

    Args:
        good_rows: DataFrame with 'system' and 'conversations' columns.
        parent_dataset: Parent dataset whose 'train' split is merged in.

    Returns:
        Combined, deduplicated DataFrame.
    """
    log = logging.getLogger(__name__)

    def first_turn(convo):
        # The opening turn of a conversation is the user prompt.
        return convo[0]["value"]

    seen_prompts = {first_turn(convo) for convo in good_rows["conversations"]}
    log.info(f"Found {len(good_rows)} rows in good_rows")
    log.info(f"Found {len(seen_prompts)} unique user prompts in good_rows")

    # Pull the parent 'train' split into pandas for filtering.
    parent_df = parent_dataset.data["train"].to_pandas()
    log.info(f"Found {len(parent_df)} rows in parent dataset")

    keep = ~parent_df["conversations"].apply(first_turn).isin(seen_prompts)
    parent_filtered = parent_df[keep]
    log.info(f"Found {len(parent_filtered)} rows in parent dataset after filtering")

    # Append the curated rows after the deduplicated parent rows.
    combined = pd.concat([parent_filtered, good_rows], ignore_index=True)
    log.info(f"Found {len(combined)} rows in final dataset after combining")
    return combined
|
||||
|
||||
|
||||
def main():
    """Build a filtered training dataset from evaluation results, then push it
    to HuggingFace or save it locally.

    Configuration comes from CLI flags, each falling back to an environment
    variable, so both invocation styles work: the eval workflow passes
    --repo_name/--eval_path/--datatrace_path, while the container sets env vars.
    """
    parser = argparse.ArgumentParser(description="Create filtered dataset from evaluation results")
    # The original built this parser but never called parse_args(), so the
    # flags passed by .github/workflows/run_evals.yml were silently ignored.
    parser.add_argument(
        "--datatrace_path", type=str, default=os.getenv("DATATRACE_PATH"), help="Path to data trace JSONL"
    )
    parser.add_argument(
        "--fullthoughts_path",
        type=str,
        # docker-compose exports FULL_THOUGHTS_PATH; also accept the spelling
        # the original read (FULLTHOUGHTS_PATH) for backward compatibility.
        default=os.getenv("FULL_THOUGHTS_PATH") or os.getenv("FULLTHOUGHTS_PATH"),
        help="Path to full thoughts JSONL",
    )
    parser.add_argument(
        "--eval_path", type=str, default=os.getenv("EVAL_PATH"), help="Comma-separated evaluation result CSV paths"
    )
    parser.add_argument(
        "--output_path", type=str, default=os.getenv("OUTPUT_PATH"), help="Directory to save the dataset locally"
    )
    parser.add_argument(
        "--repo_name", type=str, default=os.getenv("REPO_NAME"), help="HuggingFace repo in user/dataset format"
    )
    parser.add_argument(
        "--thoughts_repo_name",
        type=str,
        default=os.getenv("THOUGHTS_REPO_NAME"),
        help="HuggingFace repo for the full-thoughts dataset",
    )
    args = parser.parse_args()

    datatrace_path = args.datatrace_path
    fullthoughts_path = args.fullthoughts_path
    eval_paths = args.eval_path.split(",") if args.eval_path else None
    output_path = args.output_path
    repo_name = args.repo_name
    thoughts_repo_name = args.thoughts_repo_name

    hf_token = os.getenv("HF_TOKEN")

    try:
        # Load the data trace and the evaluation results.
        datatrace_df = load_dataset_from_jsonl(datatrace_path)
        eval_df = load_eval_results_from_csv(eval_paths)

        # Keep only trace rows whose evaluation succeeded.
        good_rows = get_good_dataset_rows(datatrace_df, eval_df)

        # Load in the parent dataset, stratos.
        parent_dataset = load_dataset("bespokelabs/Bespoke-Stratos-17k", split=None)
        logger.info(f"Loaded parent dataset with {len(parent_dataset)} rows")

        # Load the full-thoughts trace, normalizing its system prompt.
        fullthoughts_df = load_dataset_from_jsonl(fullthoughts_path, replace_system_prompt=True)
        logger.info(f"Loaded fullthoughts dataset with {len(fullthoughts_df)} rows")

        # Dedupe against the parent dataset and serialize to JSON.
        with tempfile.TemporaryDirectory() as tmp_dir:
            json_path = os.path.join(tmp_dir, "data.json")
            deduplicated_rows = deduplicate_rows(good_rows, parent_dataset)
            deduplicated_rows.to_json(json_path, orient="records", indent=2)

            if repo_name and hf_token:
                api = HfApi(token=hf_token)
                api.upload_file(
                    path_or_fileobj=json_path,
                    path_in_repo="data.json",
                    repo_id=repo_name,
                    repo_type="dataset",
                    commit_message="Upload filtered dataset as JSON",
                )
                logging.info(f"Pushed JSON dataset with {len(good_rows)} rows to {repo_name}")
            elif output_path:
                # Save locally if no HF repo specified.
                output_json = os.path.join(output_path, "data.json")
                good_rows.to_json(output_json, orient="records", indent=2)
                logging.info(f"Saved JSON dataset with {len(good_rows)} rows to {output_json}")

        if fullthoughts_path:
            # Repeat the filter/dedupe pipeline for the full-thoughts trace.
            good_rows = get_good_dataset_rows(fullthoughts_df, eval_df)
            deduplicated_rows = deduplicate_rows(good_rows, parent_dataset)

            with tempfile.TemporaryDirectory() as tmp_dir:
                json_path = os.path.join(tmp_dir, "data.json")
                deduplicated_rows.to_json(json_path, orient="records", indent=2)

                # Gate on thoughts_repo_name (the repo actually uploaded to);
                # the original gated on repo_name and could pass repo_id=None.
                if thoughts_repo_name and hf_token:
                    api = HfApi(token=hf_token)
                    api.upload_file(
                        path_or_fileobj=json_path,
                        path_in_repo="data.json",
                        repo_id=thoughts_repo_name,
                        repo_type="dataset",
                        commit_message="Upload filtered dataset as JSON",
                    )
                    # Log the destination repo (original logged repo_name).
                    logging.info(f"Pushed JSON dataset with {len(deduplicated_rows)} rows to {thoughts_repo_name}")
                elif output_path:
                    output_json = os.path.join(output_path, "data.json")
                    # Save the filtered/deduplicated rows (the original wrote
                    # the raw, unfiltered fullthoughts_df here).
                    deduplicated_rows.to_json(output_json, orient="records", indent=2)
                    logging.info(f"Saved JSON dataset with {len(deduplicated_rows)} rows to {output_json}")

    except Exception as e:
        logging.error(f"Error processing dataset: {str(e)}")
        raise
|
||||
|
||||
|
||||
# Entry point when executed as a script.
if __name__ == "__main__":
    main()
|
||||
@@ -32,7 +32,8 @@ DEBUG = in_debug_mode()
|
||||
|
||||
# All Subdomains of KHOJ_DOMAIN are trusted
|
||||
KHOJ_DOMAIN = os.getenv("KHOJ_DOMAIN", "khoj.dev")
|
||||
ALLOWED_HOSTS = [f".{KHOJ_DOMAIN}", "localhost", "127.0.0.1", "[::1]", f"{KHOJ_DOMAIN}"]
|
||||
KHOJ_ALLOWED_DOMAIN = os.getenv("KHOJ_ALLOWED_DOMAIN", KHOJ_DOMAIN)
|
||||
ALLOWED_HOSTS = [f".{KHOJ_ALLOWED_DOMAIN}", "localhost", "127.0.0.1", "[::1]", f"{KHOJ_ALLOWED_DOMAIN}"]
|
||||
|
||||
CSRF_TRUSTED_ORIGINS = [
|
||||
f"https://*.{KHOJ_DOMAIN}",
|
||||
@@ -45,7 +46,7 @@ CSRF_TRUSTED_ORIGINS = [
|
||||
DISABLE_HTTPS = is_env_var_true("KHOJ_NO_HTTPS")
|
||||
|
||||
COOKIE_SAMESITE = "None"
|
||||
if DEBUG or os.getenv("KHOJ_DOMAIN") == None:
|
||||
if DEBUG and os.getenv("KHOJ_DOMAIN") == None:
|
||||
SESSION_COOKIE_DOMAIN = "localhost"
|
||||
CSRF_COOKIE_DOMAIN = "localhost"
|
||||
else:
|
||||
|
||||
@@ -65,15 +65,6 @@ def extract_questions_anthropic(
|
||||
last_new_year = current_new_year.replace(year=today.year - 1)
|
||||
|
||||
system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
|
||||
current_date=today.strftime("%Y-%m-%d"),
|
||||
day_of_week=today.strftime("%A"),
|
||||
current_month=today.strftime("%Y-%m"),
|
||||
last_new_year=last_new_year.strftime("%Y"),
|
||||
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
|
||||
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
|
||||
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
|
||||
location=location,
|
||||
username=username,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
@@ -176,18 +167,7 @@ def converse_anthropic(
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
else:
|
||||
system_prompt = prompts.personality.format(
|
||||
current_date=current_date.strftime("%Y-%m-%d"),
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
|
||||
if location_data:
|
||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||
|
||||
if user_name:
|
||||
user_name_prompt = prompts.user_name.format(name=user_name)
|
||||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
system_prompt = prompts.personality.format(current_date=current_date.strftime("%Y-%m-%d"))
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
|
||||
@@ -67,14 +67,6 @@ def extract_questions_gemini(
|
||||
|
||||
system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
|
||||
current_date=today.strftime("%Y-%m-%d"),
|
||||
day_of_week=today.strftime("%A"),
|
||||
current_month=today.strftime("%Y-%m"),
|
||||
last_new_year=last_new_year.strftime("%Y"),
|
||||
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
|
||||
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
|
||||
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
|
||||
location=location,
|
||||
username=username,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
@@ -186,10 +178,7 @@ def converse_gemini(
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
else:
|
||||
system_prompt = prompts.personality.format(
|
||||
current_date=current_date.strftime("%Y-%m-%d"),
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
system_prompt = prompts.personality.format(current_date=current_date.strftime("%Y-%m-%d"))
|
||||
|
||||
system_prompt += f"{system_prompt}\n\n{prompts.gemini_verbose_language_personality}"
|
||||
if location_data:
|
||||
|
||||
@@ -22,11 +22,13 @@ from tenacity import (
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
commit_dataset_trace,
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import (
|
||||
get_chat_usage_metrics,
|
||||
is_datatrace_enabled,
|
||||
is_none_or_empty,
|
||||
is_promptrace_enabled,
|
||||
)
|
||||
@@ -90,6 +92,8 @@ def gemini_completion_with_backoff(
|
||||
tracer["temperature"] = temperature
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, response_text, tracer)
|
||||
if is_datatrace_enabled(tracer):
|
||||
commit_dataset_trace(messages, response_text)
|
||||
|
||||
return response_text
|
||||
|
||||
|
||||
@@ -45,13 +45,10 @@ def extract_questions(
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
"""
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = "".join(
|
||||
[
|
||||
f'Q: {chat["intent"]["query"]}\nKhoj: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
|
||||
f'Q: {chat["intent"]["query"]}\S: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
|
||||
for chat in conversation_log.get("chat", [])[-4:]
|
||||
if chat["by"] == "khoj" and "to-image" not in chat["intent"].get("type")
|
||||
]
|
||||
@@ -60,23 +57,13 @@ def extract_questions(
|
||||
# Get dates relative to today for prompt creation
|
||||
today = datetime.today()
|
||||
current_new_year = today.replace(month=1, day=1)
|
||||
last_new_year = current_new_year.replace(year=today.year - 1)
|
||||
temperature = 0.7
|
||||
|
||||
prompt = prompts.extract_questions.format(
|
||||
current_date=today.strftime("%Y-%m-%d"),
|
||||
day_of_week=today.strftime("%A"),
|
||||
current_month=today.strftime("%Y-%m"),
|
||||
last_new_year=last_new_year.strftime("%Y"),
|
||||
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
|
||||
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
|
||||
bob_tom_age_difference={current_new_year.year - 1984 - 30},
|
||||
bob_age={current_new_year.year - 1984},
|
||||
chat_history=chat_history,
|
||||
text=text,
|
||||
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
|
||||
location=location,
|
||||
username=username,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
@@ -178,17 +165,8 @@ def converse_openai(
|
||||
else:
|
||||
system_prompt = prompts.personality.format(
|
||||
current_date=current_date.strftime("%Y-%m-%d"),
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
|
||||
if location_data:
|
||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||
|
||||
if user_name:
|
||||
user_name_prompt = prompts.user_name.format(name=user_name)
|
||||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
completion_func(chat_response=prompts.no_notes_found.format())
|
||||
|
||||
@@ -18,10 +18,12 @@ from tenacity import (
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
commit_dataset_trace,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
get_chat_usage_metrics,
|
||||
get_openai_client,
|
||||
is_datatrace_enabled,
|
||||
is_promptrace_enabled,
|
||||
)
|
||||
|
||||
@@ -81,6 +83,7 @@ def completion_with_backoff(
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
timeout=20,
|
||||
max_tokens=3000,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@@ -112,6 +115,8 @@ def completion_with_backoff(
|
||||
tracer["temperature"] = temperature
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
if is_datatrace_enabled(tracer):
|
||||
commit_dataset_trace(messages, aggregated_response)
|
||||
|
||||
return aggregated_response
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.prompts import PromptTemplate
|
||||
## --
|
||||
personality = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, a smart, inquisitive and helpful personal assistant.
|
||||
You are a smart, inquisitive and helpful personal assistant.
|
||||
Use your general knowledge and past conversation with the user as context to inform your responses.
|
||||
You were created by Khoj Inc. with the following capabilities:
|
||||
|
||||
@@ -18,7 +18,8 @@ You were created by Khoj Inc. with the following capabilities:
|
||||
- Provide inline references to quotes from the user's notes or any web pages you refer to in your responses in markdown format. For example, "The farmer had ten sheep. [1](https://example.com)". *ALWAYS CITE YOUR SOURCES AND PROVIDE REFERENCES*. Add them inline to directly support your claim.
|
||||
|
||||
Note: More information about you, the company or Khoj apps can be found at https://khoj.dev.
|
||||
Today is {day_of_week}, {current_date} in UTC.
|
||||
|
||||
Current Date: {current_date} in UTC.
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -554,68 +555,53 @@ Q: {query}
|
||||
|
||||
extract_questions = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes and documents.
|
||||
You are an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes and documents.
|
||||
Construct search queries to retrieve relevant information to answer the user's question.
|
||||
- You will be provided example and actual past user questions(Q), search queries(Khoj) and answers(A) for context.
|
||||
- You will be provided example and actual past user questions (Q), search queries (S), and answers (A) for context.
|
||||
- Add as much context from the previous questions and answers as required into your search queries.
|
||||
- Break your search down into multiple search queries from a diverse set of lenses to retrieve all related documents.
|
||||
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
|
||||
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
|
||||
{personality_context}
|
||||
What searches will you perform to answer the user's question? Respond with search queries as list of strings in a JSON object.
|
||||
Current Date: {day_of_week}, {current_date}
|
||||
User's Location: {location}
|
||||
{username}
|
||||
|
||||
Examples
|
||||
---
|
||||
Q: How was my trip to Cambodia?
|
||||
Khoj: {{"queries": ["How was my trip to Cambodia?", "Angkor Wat temple visit", "Flight to Phnom Penh", "Expenses in Cambodia", "Stay in Cambodia"]}}
|
||||
S: {{"queries": ["How was my trip to Cambodia?", "Angkor Wat temple visit", "Flight to Phnom Penh", "Expenses in Cambodia", "Stay in Cambodia"]}}
|
||||
A: The trip was amazing. You went to the Angkor Wat temple and it was beautiful.
|
||||
|
||||
Q: Who did i visit that temple with?
|
||||
Khoj: {{"queries": ["Who did I visit the Angkor Wat Temple in Cambodia with?"]}}
|
||||
S: {{"queries": ["Who did I visit the Angkor Wat Temple in Cambodia with?"]}}
|
||||
A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi.
|
||||
|
||||
Q: What national parks did I go to last year?
|
||||
Khoj: {{"queries": ["National park I visited in {last_new_year} dt>='{last_new_year_date}' dt<'{current_new_year_date}'"]}}
|
||||
A: You visited the Grand Canyon and Yellowstone National Park in {last_new_year}.
|
||||
|
||||
Q: How can you help me?
|
||||
Khoj: {{"queries": ["Social relationships", "Physical and mental health", "Education and career", "Personal life goals and habits"]}}
|
||||
S: {{"queries": ["Social relationships", "Physical and mental health", "Education and career", "Personal life goals and habits"]}}
|
||||
A: I can help you live healthier and happier across work and personal life
|
||||
|
||||
Q: How many tennis balls fit in the back of a 2002 Honda Civic?
|
||||
Khoj: {{"queries": ["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]}}
|
||||
S: {{"queries": ["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]}}
|
||||
A: 1085 tennis balls will fit in the trunk of a Honda Civic
|
||||
|
||||
Q: Share some random, interesting experiences from this month
|
||||
Khoj: {{"queries": ["Exciting travel adventures from {current_month}", "Fun social events dt>='{current_month}-01' dt<'{current_date}'", "Intense emotional experiences in {current_month}"]}}
|
||||
A: You had a great time at the local beach with your friends, attended a music concert and had a deep conversation with your friend, Khalid.
|
||||
|
||||
Q: Is Bob older than Tom?
|
||||
Khoj: {{"queries": ["When was Bob born?", "What is Tom's age?"]}}
|
||||
S: {{"queries": ["When was Bob born?", "What is Tom's age?"]}}
|
||||
A: Yes, Bob is older than Tom. As Bob was born on 1984-01-01 and Tom is 30 years old.
|
||||
|
||||
Q: What is their age difference?
|
||||
Khoj: {{"queries": ["What is Bob's age?", "What is Tom's age?"]}}
|
||||
S: {{"queries": ["What is Bob's age?", "What is Tom's age?"]}}
|
||||
A: Bob is {bob_tom_age_difference} years older than Tom. As Bob is {bob_age} years old and Tom is 30 years old.
|
||||
|
||||
Q: Who all did I meet here yesterday?
|
||||
Khoj: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}}
|
||||
A: Yesterday's note mentions your visit to your local beach with Ram and Shyam.
|
||||
|
||||
Actual
|
||||
---
|
||||
{chat_history}
|
||||
Q: {text}
|
||||
Khoj:
|
||||
S:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
extract_questions_anthropic_system_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes.
|
||||
You are an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes.
|
||||
Construct search queries to retrieve relevant information to answer the user's question.
|
||||
- You will be provided past questions(User), search queries(Assistant) and answers(A) for context.
|
||||
- Add as much context from the previous questions and answers as required into your search queries.
|
||||
@@ -625,32 +611,15 @@ Construct search queries to retrieve relevant information to answer the user's q
|
||||
{personality_context}
|
||||
What searches will you perform to answer the users question? Respond with a JSON object with the key "queries" mapping to a list of searches you would perform on the user's knowledge base. Just return the queries and nothing else.
|
||||
|
||||
Current Date: {day_of_week}, {current_date}
|
||||
User's Location: {location}
|
||||
{username}
|
||||
|
||||
Here are some examples of how you can construct search queries to answer the user's question:
|
||||
|
||||
User: How was my trip to Cambodia?
|
||||
Assistant: {{"queries": ["How was my trip to Cambodia?", "Angkor Wat temple visit", "Flight to Phnom Penh", "Expenses in Cambodia", "Stay in Cambodia"]}}
|
||||
A: The trip was amazing. You went to the Angkor Wat temple and it was beautiful.
|
||||
|
||||
User: What national parks did I go to last year?
|
||||
Assistant: {{"queries": ["National park I visited in {last_new_year} dt>='{last_new_year_date}' dt<'{current_new_year_date}'"]}}
|
||||
A: You visited the Grand Canyon and Yellowstone National Park in {last_new_year}.
|
||||
|
||||
User: How can you help me?
|
||||
Assistant: {{"queries": ["Social relationships", "Physical and mental health", "Education and career", "Personal life goals and habits"]}}
|
||||
A: I can help you live healthier and happier across work and personal life
|
||||
|
||||
User: Who all did I meet here yesterday?
|
||||
Assistant: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}}
|
||||
A: Yesterday's note mentions your visit to your local beach with Ram and Shyam.
|
||||
|
||||
User: Share some random, interesting experiences from this month
|
||||
Assistant: {{"queries": ["Exciting travel adventures from {current_month}", "Fun social events dt>='{current_month}-01' dt<'{current_date}'", "Intense emotional experiences in {current_month}"]}}
|
||||
A: You had a great time at the local beach with your friends, attended a music concert and had a deep conversation with your friend, Khalid.
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -727,7 +696,7 @@ Here's some additional context about you:
|
||||
|
||||
plan_function_execution = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, a smart, creative and methodical researcher. Use the provided tool AIs to investigate information to answer query.
|
||||
You are a smart, creative and methodical researcher. Use the provided tool AIs to investigate information to answer query.
|
||||
Create a multi-step plan and intelligently iterate on the plan based on the retrieved information to find the requested information.
|
||||
{personality_context}
|
||||
|
||||
@@ -740,26 +709,6 @@ Create a multi-step plan and intelligently iterate on the plan based on the retr
|
||||
- You are allowed upto {max_iterations} iterations to use the help of the provided tool AIs to answer the user's question.
|
||||
- Stop when you have the required information by returning a JSON object with an empty "tool" field. E.g., {{scratchpad: "I have all I need", tool: "", query: ""}}
|
||||
|
||||
# Examples
|
||||
Assuming you can search the user's notes and the internet.
|
||||
- When the user asks for the population of their hometown
|
||||
1. Try look up their hometown in their notes. Ask the note search AI to search for their birth certificate, childhood memories, school, resume etc.
|
||||
2. If not found in their notes, try infer their hometown from their online social media profiles. Ask the online search AI to look for {username}'s biography, school, resume on linkedin, facebook, website etc.
|
||||
3. Only then try find the latest population of their hometown by reading official websites with the help of the online search and web page reading AI.
|
||||
- When the user asks for their computer's specs
|
||||
1. Try find their computer model in their notes.
|
||||
2. Now find webpages with their computer model's spec online.
|
||||
3. Ask the webpage tool AI to extract the required information from the relevant webpages.
|
||||
- When the user asks what clothes to carry for their upcoming trip
|
||||
1. Find the itinerary of their upcoming trip in their notes.
|
||||
2. Next find the weather forecast at the destination online.
|
||||
3. Then find if they mentioned what clothes they own in their notes.
|
||||
|
||||
# Background Context
|
||||
- Current Date: {day_of_week}, {current_date}
|
||||
- User Location: {location}
|
||||
- User Name: {username}
|
||||
|
||||
# Available Tool AIs
|
||||
Which of the tool AIs listed below would you use to answer the user's question? You **only** have access to the following tool AIs:
|
||||
|
||||
@@ -788,7 +737,7 @@ previous_iteration = PromptTemplate.from_template(
|
||||
|
||||
pick_relevant_tools = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful search assistant.
|
||||
You are an extremely smart and helpful search assistant.
|
||||
{personality_context}
|
||||
- You have access to a variety of data sources to help you answer the user's question.
|
||||
- You can use any subset of data sources listed below to collect more relevant information.
|
||||
@@ -858,7 +807,7 @@ Khoj:
|
||||
|
||||
infer_webpages_to_read = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an advanced web page reading assistant. You are to construct **up to three, valid** webpage urls to read before answering the user's question.
|
||||
You are an advanced web page reading assistant. You are to construct **up to three, valid** webpage urls to read before answering the user's question.
|
||||
- You will receive the conversation history as context.
|
||||
- Add as much context from the previous questions and answers as required to construct the webpage urls.
|
||||
- Use multiple web page urls if required to retrieve the relevant information.
|
||||
@@ -903,7 +852,7 @@ Khoj:
|
||||
|
||||
online_search_conversation_subqueries = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
|
||||
You are an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
|
||||
- You will receive the actual chat history as context.
|
||||
- Add as much context from the chat history as required into your search queries.
|
||||
- Break messages into multiple search queries when required to retrieve the relevant information.
|
||||
@@ -976,12 +925,11 @@ Khoj:
|
||||
# --
|
||||
python_code_generation_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an advanced python programmer. You are tasked with constructing a python program to best answer the user query.
|
||||
You are an advanced python programmer. You are tasked with constructing a python program to best answer the user query.
|
||||
- The python program will run in a pyodide python sandbox with no network access.
|
||||
- You can write programs to run complex calculations, analyze data, create charts, generate documents to meticulously answer the query.
|
||||
- The sandbox has access to the standard library, matplotlib, panda, numpy, scipy, bs4 and sympy packages. The requests, torch, catboost, tensorflow and tkinter packages are not available.
|
||||
- List known file paths to required user documents in "input_files" and known links to required documents from the web in the "input_links" field.
|
||||
- The python program should be self-contained. It can only read data generated by the program itself and from provided input_files, input_links by their basename (i.e filename excluding file path).
|
||||
- The sandbox has access to only the standard library, matplotlib, panda, numpy, scipy, bs4 and sympy packages. The requests, torch, catboost, tensorflow and tkinter packages are not available.
|
||||
- The python program should be self-contained. It can only read data generated by the program itself.
|
||||
- Do not try display images or plots in the code directly. The code should save the image or plot to a file instead.
|
||||
- Write any document, charts etc. to be shared with the user to file. These files can be seen by the user.
|
||||
- Use as much context from the previous questions and answers as required to generate your code.
|
||||
@@ -992,25 +940,100 @@ Current Date: {current_date}
|
||||
User's Location: {location}
|
||||
{username}
|
||||
|
||||
The response JSON schema is of the form {{"code": "<python_code>", "input_files": ["file_path_1", "file_path_2"], "input_links": ["link_1", "link_2"]}}
|
||||
Examples:
|
||||
The response should contain python code wrapped in markdown code blocks (starting with```python and ending with ```)
|
||||
Example 1:
|
||||
---
|
||||
{{
|
||||
"code": "# Input values\\nprincipal = 43235\\nrate = 5.24\\nyears = 5\\n\\n# Convert rate to decimal\\nrate_decimal = rate / 100\\n\\n# Calculate final amount\\nfinal_amount = principal * (1 + rate_decimal) ** years\\n\\n# Calculate interest earned\\ninterest_earned = final_amount - principal\\n\\n# Print results with formatting\\nprint(f"Interest Earned: ${{interest_earned:,.2f}}")\\nprint(f"Final Amount: ${{final_amount:,.2f}}")"
|
||||
}}
|
||||
Q: Calculate the interest earned and final amount for a principal of $43,235 invested at a rate of 5.24 percent for 5 years.
|
||||
A: Ok, to calculate the interest earned and final amount, we can use the formula for compound interest: $T = P(1 + r/n)^{{nt}}$,
|
||||
where T: total amount, P: principal, r: interest rate, n: number of times interest is compounded per year, and t: time in years.
|
||||
|
||||
{{
|
||||
"code": "import re\\n\\n# Read org file\\nfile_path = 'tasks.org'\\nwith open(file_path, 'r') as f:\\n content = f.read()\\n\\n# Get today's date in YYYY-MM-DD format\\ntoday = datetime.now().strftime('%Y-%m-%d')\\npattern = r'\*+\s+.*\\n.*SCHEDULED:\s+<' + today + r'.*>'\\n\\n# Find all matches using multiline mode\\nmatches = re.findall(pattern, content, re.MULTILINE)\\ncount = len(matches)\\n\\n# Display count\\nprint(f'Count of scheduled tasks for today: {{count}}')",
|
||||
"input_files": ["/home/linux/tasks.org"]
|
||||
}}
|
||||
Let's write the Python program to calculate this.
|
||||
|
||||
{{
|
||||
"code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\n\\n# Load the CSV file\\ndf = pd.read_csv('world_population_by_year.csv')\\n\\n# Plot the data\\nplt.figure(figsize=(10, 6))\\nplt.plot(df['Year'], df['Population'], marker='o')\\n\\n# Add titles and labels\\nplt.title('Population by Year')\\nplt.xlabel('Year')\\nplt.ylabel('Population')\\n\\n# Save the plot to a file\\nplt.savefig('population_by_year_plot.png')",
|
||||
"input_links": ["https://population.un.org/world_population_by_year.csv"]
|
||||
}}
|
||||
```python
|
||||
# Input values
|
||||
principal = 43235
|
||||
rate = 5.24
|
||||
years = 5
|
||||
|
||||
Now it's your turn to construct a python program to answer the user's question. Provide the code, required input files and input links in a JSON object. Do not say anything else.
|
||||
Context:
|
||||
# Convert rate to decimal
|
||||
rate_decimal = rate / 100
|
||||
|
||||
# Calculate final amount
|
||||
final_amount = principal * (1 + rate_decimal) ** years
|
||||
|
||||
# Calculate interest earned
|
||||
interest_earned = final_amount - principal
|
||||
|
||||
# Print results with formatting
|
||||
print(f"Interest Earned: ${{interest_earned:,.2f}}")
|
||||
print(f"Final Amount: ${{final_amount:,.2f}}")
|
||||
```
|
||||
|
||||
Example 2:
|
||||
---
|
||||
Q: Simplify first, then evaluate: $-7x+2(x^{{2}}-1)-(2x^{{2}}-x+3)$, where $x=1$.
|
||||
A: Certainly! Let's break down the problem step-by-step and utilize Python with SymPy to simplify and evaluate the expression.
|
||||
|
||||
1. **Expression Simplification:**
|
||||
We start with the expression \\(-7x + 2(x^2 - 1) - (2x^2 - x + 3)\\).
|
||||
|
||||
2. **Substitute \\(x=1\\) into the simplified expression:**
|
||||
Once simplified, we will substitute \\(x=1\\) into the expression to find its value.
|
||||
|
||||
Let's implement this in Python using SymPy (as it is listed as an available in the sanbox):
|
||||
|
||||
```python
|
||||
import sympy as sp
|
||||
|
||||
# Define the variable
|
||||
x = sp.symbols('x')
|
||||
|
||||
# Define the expression
|
||||
expression = -7*x + 2*(x**2 - 1) - (2*x**2 - x + 3)
|
||||
|
||||
# Simplify the expression
|
||||
simplified_expression = sp.simplify(expression)
|
||||
|
||||
# Substitute x = 1 into the simplified expression
|
||||
evaluated_expression = simplified_expression.subs(x, 1)
|
||||
|
||||
# Print the simplified expression and its evaluated value
|
||||
print(\"Simplified Expression:\", simplified_expression)
|
||||
print(\"Evaluated Expression at x=1:\", evaluated_expression)
|
||||
```
|
||||
|
||||
Example 3:
|
||||
---
|
||||
Q: Plot the world ppulation growth over the years, given this year, world population world tuples: [(2000, 6), (2001, 7), (2002, 8), (2003, 9), (2004, 10)].
|
||||
A: Absolutely! We can utilize the Pandas and Matplotlib libraries (as both are available in the sandbox) to create the world population growth plot.
|
||||
```python
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Create a DataFrame of world population from the provided data
|
||||
data = {{
|
||||
'Year': [2000, 2001, 2002, 2003, 2004],
|
||||
'Population': [6, 7, 8, 9, 10]
|
||||
}}
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
# Plot the data
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(df['Year'], df['Population'], marker='o')
|
||||
|
||||
# Add titles and labels
|
||||
plt.title('Population by Year')
|
||||
plt.xlabel('Year')
|
||||
plt.ylabel('Population')
|
||||
|
||||
# Save the plot to a file
|
||||
plt.savefig('population_by_year_plot.png')
|
||||
```
|
||||
|
||||
Now it's your turn to construct a python program to answer the user's query using the provided context and coversation provided below.
|
||||
Ensure the python code to execute is wrapped in a markdown code block.
|
||||
|
||||
"Context:
|
||||
---
|
||||
{context}
|
||||
|
||||
@@ -1018,8 +1041,9 @@ Chat History:
|
||||
---
|
||||
{chat_history}
|
||||
|
||||
User: {query}
|
||||
Khoj:
|
||||
User Query:
|
||||
---
|
||||
{query}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
@@ -308,6 +309,39 @@ def save_to_conversation_log(
|
||||
user_message=q,
|
||||
)
|
||||
|
||||
EXCLUDED_PHRASES = [
|
||||
"**Generating a well-informed response**",
|
||||
"**Searching the Internet for**",
|
||||
"**Running code snippet**",
|
||||
"**Ran code snippets**",
|
||||
]
|
||||
|
||||
trains_of_thought = ""
|
||||
for t in train_of_thought:
|
||||
tot_contains_excluded_phrase = any(phrase in t["data"] for phrase in EXCLUDED_PHRASES)
|
||||
if t["type"] == "status" and not tot_contains_excluded_phrase:
|
||||
trains_of_thought += t["data"] + "\n\n"
|
||||
|
||||
formatted_assistant_response = f"<begin_of_thought>\n\n{trains_of_thought}<end_of_thought>\n\n<begin_of_solution>\n\n{chat_response}<end_of_solution>"
|
||||
|
||||
system_prompt = prompts.personality.format(
|
||||
current_date=datetime.now().strftime("%Y-%m-%d"),
|
||||
)
|
||||
|
||||
session_messages = [
|
||||
ChatMessage(content=q, role="user"),
|
||||
ChatMessage(content=system_prompt, role="system"),
|
||||
]
|
||||
|
||||
FULL_THOUGHTS_PATH = os.getenv("FULL_THOUGHTS_PATH", None)
|
||||
|
||||
if FULL_THOUGHTS_PATH:
|
||||
commit_dataset_trace(
|
||||
session_messages,
|
||||
formatted_assistant_response,
|
||||
FULL_THOUGHTS_PATH,
|
||||
)
|
||||
|
||||
if is_promptrace_enabled():
|
||||
merge_message_into_conversation_trace(q, chat_response, tracer)
|
||||
|
||||
@@ -871,3 +905,63 @@ def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str:
|
||||
return str(content)
|
||||
|
||||
return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages])
|
||||
|
||||
|
||||
def commit_dataset_trace(
|
||||
session: list[ChatMessage],
|
||||
response: str | list[dict],
|
||||
dataset_path: str = None,
|
||||
) -> str:
|
||||
"""Save data trace of conversation step in JSONL format."""
|
||||
dataset_path = dataset_path if not is_none_or_empty(dataset_path) else os.getenv("DATATRACE_PATH")
|
||||
if not dataset_path:
|
||||
return None
|
||||
|
||||
# Ensure .jsonl extension
|
||||
dataset_path = dataset_path if dataset_path.endswith(".jsonl") else f"{dataset_path}.jsonl"
|
||||
|
||||
# Create directory if needed
|
||||
os.makedirs(os.path.dirname(dataset_path), exist_ok=True)
|
||||
|
||||
# Format the new record
|
||||
session = [
|
||||
ChatMessage(
|
||||
content="\n\n".join(message.content) if isinstance(message.content, list) else message.content,
|
||||
role=message.role,
|
||||
)
|
||||
for message in session
|
||||
]
|
||||
|
||||
system_message = "\n\n".join([message.content for message in session if message.role == "system"])
|
||||
if is_none_or_empty(system_message):
|
||||
system_message = session.pop(0).content
|
||||
|
||||
session.append(ChatMessage(content=response, role="assistant"))
|
||||
formatted_session = [
|
||||
{"from": message.role, "value": message.content} for message in session if message.role in ["user", "assistant"]
|
||||
]
|
||||
|
||||
new_row = {
|
||||
"system": system_message,
|
||||
"conversations": formatted_session,
|
||||
}
|
||||
|
||||
# Append single record atomically
|
||||
temp_path = f"{dataset_path}.tmp"
|
||||
with open(temp_path, "a", encoding="utf-8") as f:
|
||||
json.dump(new_row, f, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
|
||||
if os.path.exists(dataset_path):
|
||||
with open(dataset_path, "a", encoding="utf-8") as main_file:
|
||||
with open(temp_path, "r", encoding="utf-8") as temp_file:
|
||||
main_file.write(temp_file.read())
|
||||
else:
|
||||
os.replace(temp_path, dataset_path)
|
||||
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove temporary file {temp_path}: {e}")
|
||||
|
||||
return dataset_path
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import datetime
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, NamedTuple, Optional
|
||||
|
||||
@@ -127,23 +130,47 @@ async def generate_python_code(
|
||||
response = await send_message_to_model_wrapper(
|
||||
code_generation_prompt,
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
response = load_complex_json(response)
|
||||
code = response.get("code", "").strip()
|
||||
input_files = response.get("input_files", [])
|
||||
input_links = response.get("input_links", [])
|
||||
# Extract python code wrapped in markdown code blocs from the response
|
||||
code_blocks = re.findall(r"```(?:python)?\n(.*?)\n```", response, re.DOTALL)
|
||||
|
||||
if not code_blocks:
|
||||
raise ValueError("No Python code blocks found in response")
|
||||
|
||||
# Join multiple code blocks with newlines and strip any leading/trailing whitespace
|
||||
code = "\n".join(code_blocks).strip()
|
||||
|
||||
if not isinstance(code, str) or is_none_or_empty(code):
|
||||
raise ValueError
|
||||
return GeneratedCode(code, input_files, input_links)
|
||||
return GeneratedCode(code, [], [])
|
||||
|
||||
|
||||
def async_retry_with_backoff(retries=3, backoff_in_seconds=1):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
retry_count = 0
|
||||
while retry_count < retries:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
|
||||
retry_count += 1
|
||||
if retry_count == retries:
|
||||
raise e
|
||||
wait_time = backoff_in_seconds * (2 ** (retry_count - 1)) # exponential backoff
|
||||
await asyncio.sleep(wait_time)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@async_retry_with_backoff(retries=3, backoff_in_seconds=1)
|
||||
async def execute_sandboxed_python(code: str, input_data: list[dict], sandbox_url: str = SANDBOX_URL) -> dict[str, Any]:
|
||||
"""
|
||||
Takes code to run as a string and calls the terrarium API to execute it.
|
||||
@@ -157,7 +184,7 @@ async def execute_sandboxed_python(code: str, input_data: list[dict], sandbox_ur
|
||||
data = {"code": cleaned_code, "files": input_data}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(sandbox_url, json=data, headers=headers) as response:
|
||||
async with session.post(sandbox_url, json=data, headers=headers, timeout=30) as response:
|
||||
if response.status == 200:
|
||||
result: dict[str, Any] = await response.json()
|
||||
result["code"] = cleaned_code
|
||||
|
||||
@@ -357,8 +357,12 @@ async def aget_data_sources_and_output_format(
|
||||
source_options_str = ""
|
||||
|
||||
agent_sources = agent.input_tools if agent else []
|
||||
user_has_entries = await EntryAdapters.auser_has_entries(user)
|
||||
|
||||
for source, description in tool_descriptions_for_llm.items():
|
||||
# Skip showing Notes tool as an option if user has no entries
|
||||
if source == ConversationCommand.Notes and not user_has_entries:
|
||||
continue
|
||||
source_options[source.value] = description
|
||||
if len(agent_sources) == 0 or source.value in agent_sources:
|
||||
source_options_str += f'- "{source.value}": "{description}"\n'
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Callable, Dict, List, Optional
|
||||
import yaml
|
||||
from fastapi import Request
|
||||
|
||||
from khoj.database.adapters import EntryAdapters
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import (
|
||||
@@ -54,13 +55,17 @@ async def apick_next_tool(
|
||||
tool_options = dict()
|
||||
tool_options_str = ""
|
||||
agent_tools = agent.input_tools if agent else []
|
||||
user_has_entries = await EntryAdapters.auser_has_entries(user)
|
||||
for tool, description in function_calling_description_for_llm.items():
|
||||
# Skip showing Notes tool as an option if user has no entries
|
||||
if tool == ConversationCommand.Notes and not user_has_entries:
|
||||
continue
|
||||
tool_options[tool.value] = description
|
||||
if len(agent_tools) == 0 or tool.value in agent_tools:
|
||||
tool_options_str += f'- "{tool.value}": "{description}"\n'
|
||||
|
||||
# Construct chat history with user and iteration history with researcher agent for context
|
||||
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
|
||||
|
||||
if query_images:
|
||||
@@ -76,15 +81,12 @@ async def apick_next_tool(
|
||||
tools=tool_options_str,
|
||||
chat_history=chat_history,
|
||||
personality_context=personality_context,
|
||||
current_date=today.strftime("%Y-%m-%d"),
|
||||
day_of_week=today.strftime("%A"),
|
||||
username=user_name or "Unknown",
|
||||
location=location_data,
|
||||
previous_iterations=previous_iterations_history,
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
||||
try:
|
||||
tracer["save_to_dataset"] = True
|
||||
with timer("Chat actor: Infer information sources to refer", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
query=query,
|
||||
@@ -95,6 +97,7 @@ async def apick_next_tool(
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
tracer["save_to_dataset"] = False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to infer information sources to refer: {e}", exc_info=True)
|
||||
yield InformationCollectionIteration(
|
||||
|
||||
@@ -460,6 +460,12 @@ def is_promptrace_enabled():
|
||||
return not is_none_or_empty(os.getenv("PROMPTRACE_DIR"))
|
||||
|
||||
|
||||
def is_datatrace_enabled(tracer: dict):
|
||||
"""Check if Khoj is running with data tracing enabled.
|
||||
Set DATATRACE_PATH environment variable to prompt tracing path to enable it."""
|
||||
return tracer.get("save_to_dataset", False) and not is_none_or_empty(os.getenv("DATATRACE_PATH"))
|
||||
|
||||
|
||||
def is_valid_url(url: str) -> bool:
|
||||
"""Check if a string is a valid URL"""
|
||||
try:
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import argparse
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from io import StringIO
|
||||
@@ -19,6 +19,9 @@ import pandas as pd
|
||||
import requests
|
||||
import yaml
|
||||
from datasets import Dataset, load_dataset
|
||||
from dotenv import load_dotenv
|
||||
from latex2sympy2_extended.latex2sympy2 import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
from tqdm import tqdm
|
||||
|
||||
from khoj.utils.helpers import (
|
||||
@@ -32,6 +35,9 @@ from khoj.utils.helpers import (
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Configuration
|
||||
KHOJ_URL = os.getenv("KHOJ_URL", "http://localhost:42110")
|
||||
KHOJ_CHAT_API_URL = f"{KHOJ_URL}/api/chat"
|
||||
@@ -48,6 +54,9 @@ BATCH_SIZE = int(
|
||||
) # Examples to evaluate in each batch
|
||||
SLEEP_SECONDS = 3 if KHOJ_MODE == "general" else 1 # Sleep between API calls to avoid rate limiting
|
||||
|
||||
# Add at module level
|
||||
DF_LOCK = Lock()
|
||||
|
||||
|
||||
class Counter:
|
||||
"""Thread-safe counter for tracking metrics"""
|
||||
@@ -204,6 +213,77 @@ def load_frames_dataset():
|
||||
return None
|
||||
|
||||
|
||||
def load_skythought_dataset(output_file: str, regrade_mode=False):
|
||||
"""
|
||||
Load the skythought training dataset from HuggingFace
|
||||
|
||||
The Sky T-1 dataset is comprised of 17000+ training samples.
|
||||
|
||||
### Data Fields
|
||||
- system: The system prompt used for the agent
|
||||
- conversations:
|
||||
- [0]: from: user, value: user prompt
|
||||
- [1]: from: assistant, value: assistant response
|
||||
|
||||
## Assistant response format
|
||||
The assistant response is formatted as follows:
|
||||
<|begin_of_thought|> <thought> <|end_of_thought|> <|begin_of_solution|> <answer> <|end_of_solution|>
|
||||
"""
|
||||
|
||||
try:
|
||||
# Load already processed indices from output_file
|
||||
processed_indices = set()
|
||||
if os.path.exists(output_file):
|
||||
df_processed = pd.read_csv(output_file, usecols=["id"])
|
||||
processed_indices = set(df_processed["id"].tolist())
|
||||
logger.info(f"Loaded {len(processed_indices)} previously processed samples.")
|
||||
|
||||
dataset = load_dataset("bespokelabs/Bespoke-Stratos-17k")
|
||||
# Assign unique id based on original index before shuffling
|
||||
dataset = dataset.map(lambda example, idx: {"id": idx}, with_indices=True)
|
||||
|
||||
# Use test split for evaluation. Sample and shuffle dataset if configured
|
||||
dataset = dataset.shuffle() if RANDOMIZE else dataset
|
||||
logger.info(f"Total samples in dataset: {dataset['train'].num_rows}")
|
||||
|
||||
formatted_data = []
|
||||
for d in dataset["train"]:
|
||||
if len(formatted_data) + len(processed_indices) >= int(SAMPLE_SIZE) and not regrade_mode:
|
||||
logger.info(f"Processing remaining {len(formatted_data)} of {SAMPLE_SIZE} samples.")
|
||||
# Exit loop if sample size is reached
|
||||
break
|
||||
|
||||
# Check if current_index is already processed
|
||||
idx = d["id"]
|
||||
if idx in processed_indices and not regrade_mode:
|
||||
continue
|
||||
|
||||
# Extract the answer from the assistant response
|
||||
user_prompt = d["conversations"][0]["value"]
|
||||
assistant_response = d["conversations"][1]["value"]
|
||||
match = re.search(
|
||||
r"<\|begin_of_solution\|>\s*([\s\S]*?)(?=\s*<\|end_of_solution\|>|\Z)", assistant_response, re.DOTALL
|
||||
)
|
||||
answer = match.group(1).strip() if match else None
|
||||
formatted_data.append(
|
||||
{
|
||||
"id": idx,
|
||||
"Prompt": user_prompt,
|
||||
"Answer": answer or assistant_response,
|
||||
"reasoning_types": "unknown",
|
||||
}
|
||||
)
|
||||
|
||||
dataset = Dataset.from_list(formatted_data)
|
||||
if regrade_mode:
|
||||
return dataset
|
||||
return dataset[: int(SAMPLE_SIZE)] if SAMPLE_SIZE else dataset
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading dataset: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def load_simpleqa_dataset():
|
||||
"""
|
||||
Load the OpenAI SimpleQA benchmark dataset from their public bucket.
|
||||
@@ -378,7 +458,7 @@ def calculate_precision_recall(numerator: int, denominator: int) -> float:
|
||||
return numerator / denominator
|
||||
|
||||
|
||||
def calculate_fi(precision: float, recall: float) -> float:
|
||||
def calculate_f1(precision: float, recall: float) -> float:
|
||||
"""Calculate F1 score from precision and recall"""
|
||||
return 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0
|
||||
|
||||
@@ -401,7 +481,7 @@ def evaluate_response_for_ir(
|
||||
articles = get_articles_by_prompt_id(ground_truth)
|
||||
precision = calculate_precision_recall(count_of_correct_articles_used_by_agent, len(unique_file_refs))
|
||||
recall = calculate_precision_recall(count_of_correct_articles_used_by_agent, len(articles))
|
||||
f1 = calculate_fi(precision, recall)
|
||||
f1 = calculate_f1(precision, recall)
|
||||
|
||||
explanation = (
|
||||
f"Information Retrieval F1 Score: {f1:.2%} Recall: {recall:.2%}, Precision: {precision:.2%}.\n"
|
||||
@@ -424,6 +504,55 @@ def evaluate_response_for_ir(
|
||||
return None, f"Evaluation failed: {str(e)}", 0.0
|
||||
|
||||
|
||||
def is_picklable(obj):
|
||||
"""Test if an object can be pickled"""
|
||||
try:
|
||||
pickle.dumps(obj)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
def evaluate_response_for_latex_match_with_llm_fallback(
|
||||
query: str, agent_response: str, ground_truth: str, agent_references: dict = {}
|
||||
) -> tuple[bool | None, str, float]:
|
||||
"""Evaluate Khoj response against benchmark ground truth using string matching"""
|
||||
# Initialize variables
|
||||
cost = 0.0
|
||||
extraction_config = [
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
basic_latex=True, boxed=True, equations=True, units=False, malformed_operators=False, nits=False
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
with multiprocessing.Pool(1) as pool:
|
||||
# Extract answer from agent response - pass function and args separately
|
||||
extracted_answer = pool.apply(parse, args=(agent_response, extraction_config))
|
||||
extracted_ground_truth = pool.apply(parse, args=(ground_truth, extraction_config))
|
||||
|
||||
# Check if extracted answer matches extracted ground truth
|
||||
if is_picklable(extracted_answer) and is_picklable(extracted_ground_truth):
|
||||
decision = pool.apply(verify, args=(extracted_ground_truth, extracted_answer))
|
||||
else:
|
||||
decision = None
|
||||
|
||||
# Fallback to LLM for decision if parse & match fails
|
||||
if not decision:
|
||||
decision, explanation, cost = evaluate_response_with_gemini(
|
||||
query, agent_response, ground_truth, agent_references
|
||||
)
|
||||
explanation = f"Agent response {'matches' if decision else 'does not match'} ground truth {ground_truth}"
|
||||
|
||||
# Return decision, explanation and cost in structured form
|
||||
return float(decision), explanation, cost
|
||||
except Exception as e:
|
||||
logger.error(f"Error in evaluation: {e}")
|
||||
return None, f"Evaluation failed: {str(e)}", cost
|
||||
|
||||
|
||||
def evaluate_response_with_mcq_match(
|
||||
query: str, agent_response: str, ground_truth: str, agent_references: dict = {}
|
||||
) -> tuple[bool | None, str, float]:
|
||||
@@ -452,7 +581,7 @@ def evaluate_response_with_gemini(
|
||||
evaluation_prompt = f"""
|
||||
Compare the following agent response with the ground truth answer.
|
||||
Determine if the agent response contains the key information from the ground truth.
|
||||
Focus on factual correctness rather than exact wording.
|
||||
Focus on factual correctness rather than exact wording. The responses may be in Latex format.
|
||||
|
||||
Query: {query}
|
||||
Agent Response: {agent_response}
|
||||
@@ -498,17 +627,95 @@ def evaluate_response_with_gemini(
|
||||
return None, f"Evaluation failed: {str(e)}", 0.0
|
||||
|
||||
|
||||
def process_batch_regrade(parent_data, batch, dataset_length, response_evaluator, output_file):
    """Re-grade already-collected agent responses from a previous eval run.

    For each (index, prompt, agent_response, evaluation_explanation) row in
    *batch*, looks up the ground-truth answer in *parent_data* (the original
    dataset), re-runs *response_evaluator* when the prior grade failed, and
    writes the updated decision/explanation back into *output_file* in place.

    Runs inside a ThreadPoolExecutor; CSV writes are serialized via DF_LOCK.
    Updates the module-level running_cost / running_true_count /
    running_total_count accumulators.
    """
    global running_cost

    # Convert parent_data to DataFrame at start so per-row lookups by
    # "Prompt" are cheap inside the loop.
    parent_df = pd.DataFrame(parent_data)

    for current_index, prompt, agent_response, evaluation_explanation in batch:
        logger.info(f"Processing example: {current_index+1}/{dataset_length}")

        # Get latest data — re-read the CSV each iteration so concurrent
        # batch threads see each other's writes.
        existing_results = pd.read_csv(output_file)

        # Match existing row by prompt
        parent_row = parent_df[parent_df["Prompt"] == prompt]
        answer = parent_row["Answer"].values[0] if not parent_row.empty else None

        existing_row = existing_results[existing_results["prompt"] == prompt]
        # NOTE(review): "usage" read back from CSV is a *string* repr, not a
        # dict — the .get() call below will raise and be swallowed by the
        # running-cost try/except; confirm that is acceptable.
        agent_usage = existing_row["usage"].values[0] if not existing_row.empty else {}
        agent_references = {}

        # Evaluate response
        if is_none_or_empty(agent_response):
            decision = None
            explanation = "Agent response is empty. This maybe due to a service error."
            continue  # Do not store results. Allows including this eval row in next resumable eval run
        else:
            # Reuse a prior successful grade; only re-evaluate rows whose
            # previous evaluation failed.
            if evaluation_explanation is not None and "Evaluation failed" not in evaluation_explanation:
                explanation = evaluation_explanation
                decision = existing_row["evaluation_decision"].values[0]
                eval_cost = 0.0
            else:
                decision, explanation, eval_cost = response_evaluator(prompt, agent_response, answer, agent_references)

        # Store results
        # Thread-safe DataFrame modification
        with DF_LOCK:
            # Update existing row with new evaluation results
            existing_results.loc[existing_results["prompt"] == prompt, "agent_response"] = agent_response
            existing_results.loc[existing_results["prompt"] == prompt, "evaluation_decision"] = decision
            existing_results.loc[existing_results["prompt"] == prompt, "evaluation_explanation"] = explanation
            existing_results.to_csv(output_file, index=False)

        # logger.info(f"Results: new_row: {results.to_dict()}")

        # Update running cost
        try:
            query_cost = float(agent_usage.get("cost", 0.0))
            running_cost.add(query_cost + eval_cost)
        except Exception as e:
            logger.error(f"Error updating running cost: {e}")

        # Update running accuracy
        running_accuracy = 0.0
        if decision is not None:
            running_true_count.add(decision)
            running_total_count.add(1)
            running_accuracy = running_true_count.get() / running_total_count.get()

        ## Log results
        # NOTE(review): `decision > 0.5` yields only True/False, so the None
        # key is unreachable — and if decision is None (evaluator failure),
        # the comparison raises TypeError and kills this worker thread.
        # Confirm whether a `decision is None` guard is needed here.
        decision_color = {True: "green", None: "blue", False: "red"}[decision > 0.5]
        colored_decision = color_text(str(decision), decision_color)
        result_to_print = f"""
        ---------
        Decision: {colored_decision}
        Accuracy: {running_accuracy:.2%}
        Question: {prompt}
        Expected Answer: {answer}
        Agent Answer: {agent_response}
        Explanation: {explanation}
        ---------
        """
        logger.info(dedent(result_to_print).lstrip())

        # Sleep between API calls to avoid rate limiting
        time.sleep(SLEEP_SECONDS)
||||
|
||||
def process_batch(batch, dataset_length, response_evaluator, output_file, regrade_mode=False):
|
||||
global running_cost
|
||||
for current_index, prompt, answer, reasoning_type in batch:
|
||||
logger.info(f"Processing example: {current_index+1}/{dataset_length}")
|
||||
|
||||
# Trigger research mode if enabled
|
||||
prompt = f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE and not prompt.startswith(f"/{KHOJ_MODE}") else prompt
|
||||
prompt_to_evaluate = (
|
||||
f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE and not prompt.startswith(f"/{KHOJ_MODE}") else prompt
|
||||
)
|
||||
|
||||
# Get agent response
|
||||
response = get_agent_response(prompt)
|
||||
response = get_agent_response(prompt_to_evaluate)
|
||||
agent_response = response["response"]
|
||||
agent_usage = response["usage"]
|
||||
agent_references = response["references"]
|
||||
@@ -517,23 +724,29 @@ def process_batch(batch, batch_start, results, dataset_length, response_evaluato
|
||||
if is_none_or_empty(agent_response):
|
||||
decision = None
|
||||
explanation = "Agent response is empty. This maybe due to a service error."
|
||||
continue # Do not store results. Allows including this eval row in next resumable eval run
|
||||
else:
|
||||
decision, explanation, eval_cost = response_evaluator(prompt, agent_response, answer, agent_references)
|
||||
|
||||
# Store results
|
||||
results.append(
|
||||
{
|
||||
"index": current_index,
|
||||
"prompt": prompt,
|
||||
"ground_truth": answer,
|
||||
"agent_response": agent_response,
|
||||
"evaluation_decision": decision,
|
||||
"evaluation_explanation": explanation,
|
||||
"reasoning_type": reasoning_type,
|
||||
"usage": agent_usage,
|
||||
"references": agent_references,
|
||||
}
|
||||
)
|
||||
# Thread-safe DataFrame modification
|
||||
with DF_LOCK:
|
||||
pd.DataFrame(
|
||||
[
|
||||
{
|
||||
"id": current_index,
|
||||
"prompt": prompt,
|
||||
"ground_truth": answer,
|
||||
"agent_response": agent_response,
|
||||
"evaluation_decision": decision,
|
||||
"evaluation_explanation": explanation,
|
||||
"reasoning_type": reasoning_type,
|
||||
"usage": agent_usage,
|
||||
}
|
||||
]
|
||||
).to_csv(output_file, mode="a", header=False, index=False)
|
||||
|
||||
# logger.info(f"Results: new_row: {results.to_dict()}")
|
||||
|
||||
# Update running cost
|
||||
query_cost = float(agent_usage.get("cost", 0.0))
|
||||
@@ -592,10 +805,17 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
"-d",
|
||||
default="frames",
|
||||
choices=["frames", "frames_ir", "simpleqa", "gpqa", "math500"],
|
||||
default="skythought",
|
||||
choices=["frames", "frames_ir", "simpleqa", "gpqa", "math500", "skythought"],
|
||||
help="Dataset to use for evaluation (default: frames)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--regrade",
|
||||
"-r",
|
||||
action="store_true",
|
||||
help="Regrade existing results in output file",
|
||||
default=False,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -603,6 +823,22 @@ def main():
|
||||
# Initialize variables
|
||||
args = parse_args()
|
||||
dataset = None
|
||||
output_file = args.output or f"{args.dataset}_evaluation_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.csv"
|
||||
if not os.path.exists(output_file):
|
||||
pd.DataFrame(
|
||||
columns=[
|
||||
"id",
|
||||
"prompt",
|
||||
"ground_truth",
|
||||
"agent_response",
|
||||
"evaluation_decision",
|
||||
"evaluation_explanation",
|
||||
"reasoning_type",
|
||||
"usage",
|
||||
]
|
||||
).to_csv(output_file, index=False)
|
||||
|
||||
regrade_mode = args.regrade
|
||||
|
||||
# Load dataset
|
||||
with timer(f"Loaded {args.dataset} dataset in", logger, log_level=logging.INFO):
|
||||
@@ -614,48 +850,72 @@ def main():
|
||||
dataset = load_gpqa_dataset()
|
||||
elif args.dataset == "math500":
|
||||
dataset = load_math500_dataset()
|
||||
elif args.dataset == "skythought":
|
||||
dataset = load_skythought_dataset(output_file, regrade_mode)
|
||||
elif args.dataset == "frames_ir":
|
||||
indexed = index_frames_kb()
|
||||
if indexed:
|
||||
dataset = load_frames_dataset()
|
||||
# Rename the index field, 'Unnamed: 0' to 'Answer' for IR evaluation
|
||||
dataset["Answer"] = dataset["Unnamed: 0"]
|
||||
if dataset is None:
|
||||
if dataset is None or len(dataset) == 0:
|
||||
return
|
||||
|
||||
# Initialize variables
|
||||
results = []
|
||||
dataset_length = len(dataset["Prompt"])
|
||||
if args.dataset == "gpqa":
|
||||
response_evaluator = evaluate_response_with_mcq_match
|
||||
elif args.dataset == "math500":
|
||||
response_evaluator = partial(
|
||||
evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-flash-002")
|
||||
)
|
||||
response_evaluator = evaluate_response_for_latex_match_with_llm_fallback
|
||||
elif args.dataset == "frames_ir":
|
||||
response_evaluator = evaluate_response_for_ir
|
||||
elif args.dataset == "skythought":
|
||||
response_evaluator = evaluate_response_for_latex_match_with_llm_fallback
|
||||
else:
|
||||
response_evaluator = evaluate_response_with_gemini
|
||||
|
||||
# Process examples in batches
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = []
|
||||
for i in range(0, dataset_length, BATCH_SIZE):
|
||||
batch_start = i
|
||||
batch = zip(
|
||||
dataset["Prompt"][i : i + BATCH_SIZE],
|
||||
dataset["Answer"][i : i + BATCH_SIZE],
|
||||
dataset["reasoning_types"][i : i + BATCH_SIZE],
|
||||
)
|
||||
futures.append(
|
||||
executor.submit(process_batch, batch, batch_start, results, dataset_length, response_evaluator)
|
||||
)
|
||||
if regrade_mode:
|
||||
existing_data = pd.read_csv(output_file)
|
||||
existing_data_length = len(existing_data)
|
||||
parallel_size = max(existing_data_length // BATCH_SIZE, 4)
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=parallel_size) as executor:
|
||||
futures = []
|
||||
for i in range(0, existing_data_length, BATCH_SIZE):
|
||||
batch_start, batch_end = i, min(i + BATCH_SIZE, existing_data_length)
|
||||
batch = zip(
|
||||
existing_data["id"][batch_start:batch_end],
|
||||
existing_data["prompt"][batch_start:batch_end],
|
||||
existing_data["agent_response"][batch_start:batch_end],
|
||||
existing_data["evaluation_explanation"][batch_start:batch_end],
|
||||
)
|
||||
futures.append(
|
||||
executor.submit(
|
||||
process_batch_regrade, dataset, batch, existing_data_length, response_evaluator, output_file
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for all futures to complete
|
||||
concurrent.futures.wait(futures)
|
||||
# Wait for all futures to complete
|
||||
concurrent.futures.wait(futures)
|
||||
else:
|
||||
# Process examples in batches
|
||||
parallel_size = max(dataset_length // BATCH_SIZE, 4)
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=parallel_size) as executor:
|
||||
futures = []
|
||||
for i in range(0, dataset_length, BATCH_SIZE):
|
||||
batch_start, batch_end = i, min(i + BATCH_SIZE, dataset_length)
|
||||
batch = zip(
|
||||
dataset["id"][batch_start:batch_end],
|
||||
dataset["Prompt"][batch_start:batch_end],
|
||||
dataset["Answer"][batch_start:batch_end],
|
||||
dataset["reasoning_types"][batch_start:batch_end],
|
||||
)
|
||||
futures.append(executor.submit(process_batch, batch, dataset_length, response_evaluator, output_file))
|
||||
|
||||
# Wait for all futures to complete
|
||||
concurrent.futures.wait(futures)
|
||||
|
||||
# Calculate metrics
|
||||
df = pd.DataFrame(results)
|
||||
df = pd.read_csv(output_file)
|
||||
eval_df = df.dropna(subset=["evaluation_decision"]) # Exclude rows with missing evaluation decision
|
||||
accuracy = (eval_df["evaluation_decision"]).mean()
|
||||
|
||||
@@ -683,6 +943,10 @@ def main():
|
||||
|
||||
# Save raw results to file
|
||||
output_file = args.output or f"{args.dataset}_evaluation_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.csv"
|
||||
|
||||
if regrade_mode:
|
||||
output_file = output_file.replace(".csv", "_regraded.csv")
|
||||
|
||||
df.to_csv(output_file, index=False)
|
||||
logger.info(f"Results saved to {summary_file}, {output_file}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user