#!/usr/bin/env python3
"""Phase 14j — HADDOCK3 full pipeline via official Docker image (Linux x86_64).

Bypasses the macOS-arm64 CNS template variable-substitution bug
($mol_fix_origin_$maxid) by running the same HADDOCK3 2026.5.0 binary against
the Linux CNS that ships in the Docker image.

Inputs are the same as the host pipeline (haddock3_full_pipeline.py); we just
swap the executable paths to be container-relative.
"""
from __future__ import annotations
import os, subprocess, csv, json, shutil, time
from pathlib import Path

# Inside the container the repository is mounted at /repo.
REPO = Path("/repo")

# Output directory for everything this script writes. It is created as soon as
# the module is loaded so later writes can rely on it being present.
OUT = REPO.joinpath("14_inhibitor_design", "07_advanced_methods", "haddock3_full")
OUT.mkdir(parents=True, exist_ok=True)

# HADDOCK3 entry point inside the image. No CNS path is configured here: the
# container's bundled CNS is used automatically by haddock3.
HADDOCK3 = Path("/usr/local/bin/haddock3")

# Receptor structure that prep_receptor_chain_a_only() reads.
APO_PDB = REPO.joinpath("06f_receptor_fixed", "dimer_noH.pdb")

# Residue numbers on chain A that receive an ambiguous restraint
# (consumed by make_airs_tbl).
ACTIVE_RESIDUES_A = [
    20, 21, 22, 23, 24, 25,
    32, 33, 34, 35, 36, 37, 39,
    117, 118, 135,
    148, 150, 151, 153,
    157, 158, 159, 160,
    167, 168,
    172, 173, 175, 177, 178,
    199, 200, 202,
]

# Every positive residue number within two positions of an active residue
# that is not itself active, in ascending order.
PASSIVE_RESIDUES_A = sorted(
    {
        neighbour
        for centre in ACTIVE_RESIDUES_A
        for neighbour in range(centre - 2, centre + 3)
        if neighbour > 0
    }
    - set(ACTIVE_RESIDUES_A)
)


def prep_receptor_chain_a_only(out_pdb: Path):
    """Write a single-chain receptor built from the first half of APO_PDB.

    Only ``ATOM`` records are considered; every other record type is dropped.
    The first ``n // 2`` of them (in file order) are kept, their chain-ID
    column is overwritten with ``A``, and ``TER`` / ``END`` lines are appended.

    Workaround rationale (from the original author): HADDOCK3 2026.5.0 /
    2025.11.0 has a CNS template bug triggered by multi-chain receptors, and
    the dimer-interface AIRs on chain A still work with the partner protomer
    left implicit.

    NOTE(review): "first half" is decided purely by ATOM count, so this
    presumably expects two protomers with identical atom counts listed one
    after the other — confirm against the input file.

    Returns ``out_pdb``.
    """
    atom_records = [
        record
        for record in APO_PDB.read_text().splitlines()
        if record.startswith("ATOM")
    ]
    n_keep = len(atom_records) // 2

    # Index 21 is the chain-ID column of a fixed-width PDB record.
    relabelled = [
        record[:21] + "A" + record[22:] for record in atom_records[:n_keep]
    ]

    out_pdb.write_text("\n".join(relabelled) + "\nTER\nEND\n")
    return out_pdb


