"""
Admin Control Plane Policy Engine · HTTP service · Phase 2a
A-owned · session_a · FastAPI

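Loads the 5 Phase 1 JSON models once at startup · wraps policy_engine.py
with an HTTP boundary. Dev-mode context via X-PTT-* headers, query params
(GET) or body.
Phase 2a: no JWT · no Redis cache · no audit sink · decision-only (no state
mutation). Phase 2b (appended at the end of this module, additive only) adds
an optional JWT bearer resolver, a file-backed approval store and an audit
preview endpoint; the audit sink is still deferred.

Endpoints (Phase 2a; Phase 2b additions are listed in its own section):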
  GET  /api/policy/access/check
  POST /api/policy/access/check/batch
  POST /api/policy/mask/resolve
  POST /api/policy/assist/validate
  POST /api/policy/view-as/validate
  POST /api/policy/sensitive/check
  GET  /api/policy/health

Contract: ./policy_contract.json
Examples: ./policy_examples.json
Run:      bash run.sh
"""
from __future__ import annotations

import json
import logging
import os
import time
import uuid
from datetime import datetime, timezone
from typing import Any

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from policy_engine import (
    ENGINE_VERSION,
    access_check, mask_resolve, assist_validate, viewas_validate,
    sensitive_check, health as engine_health,
)

SERVICE_VERSION = "2a-py-scaffold-0.1"

# Directory containing this module; anchors the default model directory.
HERE = os.path.dirname(os.path.abspath(__file__))
# Phase 1 model directory. Override with ADMIN_MODELS_DIR; defaults to the
# sibling ../admin-control-plane directory.
ADMIN_DIR = os.environ.get("ADMIN_MODELS_DIR",
                           os.path.abspath(os.path.join(HERE, "..", "admin-control-plane")))

# Logical model key -> file name inside ADMIN_DIR (read by _load_models()).
MODEL_FILES = {
    "role_registry":        "role_registry.json",
    "access_policy":        "access_policy.json",
    "assist_model":         "assist_model.json",
    "approval_queue_model": "approval_queue_model.json",
    "audit_event_model":    "audit_event_model.json",
}

# One JSON object per log line. NOTE: %(message)s is interpolated *unquoted*
# into "msg", so callers must log an already-serialised JSON value
# (json.dumps(...)) or the line stops being valid JSON.
# logging only recognises upper-case level names ("DEBUG", not "debug"), so the
# env value is normalised; otherwise POLICY_LOG_LEVEL=debug raises ValueError
# at import time. An empty value falls back to INFO.
logging.basicConfig(
    level=os.environ.get("POLICY_LOG_LEVEL", "INFO").strip().upper() or "INFO",
    format='{"ts":"%(asctime)s","level":"%(levelname)s","msg":%(message)s}',
)
log = logging.getLogger("policy-service")


def _load_models() -> tuple[dict, dict]:
    """Read every Phase 1 model listed in MODEL_FILES from ADMIN_DIR.

    Returns ``(models, errors)``. *models* maps each key whose file parsed to
    its JSON content. *errors* maps each failed key to a human-readable reason.
    Loading succeeded iff *errors* is empty. This function never raises.
    """

    def read_model(path):
        # -> (parsed_json, None) on success, (None, reason) on failure
        try:
            with open(path, encoding="utf-8") as handle:
                return json.load(handle), None
        except FileNotFoundError:
            return None, f"not found at {path}"
        except json.JSONDecodeError as exc:
            return None, f"JSON parse failed: {exc}"
        except Exception as exc:  # noqa: BLE001
            return None, f"load error: {exc}"

    loaded: dict = {}
    failures: dict = {}
    for name, filename in MODEL_FILES.items():
        value, problem = read_model(os.path.join(ADMIN_DIR, filename))
        if problem is None:
            loaded[name] = value
        else:
            failures[name] = problem
    return loaded, failures


# Models are loaded once at import time. Failures are recorded, not raised, so
# the service still starts and reports "degraded" via /api/policy/health while
# decision endpoints answer 503 (see _require_models).
MODELS, MODEL_ERRORS = _load_models()
STARTED_AT = datetime.now(timezone.utc)

app = FastAPI(
    title="PTT Admin Control Plane · Policy Engine",
    version=SERVICE_VERSION,
    description="Phase 2a HTTP boundary · wraps policy_engine.py (24/24 canonical examples pass)",
)

# CORS allow-list. The default "*" keeps dev-mode behaviour. Set
# POLICY_CORS_ORIGINS to a comma-separated list of origins to restrict which
# browser origins may call this admin service; a blank value falls back to "*".
_CORS_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("POLICY_CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ORIGINS,
    allow_credentials=False,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)


# ================================================================ helpers ===

def _envelope(ok: bool, data: Any = None, error: dict | None = None, request_id: str | None = None) -> dict:
    """Wrap a payload in the standard response envelope.

    Shape: ``{"ok", "data", "error", "service"}``. ``service`` carries the
    service version, the engine version and a request id. A fresh UUID4 is
    used when *request_id* is None or empty.
    """
    service_meta = {
        "service_version": SERVICE_VERSION,
        "engine_version": ENGINE_VERSION,
        "request_id": request_id if request_id else str(uuid.uuid4()),
    }
    return dict(ok=ok, data=data, error=error, service=service_meta)


def _err(code: str, message: str, hint: str | None = None) -> dict:
    e = {"code": code, "message": message}
    if hint:
        e["hint"] = hint
    return e


def _req_id(request: Request) -> str:
    return request.headers.get("x-request-id") or str(uuid.uuid4())


def _hdr_ctx(request: Request) -> dict[str, Any]:
    """Extract context from X-PTT-* headers. Missing fields → None."""
    h = request.headers
    out: dict[str, Any] = {
        "actor_role":       h.get("x-ptt-actor-role"),
        "actor_user_id":    h.get("x-ptt-actor-user-id"),
        "actor_tenant_id":  h.get("x-ptt-actor-tenant-id"),
        "target_tenant_id": h.get("x-ptt-target-tenant-id"),
        "target_user_id":   h.get("x-ptt-target-user-id"),
    }
    refs = h.get("x-ptt-approval-refs")
    if refs:
        out["approval_refs"] = [s.strip() for s in refs.split(",") if s.strip()]
    return {k: v for k, v in out.items() if v is not None}


def _merge_ctx(*sources: dict[str, Any]) -> dict[str, Any]:
    """Merge right-to-left · later sources override earlier. Drops None values."""
    merged: dict[str, Any] = {}
    for s in sources:
        if not s:
            continue
        for k, v in s.items():
            if v is None:
                continue
            merged[k] = v
    return merged


def _require_models(request_id: str):
    """Raise HTTP 503 with a ``models_unavailable`` envelope if any Phase 1 model failed to load.

    The hint lists every failure as ``key: reason`` joined by ``"; "``.
    Returns None when all models are loaded.
    """
    if not MODEL_ERRORS:
        return
    failures = "; ".join(f"{name}: {reason}" for name, reason in MODEL_ERRORS.items())
    body = _envelope(
        False,
        error=_err("models_unavailable",
                   "one or more Phase 1 model files failed to load",
                   hint=failures),
        request_id=request_id,
    )
    raise HTTPException(status_code=503, detail=body)


def _log_event(event: str, rid: str, decision: Any, duration_ms: float, extra: dict | None = None):
    """Emit one JSON-serialised INFO line for a Phase 2a decision endpoint.

    Keys in *extra* are merged last and may override the base fields.
    """
    record = dict(request_id=rid, event=event, decision=decision, duration_ms=duration_ms)
    record.update(extra or {})
    log.info(json.dumps(record))


# =============================================================== handlers ===

@app.exception_handler(HTTPException)
async def http_ex_handler(request: Request, exc: HTTPException):
    """Render an HTTPException as the standard response envelope.

    Route code raises HTTPException with a ready-made envelope as ``detail``
    (recognised by its "ok" key), and that envelope is returned verbatim.
    Any other detail is wrapped as an ``internal_error`` envelope.

    Headers attached to the exception are preserved. ``X-Request-Id`` is
    added (unless already set) so error responses correlate like success
    responses, which always carry it.
    """
    if isinstance(exc.detail, dict) and "ok" in exc.detail:
        content = exc.detail
        # Reuse the id already embedded in the envelope so body and header agree.
        service = content.get("service")
        rid = service.get("request_id") if isinstance(service, dict) else None
        rid = str(rid) if rid else _req_id(request)
    else:
        rid = _req_id(request)
        content = _envelope(False, error=_err("internal_error", str(exc.detail)), request_id=rid)
    headers = dict(getattr(exc, "headers", None) or {})
    headers.setdefault("X-Request-Id", rid)
    return JSONResponse(status_code=exc.status_code, content=content, headers=headers)


@app.exception_handler(Exception)
async def any_ex_handler(request: Request, exc: Exception):
    """Last-resort handler: log the failure and return a generic enveloped 500.

    The log record is a single JSON value that includes the traceback. The
    configured log format embeds the message unquoted, so a bare string such
    as "unhandled" would produce an invalid JSON line.

    The exception text is deliberately NOT echoed to the client, so internals
    such as file paths or model contents are not leaked. Clients quote the
    request id, which is returned in the envelope and the X-Request-Id header,
    to find the server-side log entry.
    """
    import traceback

    rid = _req_id(request)
    log.error(json.dumps({
        "request_id": rid,
        "event": "unhandled_exception",
        "exc_type": type(exc).__name__,
        "exc_message": str(exc),
        "traceback": "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
    }))
    return JSONResponse(
        status_code=500,
        content=_envelope(False, error=_err("internal_error", "unhandled server error"), request_id=rid),
        headers={"X-Request-Id": rid},
    )


# ================================================================ routes ===

@app.get("/api/policy/health")
def health(request: Request):
    """Phase 2a health probe.

    200 with the engine's health view when all models loaded. Otherwise 503
    with a locally built "degraded" view: which models loaded, empty counts,
    uptime and engine version. ``models_errors`` is None when healthy.
    """
    rid = _req_id(request)
    healthy = not MODEL_ERRORS
    if healthy:
        engine_view = engine_health(MODELS, STARTED_AT)
    else:
        uptime = datetime.now(timezone.utc) - STARTED_AT
        engine_view = {
            "status": "degraded",
            "models_loaded": {name: bool(MODELS.get(name)) for name in MODEL_FILES},
            "counts": {},
            "uptime_seconds": int(uptime.total_seconds()),
            "engine_version": ENGINE_VERSION,
        }
    payload = dict(engine_view)
    payload["service_version"] = SERVICE_VERSION
    payload["models_errors"] = MODEL_ERRORS or None
    return JSONResponse(
        status_code=200 if healthy else 503,
        content=_envelope(healthy, data=payload, request_id=rid),
    )


@app.get("/api/policy/access/check")
def ep_access_check(request: Request):
    """GET access check.

    Context comes from the X-PTT-* headers, overridden by query params:
    actor_role, actor_user_id, actor_tenant_id, target_tenant_id,
    target_user_id, field_category, requested_action, now_iso,
    is_sensitive (1/true/yes, case-insensitive; always sent to the engine as
    a bool) and approval_refs (comma-separated).
    actor_role and actor_user_id are required; otherwise 400.
    """
    rid = _req_id(request)
    _require_models(rid)
    params = dict(request.query_params)

    passthrough = ("actor_role", "actor_user_id", "actor_tenant_id", "target_tenant_id",
                   "target_user_id", "field_category", "requested_action")
    query_ctx: dict[str, Any] = {name: params.get(name) for name in passthrough}
    query_ctx["is_sensitive"] = params.get("is_sensitive", "false").lower() in ("1", "true", "yes")
    query_ctx["now_iso"] = params.get("now_iso")
    refs = [part.strip() for part in (params.get("approval_refs") or "").split(",") if part.strip()]
    if refs:
        query_ctx["approval_refs"] = refs

    ctx = _merge_ctx(_hdr_ctx(request), query_ctx)
    if not (ctx.get("actor_role") and ctx.get("actor_user_id")):
        missing_actor = _err("invalid_request", "actor_role and actor_user_id are required",
                             hint="set X-PTT-Actor-Role / X-PTT-Actor-User-Id or query params")
        raise HTTPException(status_code=400, detail=_envelope(False, error=missing_actor, request_id=rid))

    started = time.perf_counter()
    result = access_check(MODELS, ctx)
    elapsed_ms = round((time.perf_counter() - started) * 1000, 3)
    _log_event("access_check", rid, result.get("decision"), elapsed_ms,
               {"actor_role": ctx.get("actor_role"), "field_category": ctx.get("field_category")})
    return JSONResponse(
        status_code=200,
        content=_envelope(True, data=result, request_id=rid),
        headers={"X-Request-Id": rid},
    )


@app.post("/api/policy/access/check/batch")
async def ep_access_check_batch(request: Request):
    """Run access_check for many items that share one base context.

    Body: ``{"context": {...}, "items": [{...}, ...]}``. The base context is
    the X-PTT-* headers overridden by ``context``. Each item then overrides
    the base; non-object items add nothing. The base context must carry
    actor_role and actor_user_id. Returns one result per item, in order.

    400 when the body is not valid JSON or not a JSON object, when
    ``context`` is not an object, or when ``items`` is missing or empty.
    """
    rid = _req_id(request)
    _require_models(rid)
    body_hint = 'expect {"context":{...},"items":[...]}'
    try:
        body = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail=_envelope(
            False,
            error=_err("invalid_request", "body is not valid JSON", hint=body_hint),
            request_id=rid,
        )) from None
    # Valid JSON that is not an object ([] / "x" / 1) used to crash on
    # body.get(...) and surface as a 500. It is a client error.
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail=_envelope(
            False,
            error=_err("invalid_request", "body must be a JSON object", hint=body_hint),
            request_id=rid,
        ))

    shared = body.get("context") or {}
    if not isinstance(shared, dict):
        raise HTTPException(status_code=400, detail=_envelope(
            False, error=_err("invalid_request", "context must be a JSON object"), request_id=rid,
        ))
    items = body.get("items")
    if not isinstance(items, list) or not items:
        raise HTTPException(status_code=400, detail=_envelope(
            False, error=_err("invalid_request", "items[] missing or empty"), request_id=rid,
        ))

    base_ctx = _merge_ctx(_hdr_ctx(request), shared)
    if not base_ctx.get("actor_role") or not base_ctx.get("actor_user_id"):
        raise HTTPException(status_code=400, detail=_envelope(
            False,
            error=_err("invalid_request", "context.actor_role and context.actor_user_id required"),
            request_id=rid,
        ))

    t0 = time.perf_counter()
    results = [
        access_check(MODELS, _merge_ctx(base_ctx, it if isinstance(it, dict) else {}))
        for it in items
    ]
    dur = round((time.perf_counter() - t0) * 1000, 3)
    _log_event("access_check_batch", rid, None, dur, {"count": len(results)})

    return JSONResponse(status_code=200,
                        content=_envelope(True, data=results, request_id=rid),
                        headers={"X-Request-Id": rid})


@app.post("/api/policy/mask/resolve")
async def ep_mask_resolve(request: Request):
    """Resolve the mask level for a field.

    Context is the X-PTT-* headers, overridden by the body ``context``,
    overridden by root-level ``field_category`` / ``target_user_id``. An
    unparseable or non-object body, or a non-object ``context``, is treated
    as absent. actor_role is required; otherwise 400.
    """
    rid = _req_id(request)
    _require_models(rid)
    try:
        body = await request.json()
    except Exception:
        body = {}
    # Same lenient fallback as the parse-failure path: a JSON array/scalar body
    # or non-object context previously raised AttributeError (500).
    if not isinstance(body, dict):
        body = {}
    shared = body.get("context")
    if not isinstance(shared, dict):
        shared = {}
    # Pull field_category and target from body root if present
    extra = {k: body.get(k) for k in ("field_category", "target_user_id") if body.get(k) is not None}
    ctx = _merge_ctx(_hdr_ctx(request), shared, extra)

    if not ctx.get("actor_role"):
        raise HTTPException(status_code=400, detail=_envelope(
            False, error=_err("invalid_request", "actor_role required"), request_id=rid,
        ))

    t0 = time.perf_counter()
    result = mask_resolve(MODELS, ctx)
    dur = round((time.perf_counter() - t0) * 1000, 3)
    _log_event("mask_resolve", rid, result.get("mask_level"), dur,
               {"field_category": ctx.get("field_category")})
    return JSONResponse(status_code=200,
                        content=_envelope(True, data=result, request_id=rid),
                        headers={"X-Request-Id": rid})


@app.post("/api/policy/assist/validate")
async def ep_assist_validate(request: Request):
    """Validate an assist request via the engine's assist_validate.

    Context is the X-PTT-* headers, overridden by the body ``context``,
    overridden by a root-level ``proposed_action``. An unparseable or
    non-object body, or a non-object ``context``, is treated as absent.
    No actor fields are required at this layer.
    """
    rid = _req_id(request)
    _require_models(rid)
    try:
        body = await request.json()
    except Exception:
        body = {}
    # Same lenient fallback as the parse-failure path: a JSON array/scalar body
    # or non-object context previously raised AttributeError (500).
    if not isinstance(body, dict):
        body = {}
    shared = body.get("context")
    if not isinstance(shared, dict):
        shared = {}
    extra = {k: body.get(k) for k in ("proposed_action",) if body.get(k) is not None}
    ctx = _merge_ctx(_hdr_ctx(request), shared, extra)

    t0 = time.perf_counter()
    result = assist_validate(MODELS, ctx)
    dur = round((time.perf_counter() - t0) * 1000, 3)
    _log_event("assist_validate", rid, result.get("valid"), dur,
               {"actor_role": ctx.get("actor_role")})
    return JSONResponse(status_code=200,
                        content=_envelope(True, data=result, request_id=rid),
                        headers={"X-Request-Id": rid})


@app.post("/api/policy/view-as/validate")
async def ep_viewas_validate(request: Request):
    """Validate a view-as request via the engine's viewas_validate.

    Context is the X-PTT-* headers, overridden by the body ``context``,
    overridden by a root-level ``proposed_action``. An unparseable or
    non-object body, or a non-object ``context``, is treated as absent.
    No actor fields are required at this layer.
    """
    rid = _req_id(request)
    _require_models(rid)
    try:
        body = await request.json()
    except Exception:
        body = {}
    # Same lenient fallback as the parse-failure path: a JSON array/scalar body
    # or non-object context previously raised AttributeError (500).
    if not isinstance(body, dict):
        body = {}
    shared = body.get("context")
    if not isinstance(shared, dict):
        shared = {}
    extra = {k: body.get(k) for k in ("proposed_action",) if body.get(k) is not None}
    ctx = _merge_ctx(_hdr_ctx(request), shared, extra)

    t0 = time.perf_counter()
    result = viewas_validate(MODELS, ctx)
    dur = round((time.perf_counter() - t0) * 1000, 3)
    _log_event("viewas_validate", rid, result.get("valid"), dur,
               {"actor_role": ctx.get("actor_role")})
    return JSONResponse(status_code=200,
                        content=_envelope(True, data=result, request_id=rid),
                        headers={"X-Request-Id": rid})


@app.post("/api/policy/sensitive/check")
async def ep_sensitive_check(request: Request):
    """Phase 2a sensitive-field check via the engine's sensitive_check.

    Context is the X-PTT-* headers, overridden by the body ``context``,
    overridden by root-level ``field_category`` / ``requested_action``. An
    unparseable or non-object body, or a non-object ``context``, is treated
    as absent. actor_role is required; otherwise 400.
    """
    rid = _req_id(request)
    _require_models(rid)
    try:
        body = await request.json()
    except Exception:
        body = {}
    # Same lenient fallback as the parse-failure path: a JSON array/scalar body
    # or non-object context previously raised AttributeError (500).
    if not isinstance(body, dict):
        body = {}
    shared = body.get("context")
    if not isinstance(shared, dict):
        shared = {}
    extra = {k: body.get(k) for k in ("field_category", "requested_action") if body.get(k) is not None}
    ctx = _merge_ctx(_hdr_ctx(request), shared, extra)

    if not ctx.get("actor_role"):
        raise HTTPException(status_code=400, detail=_envelope(
            False, error=_err("invalid_request", "actor_role required"), request_id=rid,
        ))

    t0 = time.perf_counter()
    result = sensitive_check(MODELS, ctx)
    dur = round((time.perf_counter() - t0) * 1000, 3)
    _log_event("sensitive_check", rid, result.get("decision"), dur,
               {"field_category": ctx.get("field_category")})
    return JSONResponse(status_code=200,
                        content=_envelope(True, data=result, request_id=rid),
                        headers={"X-Request-Id": rid})


@app.get("/")
def root(request: Request):
    """Service index: name, phase, endpoint list and contract pointers (always ok=True)."""
    endpoint_list = [
        "GET  /api/policy/access/check",
        "POST /api/policy/access/check/batch",
        "POST /api/policy/mask/resolve",
        "POST /api/policy/assist/validate",
        "POST /api/policy/view-as/validate",
        "POST /api/policy/sensitive/check",
        "GET  /api/policy/health",
        "GET  /api/policy/approvals/{approval_id}    (Phase 2b)",
        "POST /api/policy/approvals/validate         (Phase 2b)",
        "POST /api/policy/sensitive.check.2b         (Phase 2b)",
        "POST /api/policy/audit/preview              (Phase 2b)",
        "GET  /phase-2b/health                       (Phase 2b)",
    ]
    info = dict(
        service="PTT Admin Control Plane · Policy Engine",
        phase="2a (base) + 2b (additive)",
        endpoints=endpoint_list,
        contract="policy_contract.json (2a) · phase_2b_contract.json (2b)",
        examples_regression="verify_examples.py (24/24 pass · unchanged by Phase 2b)",
    )
    return _envelope(True, data=info, request_id=_req_id(request))


# ============================================================================
# PHASE 2B · additive only · does NOT alter Phase 2a endpoints above
# ============================================================================
#  * JWT Bearer resolver (decode-only unless verification preconditions met)
#  * File-backed approval store (loaded at startup)
#  * NEW endpoints:
#      GET  /api/policy/approvals/{approval_id}
#      POST /api/policy/approvals/validate
#      POST /api/policy/sensitive.check.2b
#      POST /api/policy/audit/preview
#      GET  /phase-2b/health
# verify_examples.py (24/24) is unaffected — policy_engine.py existing
# functions are untouched. policy_engine.validate_approval_row() and
# sensitive_check_2b() are new additive helpers.
# Contract: phase_2b_contract.json · audit_sink_contract.json

import hashlib as _hashlib

PHASE_2B_VERSION = "2b-py-additive-0.1"

# File-backed approval store, read once at startup.
# Override with ADMIN_APPROVAL_STORE_PATH.
APPROVAL_STORE_PATH = os.environ.get(
    "ADMIN_APPROVAL_STORE_PATH",
    os.path.abspath(os.path.join(HERE, "approval_examples.json")),
)

# JWT settings. _parse_bearer_admin attempts verification only when PyJWT is
# importable, POLICY_JWT_VERIFY is 1/true/yes (case-insensitive) and a key
# source is configured.
JWT_VERIFY_FLAG = os.environ.get("POLICY_JWT_VERIFY", "").lower() in {"1", "true", "yes"}
JWT_PUBLIC_KEY = os.environ.get("POLICY_JWT_PUBLIC_KEY")
JWT_JWKS_URL = os.environ.get("POLICY_JWKS_URL")
# NOTE(review): the default audience is "pty-admin-policy" although the service
# is branded PTT. This may be a typo; confirm against the token issuer before
# changing it, because changing it would reject existing tokens.
JWT_AUDIENCE = os.environ.get("POLICY_JWT_AUDIENCE", "pty-admin-policy")

# PyJWT is OPTIONAL (not a hard dependency in requirements.txt). Without it,
# bearer payloads are only base64-decoded and never verified.
try:
    import jwt as _pyjwt
except Exception:  # noqa: BLE001
    _pyjwt = None  # type: ignore
PYJWT_AVAILABLE = _pyjwt is not None


def _load_approval_store() -> tuple[list[dict], dict[str, int], str | None]:
    """Load the file-backed approval store from APPROVAL_STORE_PATH.

    Returns ``(rows, index, error)``:
      * rows  - rows that passed soft validation: a JSON object whose required
                fields are all present and non-None, with a string approval_id;
      * index - approval_id -> position in *rows* (if an id repeats, the later
                row wins);
      * error - None on success, otherwise a human-readable reason. In that
                case *rows* and *index* are empty.
    Never raises.
    """
    try:
        with open(APPROVAL_STORE_PATH, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return [], {}, f"approval file not found at {APPROVAL_STORE_PATH}"
    except json.JSONDecodeError as e:
        return [], {}, f"approval JSON parse failed: {e}"
    except Exception as e:  # noqa: BLE001
        return [], {}, f"approval load error: {e}"

    # A top-level array/scalar used to surface only as the opaque
    # "'list' object has no attribute 'get'" load error.
    if not isinstance(data, dict):
        return [], {}, "approval store root is not a JSON object"
    rows = data.get("approvals") or []
    if not isinstance(rows, list):
        return [], {}, "approval_examples.approvals is not a list"

    # Soft validation: drop malformed rows instead of failing the whole store.
    # A non-string approval_id is dropped too. An unhashable id (list/dict)
    # previously raised TypeError while indexing and discarded every row.
    required = ("approval_id", "matrix_row", "target_type", "target_id",
                "required_signer_count", "state", "sla_due_at")
    valid: list[dict] = [
        r for r in rows
        if isinstance(r, dict)
        and all(r.get(k) is not None for k in required)
        and isinstance(r["approval_id"], str)
    ]
    idx = {r["approval_id"]: i for i, r in enumerate(valid)}
    return valid, idx, None


# Loaded once at import. The approval endpoints answer 503 when APPROVAL_ERR is set.
APPROVAL_ROWS, APPROVAL_INDEX, APPROVAL_ERR = _load_approval_store()


def _parse_bearer_admin(request: Request) -> tuple[dict, str, list[str]]:
    """Resolve an ``Authorization: Bearer <token>`` header into claims.

    Returns ``(claims, auth_source, warnings)``. auth_source is one of:
      * "jwt"            - verified by PyJWT (signature + audience);
      * "jwt_unverified" - payload decoded WITHOUT verification. This happens
                           only when verification is not configured (dev mode);
      * "none"           - no usable token, and *claims* is ``{}``.

    Verification is attempted when PyJWT is installed, POLICY_JWT_VERIFY is
    truthy and a key source is set. In that mode a token that fails
    verification is rejected (fail closed) and its claims are NOT used.
    Falling back to an unverified decode would let a forged token set
    actor_role, and claims outrank every other source in _merge_ctx_2b.

    NOTE: only JWT_PUBLIC_KEY is passed to PyJWT. JWKS fetching is not
    implemented, so a JWKS-URL-only configuration rejects every token.
    """
    warnings: list[str] = []
    hdr = request.headers.get("authorization", "")
    if not hdr.lower().startswith("bearer "):
        return {}, "none", warnings
    token = hdr[7:].strip()
    if not token:
        return {}, "none", warnings

    can_verify = PYJWT_AVAILABLE and JWT_VERIFY_FLAG and (JWT_PUBLIC_KEY or JWT_JWKS_URL)
    if can_verify and _pyjwt is not None:
        try:
            # NOTE(review): accepting both RS256 and HS256 with one key invites
            # algorithm-confusion mistakes; consider pinning to the issuer's alg.
            claims = _pyjwt.decode(
                token, JWT_PUBLIC_KEY,
                algorithms=["RS256", "HS256"],
                audience=JWT_AUDIENCE,
            )
        except Exception as e:  # noqa: BLE001
            return {}, "none", warnings + [f"jwt_verify_failed: {e}"]
        if not isinstance(claims, dict):
            return {}, "none", warnings + ["bearer_claims_not_object"]
        return claims, "jwt", warnings

    # Verification not configured: unverified decode (dev mode only).
    try:
        if _pyjwt is not None:
            claims = _pyjwt.decode(token, options={"verify_signature": False, "verify_exp": False, "verify_aud": False})
        else:
            import base64
            parts = token.split(".")
            if len(parts) < 2:
                return {}, "none", warnings + ["bearer_malformed"]
            pad = "=" * (-len(parts[1]) % 4)
            claims = json.loads(base64.urlsafe_b64decode(parts[1] + pad))
    except Exception as e:  # noqa: BLE001
        return {}, "none", warnings + [f"bearer_decode_failed: {e}"]
    # A payload such as "[]" or "42" decodes cleanly but is not a claims object;
    # downstream code calls claims.get(...) and would raise (500).
    if not isinstance(claims, dict):
        return {}, "none", warnings + ["bearer_claims_not_object"]
    warnings.append("auth_not_verified")
    return claims, "jwt_unverified", warnings


def _ctx_from_bearer_admin(claims: dict) -> dict:
    """Map JWT claims onto admin context vocabulary."""
    out: dict[str, Any] = {}
    if claims.get("sub"):       out["actor_user_id"] = claims["sub"]
    if claims.get("role"):      out["actor_role"] = claims["role"]
    if claims.get("tenant_id"): out["actor_tenant_id"] = claims["tenant_id"]
    if claims.get("approval_refs"): out["approval_refs"] = claims["approval_refs"]
    return out


def _merge_ctx_2b(request: Request, query: dict[str, Any], body: dict[str, Any] | None = None) -> tuple[dict[str, Any], str, list[str]]:
    """Build the Phase 2b context. Precedence: Bearer claims > X-PTT-* headers > body > query.

    None values never override an earlier value. Non-dict *query*/*body*
    arguments (for example a JSON array sent as "context") are ignored rather
    than raising AttributeError.

    auth_source:
      * no bearer token: "mixed" (headers plus body/query), "dev_headers",
        "body", "query", or "none" when nothing was supplied;
      * bearer token: "jwt" / "jwt_unverified" from _parse_bearer_admin,
        or "mixed" when headers, body or query also contributed.
    Returns ``(ctx, auth_source, warnings)``.
    """
    query = query if isinstance(query, dict) else {}
    body = body if isinstance(body, dict) else {}
    claims, auth_src, warnings = _parse_bearer_admin(request)
    hdr = _hdr_ctx(request)
    from_claims = _ctx_from_bearer_admin(claims if isinstance(claims, dict) else {})

    merged: dict[str, Any] = {}
    for src in (query, body, hdr, from_claims):  # lowest -> highest precedence
        for k, v in src.items():
            if v is not None:
                merged[k] = v

    if auth_src == "none":
        if hdr and (query or body):
            auth_src = "mixed"
        elif hdr:
            auth_src = "dev_headers"
        elif body:
            auth_src = "body"
        elif query:
            auth_src = "query"
    elif auth_src.startswith("jwt") and (hdr or body or query):
        auth_src = "mixed"
    return merged, auth_src, warnings


def _log_2b(event: str, rid: str, extra: dict | None = None):
    """Emit one JSON-serialised INFO line tagged ``"phase": "2b"``; keys in *extra* override base fields."""
    log.info(json.dumps({"request_id": rid, "event": event, "phase": "2b", **(extra or {})}))


# ------------------------------------------------------------------ routes --

@app.get("/api/policy/approvals/{approval_id}")
def ep_approval_get(approval_id: str, request: Request):
    """Phase 2b · READ-ONLY inspection of a single approval row.

    503 if the models or the approval store failed to load, 404 if the id is
    not in the store. Otherwise returns the row together with the output of
    policy_engine.validate_approval_row(row).
    """
    rid = _req_id(request)
    _require_models(rid)
    if APPROVAL_ERR:
        store_error = _err("approval_store_unavailable", APPROVAL_ERR,
                           hint=f"check ADMIN_APPROVAL_STORE_PATH={APPROVAL_STORE_PATH}")
        raise HTTPException(status_code=503, detail=_envelope(False, error=store_error, request_id=rid))

    position = APPROVAL_INDEX.get(approval_id)
    if position is None:
        missing = _err("approval_not_found", f"no row with approval_id={approval_id}")
        raise HTTPException(status_code=404, detail=_envelope(False, error=missing, request_id=rid))

    row = APPROVAL_ROWS[position]
    from policy_engine import validate_approval_row
    validation = validate_approval_row(row)
    _log_2b("approval_read", rid, {
        "approval_id": approval_id,
        "state": row.get("state"),
        "expired": validation["expired"],
    })
    payload = {"row": row, "validation": validation}
    return JSONResponse(
        status_code=200,
        content=_envelope(True, data=payload, request_id=rid),
        headers={"X-Request-Id": rid},
    )


@app.post("/api/policy/approvals/validate")
async def ep_approval_validate(request: Request):
    """Phase 2b · validate one or more approval_ids against the store with TTL.

    Request body:
      {"approval_refs": ["APP-...","APP-..."], "now_iso": "optional"}

    Response data:
      {"per_ref": [...], "any_valid": bool, "any_expired": bool, "all_valid": bool}

    Unknown ids produce a per_ref entry with valid=False and reason
    ``approval_not_found``. 400 when the body is not a JSON object, or when
    approval_refs is missing, empty, or contains non-strings.
    """
    rid = _req_id(request)
    _require_models(rid)
    if APPROVAL_ERR:
        raise HTTPException(status_code=503, detail=_envelope(
            False, error=_err("approval_store_unavailable", APPROVAL_ERR), request_id=rid,
        ))
    body_hint = 'expect {"approval_refs": [string], "now_iso"?: string}'
    try:
        body = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail=_envelope(
            False, error=_err("invalid_request", "body is not valid JSON", hint=body_hint),
            request_id=rid,
        )) from None
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail=_envelope(
            False, error=_err("invalid_request", "body must be a JSON object", hint=body_hint),
            request_id=rid,
        ))
    refs = body.get("approval_refs") or []
    if not isinstance(refs, list) or not refs:
        raise HTTPException(status_code=400, detail=_envelope(
            False, error=_err("invalid_request", "approval_refs[] missing or empty"),
            request_id=rid,
        ))
    # Unhashable refs (objects/arrays) would raise TypeError on the index lookup.
    if not all(isinstance(ref, str) for ref in refs):
        raise HTTPException(status_code=400, detail=_envelope(
            False, error=_err("invalid_request", "approval_refs[] must contain only strings",
                              hint=body_hint),
            request_id=rid,
        ))

    from policy_engine import validate_approval_row
    # policy_engine._now is a private helper. Import it separately so that its
    # absence falls back to the wall clock instead of failing with ImportError.
    try:
        from policy_engine import _now as _engine_now
    except ImportError:
        _engine_now = None
    now = _engine_now(body.get("now_iso")) if _engine_now is not None else None
    if now is None:
        now = datetime.now(timezone.utc)

    per_ref: list[dict] = []
    for ref in refs:
        i = APPROVAL_INDEX.get(ref)
        if i is None:
            per_ref.append({
                "approval_id": ref, "valid": False, "expired": False,
                "state": None, "ttl_remaining_seconds": 0,
                "reasons": [{"id": "approval_not_found", "text": f"approval_id={ref} not in store"}],
            })
            continue
        per_ref.append(validate_approval_row(APPROVAL_ROWS[i], now))
    any_valid = any(p["valid"] for p in per_ref)
    any_expired = any(p["expired"] for p in per_ref)
    all_valid = all(p["valid"] for p in per_ref)

    _log_2b("approval_validate", rid, {
        "count": len(refs),
        "valid_count": sum(1 for p in per_ref if p.get("valid")),
        "expired_count": sum(1 for p in per_ref if p.get("expired")),
    })
    return JSONResponse(
        status_code=200,
        content=_envelope(True, data={
            "per_ref": per_ref,
            "any_valid": any_valid,
            "any_expired": any_expired,
            "all_valid": all_valid,
        }, request_id=rid),
        headers={"X-Request-Id": rid},
    )


@app.post("/api/policy/sensitive.check.2b")
async def ep_sensitive_check_2b(request: Request):
    """Phase 2b · store-aware sensitive.check (NOT a replacement for 2a).

    Context comes from _merge_ctx_2b (bearer > headers > body ``context``).
    Root-level ``field_category`` / ``requested_action`` then override it.
    An unparseable or non-object body, or a non-object ``context``, is treated
    as absent. actor_role is required; otherwise 400. The engine result is
    returned with ``auth_source`` added, plus ``warnings`` when there are any.
    """
    rid = _req_id(request)
    _require_models(rid)
    try:
        body = await request.json()
    except Exception:
        body = {}
    # A JSON array/scalar body or non-object context previously raised
    # AttributeError (500). Treat it as absent, like the parse-failure path.
    if not isinstance(body, dict):
        body = {}
    shared = body.get("context")
    ctx, auth_source, warnings = _merge_ctx_2b(request, {}, shared if isinstance(shared, dict) else {})
    for k in ("field_category", "requested_action"):
        if body.get(k) is not None:
            ctx[k] = body[k]
    if not ctx.get("actor_role"):
        raise HTTPException(status_code=400, detail=_envelope(
            False, error=_err("invalid_request", "actor_role required"), request_id=rid,
        ))
    from policy_engine import sensitive_check_2b as engine_s2b
    t0 = time.perf_counter()
    result = engine_s2b(MODELS, ctx, APPROVAL_ROWS)
    dur = round((time.perf_counter() - t0) * 1000, 3)
    result["auth_source"] = auth_source
    if warnings:
        result["warnings"] = warnings
    _log_2b("sensitive_check_2b", rid, {
        "decision": result.get("decision"),
        "any_valid_ref": result.get("any_valid_ref"),
        "any_expired_ref": result.get("any_expired_ref"),
        "duration_ms": dur,
        "auth_source": auth_source,
    })
    return JSONResponse(status_code=200,
                        content=_envelope(True, data=result, request_id=rid),
                        headers={"X-Request-Id": rid})


@app.post("/api/policy/audit/preview")
async def ep_audit_preview(request: Request):
    """Phase 2b · render the audit envelope for an (actor, action, subject) tuple.

    Nothing is delivered: sink status is ALWAYS 'deferred' in this phase. Body:
      {"context":{...}, "event_type": "access.grant", "subject": {...},
       "action": "read", "approval_refs"?: [...]}
    Optional passthrough fields: mask_level_at_event, session_ref, reasons.

    Malformed optional fields fall back to defaults with a warning instead of
    crashing. Unknown event types get a warning and category "governance".
    400 when the body is not valid JSON or not a JSON object.
    """
    rid = _req_id(request)
    _require_models(rid)
    try:
        body = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail=_envelope(
            False, error=_err("invalid_request", "body is not valid JSON"),
            request_id=rid,
        )) from None
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail=_envelope(
            False, error=_err("invalid_request", "body must be a JSON object"),
            request_id=rid,
        ))

    context = body.get("context")
    ctx, auth_source, warnings = _merge_ctx_2b(request, {}, context if isinstance(context, dict) else {})

    event_type = body.get("event_type") or "access.grant"
    if not isinstance(event_type, str):
        # Unhashable values would raise TypeError in the model lookup below.
        warnings.append("event_type_not_string")
        event_type = "access.grant"
    subject = body.get("subject") or {"type": "unknown", "id": "n/a"}
    if not isinstance(subject, dict):
        warnings.append("subject_not_object")
        subject = {"type": "unknown", "id": "n/a"}
    action = body.get("action") or "read"
    approval_refs = body.get("approval_refs") or []
    if not isinstance(approval_refs, list):
        warnings.append("approval_refs_not_list")
        approval_refs = []

    # Allowed event_types come from audit_event_model.event_types
    event_model = MODELS.get("audit_event_model") or {}
    type_entries = [t for t in (event_model.get("event_types") or []) if isinstance(t, dict)]
    category_by_type = {t.get("id"): t.get("category") for t in type_entries}
    if event_type not in category_by_type:
        warnings.append(f"event_type_not_in_model:{event_type}")
    category = category_by_type.get(event_type, "governance")

    # A caller-supplied now_iso is used as-is when it looks like an ISO-8601
    # date-time (contains "T"). Otherwise the current UTC time is used in
    # extended ISO-8601 form ("...+00:00"); the old strftime("%z") gave "+0000".
    now = datetime.now(timezone.utc)
    supplied_ts = ctx.get("now_iso")
    if isinstance(supplied_ts, str) and "T" in supplied_ts:
        ts_iso = supplied_ts
    else:
        ts_iso = now.isoformat(timespec="seconds")

    # Deterministic-ish event_id: hash of request_id + actor + timestamp + event + subject.
    h = _hashlib.sha256(
        f"{rid}|{ctx.get('actor_user_id','')}|{ts_iso}|{event_type}|{subject.get('type','')}|{subject.get('id','')}".encode()
    ).hexdigest()[:8]
    event_id = f"aud-{now.strftime('%y%m%d')}-{h}"

    assist_ctx = ctx.get("assist_ctx")
    assist_ctx_id = assist_ctx.get("consent_record_id") if isinstance(assist_ctx, dict) else None

    envelope = {
        "event_id": event_id,
        "event_type": event_type,
        "event_category": category,
        "timestamp": ts_iso,
        "actor": {
            "user_id":       ctx.get("actor_user_id"),
            "role_key":      ctx.get("actor_role"),
            "tier":          ctx.get("tier"),
            "tenant_id":     ctx.get("actor_tenant_id"),
            "assist_ctx_id": assist_ctx_id,
            "view_as_ctx_id": "active" if ctx.get("view_as_ctx") else None,
            "ip_fingerprint": "ip-sha256:preview",
            "user_agent_hash": "ua-sha256:preview",
        },
        "subject": subject,
        "action": action,
        "approval_refs": approval_refs,
        "mask_level_at_event": body.get("mask_level_at_event"),
        "session_ref": body.get("session_ref") or f"sess-preview-{rid[:8]}",
        "reasons": body.get("reasons") or [],
        "producer": f"admin-control-plane-service/{PHASE_2B_VERSION}",
        "sink_status": "deferred",
        "honest_note": "Preview only · NOT delivered to Kafka/WORM · see audit_sink_contract.json",
    }

    _log_2b("audit_preview", rid, {
        "event_type": event_type,
        "event_category": category,
        "actor_user_id": ctx.get("actor_user_id"),
        "approval_refs_count": len(approval_refs),
    })
    log.info(json.dumps({"audit_preview_envelope": envelope}))

    return JSONResponse(status_code=200, content=_envelope(True, data={
        "envelope": envelope,
        "auth_source": auth_source,
        "warnings": warnings,
        "sink_status": "deferred",
        "delivery_mode": "preview-only",
        "contract": "audit_sink_contract.json",
    }, request_id=rid), headers={"X-Request-Id": rid})


@app.get("/phase-2b/health")
def health_2b(request: Request):
    """Phase 2b status: approval-store load state, JWT configuration and endpoint list.

    Always 200. ``jwt_verify_live`` is true only when PyJWT is installed, the
    verify flag is on and a public key is configured. Only JWT_PUBLIC_KEY is
    ever passed to PyJWT (JWKS fetching is not implemented), so a JWKS URL on
    its own cannot verify tokens.
    """
    rid = _req_id(request)
    verify_live = PYJWT_AVAILABLE and JWT_VERIFY_FLAG and bool(JWT_PUBLIC_KEY)
    data = {
        "phase": "2b",
        "service_version": SERVICE_VERSION,
        # Single source of truth; previously a duplicated string literal.
        "engine_2b_version": PHASE_2B_VERSION,
        "approval_store_loaded": APPROVAL_ERR is None,
        "approval_store_path": APPROVAL_STORE_PATH,
        "approval_store_error": APPROVAL_ERR,
        "approval_count": len(APPROVAL_ROWS),
        "pyjwt_installed": PYJWT_AVAILABLE,
        "jwt_verify_flag_env": JWT_VERIFY_FLAG,
        "jwt_public_key_present": bool(JWT_PUBLIC_KEY),
        "jwks_url_present": bool(JWT_JWKS_URL),
        "jwt_verify_live": verify_live,
        "jwt_audience": JWT_AUDIENCE,
        "audit_sink_status": "deferred",
        "phase_2b_endpoints": [
            "GET  /api/policy/approvals/{approval_id}",
            "POST /api/policy/approvals/validate",
            "POST /api/policy/sensitive.check.2b",
            "POST /api/policy/audit/preview",
            "GET  /phase-2b/health",
        ],
        "contract": "phase_2b_contract.json",
        "audit_contract": "audit_sink_contract.json",
        "honest": "Additive only · Phase 2a endpoints unchanged · 24/24 parity still holds",
    }
    return JSONResponse(status_code=200,
                        content=_envelope(True, data=data, request_id=rid),
                        headers={"X-Request-Id": rid})
