#!/usr/bin/env python3
"""Single matplotlib panel summarising the per-residue Plasmodium vs human
cavity-18 substitutions and what each substitution does to a putative
ligand–protein contact.

For each of the 21 Pf↔Hs cavity-18 substitutions:
  - WT residue (human) + replacement (Pf)
  - Δ hydropathy (Kyte–Doolittle)
  - Δ side-chain volume (Å³, Zamyatnin)
  - "Hot" if the position is also a contact residue in the indazole or
    ibuprofen cavity-18 pose (from all_interactions.csv)
"""
import csv, json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mp
import numpy as np

# Repository root: two directory levels above the folder holding this script.
REPO = Path(__file__).resolve().parents[2]

# Everything this script reads or writes lives under the inhibitor-design tree.
_DESIGN_DIR = REPO.joinpath("14_inhibitor_design")
_ALLOSTERIC_DIR = _DESIGN_DIR.joinpath("04_allosteric")

# Input: per-taxon cavity-18 substitution lists (JSON).
MUT_JSON = _ALLOSTERIC_DIR.joinpath(
    "cavity18_evidence", "downloads", "cavity18_mutations_per_taxon.json"
)
# Input (optional): ligand–protein interaction table for the docked poses.
INTERACT = _ALLOSTERIC_DIR.joinpath("poses", "all_interactions.csv")
# Output: the rendered PNG figure.
OUT = _DESIGN_DIR.joinpath("presentation", "figures", "pf_specificity_panel.png")

# Kyte–Doolittle hydropathy index, keyed by one-letter amino-acid code.
# Positive = hydrophobic, negative = hydrophilic.
KD = {
    "A": 1.8, "C": 2.5, "D": -3.5, "E": -3.5,
    "F": 2.8, "G": -0.4, "H": -3.2, "I": 4.5,
    "K": -3.9, "L": 3.8, "M": 1.9, "N": -3.5,
    "P": -1.6, "Q": -3.5, "R": -4.5, "S": -0.8,
    "T": -0.7, "V": 4.2, "W": -0.9, "Y": -1.3,
}

# Zamyatnin volume (Å³), keyed by one-letter amino-acid code.
# NOTE(review): a value of 60.1 for Gly suggests these are whole-residue
# volumes rather than side-chain-only volumes — confirm against the source
# table. The script only ever uses the difference VOL[new] - VOL[wt].
VOL = {
    "A": 88.6, "C": 108.5, "D": 111.1, "E": 138.4,
    "F": 189.9, "G": 60.1, "H": 153.2, "I": 166.7,
    "K": 168.6, "L": 166.7, "M": 162.9, "N": 114.1,
    "P": 112.7, "Q": 143.8, "R": 173.4, "S": 89.0,
    "T": 116.1, "V": 140.0, "W": 227.8, "Y": 193.6,
}

# Substitution codes for P. falciparum. Each code is read as
# <first letter><integer position><last letter>; the first letter is looked up
# as the wild-type (human) residue and the last as the Pf replacement.
muts = json.loads(MUT_JSON.read_text())["Plasmodium_falciparum"]

rows = []
for code in muts:
    wt, new = code[0], code[-1]
    pos = int(code[1:-1])
    # Deltas are (Pf − human); an unknown residue letter raises KeyError.
    delta_hydropathy = KD[new] - KD[wt]
    delta_volume = VOL[new] - VOL[wt]
    rows.append(dict(pos=pos, wt=wt, new=new,
                     d_kd=delta_hydropathy, d_vol=delta_volume))

# Contact-residue flag — which positions ALSO touch indazole or ibuprofen
# in the cavity-18 docking?
#
# The interaction table is optional: when it is absent the set stays empty and
# no position is marked as "hot".
contact_pos = set()
if INTERACT.exists():
    # `with` guarantees the handle is closed; newline="" is what the csv
    # module requires so quoted fields with embedded newlines parse correctly.
    with INTERACT.open(newline="") as fh:
        for r in csv.DictReader(fh):
            # NOTE(review): this is a substring test, so a cavity label such
            # as "118" or "180" would also match — confirm the label format.
            if r.get("cavity") and "18" in r["cavity"]:
                try:
                    contact_pos.add(int(r["resid"]))
                except (KeyError, TypeError, ValueError):
                    # Best-effort: skip rows whose "resid" column is missing
                    # (KeyError), absent on a short row (None -> TypeError),
                    # or not an integer (ValueError).
                    continue

# Order the substitutions by residue number (in place), then derive the
# per-bar quantities every panel shares.
rows.sort(key=lambda row: row["pos"])
n = len(rows)
xs = np.arange(n)  # one x slot per substitution

# Tick labels in <wt><position><replacement> form.
labels = [f"{row['wt']}{row['pos']}{row['new']}" for row in rows]