def prep_peptide(sequence: str, out_pdb: Path, chain: str = "C") -> Path:
    """Build a 3D peptide model with RDKit and write it as a PDB file.

    The peptide is created from its one-letter ``sequence``, hydrogens are
    added, a conformer is embedded with a fixed random seed (42) and then
    relaxed with MMFF on a best-effort basis. RDKit's raw output goes to
    ``<out_pdb stem>.raw.pdb``; ``out_pdb`` receives only the coordinate
    records, all written as ``ATOM`` with the chain-ID column set to ``chain``,
    followed by ``END``.

    Args:
        sequence: One-letter amino-acid sequence.
        out_pdb: Destination of the cleaned PDB file.
        chain: Single-character chain identifier to stamp on every atom.

    Returns:
        ``out_pdb``.

    Raises:
        ValueError: If ``chain`` is not exactly one character, or RDKit cannot
            build a molecule from ``sequence``.
        RuntimeError: If 3D embedding fails, because the PDB would otherwise
            be written without meaningful coordinates.
    """
    from rdkit import Chem
    from rdkit.Chem import AllChem

    # The chain ID occupies exactly one fixed column; anything else would
    # shift every following field of the record.
    if len(chain) != 1:
        raise ValueError(f"chain must be a single character, got {chain!r}")

    mol = Chem.MolFromSequence(sequence)
    if mol is None or mol.GetNumAtoms() == 0:
        raise ValueError(f"RDKit could not build a peptide from sequence {sequence!r}")
    mol = Chem.AddHs(mol)

    # EmbedMolecule reports failure through a negative conformer id rather
    # than an exception, so it has to be checked explicitly.
    if AllChem.EmbedMolecule(mol, randomSeed=42) < 0:
        raise RuntimeError(f"3D embedding failed for peptide {sequence!r}")

    # Optimisation stays best-effort (the embedded geometry is still usable),
    # but a failure is reported instead of being silently discarded.
    try:
        AllChem.MMFFOptimizeMolecule(mol, maxIters=1000)
    except Exception as exc:
        print(f"    warning: MMFF optimisation skipped for {sequence}: {exc}",
              flush=True)

    raw_pdb = out_pdb.with_suffix(".raw.pdb")
    Chem.MolToPDBFile(mol, str(raw_pdb))

    # Keep coordinate records only, force the record name to ATOM and stamp
    # the requested chain ID (column index 21).
    lines = [
        "ATOM  " + ln[6:21] + chain + ln[22:]
        for ln in raw_pdb.read_text().splitlines()
        if ln.startswith(("ATOM", "HETATM"))
    ]
    out_pdb.write_text("\n".join(lines) + "\nEND\n")
    return out_pdb


def make_airs_tbl(out_tbl: Path, peptide_chain: str = "C"):
    """Write the ambiguous interaction restraint table to ``out_tbl``.

    The file starts with a ``!`` comment line, followed by one ``assign``
    statement per entry of ACTIVE_RESIDUES_A. Each statement pairs that
    residue on segid ``A`` with the whole ``peptide_chain`` segid and ends
    with the literal values ``2.0 2.0 0.0``.

    NOTE(review): the three trailing numbers are presumably the CNS
    distance / lower-tolerance / upper-tolerance triple — confirm against the
    HADDOCK restraint documentation.

    Returns ``out_tbl``.
    """
    header = "! Phase 14j dimer-interface AIRs"
    statements = [
        f"assign ( resid {resid} and segid A ) ( segid {peptide_chain} ) 2.0 2.0 0.0"
        for resid in ACTIVE_RESIDUES_A
    ]
    out_tbl.write_text("\n".join([header, *statements]) + "\n")
    return out_tbl


def write_config(run_dir: Path, receptor_pdb: Path, peptide_pdb: Path,
                 airs_tbl: Path, config_path: Path, *, ncores: int = 4,
                 sampling: int = 50, select: int = 10) -> Path:
    """Write the HADDOCK3 workflow configuration file.

    The workflow is ``topoaa`` -> ``rigidbody`` -> ``seletop`` -> ``flexref``
    -> ``caprieval``; both docking stages use the same restraint table.

    Args:
        run_dir: Value written as ``run_dir`` in the config.
        receptor_pdb: First entry of the ``molecules`` list.
        peptide_pdb: Second entry of the ``molecules`` list.
        airs_tbl: Restraint table written as ``ambig_fname`` for both
            ``rigidbody`` and ``flexref``.
        config_path: File to write.
        ncores: Value of the global ``ncores`` setting.
        sampling: Value of ``sampling`` in the ``rigidbody`` section.
        select: Value of ``select`` in the ``seletop`` section.

    Returns:
        ``config_path``.
    """
    config = f"""# HADDOCK3 config — TYMS dimer + LR peptide (Docker)
run_dir = "{run_dir}"
postprocess = true
clean = false
ncores = {ncores}

molecules = [
    "{receptor_pdb}",
    "{peptide_pdb}",
]

[topoaa]

[rigidbody]
ambig_fname = "{airs_tbl}"
sampling = {sampling}

[seletop]
select = {select}

[flexref]
ambig_fname = "{airs_tbl}"

[caprieval]
"""
    # The header comment contains a non-ASCII dash, so the encoding is pinned
    # rather than left to the container's locale.
    config_path.write_text(config, encoding="utf-8")
    return config_path


def run_haddock3(config: Path):
    cmd = [str(HADDOCK3), str(config)]
    print(f"  ▶ {' '.join(cmd)}", flush=True)
    t0 = time.time()
    proc = subprocess.run(cmd, capture_output=True, timeout=7200,
                          cwd=str(config.parent))
    wall = time.time() - t0
    print(f"    completed in {wall:.0f}s (rc={proc.returncode})", flush=True)
    return proc.returncode, proc.stdout.decode() + proc.stderr.decode(), wall


