lilith-platform.live/tooling/scripts/illustrations/iterate.py
2026-04-08 09:41:11 -07:00

602 lines
21 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Illustration iteration pipeline — Jessica Rabbit concept.
Quinn is always herself: tall, blonde, floor-length gown, hourglass.
Destination = one glamorous tourist prop.
Specialty = tasteful silhouette abstraction for a service category.
Usage:
    iterate.py [--slug SLUG] [--kind destination|specialty]
               [--seeds N] [--min-keepers 3] [--min-score 0.55] [--max-rounds 3]
               [--device DEVICE]
Default: iterate all slugs in both kinds (--min-score is on a 0-1 scale).
"""
from __future__ import annotations
import argparse
import base64
import json
import logging
import os
import subprocess
import sys
import time
import tomllib
import urllib.error
import urllib.request
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Literal
from local_scorer import LocalScorer, ScoreBreakdown
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Everything is resolved relative to this script so it can run from any cwd.
SCRIPT_DIR = Path(__file__).parent
PROMPTS_DIR = SCRIPT_DIR.joinpath("prompts")  # destinations.toml / specialties.toml
OUT_DIR = SCRIPT_DIR.joinpath("out")  # raw renders, silhouettes, scores, index, review page
# Base URL of the model-boss REST API; _api_request prefixes every path with it.
API_BASE = "http://localhost:8210/api/v1"
# ---------------------------------------------------------------------------
# Prompt constants (verbatim from v6_pipeline)
# ---------------------------------------------------------------------------
# Character tags prepended to every positive prompt (build_prompt puts them
# first). Ends with a comma and no trailing space: build_prompt joins the
# pieces with single spaces.
BASE = (
    "1girl, solo, "
    "tall voluptuous woman, long wavy blonde hair past shoulders, "
    "floor-length strapless evening gown, hourglass figure, long elegant legs, "
    "bombshell, femme fatale, glamorous,"
)
# Framing tags placed after the slug's prop: a single full-body figure on a
# plain white background with no ground or shadow (presumably so the rembg
# silhouette step gets a clean cutout -- NOTE(review): confirm intent).
COMPOSITION = (
    "white background, simple background, full body, floating, "
    "no floor, no ground, no shadow,"
)
# Rendering/quality tags appended last to every positive prompt.
STYLE = "painted illustration, clean crisp edges, strong readable silhouette outline, masterpiece, best quality"
# Negative prompt shared by every slug; build_prompt appends the slug's
# SlugEntry.extra_neg (after ", ") when it is non-empty. Groups, in order:
# extra people, multi-panel layouts, artifacts, shadows/ground, non-white
# backgrounds, scenery, photoreal styles, short/wide skirts, childlike proportions.
UNIVERSAL_NEG = (
    "2girls, 3girls, multiple girls, multiple persons, multiple characters, "
    "couple, crowd, group, "
    "split image, side by side, reference sheet, multiple poses, character sheet, "
    "text, watermark, blurry, low quality, deformed, mutated, disfigured, "
    "shadow, drop shadow, cast shadow, "
    "floor, ground, grass, pavement, base, pedestal, "
    "black background, dark background, colored background, gradient background, "
    "scenery background, outdoor scene, indoor scene, "
    "landscape, sky, horizon, clouds, border, frame, "
    "photograph, photorealistic, 3D render, CGI, hyperrealism, "
    "short skirt, mini skirt, mini dress, short dress, short hemline, "
    "wide skirt, bell skirt, dome skirt, crinoline, poofy skirt, "
    "short, petite, chibi, loli, young girl, teenage, child, "
    "sd chibi, deformed proportions, stumpy legs, child proportions"
)
# ---------------------------------------------------------------------------
# Data types
# ---------------------------------------------------------------------------
# The two illustration families; load_slugs maps each to its own TOML file.
Kind = Literal["destination", "specialty"]


@dataclass
class SlugEntry:
    """One TOML table: the slug's prompt fragment plus optional extra negatives."""

    prop: str  # inserted between BASE and COMPOSITION in the positive prompt
    extra_neg: str = field(default="")  # appended to UNIVERSAL_NEG when non-empty


