diff --git a/prototypes/click_eval/README.md b/prototypes/click_eval/README.md index 34dea636..8c72c572 100644 --- a/prototypes/click_eval/README.md +++ b/prototypes/click_eval/README.md @@ -44,7 +44,12 @@ portion is: "provider": "openrouter", "model": "bytedance/ui-tars-1.5-7b" }, - {"name": "moondream", "provider": "moondream", "model": "moondream-cloud"} + {"name": "moondream", "provider": "moondream", "model": "moondream-cloud"}, + { + "name": "gemini-computer-use", + "provider": "gemini", + "model": "gemini-2.5-computer-use-preview-10-2025" + } ] } ``` @@ -125,7 +130,8 @@ gated/private downloads, set `HF_TOKEN`. For Azure/Foundry-hosted variants, expect an endpoint URL plus API key and a dedicated provider adapter. Moondream candidates use a provider-qualified -entry: +entry. Gemini Computer Use candidates use `provider: "gemini"` and require +`GEMINI_API_KEY` or `GOOGLE_API_KEY`. ```json { @@ -134,6 +140,11 @@ entry: "name": "moondream", "provider": "moondream", "model": "moondream-cloud" + }, + { + "name": "gemini-computer-use", + "provider": "gemini", + "model": "gemini-2.5-computer-use-preview-10-2025" } ] } @@ -147,6 +158,8 @@ uv sync export OPENROUTER_API_KEY=... # Optional, for Moondream candidates: export MOONDREAM_API_KEY=... +# Optional, for Gemini Computer Use candidates: +export GEMINI_API_KEY=... uv run click-eval run ``` @@ -162,10 +175,14 @@ On an interactive terminal, `run` shows tqdm progress bars for tasks and model calls. In non-interactive output, it prints plain status lines instead. Use `--no-progress` to suppress both. -The CLI also loads `MOONDREAM_API_KEY` and `OPENROUTER_API_KEY` from a local -`.env` file in `prototypes/click_eval/` or the current working directory. +The CLI also loads `MOONDREAM_API_KEY`, `GEMINI_API_KEY`, `GOOGLE_API_KEY`, and +`OPENROUTER_API_KEY` from a local `.env` file in `prototypes/click_eval/` or +the current working directory. Moondream calls use `POST https://api.moondream.ai/v1/point` with the screenshot as a base64 data URL and the click instruction converted to an object query. +Gemini Computer Use calls use the Google GenAI SDK with the Computer Use tool, +request a single `click_at` action, and scale Gemini's normalized `0..1000` +coordinates back to screenshot pixels. During a run, the CLI shows progress bars for tasks and per-task candidate model calls. It also prints compact status lines for GT resolution, provider/model @@ -173,7 +190,7 @@ calls, prediction failures, and the output directory. OpenRouter candidate calls are sent concurrently in bounded batches of 4. Local HF/GPU candidates stay synchronous and serial to avoid GPU memory contention; -Moondream and GT resolution also remain synchronous. +Moondream, Gemini, and GT resolution also remain synchronous. Outputs: diff --git a/prototypes/click_eval/examples/models.json b/prototypes/click_eval/examples/models.json index c924e32b..4db6d16a 100644 --- a/prototypes/click_eval/examples/models.json +++ b/prototypes/click_eval/examples/models.json @@ -21,6 +21,11 @@ "provider": "moondream", "model": "moondream-cloud" }, + { + "name": "gemini-computer-use", + "provider": "gemini", + "model": "gemini-2.5-computer-use-preview-10-2025" + }, { "name": "qwen3-vl-4b-instruct-local", "provider": "local_hf", diff --git a/prototypes/click_eval/pyproject.toml b/prototypes/click_eval/pyproject.toml index 22c40f05..06c3d0f4 100644 --- a/prototypes/click_eval/pyproject.toml +++ b/prototypes/click_eval/pyproject.toml @@ -8,6 +8,7 @@ version = "0.1.0" description = "Quick OpenRouter VLM click-point evaluation prototype" requires-python = ">=3.12" dependencies = [ + "google-genai>=1.0.0", "Pillow>=10.0.0", "tqdm>=4.67.3", ] diff --git a/prototypes/click_eval/src/click_eval/gemini.py b/prototypes/click_eval/src/click_eval/gemini.py new file mode 100644 index 00000000..b746cb68 --- /dev/null +++ b/prototypes/click_eval/src/click_eval/gemini.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any + +from .contracts import ModelReply, Point +from .image_utils import image_size + + +class GeminiComputerUseClient: + def __init__( + self, + api_key: str | None = None, + timeout_seconds: int = 90, + ) -> None: + self.api_key = ( + api_key + or os.environ.get("GEMINI_API_KEY") + or os.environ.get("GOOGLE_API_KEY") + ) + if not self.api_key: + raise RuntimeError("GEMINI_API_KEY or GOOGLE_API_KEY is required") + self.timeout_seconds = timeout_seconds + + def predict_point( + self, + model_id: str, + image_path: Path, + instruction: str, + purpose: str, + ) -> ModelReply: + try: + from google import genai + from google.genai import types + from google.genai.types import Content, Part + except ImportError as exc: + raise RuntimeError( + "google-genai is required for provider=gemini; run `uv sync`" + ) from exc + + width, height = image_size(image_path) + client = genai.Client(api_key=self.api_key) + contents = [ + Content( + role="user", + parts=[ + Part(text=_computer_use_prompt(instruction, purpose)), + Part.from_bytes( + data=image_path.read_bytes(), + mime_type="image/png", + ), + ], + ) + ] + config = types.GenerateContentConfig( + tools=[ + types.Tool( + computer_use=types.ComputerUse( + environment=types.Environment.ENVIRONMENT_BROWSER, + excluded_predefined_functions=_excluded_functions(), + ) + ) + ], + temperature=0, + ) + response = client.models.generate_content( + model=model_id, + contents=contents, + config=config, + ) + call = _first_function_call(response) + raw = _raw_response(response) + if call is None: + return ModelReply(text=_response_text(response), raw=raw) + + point = _point_from_call(call) + if point is None: + return ModelReply(text=_response_text(response), raw=raw) + + scaled = Point(x=point.x / 1000 * width, y=point.y / 1000 * height) + return ModelReply( + text=json.dumps( + { + "x": scaled.x, + "y": scaled.y, + "reason": f"Gemini Computer Use function_call {call['name']}", + } + ), + raw=raw, + ) + + +def _computer_use_prompt(instruction: str, purpose: str) -> str: + role_line = ( + "Choose the ground-truth click point for this instruction." + if purpose == "ground_truth" + else "Predict the click point for this instruction." + ) + return ( + f"{role_line}\n\n" + "Use the screenshot and emit exactly one Computer Use `click_at` action. " + "Do not navigate, type, scroll, hover, or wait. Choose the center of the " + "target UI element when possible.\n\n" + f"Instruction: {instruction}" + ) + + +def _excluded_functions() -> list[str]: + return [ + "open_web_browser", + "wait_5_seconds", + "go_back", + "go_forward", + "search", + "navigate", + "hover_at", + "type_text_at", + "key_combination", + "scroll_document", + "drag_and_drop", + ] + + +def _first_function_call(response) -> dict[str, Any] | None: + for candidate in getattr(response, "candidates", []) or []: + content = getattr(candidate, "content", None) + for part in getattr(content, "parts", []) or []: + function_call = getattr(part, "function_call", None) + if function_call is not None: + return _function_call_dict(function_call) + return _first_function_call_from_dict(_raw_response(response)) + + +def _function_call_dict(function_call) -> dict[str, Any]: + name = getattr(function_call, "name", None) + args = getattr(function_call, "args", None) + return {"name": str(name or ""), "args": _plain_dict(args)} + + +def _first_function_call_from_dict(value: Any) -> dict[str, Any] | None: + if isinstance(value, dict): + function_call = value.get("functionCall") or value.get("function_call") + if isinstance(function_call, dict): + return { + "name": str(function_call.get("name") or ""), + "args": _plain_dict(function_call.get("args")), + } + for child in value.values(): + found = _first_function_call_from_dict(child) + if found is not None: + return found + if isinstance(value, list): + for child in value: + found = _first_function_call_from_dict(child) + if found is not None: + return found + return None + + +def _point_from_call(call: dict[str, Any]) -> Point | None: + args = call.get("args") + if not isinstance(args, dict): + return None + try: + return Point(x=float(args["x"]), y=float(args["y"])) + except (KeyError, TypeError, ValueError): + return None + + +def _response_text(response) -> str: + parts = [] + for candidate in getattr(response, "candidates", []) or []: + content = getattr(candidate, "content", None) + for part in getattr(content, "parts", []) or []: + text = getattr(part, "text", None) + if text: + parts.append(str(text)) + if parts: + return "\n".join(parts) + return json.dumps(_raw_response(response), ensure_ascii=False) + + +def _raw_response(response) -> dict[str, Any]: + for method_name in ("to_json_dict", "model_dump", "dict"): + method = getattr(response, method_name, None) + if method is None: + continue + try: + value = method() + except TypeError: + continue + if isinstance(value, dict): + return _plain_dict(value) + return {"repr": repr(response)} + + +def _plain_dict(value: Any) -> Any: + if isinstance(value, dict): + return {str(key): _plain_dict(item) for key, item in value.items()} + if isinstance(value, list): + return [_plain_dict(item) for item in value] + if isinstance(value, tuple): + return [_plain_dict(item) for item in value] + if isinstance(value, (str, int, float, bool)) or value is None: + return value + try: + return dict(value) + except (TypeError, ValueError): + return repr(value) diff --git a/prototypes/click_eval/src/click_eval/providers.py b/prototypes/click_eval/src/click_eval/providers.py index 5e4eb022..432f279d 100644 --- a/prototypes/click_eval/src/click_eval/providers.py +++ b/prototypes/click_eval/src/click_eval/providers.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Callable from .contracts import ModelReply, ModelSpec +from .gemini import GeminiComputerUseClient from .local_hf import LocalHFClient from .moondream import MoondreamClient from .openrouter import OpenRouterClient @@ -19,6 +20,7 @@ class ProviderClient: self._log_callback = log_callback self._openrouter: OpenRouterClient | None = None self._moondream: MoondreamClient | None = None + self._gemini: GeminiComputerUseClient | None = None self._local_hf: LocalHFClient | None = None def predict_point( @@ -37,6 +39,10 @@ class ProviderClient: return self._moondream_client().predict_point( model.model_id, image_path, instruction, purpose ) + if provider == "gemini": + return self._gemini_client().predict_point( + model.model_id, image_path, instruction, purpose + ) if provider == "local_hf": return self._local_hf_client().predict_point( model, image_path, instruction, purpose @@ -54,6 +60,11 @@ class ProviderClient: self._moondream = MoondreamClient(timeout_seconds=self.timeout_seconds) return self._moondream + def _gemini_client(self) -> GeminiComputerUseClient: + if self._gemini is None: + self._gemini = GeminiComputerUseClient(timeout_seconds=self.timeout_seconds) + return self._gemini + def _local_hf_client(self) -> LocalHFClient: if self._local_hf is None: self._local_hf = LocalHFClient(