# True where the substituted position is also in the contact set.
positions = [row["pos"] for row in rows]
hot_mask = np.array([p in contact_pos for p in positions])

# One figure with two stacked axes on a shared x axis:
# axes[0] = Δ hydropathy bars, axes[1] = Δ side-chain volume bars.
_grid_layout = {
    "hspace": 0.15,
    "left": 0.07,
    "right": 0.985,
    "top": 0.91,
    "bottom": 0.13,
}
fig, axes = plt.subplots(
    nrows=2,
    ncols=1,
    figsize=(15, 6.2),
    sharex=True,
    gridspec_kw=_grid_layout,
)

# Top: Δ hydropathy
ax = axes[0]
# Colour code: red = Pf clearly more hydrophobic (> +1.5), blue = Pf clearly
# more polar (< −1.5), grey = mild change.
colours_kd = ["#C84427" if r["d_kd"] > 1.5 else
              "#2563EB" if r["d_kd"] < -1.5 else
              "#8A8470" for r in rows]
ax.bar(xs, [r["d_kd"] for r in rows], color=colours_kd, edgecolor="#1D1F24",
       linewidth=0.5)
ax.set_ylabel("Δ Hydropathy\n(Pf − Hs)", fontsize=11)
ax.axhline(0, color="#3A3628", lw=0.5)
ax.set_facecolor("#FDFCF7")
ax.grid(True, axis="y", color="#D9D4C2", linewidth=0.5, alpha=0.7)
for s in ("top","right"): ax.spines[s].set_visible(False)
# The position count comes from the data rather than being hard-coded, so the
# title stays correct if the substitution list changes.
ax.set_title(f"Plasmodium falciparum ↔ Homo sapiens — cavity-18 residue substitutions  ({n} positions)",
             fontsize=13.5, fontweight="bold", color="#1D1F24")
# Mark hot positions
# The marker height is read from the y-limits ONCE, before any marker is
# drawn. Re-reading the limits after each marker lets autoscaling raise the
# top limit, so later markers would sit higher than earlier ones.
hot_idx = np.flatnonzero(hot_mask)
if hot_idx.size:
    marker_y = ax.get_ylim()[1] - 0.4
    ax.plot(hot_idx, [marker_y] * hot_idx.size, linestyle="none",
            marker="v", color="#C84427", markersize=8)

# Bottom: Δ side-chain volume
ax = axes[1]

# Panel styling that does not depend on the data.
ax.set_facecolor("#FDFCF7")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

# Colour code: red = Pf clearly larger (> +30 Å³), blue = Pf clearly smaller
# (< −30 Å³), grey = mild change.
d_vols = [r["d_vol"] for r in rows]
colours_vol = []
for dv in d_vols:
    if dv > 30:
        shade = "#C84427"
    elif dv < -30:
        shade = "#2563EB"
    else:
        shade = "#8A8470"
    colours_vol.append(shade)

ax.bar(xs, d_vols, color=colours_vol, edgecolor="#1D1F24", linewidth=0.5)
ax.axhline(0, color="#3A3628", lw=0.5)
ax.grid(True, axis="y", color="#D9D4C2", linewidth=0.5, alpha=0.7)
ax.set_ylabel("Δ Side-chain vol.\n(Å³, Pf − Hs)", fontsize=11)

# The x axis is shared, so the tick labels are set only on this lower panel.
ax.set_xticks(xs)
ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=9.5)
ax.set_xlabel(
    "Mutation  (human residue → Plasmodium residue, at cavity-18 positions)",
    fontsize=10.5,
    labelpad=8,
    color="#3A3628",
)

# Legend / interpretation: one swatch per bar colour, shared by both panels.
_swatches = (
    ("#C84427", "Substantial change (Pf more hydrophobic / larger)"),
    ("#2563EB", "Substantial change (Pf more polar / smaller)"),
    ("#8A8470", "Mild change"),
)
handles = [
    mp.Patch(facecolor=fill, edgecolor="#1D1F24", label=text)
    for fill, text in _swatches
]

# The triangle entry is only listed when at least one position is flagged.
if hot_mask.any():
    handles += [
        plt.Line2D(
            [0], [0],
            marker="v",
            color="w",
            markerfacecolor="#C84427",
            markersize=10,
            label="▼ position also touches an indazole/ibuprofen pose",
        )
    ]

axes[0].legend(handles=handles, loc="upper right", frameon=False, fontsize=9)

# Create the output directory first: savefig does not create missing parent
# folders and would otherwise fail with FileNotFoundError on a fresh checkout.
OUT.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(OUT, dpi=150, bbox_inches="tight", facecolor="#F5F3EC")
# Report where the PNG was written and its size on disk.
print(f"→ {OUT}  ({OUT.stat().st_size/1024:.0f} KB)")
