#!/usr/bin/env python
"""Render the three sub-case avatars for the Submoves (2022) study.

Each avatar is derived from its case definition -- the hole and ruru avatars
from their input files, the internal one from its analytic potential -- so the
picture faithfully represents the system that PyRETIS samples:

    internal -> the 1D double-well potential V(x) = a x^4 - b (x - c)^2
    hole     -> the dodecane thin-film carbon density (the H2O/film "hole")
    ruru     -> the di-ruthenium aqua complex (proton-transfer redox)

Usage
-----
    python make_submoves_avatars.py [SOURCE_DIR] [OUT_DIR]

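SOURCE_DIR defaults to the current directory, which should contain the hole/
and ruru/ case folders (the internal avatar reads no input files); OUT_DIR
defaults to ./avatars and is created if it does not exist.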
"""
import os
import sys

import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch

# Output geometry: every avatar is a PX x PX pixel square rendered at DPI.
PX = 400
DPI = 100
FIGSIZE = (PX / DPI,) * 2   # (width, height) in inches

# Marker colours for the left and right metastable states.
BLUE = '#1f9bff'
ORANGE = '#ffa500'


def _save_square(fig, path):
    """Write *fig* to *path* on an opaque white background, then close it.

    No bounding-box trimming is applied (``bbox_inches=None``), so the saved
    image is exactly the figure size times ``DPI`` pixels.
    """
    options = dict(dpi=DPI, transparent=False, facecolor='white',
                   bbox_inches=None, pad_inches=0)
    fig.savefig(path, **options)
    plt.close(fig)


def _pad_to_square(raw_path, out_path):
    """Re-draw the image at *raw_path* on a square white canvas at *out_path*.

    The image is shown with equal aspect on axes that fill the whole
    ``FIGSIZE`` figure, so it ends up centred with white margins along its
    shorter side. The intermediate file at *raw_path* is deleted afterwards.
    """
    image = plt.imread(raw_path)

    canvas, axes = plt.subplots(figsize=FIGSIZE, dpi=DPI)
    canvas.subplots_adjust(left=0, right=1, top=1, bottom=0)

    axes.set_facecolor('white')
    axes.imshow(image)
    axes.axis('off')
    axes.set_aspect('equal')

    _save_square(canvas, out_path)
    os.remove(raw_path)


# --------------------------------------------------------------------------
def avatar_internal(source_dir, out_dir):
    """1D double-well potential with both wells and a reactive path marked.

    The curve is V(x) = a x^4 - b (x - c)^2. The two wells and the barrier
    top are located from the stationary points of V (real roots of dV/dx),
    so the markers stay on the minima if a, b or c are changed.

    Parameters
    ----------
    source_dir : str
        Unused; accepted so every avatar builder shares one signature.
    out_dir : str
        Directory that receives ``submoves-internal-400x400.png``.

    Raises
    ------
    ValueError
        If V(x) does not have two minima separated by a barrier.
    """
    from matplotlib.path import Path

    a, b, c = 1.0, 2.0, 0.0

    def potential(q):
        return a * q**4 - b * (q - c)**2

    x = np.linspace(-1.55, 1.55, 600)
    v = potential(x)

    # Stationary points: dV/dx = 4a x^3 - 2b x + 2bc = 0. A double well has
    # three real ones, in order: left minimum, barrier top, right minimum.
    roots = np.roots([4.0 * a, 0.0, -2.0 * b, 2.0 * b * c])
    stationary = np.sort(roots[np.abs(roots.imag) < 1.0e-9].real)
    if stationary.size != 3:
        raise ValueError('V(x) is not a double well for a=%g, b=%g, c=%g'
                         % (a, b, c))
    x_left, x_top, x_right = stationary
    v_left, v_top, v_right = potential(stationary)

    fig, ax = plt.subplots(figsize=FIGSIZE, dpi=DPI)
    fig.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02)

    ax.fill_between(x, v, v.max(), color='#eef4fb', zorder=1)
    ax.plot(x, v, color='#1b2a3a', lw=3.5, zorder=3)

    # the two metastable states at the minima
    ax.scatter([x_left], [v_left], s=320, color=BLUE, zorder=4,
               edgecolor='white', linewidth=2)
    ax.scatter([x_right], [v_right], s=320, color=ORANGE, zorder=4,
               edgecolor='white', linewidth=2)

    # A reactive path from one well to the other that clears the barrier.
    # The quadratic Bezier is defined in data coordinates (an 'arc3'
    # connection bends in display space, so its height would depend on the
    # axes aspect). With end heights y0, y2 and control height y1, the apex
    # at t = 1/2 is (y0 + 2 y1 + y2) / 4, so y1 is chosen to put the apex
    # `clearance` above the barrier top V(x_top).
    lift, clearance = 0.18, 0.2
    start = (x_left, v_left + lift)
    end = (x_right, v_right + lift)
    apex = v_top + clearance
    control = (0.5 * (x_left + x_right),
               2.0 * apex - 0.5 * (start[1] + end[1]))
    arrow = FancyArrowPatch(path=Path([start, control, end],
                                      [Path.MOVETO, Path.CURVE3,
                                       Path.CURVE3]),
                            arrowstyle='-|>', mutation_scale=22,
                            lw=2.5, color='#d6334c', zorder=5)
    ax.add_patch(arrow)

    ax.set_xlim(-1.6, 1.6)
    ax.set_ylim(v.min() - 0.25, v.max() + 0.15)
    ax.axis('off')
    _save_square(fig, os.path.join(out_dir, 'submoves-internal-400x400.png'))


# --------------------------------------------------------------------------
def _read_g96_positions(path):
    """Return the POSITIONRED block of a GROMACS .g96 file as an (N, 3)."""
    coords = []
    inside = False
    with open(path, 'r') as fileh:
        for line in fileh:
            tag = line.strip()
            if tag == 'POSITIONRED':
                inside = True
                continue
            if inside:
                if tag == 'END':
                    break
                parts = line.split()
                coords.append([float(parts[0]), float(parts[1]),
                               float(parts[2])])
    return np.array(coords)


def avatar_hole(source_dir, out_dir):
    """Tilted render of the dodecane thin film whose rupture is sampled.

    Reads ``hole/gromacs_input/conf.g96`` under *source_dir*, keeps the
    twelve carbon sites of every dodecane, wraps them laterally into the box,
    colours them by height and writes ``submoves-hole-400x400.png`` into
    *out_dir*.

    Raises
    ------
    ValueError
        If the number of sites in conf.g96 does not match the water +
        dodecane layout assumed below.
    """
    from ase import Atoms
    from ase.io import write

    g96 = os.path.join(source_dir, 'hole', 'gromacs_input', 'conf.g96')
    pos = _read_g96_positions(g96)

    # topol.top ordering: 4 x 5984 TIP4P waters (4 sites) then 1100 dodecanes
    # of 38 sites each
    n_water = 4 * 5984 * 4
    n_dod, sites_per_dod = 1100, 38
    expected = n_water + n_dod * sites_per_dod
    if len(pos) != expected:
        raise ValueError('%s holds %d sites, expected %d (%d water sites + '
                         '%d x %d dodecane sites)'
                         % (g96, len(pos), expected, n_water, n_dod,
                            sites_per_dod))
    dod = pos[n_water:].reshape(n_dod, sites_per_dod, 3)
    # carbon offsets within a DOD molecule (1->0 indexed) from dodecane.itp
    carbon = [0, 1, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32]
    cpos = dod[:, carbon, :].reshape(-1, 3)   # fancy indexing -> a copy

    # Wrap only x and y into the periodic box. z is left as is: reducing it
    # modulo a huge number would map any negative z to ~1e9 and wreck the
    # render.
    # NOTE(review): the lateral box length is hard-coded; confirm it against
    # the BOX block of conf.g96.
    box_xy = 15.0
    cpos[:, :2] = np.mod(cpos[:, :2], box_xy)

    # NOTE(review): g96 coordinates are in nm while ASE treats positions as
    # Angstrom; the radii/scale passed to write() below appear to be chosen
    # for these nm-valued positions.
    atoms = Atoms('C%d' % len(cpos), positions=cpos)
    atoms.center()

    # colour the carbons by height so the slab reads with depth
    zcoord = cpos[:, 2]
    znorm = (zcoord - zcoord.min()) / (np.ptp(zcoord) + 1.0e-9)
    cols = plt.get_cmap('YlGnBu')(0.25 + 0.6 * znorm)[:, :3]

    tmp = os.path.join(out_dir, '_hole_raw.png')
    write(tmp, atoms, format='png', rotation='65x,0y,0z',
          colors=cols, radii=0.9, scale=12, show_unit_cell=0)
    _pad_to_square(tmp, os.path.join(out_dir, 'submoves-hole-400x400.png'))


# --------------------------------------------------------------------------
def avatar_ruru(source_dir, out_dir):
    """Render the di-ruthenium aqua complex from ruru/cp2k_input/initial.xyz.

    Atoms are drawn with Jmol colours and radii of 0.85 x their covalent
    radius. ASE's PNG writer produces an intermediate image, which is then
    padded onto the square avatar canvas as ``submoves-ruru-400x400.png``.
    """
    from ase.data import covalent_radii
    from ase.data.colors import jmol_colors
    from ase.io import read, write

    structure = read(os.path.join(source_dir, 'ruru', 'cp2k_input',
                                  'initial.xyz'))
    structure.center()

    numbers = structure.numbers
    sizes = np.array([covalent_radii[z] for z in numbers]) * 0.85
    palette = jmol_colors[numbers]

    raw_png = os.path.join(out_dir, '_ruru_raw.png')
    write(raw_png, structure, format='png', rotation='75x,15y,0z',
          radii=sizes, colors=palette, scale=28, show_unit_cell=0)

    final_png = os.path.join(out_dir, 'submoves-ruru-400x400.png')
    _pad_to_square(raw_png, final_png)


def main():
    """Command-line entry point: ``[SOURCE_DIR] [OUT_DIR]``.

    SOURCE_DIR defaults to '.', and OUT_DIR defaults to 'avatars' (created
    if missing). Extra arguments are ignored.
    """
    cli = sys.argv[1:]
    source_dir = cli[0] if cli else '.'
    out_dir = cli[1] if len(cli) > 1 else 'avatars'

    os.makedirs(out_dir, exist_ok=True)
    for render in (avatar_internal, avatar_hole, avatar_ruru):
        render(source_dir, out_dir)
    print('wrote avatars to', os.path.abspath(out_dir))


if __name__ == '__main__':
    main()
