perf(illustrations): ⚡ Optimize illustration iteration and scoring logic in iterate.py and local_scorer.py for faster performance
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
6ec164f79a
commit
ff20705bc4
2 changed files with 298 additions and 141 deletions
|
|
@ -20,7 +20,6 @@ import base64
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import tomllib
|
||||
|
|
@ -30,6 +29,8 @@ from dataclasses import dataclass, field, asdict
|
|||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from local_scorer import LocalScorer, ScoreBreakdown
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -74,31 +75,6 @@ UNIVERSAL_NEG = (
|
|||
"sd chibi, deformed proportions, stumpy legs, child proportions"
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rubric used for the claude CLI scoring step
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
RUBRIC = """Score this illustration on a scale from 0.0 to 10.0 using this rubric.
|
||||
Return ONLY valid JSON — no markdown, no prose — with these exact keys:
|
||||
|
||||
{
|
||||
"silhouette_legibility": <0.0-2.0>,
|
||||
"prop_clarity": <0.0-2.0>,
|
||||
"hourglass_gown": <0.0-2.0>,
|
||||
"white_background": <0.0-2.0>,
|
||||
"no_text_border": <0.0-2.0>,
|
||||
"total": <0.0-10.0>
|
||||
}
|
||||
|
||||
Rubric:
|
||||
- silhouette_legibility (0-2): Is the full-body silhouette clean, readable, and unambiguous?
|
||||
- prop_clarity (0-2): Is the signature prop clearly visible and identifiable?
|
||||
- hourglass_gown (0-2): Does the figure show the expected hourglass silhouette in a floor-length gown?
|
||||
- white_background (0-2): Is the background pure white with no shadows, floor, or scenery?
|
||||
- no_text_border (0-2): Is the image free of any text, watermarks, or borders?
|
||||
- total: Sum of all five dimensions.
|
||||
"""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -112,16 +88,6 @@ class SlugEntry:
|
|||
extra_neg: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoreBreakdown:
|
||||
silhouette_legibility: float
|
||||
prop_clarity: float
|
||||
hourglass_gown: float
|
||||
white_background: float
|
||||
no_text_border: float
|
||||
total: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class Finalist:
|
||||
file: str
|
||||
|
|
@ -301,62 +267,21 @@ def to_silhouette(raw_path: Path, sil_path: Path) -> None:
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def extract_prop_description(entry: "SlugEntry") -> str:
    """
    Pull the headline noun phrase out of the prop string for the scoring prompt.

    The full prop string contains pose verbiage; we want only the distinctive object.
    Strategy: take the chunk before the first comma, strip at most one leading
    verb/pose prefix (matched case-insensitively), then cap the result at
    60 characters.

    Returns an empty string when ``entry.prop`` is empty.
    """
    head = entry.prop.split(",", 1)[0].strip()

    # Strip one common leading gerund/pose phrase so that
    # "holding a champagne flute, raised elegantly in one hand" becomes
    # "champagne flute". Only the first matching prefix is removed, so the
    # longer variant ("holding a ") must be listed before the shorter one
    # ("holding ").
    # NOTE(review): "tartan sash " removes what looks like the prop noun itself;
    # kept as-is because the prop strings are not visible here — confirm intent.
    prefixes = (
        "holding a ", "holding ", "wearing a ", "wearing ", "both arms raised ",
        "one arm ", "tartan sash ", "single ", "carrying ", "with ",
    )
    lowered = head.lower()
    for prefix in prefixes:
        if lowered.startswith(prefix):
            head = head[len(prefix):]
            break

    # Trim again: removing a prefix or cutting at 60 characters can leave
    # whitespace at either end, which would end up inside the scoring prompt.
    return head[:60].strip()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -391,7 +316,7 @@ def regenerate_review_html(index: dict[str, dict]) -> None:
|
|||
bd = fin.get("breakdown", {})
|
||||
raw_name = fname.replace(".png", "_raw.png")
|
||||
bd_rows = "".join(
|
||||
f"<tr><td>{k.replace('_', ' ')}</td><td>{v:.1f}</td></tr>"
|
||||
f"<tr><td>{k.replace('_', ' ')}</td><td>{v:.3f}</td></tr>"
|
||||
for k, v in bd.items()
|
||||
if k != "total"
|
||||
)
|
||||
|
|
@ -410,7 +335,7 @@ def regenerate_review_html(index: dict[str, dict]) -> None:
|
|||
</div>
|
||||
<div class="meta">
|
||||
<strong>{fname}</strong>
|
||||
<span class="score">score: {score:.1f}</span>
|
||||
<span class="score">score: {score:.3f}</span>
|
||||
<table>{bd_rows}</table>
|
||||
</div>
|
||||
</div>"""
|
||||
|
|
@ -473,7 +398,7 @@ def process_slug(
|
|||
min_keepers: int,
|
||||
min_score: float,
|
||||
max_rounds: int,
|
||||
no_score: bool = False,
|
||||
scorer: LocalScorer,
|
||||
) -> SlugResult:
|
||||
log.info("=" * 60)
|
||||
log.info("%s / %s", kind, slug)
|
||||
|
|
@ -522,38 +447,25 @@ def process_slug(
|
|||
log.error(" [rembg] %s: %s", tag, exc)
|
||||
continue
|
||||
|
||||
if no_score:
|
||||
breakdown = ScoreBreakdown(
|
||||
silhouette_legibility=0.0,
|
||||
prop_clarity=0.0,
|
||||
hourglass_gown=0.0,
|
||||
white_background=0.0,
|
||||
no_text_border=0.0,
|
||||
total=0.0,
|
||||
)
|
||||
finalists.append(
|
||||
Finalist(file=sil_path.name, score=0.0, breakdown=breakdown)
|
||||
)
|
||||
log.info(" [keep] %s (no-score mode)", tag)
|
||||
continue
|
||||
|
||||
prop_desc = extract_prop_description(entry)
|
||||
try:
|
||||
total, breakdown = score_image(sil_path)
|
||||
except RuntimeError as exc:
|
||||
breakdown = scorer.score_image(sil_path, prop_description=prop_desc)
|
||||
except Exception as exc:
|
||||
log.error(" [score] %s: %s", tag, exc)
|
||||
continue
|
||||
|
||||
score_data = {**asdict(breakdown), "file": sil_path.name}
|
||||
total = breakdown.total
|
||||
score_data = {**breakdown.as_dict(), "file": sil_path.name}
|
||||
score_path.write_text(json.dumps(score_data, indent=2))
|
||||
log.info(" [score] %s -> %.1f", tag, total)
|
||||
log.info(" [score] %s -> %.3f", tag, total)
|
||||
|
||||
if total >= min_score:
|
||||
finalists.append(
|
||||
Finalist(file=sil_path.name, score=total, breakdown=breakdown)
|
||||
)
|
||||
log.info(" [keeper] %s (%.1f >= %.1f)", tag, total, min_score)
|
||||
log.info(" [keeper] %s (%.3f >= %.3f)", tag, total, min_score)
|
||||
else:
|
||||
log.info(" [skip] %s (%.1f < %.1f)", tag, total, min_score)
|
||||
log.info(" [skip] %s (%.3f < %.3f)", tag, total, min_score)
|
||||
|
||||
log.info(
|
||||
" Round %d done — %d finalists so far (need %d)",
|
||||
|
|
@ -592,41 +504,24 @@ def parse_args() -> argparse.Namespace:
|
|||
)
|
||||
parser.add_argument("--seeds", type=int, default=8, metavar="N", help="Seeds per round (default: 8)")
|
||||
parser.add_argument("--min-keepers", type=int, default=3, help="Minimum finalists per slug (default: 3)")
|
||||
parser.add_argument("--min-score", type=float, default=7.0, help="Minimum score to keep (default: 7.0)")
|
||||
parser.add_argument("--min-score", type=float, default=0.55, help="Minimum weighted score to keep (default: 0.55, range 0–1)")
|
||||
parser.add_argument("--max-rounds", type=int, default=3, help="Maximum rounds per slug (default: 3)")
|
||||
parser.add_argument(
|
||||
"--no-score",
|
||||
action="store_true",
|
||||
help="Skip automated rubric scoring; keep every successfully silhouetted image as a candidate for manual review.",
|
||||
)
|
||||
parser.add_argument("--device", default=None, help="Scoring device (default: cuda if available)")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def verify_scoring_available() -> None:
|
||||
"""Fail loudly at startup if the claude CLI is not available."""
|
||||
result = subprocess.run(
|
||||
["claude", "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
"claude CLI not found or not working. "
|
||||
"Scoring requires the claude CLI with model-boss MCP configured.\n"
|
||||
f"stderr: {result.stderr[:300]}"
|
||||
)
|
||||
log.info("Scoring backend: claude CLI (%s)", result.stdout.strip())
|
||||
def load_scorer(device: str | None) -> LocalScorer:
    """Load the SigLIP2 scorer on the requested device. Fails loudly.

    Constructs a ``LocalScorer`` for ``device`` and calls its ``load()``
    before returning it. Nothing is caught here, so any exception raised
    while constructing or loading propagates to the caller.
    """
    scorer = LocalScorer(device=device)
    scorer.load()
    return scorer
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
# Fail loudly if scoring is unavailable — do not silently skip scoring
|
||||
if not args.no_score:
|
||||
verify_scoring_available()
|
||||
else:
|
||||
log.info("Scoring disabled (--no-score) — keeping every successful silhouette")
|
||||
# Load scorer once (SigLIP2 — local GPU, ~2 GB VRAM)
|
||||
scorer = load_scorer(args.device)
|
||||
|
||||
OUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
@ -667,7 +562,7 @@ def main() -> None:
|
|||
min_keepers=args.min_keepers,
|
||||
min_score=args.min_score,
|
||||
max_rounds=args.max_rounds,
|
||||
no_score=args.no_score,
|
||||
scorer=scorer,
|
||||
)
|
||||
index[key] = {
|
||||
"finalists": [
|
||||
|
|
|
|||
262
tooling/scripts/illustrations/local_scorer.py
Normal file
262
tooling/scripts/illustrations/local_scorer.py
Normal file
|
|
@ -0,0 +1,262 @@
|
|||
"""Local SigLIP2 scoring for Jessica Rabbit silhouettes.
|
||||
|
||||
In-process scoring (no subprocess, no MCP hop) modeled after
|
||||
@magic-civilization/tools/sprite-generation/engine/local_scorer.py.
|
||||
|
||||
Loads SigLIP2 once, scores each silhouette across multiple dimensions
|
||||
using contrastive softmax over per-dimension positive/negative text prompts.
|
||||
Returns a normalized 0–1 score per dimension plus a weighted total.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
MODEL_NAME = "google/siglip2-so400m-patch14-384"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dimension queries (positive vs negative)
|
||||
#
|
||||
# Score per dimension = softmax probability mass on positive prompts vs
|
||||
# negative prompts in a single forward pass. Range [0, 1].
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Each dimension maps to a "positive" and a "negative" list of text prompts.
DIMENSIONS: dict[str, dict[str, list[str]]] = {
    "single_figure": {
        "positive": [
            "single woman silhouette isolated on white background",
            "one female figure clean outline solo",
            "single character silhouette no other figures",
        ],
        "negative": [
            "two women side by side double figure",
            "multiple figures crowd group of people",
            "character sheet reference page multiple poses",
            "split image diptych",
        ],
    },
    "hourglass_gown": {
        "positive": [
            "tall woman in floor-length evening gown hourglass figure",
            "voluptuous bombshell silhouette long gown reaching the floor",
            "elegant woman full-length dress curvy hourglass shape",
        ],
        "negative": [
            "short skirt mini dress bare legs",
            "wide bell skirt poofy crinoline dome shape",
            "child proportions chibi short stumpy figure",
            "narrow stick figure no curves flat silhouette",
        ],
    },
    "full_body_pose": {
        "positive": [
            "full body silhouette from head to feet visible",
            "complete figure standing pose feet visible",
            "head to toe full length silhouette",
        ],
        "negative": [
            "cropped at waist or knees portrait bust",
            "only head and shoulders close-up",
            "feet cut off below frame",
        ],
    },
    "white_background": {
        "positive": [
            "clean plain white background empty",
            "solid white backdrop no scenery",
            "isolated figure on pure white",
        ],
        "negative": [
            "scenery landscape background mountains sky",
            "indoor room interior background",
            "colored gradient background dark backdrop",
            "ground floor pavement under figure",
        ],
    },
    "no_text_or_border": {
        "positive": [
            "image without any text or letters",
            "clean image no borders no frame",
            "no watermark or signature visible",
        ],
        "negative": [
            "text caption letters watermark on image",
            "image with border frame around it",
            "signature or logo in corner",
        ],
    },
    "silhouette_quality": {
        "positive": [
            "strong readable silhouette clean crisp outline",
            "bold solid black silhouette of a woman against white",
            "high contrast figure cleanly separated from background",
        ],
        "negative": [
            "blurry fuzzy edges low quality",
            "noisy artifacts deformed mutated",
            "transparent ghostly figure poor contrast",
        ],
    },
}
|
||||
|
||||
# Per-dimension weights for the total score (sum to 1.0)
|
||||
# One weight per scoring dimension; the six values add up to 1.0.
WEIGHTS: dict[str, float] = {
    "single_figure": 0.25,
    "hourglass_gown": 0.20,
    "full_body_pose": 0.15,
    "white_background": 0.15,
    "no_text_or_border": 0.10,
    "silhouette_quality": 0.15,
}
|
||||
|
||||
# Optional dynamic prop dimension is added per-call when a prop description is provided.
|
||||
# It compares the image against "woman holding/wearing <prop>" vs generic gown imagery.
|
||||
|
||||
|
||||
@dataclass
class ScoreBreakdown:
    """Per-dimension scores plus weighted total. All values in [0, 1]."""

    # One score per static scoring dimension.
    single_figure: float
    hourglass_gown: float
    full_body_pose: float
    white_background: float
    no_text_or_border: float
    silhouette_quality: float
    # Score for the per-image prop prompt.
    prop_clarity: float
    # Weighted combination of all the scores above.
    total: float

    def as_dict(self) -> dict[str, float]:
        """Return the breakdown as a plain ``{field_name: value}`` dict."""
        return asdict(self)
|
||||
|
||||
|
||||
class LocalScorer:
|
||||
"""SigLIP2 zero-shot scorer for Jessica Rabbit silhouettes.
|
||||
|
||||
Loads the model once, then scores arbitrary images via score_image().
|
||||
Use as a context manager to release the model when done.
|
||||
"""
|
||||
|
||||
def __init__(self, device: str | None = None) -> None:
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self._device = device
|
||||
self._model: AutoModel | None = None
|
||||
self._processor: AutoProcessor | None = None
|
||||
|
||||
def __enter__(self) -> "LocalScorer":
|
||||
self.load()
|
||||
return self
|
||||
|
||||
def __exit__(self, *_exc: object) -> None:
|
||||
self.unload()
|
||||
|
||||
def load(self) -> None:
|
||||
log.info("Loading SigLIP2 (%s) on %s ...", MODEL_NAME, self._device)
|
||||
self._processor = AutoProcessor.from_pretrained(MODEL_NAME)
|
||||
self._model = AutoModel.from_pretrained(MODEL_NAME).to(self._device)
|
||||
self._model.requires_grad_(False)
|
||||
log.info("SigLIP2 loaded")
|
||||
|
||||
def unload(self) -> None:
|
||||
self._model = None
|
||||
self._processor = None
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _score_dim(
|
||||
self,
|
||||
image: Image.Image,
|
||||
positive: list[str],
|
||||
negative: list[str],
|
||||
) -> float:
|
||||
"""Single-pass softmax contrastive score: positive mass / total mass."""
|
||||
if self._model is None or self._processor is None:
|
||||
raise RuntimeError("Scorer not loaded — call load() first")
|
||||
if not positive or not negative:
|
||||
return 0.5
|
||||
|
||||
all_prompts = positive + negative
|
||||
n_pos = len(positive)
|
||||
|
||||
inputs = self._processor(
|
||||
text=all_prompts,
|
||||
images=image,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
inputs = {k: v.to(self._device) for k, v in inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self._model(**inputs)
|
||||
image_embeds = outputs.image_embeds
|
||||
text_embeds = outputs.text_embeds
|
||||
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
|
||||
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
|
||||
similarities = (image_embeds @ text_embeds.T).squeeze(0)
|
||||
probs = torch.softmax(similarities / 0.01, dim=-1)
|
||||
|
||||
prob_list = probs.cpu().tolist()
|
||||
if isinstance(prob_list, float):
|
||||
prob_list = [prob_list]
|
||||
return float(sum(prob_list[:n_pos]))
|
||||
|
||||
def score_image(
|
||||
self,
|
||||
image_path: str | Path,
|
||||
prop_description: str = "",
|
||||
) -> ScoreBreakdown:
|
||||
"""Score an image on all dimensions; returns 0–1 normalized scores.
|
||||
|
||||
prop_description: a short phrase describing the destination/specialty
|
||||
prop (e.g. "champagne flute") so the prop_clarity dimension can be
|
||||
computed dynamically. If empty, prop_clarity defaults to 0.5.
|
||||
"""
|
||||
if self._model is None:
|
||||
raise RuntimeError("Scorer not loaded — call load() first")
|
||||
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
scores: dict[str, float] = {}
|
||||
for dim, queries in DIMENSIONS.items():
|
||||
scores[dim] = self._score_dim(image, queries["positive"], queries["negative"])
|
||||
|
||||
if prop_description:
|
||||
prop_clarity = self._score_dim(
|
||||
image,
|
||||
positive=[
|
||||
f"silhouette of a woman with {prop_description}",
|
||||
f"figure clearly showing {prop_description}",
|
||||
f"{prop_description} visible in the silhouette",
|
||||
],
|
||||
negative=[
|
||||
"plain woman in a gown with no distinctive accessory",
|
||||
"generic female silhouette no recognizable prop",
|
||||
"ambiguous silhouette no clear object",
|
||||
],
|
||||
)
|
||||
else:
|
||||
prop_clarity = 0.5
|
||||
|
||||
# Weighted total: dimension weights + a separate prop weight (0.20).
|
||||
# Renormalize the base weights to (1 - prop_weight) so the total stays in [0, 1].
|
||||
prop_weight = 0.20
|
||||
base_total = sum(scores[d] * WEIGHTS[d] for d in WEIGHTS) * (1 - prop_weight)
|
||||
total = base_total + prop_clarity * prop_weight
|
||||
|
||||
return ScoreBreakdown(
|
||||
single_figure=scores["single_figure"],
|
||||
hourglass_gown=scores["hourglass_gown"],
|
||||
full_body_pose=scores["full_body_pose"],
|
||||
white_background=scores["white_background"],
|
||||
no_text_or_border=scores["no_text_or_border"],
|
||||
silhouette_quality=scores["silhouette_quality"],
|
||||
prop_clarity=prop_clarity,
|
||||
total=total,
|
||||
)
|
||||
Loading…
Add table
Reference in a new issue