@dataclass
class Finalist:
    """A scored candidate: silhouette file name, total score and full breakdown."""

    file: str
    score: float
    breakdown: ScoreBreakdown


@dataclass
class SlugResult:
    """What process_slug returns: ranked finalists and how many rounds ran."""

    finalists: list[Finalist] = field(default_factory=list)
    rounds: int = field(default=0)
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
# Script-wide console logging: "HH:MM:SS [LEVEL] message" at INFO and above.
# basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# TOML loading
# ---------------------------------------------------------------------------
def load_slugs(kind: Kind) -> dict[str, SlugEntry]:
    """Read every slug table for ``kind`` from its TOML file under PROMPTS_DIR.

    "destination" reads destinations.toml; any other kind reads specialties.toml.
    Each top-level table becomes a SlugEntry whose fields are the table's keys.

    Raises:
        FileNotFoundError: if the TOML file is missing.
    """
    filename = "destinations.toml" if kind == "destination" else "specialties.toml"
    toml_path = PROMPTS_DIR / filename
    if not toml_path.exists():
        raise FileNotFoundError(f"TOML not found: {toml_path}")
    with toml_path.open("rb") as fh:
        tables: dict[str, dict[str, str]] = tomllib.load(fh)
    entries: dict[str, SlugEntry] = {}
    for slug_name, spec in tables.items():
        entries[slug_name] = SlugEntry(**spec)
    return entries
# ---------------------------------------------------------------------------
# Prompt building
# ---------------------------------------------------------------------------
def build_prompt(entry: SlugEntry) -> tuple[str, str]:
    """Assemble the (positive, negative) prompt pair for one slug.

    Positive: BASE, the slug's prop, COMPOSITION and STYLE joined by single
    spaces. Negative: UNIVERSAL_NEG, followed by ", <extra_neg>" when the
    slug defines extra negatives. Both are returned with outer whitespace
    stripped.
    """
    negative = f"{UNIVERSAL_NEG}, {entry.extra_neg}" if entry.extra_neg else UNIVERSAL_NEG
    positive = f"{BASE} {entry.prop} {COMPOSITION} {STYLE}"
    return positive.strip(), negative.strip()
# ---------------------------------------------------------------------------
# model-boss diffusion API
# ---------------------------------------------------------------------------
def _api_request(method: str, path: str, body: dict | None = None) -> dict | list:
    """Send one JSON request to the model-boss API and return the decoded reply.

    Args:
        method: HTTP verb, e.g. "GET" or "POST".
        path: Path appended verbatim to API_BASE (expected to start with "/").
        body: Optional JSON-serialisable payload, sent with a JSON content type.

    Returns:
        The decoded JSON response (object or array).

    Raises:
        RuntimeError: on an HTTP error status, a connection failure or
            timeout, or a response body that is not valid JSON. Every caller
            in this script handles only RuntimeError, so all transport
            failures are normalised to it.
    """
    url = f"{API_BASE}{path}"
    data = json.dumps(body).encode() if body is not None else None
    req = urllib.request.Request(
        url,
        data=data,
        method=method,
        headers={"Content-Type": "application/json"} if data else {},
    )
    try:
        with urllib.request.urlopen(req, timeout=30) as resp:
            payload = resp.read()
    except urllib.error.HTTPError as e:
        body_text = e.read().decode(errors="replace")[:500]
        raise RuntimeError(f"HTTP {e.code} {method} {url}: {body_text}") from e
    except urllib.error.URLError as e:
        raise RuntimeError(f"Cannot reach {url}: {e.reason}") from e
    except OSError as e:
        # Once the request is in flight, read timeouts (TimeoutError) and
        # dropped connections surface as plain OSError, not URLError.
        raise RuntimeError(f"Connection error {method} {url}: {e}") from e
    try:
        return json.loads(payload)
    except ValueError as e:  # JSONDecodeError, or undecodable bytes
        snippet = payload[:200].decode(errors="replace")
        raise RuntimeError(f"Non-JSON response from {method} {url}: {snippet!r}") from e
