# -*- coding: utf-8 -*-
# Copyright (c) 2015, PyRETIS Development Team.
# Distributed under the LGPLv3 License. See LICENSE for more info.
"""This file defines the order parameter used for the hydrate example."""
import logging
import os
import subprocess
import numpy as np
import mdtraj as md
from pyretis.orderparameter import OrderParameter
# Module-level logger. A NullHandler is attached, as the logging docs
# recommend for library code, so nothing is emitted through the "last resort"
# handler when the application has not configured logging.
logger = logging.getLogger(name=__name__)  # pylint: disable=C0103
logger.addHandler(hdlr=logging.NullHandler())

def pbc(lst, pt):
    """Rotate trailing zeros of ``lst`` to its front (periodic wrap).

    The sequence is shifted right by one position with ``np.roll`` while
    its last element is zero, at most ``pt`` times. A run of zeros that
    wraps around the end of a periodic line therefore becomes one
    contiguous run at the start.

    Parameters
    ----------
    lst : sequence of numbers
        Values along one periodic grid line.
    pt : int
        Maximum number of single-step shifts to apply.

    Returns
    -------
    out : sequence of numbers
        ``lst`` itself when no shift was needed, otherwise the rolled
        ``numpy.ndarray`` copy.
    """
    shifts = 0
    # Once the last element is non-zero no further shift can occur.
    while shifts < pt and lst[-1] == 0:
        lst = np.roll(lst, 1)
        shifts += 1
    return lst

def cnt0(lst, cnt_max=0):
    """Return the length of the longest run of zeros in ``lst``.

    The running maximum is updated while a run is being counted. A run that
    reaches the end of ``lst`` is therefore included; previously such a
    trailing run was dropped, so an all-zero line reported no hole at all.

    Parameters
    ----------
    lst : iterable of numbers
        Values to scan (not treated as periodic; use :py:func:`pbc`
        first to merge a run that wraps around the end).
    cnt_max : int, optional
        Lower bound for the result, so the function can be chained to
        keep a running maximum over several sequences.

    Returns
    -------
    out : int
        ``max(cnt_max, length of the longest run of zeros)``.
    """
    run = 0
    for value in lst:
        if value == 0:
            run += 1
            if run > cnt_max:
                cnt_max = run
        else:
            run = 0
    return cnt_max

def cnt0seq(lst, ngrid=10):
    """Count windows of three consecutive zeros in ``lst``.

    Windows overlap (four zeros in a row count as two windows) and the
    sequence is not treated as periodic.

    Parameters
    ----------
    lst : indexable sequence of numbers
        Values to scan.
    ngrid : int, optional
        Not used by the computation; kept for call compatibility.

    Returns
    -------
    out : int
        Number of start positions ``k`` with
        ``lst[k] == lst[k + 1] == lst[k + 2] == 0``.
    """
    window_starts = range(len(lst) - 2)
    return sum(
        1 for start in window_starts
        if lst[start] == lst[start + 1] == lst[start + 2] == 0
    )


