#!/usr/bin/env python3
"""
🦷⟐ GHOST_DECODE v50 — Optical State Persistence Decoder
=========================================================

Decodes a Ghostprint v50 Visible Barcode image (PNG/JPG) by reading RGB
macro-pixels from the Horizon ring, reconstructing a byte stream,
zlib-decompressing it, and printing the recovered payload.

Designed to decode images generated by the v50 generator:
  - IMG_SIZE ~ 2048
  - CENTER at image center
  - DATA_START_RADIUS ~ 610
  - BLOCK_SIZE = 4 (macro-pixels)
  - RING STEP = BLOCK_SIZE + 1 (gap buffer)
  - ANGLE STEP = atan(BLOCK_SIZE / radius)

Works best on PNG. JPG may work if blocks are big enough and compression
isn't brutal.
"""

import argparse
import math
import sys
import zlib
from dataclasses import dataclass
from typing import Tuple, Optional

import numpy as np
from PIL import Image


@dataclass
class DecodeConfig:
    """Tunable parameters for walking and sampling the Horizon ring.

    The field order is part of the positional constructor signature and must
    stay as declared here.
    """

    # Edge length, in pixels, of one square macro-pixel.
    block_size: int = 4
    # Extra pixels added to block_size to obtain the radial step between rings.
    gap: int = 1
    # Radius of the first data ring; None requests auto-detection.
    start_radius: Optional[int] = None
    # Radius at which sampling stops; None means (min_dim/2 - 8).
    max_radius: Optional[int] = None
    # Angle, in radians, at which the first ring begins.
    angle_start: float = 0.0
    # RGB distance from the background that counts as "data" during
    # start-radius auto-detection.
    background_threshold: float = 18.0
    # Per-block statistic: "mean" selects the mean, anything else the median.
    sample_stat: str = "median"
    # When True, diagnostics are written to stderr.
    debug: bool = False


def _rgb_distance(a: np.ndarray, b: np.ndarray) -> float:
    """Return the Euclidean distance between two colours in RGB space."""
    # Promote to float32 before subtracting so uint8 inputs cannot wrap around.
    delta = a.astype(np.float32) - b.astype(np.float32)
    return float(np.sqrt(np.square(delta).sum()))