def submit_job(prompt: str, neg: str, seed: int) -> str:
    """Submit one text-to-image job to the model-boss diffusion API.

    Generation settings are fixed here: animagine-xl-4.0-opt, 768x1152,
    35 steps, guidance 6.5; only the prompts and seed vary.

    Returns:
        The job id reported by the server ("jobId" or "job_id").

    Raises:
        RuntimeError: if the request fails or the response is not an object
            carrying a job id. An explicit check is used instead of
            ``assert`` because asserts vanish under ``python -O`` and an
            AssertionError would bypass callers' RuntimeError handling.
    """
    result = _api_request(
        "POST",
        "/diffusion/jobs",
        {
            "model": "animagine-xl-4.0-opt",
            "prompt": prompt,
            "negative_prompt": neg,
            "width": 768,
            "height": 1152,
            "steps": 35,
            "guidance_scale": 6.5,
            "seed": seed,
        },
    )
    if not isinstance(result, dict):
        raise RuntimeError(f"Unexpected submit response (expected object): {result!r}")
    job_id = result.get("jobId") or result.get("job_id")
    if not job_id:
        raise RuntimeError(f"No jobId in submit response: {result}")
    return str(job_id)
def poll_until_done(job_id: str, timeout: int = 600, poll_interval: float = 15.0) -> None:
    """Block until the diffusion job reports status "completed".

    The job is polled immediately and then every ``poll_interval`` seconds.
    The last sleep is shortened so that one final poll happens at the
    deadline, rather than sleeping past it and giving up blind. The deadline
    uses a monotonic clock, so wall-clock adjustments cannot stretch or
    shorten the wait.

    Raises:
        RuntimeError: if the job reports "failed", the status response is
            not a JSON object, the request itself fails, or ``timeout``
            seconds elapse without completion.
    """
    deadline = time.monotonic() + timeout
    while True:
        result = _api_request("GET", f"/diffusion/jobs/{job_id}")
        if not isinstance(result, dict):
            raise RuntimeError(f"Unexpected status response for job {job_id}: {result!r}")
        status = result.get("status", "")
        if status == "completed":
            return
        if status == "failed":
            raise RuntimeError(f"Job {job_id} failed: {result.get('error', 'unknown')}")
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(poll_interval, remaining))
    raise RuntimeError(f"Job {job_id} timed out after {timeout}s")
