#!/usr/bin/env python

##########################################################################      
# get_occs.py -- a ligand occupancy mapper in 3D space                   #
# Copyright 2019 Manuel N. Melo                                          #       
# m.n.melo@itqb.unl.pt                                                   #       
#                                                                        #       
# This program is free software: you can redistribute it and/or modify   #       
# it under the terms of the GNU General Public License as published by   #       
# the Free Software Foundation, either version 3 of the License, or      #       
# (at your option) any later version.                                    #       
#                                                                        #       
# This program is distributed in the hope that it will be useful,        #       
# but WITHOUT ANY WARRANTY; without even the implied warranty of         #       
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          #       
# GNU General Public License for more details.                           #       
#                                                                        #       
# You should have received a copy of the GNU General Public License      #       
# along with this program.  If not, see <https://www.gnu.org/licenses/>. #       
##########################################################################  

import mdreader
import numpy as np

def round_grid_boundary(vals, bw, is_min=True):
    """Snap boundary values onto multiples of the bin width.

    Parameters
    ----------
    vals : array-like or scalar
        Boundary value(s) to snap.
    bw : float
        Bin width; results are integer multiples of it.
    is_min : bool
        If true, return the truncated multiple of ``bw``. Otherwise return
        the multiple one bin past the truncated one.

    Returns
    -------
    numpy array (or numpy scalar) of snapped values.

    Notes
    -----
    The integer cast truncates toward zero, so negative values are snapped
    toward zero rather than floored. Values that already are exact multiples
    still gain one extra bin when ``is_min`` is false.
    """
    # The builtin `int` replaces the `np.int` alias, which was removed from
    # NumPy; `np.asarray` lets plain Python scalars through as well.
    whole_bins = (np.asarray(vals) / bw).astype(int)
    if not is_min:
        whole_bins = whole_bins + 1
    return whole_bins * bw

