#!/usr/bin/env python3
"""Phase 14i — actual MM-GBSA ΔΔG_bind with GAFF-parametrised ligand.

Replaces the Phase 14g placeholder (which only computed ΔE_receptor without
a parametrised ligand). This script:

  1. Extracts the holo-pose ligand (raltitrexed / D16) from each mutant complex
     PDB in 07d_mut_docking_v4/viewer_files/<MUT>_holo_complex.pdb.
  2. Parametrises the ligand with acpype (GAFF2 atom types, Gasteiger
     charges via `-c gas`, net charge -2) after converting the pose to mol2
     with openbabel. acpype is used because openff-toolkit has no
     arm64-macOS wheel.
  3. Builds a combined OpenMM system: AMBER ff14SB (protein) + GAFF (ligand)
     + GBn2 implicit solvent.
  4. Locally minimises the complex (at most 50 L-BFGS iterations) with
     positional restraints on the receptor CA atoms
     (E = k*|r - r0|^2, k = 50 kcal/mol/Å²) so side chains and the ligand
     relax while the receptor backbone stays essentially fixed.
  5. Computes single-trajectory MM-GBSA:
        ΔG_bind = E_complex - E_receptor - E_ligand
     ΔΔG_bind(mutant) = ΔG_bind(mutant) - ΔG_bind(WT)

Inputs:
  - 07d_mut_docking_v4/viewer_files/<MUT>_holo_complex.pdb   (holo complex,
    one per system; WT_holo_complex.pdb is the WT reference for ΔΔG_bind)

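Outputs:
  - 14_inhibitor_design/07_advanced_methods/mmgbsa_gaff/
        per-system: <label>/receptor.pdb, receptor_noh.pdb, receptor_fix.pdb,
                    ligand_input.pdb, _acpype/ (lig.mol2 + acpype output
                    including the ligand prmtop/inpcrd), complex_min.pdb
        summary:    mmgbsa_gaff_results.csv
                    mmgbsa_gaff_summary.json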

Wall-time: ~3-6 min/system on arm64 CPU (8 systems → ~30-45 min total).
"""
from __future__ import annotations
import os, sys, csv, json, shutil, subprocess, time
from pathlib import Path

import numpy as np
from openmm import (
    LangevinMiddleIntegrator, NonbondedForce, CustomExternalForce,
    Platform, unit, app, XmlSerializer,
)
from openmm import unit as u
from openmm.app import (
    PDBFile, ForceField, Modeller, Simulation, NoCutoff, GBn2,
    HBonds, AllBonds, Topology,
)
import pdbfixer
import parmed

# Repository root: this script's directory sits two levels below it.
REPO = Path(__file__).resolve().parents[2]

# Every per-system work directory and the summary files are written here;
# the directory is created at import time.
OUT = REPO.joinpath("14_inhibitor_design", "07_advanced_methods", "mmgbsa_gaff")
OUT.mkdir(parents=True, exist_ok=True)

# Directory holding the docked holo complexes.
VIEWER = REPO.joinpath("07d_mut_docking_v4", "viewer_files")

# (label, comment) for each system.  Each complex file is named
# "<label>_complex.pdb" inside VIEWER.
_SYSTEM_COMMENTS = (
    ("WT_holo",          "wild-type baseline"),
    ("R175E_R176E_holo", "double Arg→Glu charge reversal"),
    ("R175E_holo",       "single Arg→Glu"),
    ("R215E_holo",       "phosphate-clamp Arg→Glu"),
    ("R215A_holo",       "phosphate-clamp Arg→Ala"),
    ("C195A_holo",       "catalytic Cys ablation"),
    ("H196A_holo",       "catalytic His ablation"),
    ("T170A_holo",       "distant-surface negative control"),
)

# (label, complex_pdb relative to VIEWER, comment)
SYSTEMS = [
    (sys_label, f"{sys_label}_complex.pdb", note)
    for sys_label, note in _SYSTEM_COMMENTS
]

LIG_RESNAME = "D16"   # raltitrexed in 1HVY

# acpype executable: prefer the one found on PATH, else a fixed fallback path.
_acpype_on_path = shutil.which("acpype")
ACPYPE = _acpype_on_path if _acpype_on_path else "/tmp/py311_haddock/bin/acpype"


def split_complex(complex_pdb: Path, wdir: Path, resname: "str | None" = None):
    """Split a holo-complex PDB into a receptor file and a ligand file.

    Records whose residue name (columns 18-20) equals *resname* go to
    ``wdir/ligand_input.pdb``.  This applies to both ``HETATM`` records and
    ``ATOM`` records, so a ligand written as ``ATOM`` is not misfiled into the
    receptor.  All remaining ``ATOM`` records go to ``wdir/receptor.pdb``.
    Every other record type (other HETATM groups, TER, CONECT, REMARK, ...)
    is dropped.  Both output files end with an ``END`` line.

    Args:
        complex_pdb: holo-complex PDB file to split.
        wdir: existing directory that receives the two output files.
        resname: ligand residue name to extract.  Defaults to the
            module-level ``LIG_RESNAME`` when ``None``.

    Returns:
        ``(receptor_path, ligand_path, n_receptor_records, n_ligand_records)``.
    """
    if resname is None:
        resname = LIG_RESNAME
    resname = resname.strip()

    rec_lines: list[str] = []
    lig_lines: list[str] = []
    for line in complex_pdb.read_text().splitlines():
        is_atom = line.startswith("ATOM")
        is_het = line.startswith("HETATM")
        if (is_atom or is_het) and line[17:20].strip() == resname:
            lig_lines.append(line)
        elif is_atom:
            rec_lines.append(line)

    rec = wdir / "receptor.pdb"
    rec.write_text("\n".join(rec_lines) + "\nEND\n")
    lig = wdir / "ligand_input.pdb"
    lig.write_text("\n".join(lig_lines) + "\nEND\n")
    return rec, lig, len(rec_lines), len(lig_lines)


def parametrise_ligand_acpype(lig_pdb: Path, wdir: Path):
    """Parametrise the ligand with acpype and return its AMBER topology files.

    The ligand PDB is first converted to mol2 with openbabel.  openbabel
    perceives bonds from the coordinates when the PDB has no CONECT records.
    acpype is then run inside ``wdir/_acpype`` with GAFF2 atom types
    (``-a gaff2``), charge method ``gas`` and net charge -2.

    Args:
        lig_pdb: ligand-only PDB, e.g. as written by ``split_complex``.
        wdir: per-system work directory; acpype runs in ``wdir/_acpype``.

    Returns:
        ``(prmtop_path, inpcrd_path)`` for the parametrised ligand.  Both
        files are guaranteed to exist.

    Raises:
        RuntimeError: openbabel cannot read or write the ligand, acpype exits
            non-zero, or acpype leaves no prmtop/inpcrd pair behind.
        subprocess.TimeoutExpired: acpype runs longer than 600 s.
    """
    work = wdir / "_acpype"
    work.mkdir(parents=True, exist_ok=True)

    # Use openbabel (bundled with acpype) to make a clean mol2.
    from openbabel import openbabel as ob
    obc = ob.OBConversion()
    obc.SetInAndOutFormats("pdb", "mol2")
    mol = ob.OBMol()
    if not obc.ReadFile(mol, str(lig_pdb)) or mol.NumAtoms() == 0:
        raise RuntimeError(f"openbabel read no ligand atoms from {lig_pdb}")
    # No AddHydrogens(): the pose is written exactly as read.
    # NOTE(review): confirm the docked pose is fully protonated.  With polar H
    # only, acpype may reject or mis-handle the -2 net charge.
    mol2_path = work / "lig.mol2"
    ok = obc.WriteFile(mol, str(mol2_path))
    # Close the output stream explicitly so lig.mol2 is fully flushed to
    # disk before acpype (a separate process) reads it.
    obc.CloseOutFile()
    if not ok:
        raise RuntimeError(f"openbabel failed to write {mol2_path}")

    # Net charge of raltitrexed (D16) at pH 7.4 = -2 (deprotonated carboxylates)
    cmd = [str(ACPYPE), "-i", "lig.mol2", "-b", "LIG", "-c", "gas", "-n", "-2",
           "-a", "gaff2"]
    proc = subprocess.run(cmd, cwd=str(work), capture_output=True, timeout=600)
    if proc.returncode != 0:
        # errors="replace": never mask the real failure with a UnicodeDecodeError.
        print("    acpype stdout:", proc.stdout.decode(errors="replace")[-600:])
        print("    acpype stderr:", proc.stderr.decode(errors="replace")[-600:])
        raise RuntimeError(f"acpype failed rc={proc.returncode}")

    # Expected acpype output: LIG.acpype/LIG_AC.prmtop + LIG_AC.inpcrd
    acdir = work / "LIG.acpype"
    prmtop = acdir / "LIG_AC.prmtop"
    inpcrd = acdir / "LIG_AC.inpcrd"
    if not (prmtop.exists() and inpcrd.exists()):
        # acpype may name them differently: take the first prmtop (sorted, so
        # the choice is deterministic) that has a matching .inpcrd.
        candidates = sorted(acdir.glob("*.prmtop")) if acdir.is_dir() else []
        for cand in candidates:
            cand_crd = cand.with_suffix(".inpcrd")
            if cand_crd.exists():
                prmtop, inpcrd = cand, cand_crd
                break
    if not (prmtop.exists() and inpcrd.exists()):
        raise RuntimeError(f"acpype produced no prmtop/inpcrd pair in {acdir}")
    print(f"    ligand parametrised: {prmtop.name}")
    return prmtop, inpcrd


# The CA positional restraint lives in its own force group.  Every energy used
# for MM-GBSA can then be evaluated without it, so E_complex, E_receptor and
# E_ligand all use the same Hamiltonian.  These are plain sets because OpenMM
# documents Context.getState(groups=...) as taking a set or an int bitmask.
_RESTRAINT_FORCE_GROUP = 31
_RESTRAINT_ONLY_GROUPS = {_RESTRAINT_FORCE_GROUP}
_NON_RESTRAINT_GROUPS = set(range(32)) - _RESTRAINT_ONLY_GROUPS


def _strip_h_and_protonate(rec_pdb: Path, wdir: Path) -> Path:
    """Remove all receptor hydrogens, then rebuild them with PDBFixer at pH 7.4.

    Rebuilding H from scratch fixes inconsistent HIS naming in the input.
    An ATOM record counts as hydrogen when its element column is ``H`` or
    when its atom name, ignoring leading digits (old-style ``1HB``), starts
    with ``H``.  Writes ``receptor_noh.pdb`` and ``receptor_fix.pdb`` into
    *wdir* and returns the path of ``receptor_fix.pdb``.
    """
    kept = []
    for ln in rec_pdb.read_text().splitlines():
        if ln.startswith("ATOM"):
            element = ln[76:78].strip() if len(ln) >= 78 else ""
            atom_name = ln[12:16].strip().lstrip("0123456789")
            if element == "H" or atom_name.startswith("H"):
                continue
        kept.append(ln)
    rec_noh = wdir / "receptor_noh.pdb"
    rec_noh.write_text("\n".join(kept))

    fixer = pdbfixer.PDBFixer(filename=str(rec_noh))
    fixer.findMissingResidues()
    fixer.missingResidues = {}          # add missing atoms only, never whole residues
    fixer.findMissingAtoms()
    fixer.addMissingAtoms()
    fixer.removeHeterogens(keepWater=False)
    fixer.addMissingHydrogens(pH=7.4)
    rec_fix = wdir / "receptor_fix.pdb"
    with open(rec_fix, "w") as fh:
        PDBFile.writeFile(fixer.topology, fixer.positions, fh)
    return rec_fix


def _gbn2_system(structure):
    """OpenMM System for a parmed Structure: no cutoff, HBond constraints, GBn2."""
    return structure.createSystem(
        nonbondedMethod=app.NoCutoff,
        constraints=app.HBonds,
        implicitSolvent=app.GBn2,
        soluteDielectric=1.0,
        solventDielectric=78.5,
    )


def _add_ca_restraint(system, structure, n_rec_atoms: int,
                      k_kcal_per_A2: float = 50.0) -> int:
    """Restrain receptor CA atoms to their current coordinates in *structure*.

    Adds ``E = k*|r - r0|^2`` (no 1/2 factor) for every atom named ``CA``
    among the first *n_rec_atoms* atoms.  The force is placed in
    ``_RESTRAINT_FORCE_GROUP`` and appended to *system*.  Returns the number
    of restrained atoms.
    """
    restraint = CustomExternalForce("k*((x-x0)^2+(y-y0)^2+(z-z0)^2)")
    k_md = (k_kcal_per_A2 * u.kilocalories_per_mole / u.angstroms**2).value_in_unit(
        u.kilojoules_per_mole / u.nanometer**2)
    restraint.addGlobalParameter("k", k_md)
    for pname in ("x0", "y0", "z0"):
        restraint.addPerParticleParameter(pname)

    # Hoisted out of the loop: .positions is a property that may rebuild the
    # whole position list on every access.
    ref_nm = structure.positions.value_in_unit(u.nanometer)
    n_restrained = 0
    for i, atom in enumerate(structure.atoms):
        if i >= n_rec_atoms:        # receptor atoms come first; ligand follows
            break
        if atom.name == "CA":
            restraint.addParticle(i, [ref_nm[i][0], ref_nm[i][1], ref_nm[i][2]])
            n_restrained += 1
    restraint.setForceGroup(_RESTRAINT_FORCE_GROUP)
    system.addForce(restraint)
    return n_restrained


def _potential_kcal(context, groups=None) -> float:
    """Potential energy of *context* in kcal/mol, optionally restricted to *groups*."""
    kwargs = {} if groups is None else {"groups": groups}
    state = context.getState(getEnergy=True, **kwargs)
    return state.getPotentialEnergy().value_in_unit(u.kilocalories_per_mole)


def _single_point_energy(structure, positions, platform) -> float:
    """GBn2 single-point potential energy (kcal/mol) of *structure* at *positions*."""
    integrator = LangevinMiddleIntegrator(300*u.kelvin, 1/u.picosecond, 0.001*u.picosecond)
    sim = Simulation(structure.topology, _gbn2_system(structure), integrator,
                     platform=platform)
    sim.context.setPositions(positions)
    return _potential_kcal(sim.context)


def build_complex_and_minimise(label: str, complex_pdb: Path, wdir: Path):
    """Minimise one holo complex and score it with single-trajectory MM-GBSA.

    Steps:
      1. Split the complex into receptor and ligand (``split_complex``).
      2. Parametrise the ligand with acpype.
      3. Strip and re-add receptor hydrogens with PDBFixer.
      4. Build a combined receptor (amber14-all) + ligand (acpype prmtop) +
         GBn2 system via parmed, then minimise it with CA restraints.
      5. Evaluate the receptor alone and the ligand alone at their
         minimised-complex coordinates.

    Args:
        label: system label (not used inside this function).
        complex_pdb: holo-complex PDB file.
        wdir: existing per-system work directory for all intermediate files.

    Returns:
        ``None`` when the complex contains no ligand records.  Otherwise a
        dict with ``E_complex``, ``E_receptor``, ``E_ligand``, ``dG_bind`` and
        ``E_restraint`` (all kcal/mol), plus ``n_rec_atoms`` and
        ``n_lig_atoms``.  ``E_complex`` excludes the artificial restraint
        energy, so ``dG_bind = E_complex - E_receptor - E_ligand`` compares
        like with like.
    """
    rec_pdb, lig_pdb, n_rec, n_lig = split_complex(complex_pdb, wdir)
    print(f"    split: {n_rec} ATOM (receptor), {n_lig} HETATM (ligand)")
    if n_lig == 0:
        return None

    # Parametrise ligand
    prmtop, inpcrd = parametrise_ligand_acpype(lig_pdb, wdir)

    # ===== Receptor (AMBER ff14SB via amber14-all.xml) =====
    rec_pdb_obj = PDBFile(str(_strip_h_and_protonate(rec_pdb, wdir)))
    rec_topo = rec_pdb_obj.topology
    rec_pos = rec_pdb_obj.positions
    rec_ff = ForceField("amber14-all.xml", "implicit/gbn2.xml")
    rec_system = rec_ff.createSystem(rec_topo, nonbondedMethod=NoCutoff,
                                     constraints=None, soluteDielectric=1.0,
                                     solventDielectric=78.5)
    rec_structure = parmed.openmm.load_topology(rec_topo, rec_system, xyz=rec_pos)
    n_rec_atoms = rec_topo.getNumAtoms()

    # ===== Ligand + combined complex =====
    lig_structure = parmed.load_file(str(prmtop), xyz=str(inpcrd))
    # Concatenation keeps atom order: receptor atoms occupy [0, n_rec_atoms),
    # ligand atoms follow.
    combined = rec_structure + lig_structure
    # NOTE(review): the GB force built by ForceField above is presumably not
    # carried over by load_topology.  Confirm parmed assigns consistent GBn2
    # radii to receptor and ligand atoms in _gbn2_system().
    sys_combined = _gbn2_system(combined)

    # Strong CA restraints (50 kcal/mol/Å²) essentially freeze the receptor
    # for single-trajectory rescoring, so only the pocket side chains and the
    # ligand relax.  They also speed minimisation up considerably.
    n_ca = _add_ca_restraint(sys_combined, combined, n_rec_atoms)
    print(f"    CA restraints:     {n_ca} atoms")

    try:
        platform = Platform.getPlatformByName("CPU")
    except Exception:
        platform = None             # fall back to OpenMM's default platform choice

    integrator = LangevinMiddleIntegrator(300*u.kelvin, 1.0/u.picosecond, 0.001*u.picosecond)
    sim = Simulation(combined.topology, sys_combined, integrator, platform=platform)
    sim.context.setPositions(combined.positions)

    e0 = _potential_kcal(sim.context, _NON_RESTRAINT_GROUPS)
    print(f"    E_complex pre-min: {e0:+.1f} kcal/mol")

    # Short minimisation: 50 steps is enough to relieve close clashes given the
    # already-docked pose; we are doing single-point rescoring, not free
    # energy MD. Each L-BFGS step on 9000 atoms with GB-NoCutoff is O(N²) and
    # takes ~5-10 s on CPU, so even 50 steps is ~5 min/system.
    sim.minimizeEnergy(tolerance=50.0*u.kilocalories_per_mole/u.angstrom, maxIterations=50)

    pos_min = sim.context.getState(getPositions=True).getPositions(asNumpy=True)
    e_complex = _potential_kcal(sim.context, _NON_RESTRAINT_GROUPS)
    e_restraint = _potential_kcal(sim.context, _RESTRAINT_ONLY_GROUPS)
    print(f"    E_complex min:     {e_complex:+.1f} kcal/mol")
    print(f"    E_restraint:       {e_restraint:+.1f} (excluded from E_complex)")

    with open(wdir / "complex_min.pdb", "w") as fh:
        PDBFile.writeFile(combined.topology, pos_min, fh)

    # ===== Single-trajectory MM-GBSA: rescore receptor-only and ligand-only =====
    # Same platform as the complex so all three energies are evaluated alike.
    e_rec = _single_point_energy(rec_structure, pos_min[:n_rec_atoms], platform)
    e_lig = _single_point_energy(lig_structure, pos_min[n_rec_atoms:], platform)

    dG_bind = e_complex - e_rec - e_lig
    print(f"    E_receptor:        {e_rec:+.1f}")
    print(f"    E_ligand:          {e_lig:+.1f}")
    print(f"    ΔG_bind (single):  {dG_bind:+.2f} kcal/mol")
    return {"E_complex": e_complex, "E_receptor": e_rec, "E_ligand": e_lig,
            "dG_bind": dG_bind, "E_restraint": e_restraint,
            "n_rec_atoms": n_rec_atoms,
            "n_lig_atoms": len(combined.atoms) - n_rec_atoms}


def main():
    """Run the MM-GBSA workflow over SYSTEMS and write the summary files.

    Each system whose complex PDB exists under VIEWER gets a fresh
    ``OUT/<label>`` work directory (any previous one is deleted), and
    ``build_complex_and_minimise`` is run in it.  Exceptions are caught per
    system and recorded under an ``"error"`` key, so one failure does not
    abort the batch.  Complexes without ligand records are recorded the same
    way instead of being silently dropped.

    Writes ``OUT/mmgbsa_gaff_results.csv`` (one row per processed system,
    with ΔΔG relative to ``WT_holo``) and ``OUT/mmgbsa_gaff_summary.json``
    (raw per-system results), then prints a ΔΔG summary table.
    """
    import traceback

    print("=== Phase 14i — MM-GBSA ΔΔG_bind with GAFF ligand ===")
    print(f"  acpype: {ACPYPE}  exists={os.path.exists(ACPYPE)}")
    results = {}
    for label, complex_rel, comment in SYSTEMS:
        cpx = VIEWER / complex_rel
        if not cpx.exists():
            print(f"  ! missing: {cpx}")
            continue
        print(f"\n--- {label}  ({comment}) ---")
        wdir = OUT / label
        if wdir.exists():
            shutil.rmtree(wdir)
        wdir.mkdir(parents=True)
        t0 = time.perf_counter()
        try:
            r = build_complex_and_minimise(label, cpx, wdir)
        except Exception as e:      # batch boundary: record the failure, keep going
            print(f"  ✗ failed: {e}")
            traceback.print_exc()
            r = {"error": str(e)}
        if r is None:
            # build_complex_and_minimise returns None when no ligand records exist.
            print(f"  ! no {LIG_RESNAME} ligand records in {cpx.name}")
            r = {"error": f"no {LIG_RESNAME} ligand records in complex"}
        r["wall_s"] = round(time.perf_counter() - t0, 1)
        r["comment"] = comment
        results[label] = r

    # Compute ΔΔG_bind relative to WT
    wt_dG = results.get("WT_holo", {}).get("dG_bind")
    if wt_dG is None:
        print("\n  ! WT_holo has no ΔG_bind — ddG_vs_WT will be empty")

    def _r2(value):
        """Round to 2 decimals; missing values become an empty CSV cell."""
        return "" if value is None else round(value, 2)

    rows = []
    for label, r in results.items():
        dG = r.get("dG_bind")
        ddG = dG - wt_dG if (dG is not None and wt_dG is not None) else None
        rows.append({
            "label": label,
            "comment": r.get("comment", ""),
            "E_complex": _r2(r.get("E_complex")),
            "E_receptor": _r2(r.get("E_receptor")),
            "E_ligand": _r2(r.get("E_ligand")),
            "dG_bind": _r2(dG),
            "ddG_vs_WT": _r2(ddG),
            "wall_s": r.get("wall_s", ""),
            "error": r.get("error", ""),
        })

    csv_path = OUT / "mmgbsa_gaff_results.csv"
    if rows:
        # Explicit UTF-8: the comment column contains non-ASCII ("→").
        with csv_path.open("w", newline="", encoding="utf-8") as f:
            w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
            w.writeheader()
            w.writerows(rows)
        print(f"\n  → {csv_path}")
    json_path = OUT / "mmgbsa_gaff_summary.json"
    json_path.write_text(json.dumps(results, indent=2, default=str), encoding="utf-8")
    print(f"  → {json_path}")

    print("\n=== ΔΔG_bind summary (relative to WT) ===")
    for r in rows:
        print(f"  {r['label']:<22s}  ΔG_bind={r['dG_bind']:>9}  ΔΔG={r['ddG_vs_WT']:>8}  {r['comment']}")


if __name__ == "__main__":
    # Script entry point: the process exit status comes from main()'s return
    # value (None -> 0); uncaught exceptions still propagate as before.
    raise SystemExit(main())
