feat: gemini computer use model

This commit is contained in:
Neel Gupta
2026-04-26 12:58:00 +01:00
parent dca8b8555f
commit 08968ff16e
5 changed files with 250 additions and 5 deletions

View File

@@ -44,7 +44,12 @@ portion is:
"provider": "openrouter",
"model": "bytedance/ui-tars-1.5-7b"
},
{"name": "moondream", "provider": "moondream", "model": "moondream-cloud"}
{"name": "moondream", "provider": "moondream", "model": "moondream-cloud"},
{
"name": "gemini-computer-use",
"provider": "gemini",
"model": "gemini-2.5-computer-use-preview-10-2025"
}
]
}
```
@@ -125,7 +130,8 @@ gated/private downloads, set `HF_TOKEN`. For Azure/Foundry-hosted variants,
expect an endpoint URL plus API key and a dedicated provider adapter.
Moondream candidates use a provider-qualified
entry:
entry. Gemini Computer Use candidates use `provider: "gemini"` and require
`GEMINI_API_KEY` or `GOOGLE_API_KEY` to be set in the environment. For example:
```json
{
@@ -134,6 +140,11 @@ entry:
"name": "moondream",
"provider": "moondream",
"model": "moondream-cloud"
},
{
"name": "gemini-computer-use",
"provider": "gemini",
"model": "gemini-2.5-computer-use-preview-10-2025"
}
]
}
@@ -147,6 +158,8 @@ uv sync
export OPENROUTER_API_KEY=...
# Optional, for Moondream candidates:
export MOONDREAM_API_KEY=...
# Optional, for Gemini Computer Use candidates:
export GEMINI_API_KEY=...
uv run click-eval run
```
@@ -162,10 +175,14 @@ On an interactive terminal, `run` shows tqdm progress bars for tasks and model
calls. In non-interactive output, it prints plain status lines instead. Use
`--no-progress` to suppress both.
The CLI also loads `MOONDREAM_API_KEY` and `OPENROUTER_API_KEY` from a local
`.env` file in `prototypes/click_eval/` or the current working directory.
The CLI also loads `MOONDREAM_API_KEY`, `GEMINI_API_KEY`, `GOOGLE_API_KEY`, and
`OPENROUTER_API_KEY` from a local `.env` file in `prototypes/click_eval/` or
the current working directory.
Moondream calls use `POST https://api.moondream.ai/v1/point` with the screenshot
as a base64 data URL and the click instruction converted to an object query.
Gemini Computer Use calls use the Google GenAI SDK with the Computer Use tool,
request a single `click_at` action, and scale Gemini's normalized `0..1000`
coordinates back to screenshot pixels.
During a run, the CLI shows progress bars for tasks and per-task candidate model
calls. It also prints compact status lines for GT resolution, provider/model
@@ -173,7 +190,7 @@ calls, prediction failures, and the output directory.
OpenRouter candidate calls are sent concurrently in bounded batches of 4. Local
HF/GPU candidates stay synchronous and serial to avoid GPU memory contention;
Moondream and GT resolution also remain synchronous.
Moondream, Gemini, and GT resolution also remain synchronous.
Outputs:

View File

