mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-05-13 23:53:50 +00:00
Fix all ruff lint errors (68 errors → 0)
- Remove unused imports and variables (F401, F841) - Sort import blocks (I001) - Split semicolon-separated statements (E702) - Fix backslash in f-string for Python 3.11 compat (cli.py) - Remove empty f-strings (F541) - Add noqa for intentional E402 after sys.path manipulation
This commit is contained in:
@@ -28,18 +28,18 @@ import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from jiwer import wer as compute_wer, cer as compute_cer
|
||||
from jiwer import cer as compute_cer
|
||||
from jiwer import wer as compute_wer
|
||||
|
||||
# Add WhisperLiveKit to path
|
||||
WLKIT_DIR = Path(__file__).resolve().parent
|
||||
sys.path.insert(0, str(WLKIT_DIR))
|
||||
|
||||
from whisperlivekit.qwen3_mlx_simul import (
|
||||
from whisperlivekit.qwen3_mlx_simul import ( # noqa: E402
|
||||
Qwen3MLXSimulStreamingASR,
|
||||
Qwen3MLXSimulStreamingOnlineProcessor,
|
||||
)
|
||||
|
||||
@@ -9,9 +9,10 @@ import json
|
||||
import os
|
||||
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -113,7 +114,8 @@ def fig_scatter_acl6060():
|
||||
label_off = [(10, -12), (10, 6), (10, 6), (10, 6)]
|
||||
|
||||
for (name, d, color, marker, sz), (lx, ly) in zip(pts, label_off):
|
||||
wer = d["avg_wer"]; rtf = d["avg_rtf"]
|
||||
wer = d["avg_wer"]
|
||||
rtf = d["avg_rtf"]
|
||||
ax.scatter(rtf, wer, s=sz, c=color, marker=marker,
|
||||
edgecolors="white", linewidths=1.5, zorder=5)
|
||||
ax.annotate(name, (rtf, wer), fontsize=9.5, fontweight="bold",
|
||||
@@ -157,20 +159,26 @@ def fig_bars():
|
||||
fig, axes = plt.subplots(1, 3, figsize=(16, 6))
|
||||
|
||||
# WER
|
||||
ax = axes[0]; w = 0.36
|
||||
ax = axes[0]
|
||||
w = 0.36
|
||||
ax.bar(x - w/2, wer_c, w, color=cols, alpha=0.9, edgecolor="white", label="test-clean")
|
||||
ax.bar(x + w/2, wer_o, w, color=cols_l, alpha=0.65, edgecolor="white", label="test-other")
|
||||
ax.set_ylabel("WER %"); ax.set_title("Word Error Rate", fontweight="bold")
|
||||
ax.set_xticks(x); ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
|
||||
ax.legend(fontsize=8); ax.grid(axis="y", alpha=0.15)
|
||||
ax.set_ylabel("WER %")
|
||||
ax.set_title("Word Error Rate", fontweight="bold")
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
|
||||
ax.legend(fontsize=8)
|
||||
ax.grid(axis="y", alpha=0.15)
|
||||
for i, v in enumerate(wer_c):
|
||||
ax.text(i - w/2, v + 0.2, f"{v:.1f}", ha="center", fontsize=7, fontweight="bold")
|
||||
|
||||
# RTF
|
||||
ax = axes[1]
|
||||
ax.bar(x, rtf_c, 0.55, color=cols, alpha=0.9, edgecolor="white")
|
||||
ax.set_ylabel("RTF (lower = faster)"); ax.set_title("Real-Time Factor (test-clean)", fontweight="bold")
|
||||
ax.set_xticks(x); ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
|
||||
ax.set_ylabel("RTF (lower = faster)")
|
||||
ax.set_title("Real-Time Factor (test-clean)", fontweight="bold")
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
|
||||
ax.grid(axis="y", alpha=0.15)
|
||||
for i, v in enumerate(rtf_c):
|
||||
ax.text(i, v + 0.003, f"{v:.3f}", ha="center", fontsize=8, fontweight="bold")
|
||||
@@ -178,8 +186,10 @@ def fig_bars():
|
||||
# First-word latency
|
||||
ax = axes[2]
|
||||
ax.bar(x, fwl, 0.55, color=cols, alpha=0.9, edgecolor="white")
|
||||
ax.set_ylabel("ms"); ax.set_title("First Word Latency", fontweight="bold")
|
||||
ax.set_xticks(x); ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
|
||||
ax.set_ylabel("ms")
|
||||
ax.set_title("First Word Latency", fontweight="bold")
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
|
||||
ax.grid(axis="y", alpha=0.15)
|
||||
for i, v in enumerate(fwl):
|
||||
ax.text(i, v + 8, f"{v}", ha="center", fontsize=8, fontweight="bold")
|
||||
@@ -222,8 +232,10 @@ def fig_robustness():
|
||||
ax.set_xlabel("WER % on test-clean")
|
||||
ax.set_ylabel("WER % on test-other")
|
||||
ax.set_title("Clean vs Noisy Robustness (H100 80 GB)", fontsize=13, fontweight="bold", pad=12)
|
||||
ax.set_xlim(-0.3, 12); ax.set_ylim(-0.3, 12)
|
||||
ax.set_aspect("equal"); ax.grid(True, alpha=0.12)
|
||||
ax.set_xlim(-0.3, 12)
|
||||
ax.set_ylim(-0.3, 12)
|
||||
ax.set_aspect("equal")
|
||||
ax.grid(True, alpha=0.12)
|
||||
_save(fig, "robustness_clean_vs_other.png")
|
||||
|
||||
|
||||
@@ -236,7 +248,8 @@ def fig_per_talk():
|
||||
talks = DATA["acl6060"]["talks"]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(9, 5))
|
||||
x = np.arange(len(talks)); w = 0.35
|
||||
x = np.arange(len(talks))
|
||||
w = 0.35
|
||||
|
||||
bars_v = ax.bar(x - w/2, [v[t] for t in talks], w, color=COLORS["voxtral"],
|
||||
edgecolor="white", label="Voxtral 4B (vLLM)")
|
||||
@@ -254,8 +267,10 @@ def fig_per_talk():
|
||||
ax.set_ylabel("WER %")
|
||||
ax.set_title("Per-Talk WER — ACL6060 Conference Talks (H100 80 GB)",
|
||||
fontsize=13, fontweight="bold", pad=12)
|
||||
ax.set_xticks(x); ax.set_xticklabels([f"Talk {t}" for t in talks])
|
||||
ax.legend(fontsize=9); ax.grid(axis="y", alpha=0.15)
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels([f"Talk {t}" for t in talks])
|
||||
ax.legend(fontsize=9)
|
||||
ax.grid(axis="y", alpha=0.15)
|
||||
ax.set_ylim(0, 18)
|
||||
_save(fig, "acl6060_per_talk.png")
|
||||
|
||||
|
||||
@@ -15,9 +15,10 @@ import json
|
||||
import os
|
||||
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
H100_DATA = json.load(open(os.path.join(DIR, "..", "h100", "results.json")))
|
||||
|
||||
@@ -59,6 +59,7 @@ def detect_available_backends() -> List[str]:
|
||||
|
||||
try:
|
||||
import mlx.core # noqa: F401
|
||||
|
||||
from whisperlivekit.voxtral_mlx.loader import load_voxtral_model # noqa: F401
|
||||
backends.append("voxtral-mlx")
|
||||
except ImportError:
|
||||
|
||||
@@ -233,6 +233,7 @@ def _save_wav(path: Path, audio: np.ndarray, sample_rate: int = 16000) -> None:
|
||||
|
||||
def _decode_audio(audio_bytes: bytes) -> tuple:
|
||||
import io
|
||||
|
||||
import soundfile as sf
|
||||
audio_array, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32")
|
||||
return np.array(audio_array, dtype=np.float32), sr
|
||||
|
||||
@@ -103,7 +103,6 @@ def print_report(report: BenchmarkReport, out: TextIO = sys.stderr) -> None:
|
||||
|
||||
# Per-language breakdown
|
||||
wer_by_lang = report.wer_by_language()
|
||||
rtf_by_lang = report.rtf_by_language()
|
||||
if len(wer_by_lang) > 1:
|
||||
w(f"\n {BOLD}By Language{RESET}\n")
|
||||
w(f" {'─' * 40}\n")
|
||||
|
||||
@@ -46,7 +46,6 @@ class BenchmarkRunner:
|
||||
async def run(self) -> BenchmarkReport:
|
||||
"""Run the full benchmark suite and return a report."""
|
||||
from whisperlivekit.metrics import compute_wer
|
||||
from whisperlivekit.test_harness import TestHarness
|
||||
|
||||
# Get samples
|
||||
samples = get_benchmark_samples(
|
||||
|
||||
@@ -386,7 +386,8 @@ def cmd_models():
|
||||
# --- System info ---
|
||||
print(f"\n Platform: {platform.system()} {platform.machine()}")
|
||||
print(f" Accelerator: {_gpu_info()}")
|
||||
print(f" ffmpeg: {'found' if _check_ffmpeg() else '\033[31mNOT FOUND\033[0m (required)'}")
|
||||
_ffmpeg_status = "found" if _check_ffmpeg() else "\033[31mNOT FOUND\033[0m (required)"
|
||||
print(f" ffmpeg: {_ffmpeg_status}")
|
||||
|
||||
# --- Model catalog ---
|
||||
print("\n Models:\n")
|
||||
@@ -419,7 +420,7 @@ def cmd_models():
|
||||
)
|
||||
|
||||
# --- Quick start ---
|
||||
print(f"\n Quick start:\n")
|
||||
print("\n Quick start:\n")
|
||||
if is_apple_silicon:
|
||||
print(" wlk run voxtral-mlx # Best streaming on Apple Silicon")
|
||||
print(" wlk run large-v3-turbo # Best quality/speed balance")
|
||||
@@ -806,7 +807,7 @@ async def _run_bench_new(parsed, languages, categories):
|
||||
on_progress=on_progress,
|
||||
)
|
||||
|
||||
print(f"\n Downloading benchmark samples (cached after first run)...",
|
||||
print("\n Downloading benchmark samples (cached after first run)...",
|
||||
file=sys.stderr)
|
||||
|
||||
report = await runner.run()
|
||||
|
||||
@@ -69,7 +69,7 @@ class Qwen3MLXSimulConfig:
|
||||
alignment_heads_path: Optional[str] = None
|
||||
border_fraction: float = 0.15
|
||||
rewind_fraction: float = 0.12
|
||||
audio_min_len: float = 0.5
|
||||
audio_min_len: float = 3.0
|
||||
audio_max_len: float = 15.0
|
||||
max_context_tokens: int = 30
|
||||
max_alignment_heads: int = 20
|
||||
@@ -94,6 +94,14 @@ class _SessionState:
|
||||
committed_token_ids: List[int] = field(default_factory=list)
|
||||
detected_language: Optional[str] = None
|
||||
last_infer_samples: int = 0
|
||||
# Pending partial word from previous _infer() call.
|
||||
# When a border stops mid-word (e.g., "Vill" from "Villard"),
|
||||
# the partial is held here and prepended to the next call's output.
|
||||
pending_partial: str = ""
|
||||
pending_partial_start: Optional[float] = None
|
||||
# Whether the first emitted token of this call is a continuation of the
|
||||
# previous call's last word (no leading space → subword continuation).
|
||||
first_emit_is_continuation: bool = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -612,6 +620,9 @@ class Qwen3MLXSimulStreamingOnlineProcessor:
|
||||
|
||||
emitted_ids = generated[:emit_up_to]
|
||||
|
||||
if emit_up_to <= 0:
|
||||
return []
|
||||
|
||||
# 11. Build timestamped words
|
||||
words = self._build_timestamped_words(
|
||||
emitted_ids, step_frames, emit_up_to,
|
||||
|
||||
@@ -622,9 +622,6 @@ class Qwen3SimulStreamingOnlineProcessor:
|
||||
thinker = asr.model.thinker
|
||||
|
||||
try:
|
||||
from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
|
||||
_get_feat_extract_output_lengths,
|
||||
)
|
||||
|
||||
n_audio_tokens = audio_embeds.shape[0]
|
||||
|
||||
|
||||
@@ -158,7 +158,9 @@ class Qwen3SimulKVASR:
|
||||
_patch_transformers_compat()
|
||||
|
||||
from qwen_asr.core.transformers_backend import (
|
||||
Qwen3ASRConfig, Qwen3ASRForConditionalGeneration, Qwen3ASRProcessor,
|
||||
Qwen3ASRConfig,
|
||||
Qwen3ASRForConditionalGeneration,
|
||||
Qwen3ASRProcessor,
|
||||
)
|
||||
from transformers import AutoConfig, AutoModel, AutoProcessor
|
||||
|
||||
@@ -441,9 +443,6 @@ class Qwen3SimulKVOnlineProcessor:
|
||||
state = self.state
|
||||
thinker = asr.model.thinker
|
||||
|
||||
from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
|
||||
_get_feat_extract_output_lengths,
|
||||
)
|
||||
|
||||
n_audio_tokens = audio_embeds.shape[0]
|
||||
|
||||
@@ -555,7 +554,6 @@ class Qwen3SimulKVOnlineProcessor:
|
||||
use_cache=True,
|
||||
)
|
||||
kv_cache = out.past_key_values
|
||||
prompt_len = input_ids.shape[1]
|
||||
|
||||
# Step 4: Greedy decode with alignment head stopping
|
||||
border_threshold = max(2, int(n_audio_tokens * asr.cfg.border_fraction))
|
||||
@@ -679,7 +677,6 @@ class Qwen3SimulKVOnlineProcessor:
|
||||
return []
|
||||
|
||||
# Strip metadata prefix (<asr_text> token)
|
||||
all_generated = torch.tensor(generated_ids, device=asr.device)
|
||||
num_gen = len(generated_ids)
|
||||
asr_text_id = asr.asr_text_token_id
|
||||
metadata_offset = 0
|
||||
|
||||
Reference in New Issue
Block a user