def retrieve_image(job_id: str, out_path: Path) -> None:
    """Download the first image of a completed job and write it to ``out_path``.

    The result endpoint returns base64-encoded images under "images"; only
    the first one is saved.

    Raises:
        RuntimeError: if the request fails, the response is not an object
            with a non-empty "images" list, or the payload is not valid
            base64. Decode errors are wrapped so callers' RuntimeError
            handling skips this job instead of aborting the run.
    """
    result = _api_request("GET", f"/diffusion/jobs/{job_id}/result")
    if not isinstance(result, dict):
        raise RuntimeError(f"Unexpected result response for job {job_id}: {result!r}")
    images = result.get("images")
    if not isinstance(images, list) or not images:
        raise RuntimeError(f"No images in result for job {job_id}: {result}")
    try:
        img_bytes = base64.b64decode(images[0])
    except (ValueError, TypeError) as exc:  # binascii.Error subclasses ValueError
        raise RuntimeError(f"Undecodable image payload for job {job_id}: {exc}") from exc
    out_path.write_bytes(img_bytes)
    log.info(" saved raw -> %s (%d KB)", out_path.name, len(img_bytes) // 1024)
# ---------------------------------------------------------------------------
# rembg silhouette step (verbatim logic from v6_pipeline, subprocess for isolation)
# ---------------------------------------------------------------------------
_REMBG_SCRIPT = """
from rembg import remove, new_session
from PIL import Image
from io import BytesIO
import numpy as np, sys
raw_path = sys.argv[1]
sil_path = sys.argv[2]
session = new_session('silueta')
with open(raw_path, 'rb') as f:
raw = f.read()
rgba = Image.open(BytesIO(remove(raw, session=session))).convert('RGBA')
arr = np.array(rgba)
mask = arr[:, :, 3] > 128
try:
from scipy.ndimage import binary_fill_holes, binary_opening
mask = binary_opening(mask, iterations=2)
mask = binary_fill_holes(mask)
except ImportError:
pass
l = np.where(mask, 0, 255).astype(np.uint8)
a = np.full(l.shape, 255, dtype=np.uint8)
Image.merge('LA', [Image.fromarray(l, 'L'), Image.fromarray(a, 'L')]).save(sil_path)
print('OK')
"""
def to_silhouette(raw_path: Path, sil_path: Path, timeout: float = 120) -> None:
    """Turn ``raw_path`` into a black-on-white silhouette PNG at ``sil_path``.

    Runs _REMBG_SCRIPT in a child interpreter. ``sys.executable`` is used, so
    the child sees the same environment (and installed rembg) as this
    script, even when it was launched via an explicit venv interpreter path.

    Raises:
        RuntimeError: if the child exits non-zero or exceeds ``timeout``
            seconds. TimeoutExpired is wrapped so callers' RuntimeError
            handling skips this image instead of aborting the whole run.
    """
    try:
        result = subprocess.run(
            [sys.executable, "-c", _REMBG_SCRIPT, str(raw_path), str(sil_path)],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as exc:
        raise RuntimeError(f"rembg timed out after {exc.timeout}s for {raw_path.name}") from exc
    if result.returncode != 0:
        # A Python traceback ends with the actual exception, so keep the tail.
        raise RuntimeError(f"rembg failed for {raw_path.name}:\n{result.stderr[-500:]}")
    log.info(" silhouette -> %s", sil_path.name)
# ---------------------------------------------------------------------------
# Scoring helpers
#
# Scoring runs in-process through LocalScorer (SigLIP2, loaded once in main());
# this section only derives the prop description that is passed to it.
# Historical NOTE: an earlier version scored via the `claude` CLI and the
# mcp__model-boss__score_image_rubric MCP tool, because model-boss had no HTTP
# scoring endpoint (its /api/v1/* routes returned the SPA shell when checked).
# That path, and its PATH / ~/.claude.json requirements, are no longer used.
# ---------------------------------------------------------------------------
def extract_prop_description(entry: SlugEntry, max_len: int = 60) -> str:
    """Return the short object phrase from ``entry.prop`` used as the scorer's text prompt.

    The prop string mixes the object with pose wording. Processing steps:

    1. Keep only the chunk before the first comma.
    2. Remove at most one known leading verb phrase (e.g. "holding a ",
       "wearing "), case-insensitively.
    3. Cap the result at ``max_len`` characters, cutting at a word boundary
       where possible so the scorer never sees a half word.
    4. Strip surrounding whitespace.

    Example: "holding a champagne flute, raised elegantly" -> "champagne flute".
    Text before the first comma is otherwise kept as-is, so pose wording
    without a comma (e.g. "... raised elegantly in one hand") survives.
    """
    head = entry.prop.split(",", 1)[0].strip()
    # Order matters: "holding a " must be tried before its shorter form "holding ".
    for prefix in (
        "holding a ", "holding ", "wearing a ", "wearing ", "both arms raised ",
        "one arm ", "tartan sash ", "single ", "carrying ", "with ",
    ):
        # Lower-case only the candidate slice: lowering the whole string can
        # change its length for some non-ASCII characters and skew the cut.
        if head[: len(prefix)].lower() == prefix:
            head = head[len(prefix):]
            break
    if len(head) > max_len:
        cut = head[:max_len]
        # Back up to the last space unless the cut already lands on a word break.
        if not head[max_len].isspace():
            space = cut.rfind(" ")
            if space > 0:
                cut = cut[:space]
        head = cut
    return head.strip()
# ---------------------------------------------------------------------------
# Index + review HTML
# ---------------------------------------------------------------------------
def load_index() -> dict[str, dict]:
    """Return the saved finalist index from OUT_DIR/_index.json.

    An empty dict is returned when the file does not exist yet (first run).
    """
    index_file = OUT_DIR / "_index.json"
    return json.loads(index_file.read_text()) if index_file.exists() else {}
def save_index(index: dict[str, dict]) -> None:
    """Persist the finalist index to OUT_DIR/_index.json atomically.

    The JSON is first written to a sibling ``.tmp`` file and then moved over
    the real file with os.replace. An interrupted run (Ctrl-C, crash) can
    therefore never leave a truncated _index.json that load_index would fail
    to parse on the next run.
    """
    index_path = OUT_DIR / "_index.json"
    tmp_path = index_path.with_name(index_path.name + ".tmp")
    tmp_path.write_text(json.dumps(index, indent=2))
    os.replace(tmp_path, index_path)
def regenerate_review_html(index: dict[str, dict]) -> None:
    """Rewrite OUT_DIR/_review.html from the finalist index.

    Page layout:

    - One <section> per index key ("<kind>/<slug>", sorted).
    - Within each section, one card per finalist, highest score first.
    - Each card shows the silhouette next to its ``*_raw.png`` render, plus
      the score breakdown with "total" omitted.

    All index-derived text is HTML-escaped, so unusual slugs or file names
    cannot break the markup. Non-numeric breakdown values are printed as
    text instead of aborting the whole page on a ``:.3f`` format error.

    Args:
        index: Mapping as built by main():
            key -> {"finalists": [{"file", "score", "breakdown"}, ...], "rounds": int}.
    """
    from html import escape

    def fmt_value(value: object) -> str:
        if isinstance(value, (int, float)):
            return f"{value:.3f}"
        return escape(str(value))

    sections: list[str] = []
    for key, data in sorted(index.items()):
        finalists: list[dict] = sorted(
            data.get("finalists", []), key=lambda f: f["score"], reverse=True
        )
        rounds = data.get("rounds", 0)
        cards: list[str] = []
        for fin in finalists:
            fname = fin["file"]
            score = fin["score"]
            bd = fin.get("breakdown") or {}
            raw_name = fname.replace(".png", "_raw.png")
            bd_rows = "".join(
                f"<tr><td>{escape(str(k).replace('_', ' '))}</td><td>{fmt_value(v)}</td></tr>"
                for k, v in bd.items()
                if k != "total"
            )
            cards.append(
                f"""
    <div class="card">
      <div class="images">
        <figure>
          <img src="{escape(fname)}" alt="silhouette" loading="lazy">
          <figcaption>silhouette</figcaption>
        </figure>
        <figure>
          <img src="{escape(raw_name)}" alt="raw" loading="lazy">
          <figcaption>raw</figcaption>
        </figure>
      </div>
      <div class="meta">
        <strong>{escape(fname)}</strong>
        <span class="score">score: {score:.3f}</span>
        <table>{bd_rows}</table>
      </div>
    </div>"""
            )
        cards_html = "\n".join(cards) if cards else "<p>No finalists.</p>"
        sections.append(
            f"""  <section>
    <h2>{escape(key)} <small>({rounds} round{'s' if rounds != 1 else ''})</small></h2>
    <div class="cards">{cards_html}
    </div>
  </section>"""
        )
    body = "\n".join(sections) if sections else "<p>No results yet.</p>"
    page = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Illustration Review</title>
<style>
  * {{ box-sizing: border-box; margin: 0; padding: 0; }}
  body {{ font-family: system-ui, sans-serif; background: #111; color: #eee; padding: 2rem; }}
  h1 {{ margin-bottom: 2rem; font-size: 1.8rem; }}
  h2 {{ font-size: 1.2rem; margin-bottom: 1rem; border-bottom: 1px solid #333; padding-bottom: .5rem; }}
  small {{ color: #888; font-weight: normal; }}
  section {{ margin-bottom: 3rem; }}
  .cards {{ display: flex; flex-wrap: wrap; gap: 1.5rem; }}
  .card {{ background: #1e1e1e; border-radius: 8px; padding: 1rem; width: 420px; }}
  .images {{ display: flex; gap: 1rem; margin-bottom: .75rem; }}
  .images figure {{ flex: 1; text-align: center; }}
  .images img {{ width: 100%; border-radius: 4px; display: block; }}
  figcaption {{ font-size: .7rem; color: #888; margin-top: .25rem; }}
  .meta strong {{ display: block; font-size: .85rem; word-break: break-all; }}
  .score {{ display: block; font-size: 1.1rem; font-weight: bold; color: #7cf; margin: .25rem 0; }}
  table {{ font-size: .75rem; width: 100%; border-collapse: collapse; margin-top: .5rem; }}
  td {{ padding: .15rem .25rem; border-bottom: 1px solid #2a2a2a; }}
  td:last-child {{ text-align: right; color: #7cf; }}
</style>
</head>
<body>
<h1>Illustration Review</h1>
{body}
</body>
</html>"""
    review_path = OUT_DIR / "_review.html"
    review_path.write_text(page)
    log.info("Review HTML updated -> %s", review_path)
# ---------------------------------------------------------------------------
# Per-slug iteration
# ---------------------------------------------------------------------------
def _render_and_score(
    kind: Kind,
    slug: str,
    round_num: int,
    vi: int,
    job_id: str,
    prop_desc: str,
    scorer: LocalScorer,
) -> Finalist | None:
    """Take one submitted job through poll -> download -> silhouette -> score.

    Writes <tag>_raw.png, <tag>.png and <tag>.score.json into OUT_DIR, where
    tag = "<kind>_<slug>_r<round>v<variant>". Returns the Finalist, or None
    (after logging) if any stage fails.
    """
    tag = f"{kind}_{slug}_r{round_num}v{vi}"
    raw_path = OUT_DIR / f"{tag}_raw.png"
    sil_path = OUT_DIR / f"{tag}.png"
    score_path = OUT_DIR / f"{tag}.score.json"
    try:
        poll_until_done(job_id)
    except RuntimeError as exc:
        log.error(" [poll] %s: %s", tag, exc)
        return None
    try:
        retrieve_image(job_id, raw_path)
    except RuntimeError as exc:
        log.error(" [retrieve] %s: %s", tag, exc)
        return None
    try:
        to_silhouette(raw_path, sil_path)
    except RuntimeError as exc:
        log.error(" [rembg] %s: %s", tag, exc)
        return None
    try:
        # Score the RAW image (not silhouette) — SigLIP2 needs semantic
        # content (color, texture, shape) for accurate classification.
        # The silhouette is the final asset; the raw is the quality signal.
        breakdown = scorer.score_image(raw_path, prop_description=prop_desc)
    except Exception as exc:  # opaque scorer: log and skip just this image
        log.error(" [score] %s: %s", tag, exc)
        return None
    total = breakdown.total
    score_data = {**breakdown.as_dict(), "file": sil_path.name}
    score_path.write_text(json.dumps(score_data, indent=2))
    log.info(" [score] %s -> %.3f", tag, total)
    log.info(" [candidate] %s score=%.3f", tag, total)
    return Finalist(file=sil_path.name, score=total, breakdown=breakdown)


def process_slug(
    slug: str,
    kind: Kind,
    entry: SlugEntry,
    seeds_count: int,
    min_keepers: int,
    min_score: float,
    max_rounds: int,
    scorer: LocalScorer,
) -> SlugResult:
    """Generate, silhouette and score candidates for one slug; return its finalists.

    Round loop:

    - Each round submits ``seeds_count`` jobs.
    - Rounds continue (at most ``max_rounds``) until at least ``min_keepers``
      candidates score >= ``min_score``.
    - Seeds are deterministic: round r uses seeds starting at
      1000 + (r - 1) * seeds_count.
    - Per-job failures (submit, poll, download, rembg, scoring) are logged
      and skipped rather than aborting the slug.

    Final selection is by ranking, not a hard threshold: the top
    ``min_keepers`` candidates by score are returned even if some fall below
    ``min_score`` once the rounds run out. Only these finalists reach the
    index and review page.
    """
    log.info("=" * 60)
    log.info("%s / %s", kind, slug)
    log.info("=" * 60)
    prompt, neg = build_prompt(entry)
    prop_desc = extract_prop_description(entry)  # identical for every job of this slug
    candidates: list[Finalist] = []
    round_num = 0

    def keeper_count() -> int:
        return sum(1 for c in candidates if c.score >= min_score)

    while round_num < max_rounds and keeper_count() < min_keepers:
        round_num += 1
        first_seed = 1000 + (round_num - 1) * seeds_count
        round_seeds = range(first_seed, first_seed + seeds_count)
        log.info(" Round %d — submitting %d jobs", round_num, len(round_seeds))
        # Submit the whole round up front, then collect results in order.
        jobs: list[tuple[int, str]] = []  # (variant index, job_id)
        for vi, seed in enumerate(round_seeds, start=1):
            try:
                job_id = submit_job(prompt, neg, seed)
            except RuntimeError as exc:
                log.error(" [submit] v%d seed=%d: %s", vi, seed, exc)
                continue
            jobs.append((vi, job_id))
            log.info(" [submit] v%d seed=%d -> %s", vi, seed, job_id)
        for vi, job_id in jobs:
            finalist = _render_and_score(kind, slug, round_num, vi, job_id, prop_desc, scorer)
            if finalist is not None:
                candidates.append(finalist)
        log.info(
            " Round %d done — %d candidates so far, %d >= min_score %.3f (need %d)",
            round_num,
            len(candidates),
            keeper_count(),
            min_score,
            min_keepers,
        )
    candidates.sort(key=lambda f: f.score, reverse=True)
    finalists = candidates[:min_keepers]
    if len(finalists) < min_keepers:
        log.warning(
            " %s/%s: only %d/%d candidates after %d rounds",
            kind,
            slug,
            len(finalists),
            min_keepers,
            round_num,
        )
    elif finalists:
        log.info(
            " %s/%s: top %d of %d candidates (score range %.3f-%.3f)",
            kind,
            slug,
            len(finalists),
            len(candidates),
            finalists[-1].score,
            finalists[0].score,
        )
        if finalists[-1].score < min_score:
            log.warning(
                " %s/%s: lowest finalist %.3f is below min_score %.3f after %d round(s)",
                kind,
                slug,
                finalists[-1].score,
                min_score,
                round_num,
            )
    return SlugResult(finalists=finalists, rounds=round_num)
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the pipeline's command-line options.

    Args:
        argv: Arguments to parse; None (the default) means ``sys.argv[1:]``.

    ``--seeds``, ``--min-keepers`` and ``--max-rounds`` must be >= 1. With 0,
    the run would submit no jobs or select no finalists while still
    reporting success.
    """

    def positive_int(text: str) -> int:
        value = int(text)  # ValueError -> argparse reports an invalid value
        if value < 1:
            raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
        return value

    parser = argparse.ArgumentParser(
        description="Illustration iteration pipeline — Jessica Rabbit concept"
    )
    parser.add_argument("--slug", help="Single slug to process (default: all)")
    parser.add_argument(
        "--kind",
        choices=["destination", "specialty"],
        help="Limit to this kind (default: both)",
    )
    parser.add_argument("--seeds", type=positive_int, default=8, metavar="N", help="Seeds per round (default: 8)")
    parser.add_argument("--min-keepers", type=positive_int, default=3, help="Minimum finalists per slug (default: 3)")
    parser.add_argument("--min-score", type=float, default=0.55, help="Minimum weighted score to keep (default: 0.55, range 0-1)")
    parser.add_argument("--max-rounds", type=positive_int, default=3, help="Maximum rounds per slug (default: 3)")
    parser.add_argument("--device", default=None, help="Scoring device (default: cuda if available)")
    return parser.parse_args(argv)
def load_scorer(device: str | None) -> LocalScorer:
    """Construct the SigLIP2 LocalScorer for ``device`` and load it eagerly.

    Nothing here catches errors: any failure raised by ``load()`` propagates
    to the caller, so the run stops before any jobs are submitted.
    """
    instance = LocalScorer(device=device)
    instance.load()
    return instance
def _collect_work(
    kinds: list[Kind], slug: str | None
) -> list[tuple[str, Kind, SlugEntry]]:
    """Return the (slug, kind, entry) work items for this run.

    With no ``slug``, every entry of every requested kind is returned. With
    a ``slug``, it is looked up in each requested kind and kept wherever it
    exists. The process exits with status 1 only if no requested kind
    defines it.
    """
    work: list[tuple[str, Kind, SlugEntry]] = []
    for kind in kinds:
        slugs_map = load_slugs(kind)
        if not slug:
            work.extend((name, kind, entry) for name, entry in slugs_map.items())
        elif slug in slugs_map:
            work.append((slug, kind, slugs_map[slug]))
    if slug and not work:
        log.error("Slug %r not found in %s TOMLs", slug, "/".join(kinds))
        sys.exit(1)
    return work


def main() -> None:
    """Run the pipeline for every selected slug, updating index + review page as it goes.

    The index and review HTML are rewritten after each slug, so partial
    progress survives an interrupted run. A slug that fails with
    RuntimeError is logged and skipped. The process exits with status 1 at
    the end if any slug failed.
    """
    args = parse_args()
    kinds: list[Kind] = (
        [args.kind] if args.kind else ["destination", "specialty"]  # type: ignore[list-item]
    )
    # Resolve the work list before loading the scorer, so a mistyped --slug
    # fails immediately instead of after the model is loaded.
    work = _collect_work(kinds, args.slug)
    log.info(
        "Pipeline start — %d slug(s), seeds=%d, min_keepers=%d, min_score=%.2f, max_rounds=%d",
        len(work),
        args.seeds,
        args.min_keepers,
        args.min_score,
        args.max_rounds,
    )
    # Load scorer once (SigLIP2 — local GPU, ~2 GB VRAM)
    scorer = load_scorer(args.device)
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    index = load_index()
    failed: list[str] = []
    for slug, kind, entry in work:
        key = f"{kind}/{slug}"
        try:
            result = process_slug(
                slug=slug,
                kind=kind,
                entry=entry,
                seeds_count=args.seeds,
                min_keepers=args.min_keepers,
                min_score=args.min_score,
                max_rounds=args.max_rounds,
                scorer=scorer,
            )
        except RuntimeError as exc:
            log.error("%s failed, index entry left unchanged: %s", key, exc)
            failed.append(key)
            continue
        index[key] = {
            "finalists": [
                {
                    "file": f.file,
                    "score": f.score,
                    "breakdown": asdict(f.breakdown),
                }
                for f in result.finalists
            ],
            "rounds": result.rounds,
        }
        save_index(index)
        regenerate_review_html(index)
        log.info("Index + review HTML updated after %s/%s", kind, slug)
    log.info("ALL DONE — outputs in %s", OUT_DIR)
    if failed:
        log.error("%d slug(s) failed: %s", len(failed), ", ".join(failed))
        sys.exit(1)


if __name__ == "__main__":
    main()