@@ -21,6 +21,11 @@
"provider": "moondream",
"model": "moondream-cloud"
},
{
"name": "gemini-computer-use",
"provider": "gemini",
"model": "gemini-2.5-computer-use-preview-10-2025"
},
{
"name": "qwen3-vl-4b-instruct-local",
"provider": "local_hf",

View File

@@ -8,6 +8,7 @@ version = "0.1.0"
description = "Quick OpenRouter VLM click-point evaluation prototype"
requires-python = ">=3.12"
dependencies = [
"google-genai>=1.0.0",
"Pillow>=10.0.0",
"tqdm>=4.67.3",
]

View File

@@ -0,0 +1,211 @@
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any
from .contracts import ModelReply, Point
from .image_utils import image_size
class GeminiComputerUseClient:
    """Ask a Gemini Computer Use model for a single click point on a screenshot.

    The API key is taken from the constructor argument, then from the
    ``GEMINI_API_KEY`` and ``GOOGLE_API_KEY`` environment variables, in that
    order.
    """

    # MIME type used when the screenshot's extension is unknown or not an image.
    _DEFAULT_IMAGE_MIME_TYPE = "image/png"

    def __init__(
        self,
        api_key: str | None = None,
        timeout_seconds: int = 90,
    ) -> None:
        """Store credentials and the per-request timeout.

        Args:
            api_key: Explicit API key; falls back to the environment when falsy.
            timeout_seconds: Request timeout applied to every SDK call.

        Raises:
            RuntimeError: If no API key is supplied or found in the environment.
        """
        self.api_key = (
            api_key
            or os.environ.get("GEMINI_API_KEY")
            or os.environ.get("GOOGLE_API_KEY")
        )
        if not self.api_key:
            raise RuntimeError("GEMINI_API_KEY or GOOGLE_API_KEY is required")
        self.timeout_seconds = timeout_seconds

    @staticmethod
    def _image_mime_type(image_path: Path) -> str:
        """Return the image MIME type implied by the file extension.

        Falls back to ``image/png`` when the extension is missing, unknown,
        or maps to a non-image type.
        """
        import mimetypes

        guessed, _ = mimetypes.guess_type(str(image_path))
        if guessed and guessed.startswith("image/"):
            return guessed
        return GeminiComputerUseClient._DEFAULT_IMAGE_MIME_TYPE

    def predict_point(
        self,
        model_id: str,
        image_path: Path,
        instruction: str,
        purpose: str,
    ) -> ModelReply:
        """Send the screenshot and instruction to Gemini and return its reply.

        Args:
            model_id: Gemini model identifier passed to ``generate_content``.
            image_path: Screenshot file; its bytes are uploaded inline.
            instruction: Click instruction inserted into the prompt.
            purpose: Selects the prompt's opening line (``"ground_truth"`` or
                anything else).

        Returns:
            A ``ModelReply`` whose ``text`` is a JSON object with pixel ``x``,
            ``y`` and ``reason`` when a function call with numeric ``x``/``y``
            arguments was found; otherwise ``text`` is the response text (or
            the JSON-encoded raw response). ``raw`` always carries the raw
            response.

        Raises:
            RuntimeError: If the ``google-genai`` package is not installed.
        """
        # Imported lazily so the package is only needed for provider=gemini.
        try:
            from google import genai
            from google.genai import types
            from google.genai.types import Content, Part
        except ImportError as exc:
            raise RuntimeError(
                "google-genai is required for provider=gemini; run `uv sync`"
            ) from exc

        width, height = image_size(image_path)

        # The SDK expects the HTTP timeout in milliseconds. Without this the
        # configured timeout_seconds would be silently ignored.
        http_options = (
            types.HttpOptions(timeout=int(self.timeout_seconds * 1000))
            if self.timeout_seconds
            else None
        )
        client = genai.Client(api_key=self.api_key, http_options=http_options)

        contents = [
            Content(
                role="user",
                parts=[
                    Part(text=_computer_use_prompt(instruction, purpose)),
                    Part.from_bytes(
                        data=image_path.read_bytes(),
                        mime_type=self._image_mime_type(image_path),
                    ),
                ],
            )
        ]
        config = types.GenerateContentConfig(
            tools=[
                types.Tool(
                    computer_use=types.ComputerUse(
                        environment=types.Environment.ENVIRONMENT_BROWSER,
                        excluded_predefined_functions=_excluded_functions(),
                    )
                )
            ],
            temperature=0,
        )
        response = client.models.generate_content(
            model=model_id,
            contents=contents,
            config=config,
        )

        call = _first_function_call(response)
        raw = _raw_response(response)
        if call is None:
            return ModelReply(text=_response_text(response), raw=raw)
        point = _point_from_call(call)
        if point is None:
            return ModelReply(text=_response_text(response), raw=raw)

        # Gemini reports coordinates on a 0..1000 grid; map them to pixels.
        scaled = Point(x=point.x / 1000 * width, y=point.y / 1000 * height)
        return ModelReply(
            text=json.dumps(
                {
                    "x": scaled.x,
                    "y": scaled.y,
                    "reason": f"Gemini Computer Use function_call {call['name']}",
                }
            ),
            raw=raw,
        )
def _computer_use_prompt(instruction: str, purpose: str) -> str:
role_line = (
"Choose the ground-truth click point for this instruction."
if purpose == "ground_truth"
else "Predict the click point for this instruction."
)
return (
f"{role_line}\n\n"
"Use the screenshot and emit exactly one Computer Use `click_at` action. "
"Do not navigate, type, scroll, hover, or wait. Choose the center of the "
"target UI element when possible.\n\n"
f"Instruction: {instruction}"
)
def _excluded_functions() -> list[str]:
return [
"open_web_browser",
"wait_5_seconds",
"go_back",
"go_forward",
"search",
"navigate",
"hover_at",
"type_text_at",
"key_combination",
"scroll_document",
"drag_and_drop",
]
def _first_function_call(response) -> dict[str, Any] | None:
    """Return the first function call found in ``response`` as a plain dict.

    Candidate parts are scanned through attribute access first; if none of
    them carries a ``function_call``, the raw dict form of the response is
    searched instead.
    """
    candidates = getattr(response, "candidates", []) or []
    for candidate in candidates:
        # A missing content or parts attribute is treated as "no parts".
        parts = getattr(getattr(candidate, "content", None), "parts", []) or []
        for part in parts:
            call = getattr(part, "function_call", None)
            if call is None:
                continue
            return _function_call_dict(call)
    return _first_function_call_from_dict(_raw_response(response))
def _function_call_dict(function_call) -> dict[str, Any]:
    """Convert a function-call object into ``{"name": str, "args": ...}``.

    A missing or falsy name becomes the empty string; the args are passed
    through ``_plain_dict``.
    """
    raw_name = getattr(function_call, "name", None)
    raw_args = getattr(function_call, "args", None)
    name = str(raw_name) if raw_name else ""
    return {"name": name, "args": _plain_dict(raw_args)}
def _first_function_call_from_dict(value: Any) -> dict[str, Any] | None:
if isinstance(value, dict):
function_call = value.get("functionCall") or value.get("function_call")
if isinstance(function_call, dict):
return {
"name": str(function_call.get("name") or ""),
"args": _plain_dict(function_call.get("args")),
}
for child in value.values():
found = _first_function_call_from_dict(child)
if found is not None:
return found
if isinstance(value, list):
for child in value:
found = _first_function_call_from_dict(child)
if found is not None:
return found
return None
def _point_from_call(call: dict[str, Any]) -> Point | None:
args = call.get("args")
if not isinstance(args, dict):
return None
try:
return Point(x=float(args["x"]), y=float(args["y"]))
except (KeyError, TypeError, ValueError):
return None
def _response_text(response) -> str:
parts = []
for candidate in getattr(response, "candidates", []) or []:
content = getattr(candidate, "content", None)
for part in getattr(content, "parts", []) or []:
text = getattr(part, "text", None)
if text:
parts.append(str(text))
if parts:
return "\n".join(parts)
return json.dumps(_raw_response(response), ensure_ascii=False)
def _raw_response(response) -> dict[str, Any]:
for method_name in ("to_json_dict", "model_dump", "dict"):
method = getattr(response, method_name, None)
if method is None:
continue
try:
value = method()
except TypeError:
continue
if isinstance(value, dict):
return _plain_dict(value)
return {"repr": repr(response)}
def _plain_dict(value: Any) -> Any:
if isinstance(value, dict):
return {str(key): _plain_dict(item) for key, item in value.items()}
if isinstance(value, list):
return [_plain_dict(item) for item in value]
if isinstance(value, tuple):
return [_plain_dict(item) for item in value]
if isinstance(value, (str, int, float, bool)) or value is None:
return value
try:
return dict(value)
except (TypeError, ValueError):
return repr(value)

View File

@@ -4,6 +4,7 @@ from pathlib import Path
from typing import Callable
from .contracts import ModelReply, ModelSpec
from .gemini import GeminiComputerUseClient
from .local_hf import LocalHFClient
from .moondream import MoondreamClient
from .openrouter import OpenRouterClient
@@ -19,6 +20,7 @@ class ProviderClient:
self._log_callback = log_callback
self._openrouter: OpenRouterClient | None = None
self._moondream: MoondreamClient | None = None
self._gemini: GeminiComputerUseClient | None = None
self._local_hf: LocalHFClient | None = None
def predict_point(
@@ -37,6 +39,10 @@ class ProviderClient:
return self._moondream_client().predict_point(
model.model_id, image_path, instruction, purpose
)
if provider == "gemini":
return self._gemini_client().predict_point(
model.model_id, image_path, instruction, purpose
)
if provider == "local_hf":
return self._local_hf_client().predict_point(
model, image_path, instruction, purpose
@@ -54,6 +60,11 @@ class ProviderClient:
self._moondream = MoondreamClient(timeout_seconds=self.timeout_seconds)
return self._moondream
def _gemini_client(self) -> GeminiComputerUseClient:
if self._gemini is None:
self._gemini = GeminiComputerUseClient(timeout_seconds=self.timeout_seconds)
return self._gemini
def _local_hf_client(self) -> LocalHFClient:
if self._local_hf is None:
self._local_hf = LocalHFClient(