class H20Hole(OrderParameter):
    """H20Hole(OrderParameter).

    This class measures holes in the 2D (x, y) distribution of the carbon
    atoms. The carbon positions are binned on a periodic
    ``gridpoints`` x ``gridpoints`` grid, and runs/patterns of empty cells
    are counted.

    Attributes
    ----------
    idx_o : numpy.ndarray
        Indices of the oxygen atoms in the structure file.
    idx_c : numpy.ndarray
        Indices of the carbon atoms in the structure file.
    box : tuple of floats
        Box lengths (x, y, z) used to map positions onto the grid.
    gridpoints : int
        Number of grid cells along x and along y.
    """

    def __init__(self, conf='gromacs_input/conf.gro',
                 box=(15.00000, 15.00000, 5.23032), gridpoints=85):
        """Initialize the order parameter.

        Parameters
        ----------
        conf : string, optional
            Structure file read with ``mdtraj.load`` to select the oxygen
            and carbon atom indices. The default is the relative path the
            example expects, resolved against the current working
            directory, so make sure the file is there.
        box : sequence of floats, optional
            Box lengths (x, y, z) used for the periodic grid mapping;
            presumably in nm to match the positions -- confirm for new
            setups.
        gridpoints : int, optional
            Number of grid cells along x and along y.
        """
        super().__init__(description='H2O-hole')
        trj = md.load(conf)
        self.idx_o = trj.top.select("symbol == O")
        self.idx_c = trj.top.select("symbol == C")
        self.box = tuple(float(length) for length in box)
        self.gridpoints = int(gridpoints)

    def _histogram(self, pos):
        """Return the (gridpoints, gridpoints) count of atoms per (x, y) cell.

        Parameters
        ----------
        pos : numpy.ndarray
            Positions, one row per atom, with one column per entry of
            ``self.box``.

        Returns
        -------
        out : numpy.ndarray
            ``out[a, b]`` is the number of atoms whose x falls in cell ``a``
            and whose y falls in cell ``b``.
        """
        ngrid = self.gridpoints
        box = np.asarray(self.box, dtype=float)
        # np.where builds a new array, so the caller's positions are untouched.
        pos = np.asarray(pos, dtype=float)
        # Negative coordinates are shifted by one box length.
        pos = np.where(pos < 0, pos + box, pos)
        # Truncate to a cell index (astype truncates toward zero, like int()).
        # The modulo folds indices outside [0, ngrid) back onto the periodic
        # grid instead of raising or aliasing via negative indexing.
        cells = (pos / box * ngrid).astype(int) % ngrid
        hist = np.zeros((ngrid, ngrid))
        # np.add.at accumulates correctly when several atoms share one cell.
        np.add.at(hist, (cells[:, 0], cells[:, 1]), 1)
        return hist

    @staticmethod
    def _count_crosses(hist):
        """Find empty plus-shaped crosses of five cells on the periodic grid.

        A cross is centred at ``(c0, c1)`` and consists of that cell plus its
        two neighbours along each grid axis, all of which must be empty.

        Parameters
        ----------
        hist : numpy.ndarray
            Square 2D histogram of counts.

        Returns
        -------
        cnt_x : int
            Number of crosses found.
        centres : list of int
            Flat list ``[c0, c1, c0', c1', ...]`` of the cross centres.
        """
        ngrid = hist.shape[0]
        cnt_x = 0
        centres = []
        for i in range(ngrid):
            ym, yc, yp = i, (i + 1) % ngrid, (i + 2) % ngrid
            for j in range(ngrid):
                xm, xc, xp = j, (j + 1) % ngrid, (j + 2) % ngrid
                if (hist[xm, yc] == 0 and hist[xc, yc] == 0
                        and hist[xp, yc] == 0 and hist[xc, ym] == 0
                        and hist[xc, yp] == 0):
                    cnt_x += 1
                    centres.extend((xc, yc))
        return cnt_x, centres

    def calculate(self, system):
        """Calculate the order parameter.

        The (x, y) positions of the carbon atoms are binned on a periodic
        2D grid, and holes (empty cells) in that grid are characterised.

        Parameters
        ----------
        system : object like :py:class:`.System`
            Only ``system.particles.pos`` is read here.

        Returns
        -------
        out : list
            ``[n_cross, longest_run_rows, longest_run_cols, n_triples,
            c0, c1, c0', c1', ...]`` where:

            * ``n_cross`` is the number of empty plus-shaped crosses;
            * ``longest_run_rows`` / ``longest_run_cols`` are the longest
              periodic runs of empty cells within a histogram row (fixed
              x cell) / column (fixed y cell);
            * ``n_triples`` is the total number of windows of three
              consecutive empty cells over all rows and columns;
            * the trailing values are the cross centres as grid indices.
        """
        ngrid = self.gridpoints
        hist = self._histogram(system.particles.pos[self.idx_c, :])

        cnt_hole = [0, 0]
        cnt_ngb = 0
        for row, col in zip(hist, hist.T):
            # Rotate trailing empty cells to the front so a hole that wraps
            # across the periodic boundary is counted as one run.
            row = pbc(row, ngrid)
            col = pbc(col, ngrid)
            cnt_hole[0] = cnt0(row, cnt_hole[0])
            cnt_hole[1] = cnt0(col, cnt_hole[1])
            cnt_ngb += cnt0seq(row, ngrid) + cnt0seq(col, ngrid)

        cnt_x, centres = self._count_crosses(hist)
        orderp = [cnt_x, *cnt_hole, cnt_ngb, *centres]
        # Was a bare print(); route diagnostics through the module logger.
        logger.debug('H2O-hole order parameter: %s', orderp)
        return orderp

        