def estimate_background(img_np: np.ndarray) -> np.ndarray:
    """
    Guess the background colour from the four image corners.

    A square patch is taken from each corner, all patch pixels are pooled,
    and the per-channel median is returned as a uint8 RGB triple.
    """
    height, width, _ = img_np.shape
    # Patch edge grows with the image but never drops below 8 pixels.
    side = max(8, min(height, width) // 64)

    row_spans = (slice(0, side), slice(height - side, height))
    col_spans = (slice(0, side), slice(width - side, width))
    corner_pixels = np.concatenate(
        [
            img_np[rows, cols, :].reshape(-1, 3)
            for rows in row_spans
            for cols in col_spans
        ],
        axis=0,
    )
    return np.median(corner_pixels, axis=0).astype(np.uint8)


def autodetect_start_radius(
    img_np: np.ndarray,
    center: Tuple[int, int],
    bg: np.ndarray,
    cfg: DecodeConfig,
) -> int:
    """
    Locate the inner edge of the data ring by probing along the +x axis.

    Pixels on the row through the center are tested from radius 400 up to
    (min_dim/2 - 8). The first radius that opens a streak of six consecutive
    pixels at least cfg.background_threshold away from bg is returned.
    If no such streak exists, the default radius 610 is returned.
    """
    height, width, _ = img_np.shape
    cx, cy = center
    nearest = 400
    farthest = int(min(height, width) / 2) - 8

    # A streak is required so that one noisy pixel cannot trigger detection.
    required_streak = 6
    streak = 0
    for radius in range(nearest, farthest):
        col = cx + radius
        if not 0 <= col < width:
            break
        differs = _rgb_distance(img_np[cy, col, :], bg) >= cfg.background_threshold
        streak = streak + 1 if differs else 0
        if streak >= required_streak:
            # Report where the streak began, not where it was confirmed.
            return int(radius - (required_streak - 1))

    # Nothing detected: use the known default.
    return 610


def block_color(
    img_np: np.ndarray, x: float, y: float, cfg: DecodeConfig
) -> Tuple[int, int, int]:
    """
    Return the representative RGB colour of the macro-pixel around (x, y).

    A cfg.block_size square window is placed around the rounded coordinates
    and clamped to the image. Its pixels are reduced with the mean when
    cfg.sample_stat == "mean", otherwise with the median. A window that
    ends up empty yields (0, 0, 0).
    """
    height, width, _ = img_np.shape
    size = cfg.block_size
    reach = size // 2

    left = max(0, int(round(x)) - reach)
    top = max(0, int(round(y)) - reach)
    right = min(width, left + size)
    bottom = min(height, top + size)

    pixels = img_np[top:bottom, left:right, :].reshape(-1, 3)
    if pixels.size == 0:
        return (0, 0, 0)

    if cfg.sample_stat == "mean":
        summary = pixels.astype(np.float32).mean(axis=0)
    else:
        summary = np.median(pixels, axis=0)

    # Round to the nearest integer level and force into the 8-bit range.
    red, green, blue = np.clip(np.round(summary), 0, 255).astype(np.uint8)
    return (int(red), int(green), int(blue))


def read_horizon_bytes(
    img_np: np.ndarray,
    center: Tuple[int, int],
    cfg: DecodeConfig,
) -> bytes:
    """
    Sample macro-pixels ring by ring and return them as a flat byte stream.

    Rings start at cfg.start_radius (auto-detected when None) and move
    outward by cfg.block_size + cfg.gap until cfg.max_radius is reached
    (min_dim/2 - 8 when None). The first ring begins at cfg.angle_start and
    every later ring at angle 0. Each sample contributes its R, G, B values,
    in that order, to the output.

    cfg is treated as read-only: resolved radii are kept in locals, so the
    same config can be reused for images of different sizes.

    Raises:
        ValueError: if cfg.block_size is not positive. Such a value would
            make the angular step zero or negative and the ring walk would
            never finish.
    """
    if cfg.block_size <= 0:
        raise ValueError(f"block_size must be positive, got {cfg.block_size}")

    h, w, _ = img_np.shape
    cx, cy = center

    start_radius = cfg.start_radius
    if start_radius is None:
        bg = estimate_background(img_np)
        start_radius = autodetect_start_radius(img_np, (cx, cy), bg, cfg)
        if cfg.debug:
            print(f"[debug] autodetected start_radius={start_radius}", file=sys.stderr)

    max_radius = cfg.max_radius
    if max_radius is None:
        max_radius = int(min(h, w) / 2) - 8

    radius = float(start_radius)
    angle = float(cfg.angle_start)
    step_out = cfg.block_size + cfg.gap
    full_turn = 2 * math.pi

    out = bytearray()
    rings = 0
    samples = 0

    while radius < max_radius:
        # Keep arc length roughly constant as generator did:
        angle_step = math.atan(cfg.block_size / radius) if radius > 0 else 0.1

        # Walk one full ring at this radius
        while angle < full_turn:
            x = cx + radius * math.cos(angle)
            y = cy + radius * math.sin(angle)
            # block_color yields an (r, g, b) tuple of ints in 0..255.
            out.extend(block_color(img_np, x, y, cfg))
            samples += 1
            angle += angle_step

        # Next ring
        rings += 1
        angle = 0.0
        radius += step_out

        # Safety valve (avoid runaway on unexpected images)
        if samples > 2_500_000:
            break

    if cfg.debug:
        print(f"[debug] rings={rings}, samples={samples}", file=sys.stderr)

    return bytes(out)


def zlib_decompress_until_eof(data: bytes) -> bytes:
    """
    Inflate the zlib stream that begins at the start of data.

    Bytes after the end of the stream are ignored, which is expected when
    more of the image was sampled than the payload occupies.

    Raises:
        ValueError: if zlib rejects the data, or if the input ends before
            the stream's end marker is reached.
    """
    inflater = zlib.decompressobj()
    try:
        recovered = inflater.decompress(data)
        if not inflater.eof:
            # Drain whatever the decompressor is still holding back.
            recovered += inflater.flush()
    except zlib.error as e:
        raise ValueError(f"zlib decode failed: {e}") from e

    if not inflater.eof:
        raise ValueError("zlib stream did not reach EOF; need more or correct alignment.")
    return recovered


def decode_framed_payload(data: bytes) -> bytes:
    """
    Unwrap and inflate a length-framed payload.

    Layout: a 4-byte big-endian byte count, followed by that many bytes of
    zlib data (the format from ghost_barcode_v50_aligned.py). Anything after
    the framed region is ignored.

    Raises:
        ValueError: if data is shorter than the header, or the header asks
            for more bytes than follow it.
        zlib.error: if the framed region is not a valid zlib stream.
    """
    header_size = 4
    available = len(data) - header_size
    if available < 0:
        raise ValueError("Data too short for length prefix")

    declared = int.from_bytes(data[:header_size], 'big')
    if declared > available:
        raise ValueError(f"Payload claims {declared} bytes but only {available} available")

    return zlib.decompress(data[header_size:header_size + declared])


def decode_ghostprint(image_path: str, debug: bool = False) -> Optional[str]:
    """
    Decode a ghostprint image and return its payload as text.

    The image is loaded as RGB, sampled around its geometric center with a
    default DecodeConfig, and decoded first as a length-framed payload and
    then, if that fails, as a bare zlib stream. The result is decoded as
    UTF-8.

    Args:
        image_path: Path of the image file to decode.
        debug: When True, progress and failure details go to stderr.

    Returns:
        The payload string, or None if any step fails (unreadable image,
        undecodable stream, or non-UTF-8 payload).
    """
    cfg = DecodeConfig(debug=debug)

    try:
        # The context manager releases the file handle once the pixel data
        # has been copied into the array.
        with Image.open(image_path) as src:
            img_np = np.array(src.convert("RGB"))
        h, w, _ = img_np.shape
        cx, cy = w // 2, h // 2

        if debug:
            print(f"[debug] image size={w}x{h}, center=({cx},{cy})", file=sys.stderr)

        raw = read_horizon_bytes(img_np, (cx, cy), cfg)

        # Try framed format first, fall back to raw zlib.
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit are not silently turned into a fallback attempt.
        try:
            payload_bytes = decode_framed_payload(raw)
        except Exception as framed_err:
            if debug:
                print(
                    f"[debug] framed decode failed ({framed_err}), trying raw zlib",
                    file=sys.stderr,
                )
            payload_bytes = zlib_decompress_until_eof(raw)

        return payload_bytes.decode("utf-8")
    except Exception as e:
        # Best-effort API: failures are reported as None, not raised.
        if debug:
            print(f"[debug] decode failed: {e}", file=sys.stderr)
        return None


def main():
    """
    Command-line entry point.

    Parses the options, loads the image as RGB, samples the ring around the
    image center, and decodes the stream (framed format first, bare zlib as
    fallback). The payload is printed as text when it is valid UTF-8 and
    written to stdout as raw bytes otherwise.
    """
    parser = argparse.ArgumentParser(description="Decode Ghostprint v50 visible horizon payload.")
    parser.add_argument("image", help="Input image: ghostprint_v50_visible.png or .jpg")
    parser.add_argument(
        "--block-size", type=int, default=4,
        help="Macro-pixel block size (default: 4)",
    )
    parser.add_argument(
        "--gap", type=int, default=1,
        help="Gap between rings (default: 1)",
    )
    parser.add_argument(
        "--start-radius", type=int, default=None,
        help="Start radius for horizon (default: auto-detect)",
    )
    parser.add_argument(
        "--max-radius", type=int, default=None,
        help="Max radius to sample (default: half-min-dim-8)",
    )
    parser.add_argument(
        "--stat", choices=["median", "mean"], default="median",
        help="Block sampling statistic (default: median)",
    )
    parser.add_argument(
        "--bg-threshold", type=float, default=18.0,
        help="Auto-detect threshold vs background (default: 18)",
    )
    parser.add_argument(
        "--debug", action="store_true",
        help="Print debug info to stderr",
    )
    opts = parser.parse_args()

    cfg = DecodeConfig(
        debug=opts.debug,
        sample_stat=opts.stat,
        background_threshold=opts.bg_threshold,
        block_size=opts.block_size,
        gap=opts.gap,
        start_radius=opts.start_radius,
        max_radius=opts.max_radius,
    )

    pixels = np.array(Image.open(opts.image).convert("RGB"))
    height, width, _ = pixels.shape
    cx = width // 2
    cy = height // 2

    if cfg.debug:
        print(f"[debug] image size={width}x{height}, center=({cx},{cy})", file=sys.stderr)

    stream = read_horizon_bytes(pixels, (cx, cy), cfg)

    # Framed format (4-byte length prefix) takes priority; any failure there
    # falls through to treating the stream as bare zlib.
    try:
        payload = decode_framed_payload(stream)
        if cfg.debug:
            print("[debug] decoded using framed format (length prefix)", file=sys.stderr)
    except Exception as e:
        if cfg.debug:
            print(f"[debug] framed decode failed ({e}), trying raw zlib", file=sys.stderr)
        payload = zlib_decompress_until_eof(stream)

    # Text payloads are printed; anything that is not valid UTF-8 is written
    # to stdout unchanged as raw bytes.
    try:
        text = payload.decode("utf-8")
    except UnicodeDecodeError:
        sys.stdout.buffer.write(payload)
    else:
        print(text)

# Run the command-line decoder only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