def parse_caprieval(run_dir: Path):
    for csv_path in run_dir.rglob("capri_ss.tsv"):
        rows = list(csv.DictReader(csv_path.open(), delimiter="\t"))
        if rows:
            rows.sort(key=lambda r: float(r.get("score", 0) or 0))
            return {"top_score": float(rows[0].get("score", 0) or 0),
                    "n_models": len(rows), "ss_file": str(csv_path),
                    "best_row": {k: rows[0].get(k) for k in
                                 ("score", "vdw", "elec", "desolv", "air", "bsa")}}
    return None


def main():
    """Dock the canonical and scrambled peptides and compare their top scores.

    Steps:
      1. Prepare the single-chain receptor once; both runs share it.
      2. For each peptide: recreate its run directory from scratch, write the
         peptide PDB, restraint table and config, run HADDOCK3, save the
         captured log and parse the caprieval table.
      3. If both runs produced a score, record the difference and a verdict.
         The canonical peptide only counts as more favourable when its top
         score is more than 1 unit below the scrambled one.
      4. Write everything to ``haddock3_docker_summary.json`` under OUT.
    """
    print("=== Phase 14j — HADDOCK3 via Docker (Linux x86_64) ===", flush=True)
    print(f"  HADDOCK3: {HADDOCK3}  exists={HADDOCK3.exists()}", flush=True)

    receptor_pdb = OUT / "receptor_A_only.pdb"
    prep_receptor_chain_a_only(receptor_pdb)
    print(f"  receptor: {receptor_pdb}", flush=True)

    # Insertion order is preserved, so "canonical" is always docked first.
    peptides = {"canonical": "LSCQLYQR", "scrambled": "QLCRQSYL"}
    results = {}
    for label, sequence in peptides.items():
        print(f"\n  --- {label} peptide ({sequence}) ---", flush=True)

        # Wipe any previous run so no stale output can be picked up later.
        run_dir = OUT / f"run_{label}_docker"
        if run_dir.exists():
            shutil.rmtree(run_dir)
        run_dir.mkdir(parents=True)

        peptide_pdb = run_dir / "peptide.pdb"
        prep_peptide(sequence, peptide_pdb, chain="C")
        print(f"  peptide: {peptide_pdb}", flush=True)

        airs_tbl = run_dir / "airs.tbl"
        make_airs_tbl(airs_tbl, peptide_chain="C")

        # HADDOCK3 writes into its own sub-directory of the run directory.
        run_subdir = run_dir / "haddock_run"
        config_path = run_dir / "config.cfg"
        write_config(run_subdir, receptor_pdb, peptide_pdb, airs_tbl, config_path)

        rc, log, wall = run_haddock3(config_path)
        (run_dir / "haddock3_run.log").write_text(log)
        print(f"    rc={rc}  log size={len(log)}  wall={wall:.0f}s", flush=True)

        capri = parse_caprieval(run_subdir)
        print(f"    capri: {capri}", flush=True)

        # Slicing an empty string yields "", so no separate empty-log case.
        results[label] = {
            "return_code": rc,
            "wall_s": wall,
            "sequence": sequence,
            "capri": capri,
            "log_tail": log[-1500:],
        }

    print("\n=== HADDOCK3 (Docker) cluster-score comparison ===", flush=True)
    # "capri" is None when a run produced no caprieval table.
    can_score = (results["canonical"]["capri"] or {}).get("top_score")
    scr_score = (results["scrambled"]["capri"] or {}).get("top_score")
    print(f"  canonical: top HADDOCK score = {can_score}", flush=True)
    print(f"  scrambled: top HADDOCK score = {scr_score}", flush=True)

    if not (can_score is None or scr_score is None):
        delta = can_score - scr_score
        if delta < -1:
            verdict = "★ canonical IS more favourable than scrambled"
        else:
            verdict = "within noise — null result confirmed under HADDOCK3"
        print(f"  Δ (canonical − scrambled) = {delta:+.2f} → {verdict}", flush=True)
        results["comparison"] = {
            "canonical": can_score,
            "scrambled": scr_score,
            "delta": delta,
            "verdict": verdict,
        }

    summary_path = OUT / "haddock3_docker_summary.json"
    summary_path.write_text(json.dumps(results, indent=2, default=str))
    print(f"\n  → {summary_path}", flush=True)


if __name__ == "__main__":
    main()