def write_dx(fname, data, origin, delta, label=None):
    """Write a 3D grid of values to an OpenDX-style text file.

    Parameters
    ----------
    fname : str
        Path of the file to create (overwritten if it exists).
    data : numpy array
        Grid values; its shape is written as the three grid counts and the
        values are written flattened in C order.
    origin : numpy array
        Coordinates of the grid's lower corner. Half a bin width is added
        before writing (presumably so that positions refer to bin centres).
    delta : float
        Bin width, used for all three axes.
    label : str, optional
        Name inserted in the trailing ``occupancy (resname ...)`` line.
        When omitted, the module-level ``resname`` variable is used, which
        is what this function implicitly relied on before.
    """
    nbins = data.shape
    flat = data.flatten()
    origin = origin + delta/2

    # Values are printed three per line; whatever does not fill a complete
    # row of three is printed afterwards, one value per line.
    # Floor division keeps the split index an integer under Python 3.
    n_in_rows = (len(flat) // 3) * 3
    fastprint_data, leftover_data = np.split(flat, [n_in_rows])
    fastprint_data = fastprint_data.reshape(-1, 3)

    if label is None:
        # Backward-compatible fallback to the global set by the calling loop.
        label = globals().get("resname", "unknown")

    with open(fname, 'w') as DX:
        DX.write("object 1 class gridpositions counts {} {} {}\n".format(*nbins))
        DX.write("origin {} {} {}\n".format(*origin))
        DX.write("delta {} 0 0\n".format(delta))
        DX.write("delta 0 {} 0\n".format(delta))
        DX.write("delta 0 0 {}\n".format(delta))
        DX.write("object 2 class gridconnections counts {} {} {}\n".format(*nbins))
        # np.prod replaces np.product, which was removed in NumPy 2.0.
        DX.write("object 3 class array type double rank 0 items {} data follows\n".format(np.prod(nbins)))
        np.savetxt(DX, fastprint_data, fmt="%g")
        if leftover_data.size:
            np.savetxt(DX, leftover_data, fmt="%g")
        DX.write('\nobject "occupancy (resname {})" class field\n'.format(label))

def sanitize(string):
    """Return ``string`` with every space character turned into an underscore."""
    pieces = string.split(' ')
    return '_'.join(pieces)

def sum_unique(arrs):
    """Sum the distinct objects in ``arrs`` element-wise.

    Distinctness is by object identity, not by value: the very same array
    appearing several times is counted once, while equal-valued copies are
    each counted. First-occurrence order is kept.
    """
    seen_ids = set()
    distinct = []
    for candidate in arrs:
        key = id(candidate)
        if key in seen_ids:
            continue
        seen_ids.add(key)
        # Holding the reference keeps the object alive, so its id stays unique.
        distinct.append(candidate)
    return np.sum(distinct, axis=0)

# Command-line setup: selections to map and the grid bin width.
syst = mdreader.MDreader()
syst.add_argument("-res", nargs='+', type=str, help="str \tThe selections (space separated) to calculate contacts against protein residues.", default=["name AM2"])
syst.add_argument("-gw", metavar='WIDTH', type=float, help="real\tThe grid width, in nm.", default=0.5)
syst.setargs(o="_occupancy.dx")

# -gw is given in nm (see its help text); the factor of 10 presumably
# converts it to Angstrom to match the coordinates -- TODO confirm.
syst.opts.gw *= 10
grid_size = syst.dimensions[:3]/2
# Allowing extra 20% for pressure wobbling
grid_min = round_grid_boundary(-grid_size*1.2, syst.opts.gw)
grid_max = round_grid_boundary(grid_size*1.2, syst.opts.gw, is_min=False)
# Builtin `int` replaces the `np.int` alias, which was removed from NumPy.
grid_nbins = np.rint((grid_max - grid_min)/syst.opts.gw).astype(int)
# One (min, max) row per axis, in the layout np.histogramdd takes as `range`.
range_edges = np.column_stack((grid_min, grid_max))

# One (output-name prefix, atom selection, integer accumulator grid) per -res selection.
analysis_residues = [(sanitize(res), syst.select_atoms(res), np.zeros(grid_nbins, dtype=np.int64))
                      for res in syst.opts.res]
prot = syst.select_atoms("protein")
# Protein centre at this point in the script; used later to place the grid origin.
prot_cog = prot.center_of_geometry()

def get_histogram():
    """Add the current positions of each selection to its occupancy grid.

    Positions are taken relative to the protein's current centre of
    geometry and binned onto the module-level grid. A bin counts once per
    call no matter how many atoms fall in it.

    Passed to ``syst.do_in_parallel``; presumably called once per
    trajectory frame -- confirm against mdreader.

    Returns
    -------
    list
        The accumulator arrays from ``analysis_residues`` (the same objects,
        updated in place), in selection order.
    """
    ret = []
    prot_cog = prot.center_of_geometry()
    for name, resgrp, resarr in analysis_residues:
        cdx = resgrp.positions - prot_cog
        hist, _ = np.histogramdd(cdx, bins=grid_nbins, range=range_edges)
        # Builtin `bool` replaces the `np.bool` alias, which is missing from
        # several NumPy releases; any non-zero count becomes a single hit.
        resarr += hist.astype(bool)
        ret.append(resarr)
    return ret

# Regroup the values returned by do_in_parallel per selection and sum the distinct accumulators.
per_call_grids = syst.do_in_parallel(get_histogram)
results = [sum_unique(res) for res in zip(*per_call_grids)]

# Write one file per selection, with hit counts divided by len(syst).
# The loop variable must stay named `resname`: write_dx reads it as a global.
for (resname, resgrp, resarr), summed_res in zip(analysis_residues, results):
    hist = summed_res.astype(np.float64) / len(syst)
    out_path = resname + syst.opts.outfile
    write_dx(out_path, hist, prot_cog+grid_min, syst.opts.gw)

