#!/usr/bin/env python

##########################################################################
# get_contacts_bootst.py -- a protein-lipid contact calculator           #
# Copyright 2019 Manuel N. Melo                                          #       
# m.n.melo@itqb.unl.pt                                                   #       
#                                                                        #       
# This program is free software: you can redistribute it and/or modify   #       
# it under the terms of the GNU General Public License as published by   #       
# the Free Software Foundation, either version 3 of the License, or      #       
# (at your option) any later version.                                    #       
#                                                                        #       
# This program is distributed in the hope that it will be useful,        #       
# but WITHOUT ANY WARRANTY; without even the implied warranty of         #       
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          #       
# GNU General Public License for more details.                           #       
#                                                                        #       
# You should have received a copy of the GNU General Public License      #       
# along with this program.  If not, see <https://www.gnu.org/licenses/>. #       
##########################################################################

import MDAnalysis
import mdreader
import numpy as np
import scikits.bootstrap as bootstrap
import sys
import cPickle

# Command-line and trajectory handler. The options registered below are the
# ones specific to this script; they are read back through ``syst.opts``.
syst = mdreader.MDreader()
syst.add_argument("-res", nargs='+', type=str, help="str \tThe selections (space separated) to calculate contacts against protein residues.", default=["name AM1 AM2"])
syst.add_argument("-cut", dest='cutoff', metavar='CUTOFF', type=float, help="real\tThe contact cutoff, in nm.", default=0.7)
syst.add_argument("-bst", metavar='TIME', type=float, help="real\tThe bootstrap block time, in ps.", default=5000000)
# Presumably sets the default of the handler's output-file option (read later
# as ``syst.opts.outfile``) -- confirm against mdreader's documentation.
syst.setargs(o="contacts.xvg")

# The cutoff is given in nm on the command line; scale it by 10 so it can be
# compared directly with MDAnalysis distances, which are in Angstrom.
syst.opts.cutoff *= 10
protein = syst.select_atoms("protein")
protres = protein.residues
nres = len(protres)
# One atom group per user-supplied selection string.
analysis_residues = [syst.select_atoms(res) for res in syst.opts.res]
nanalysis = len(analysis_residues)
# Positions (within ``protein``) where a new residue starts, used to split
# per-atom data into per-residue chunks. Residue *indices* are used rather
# than residue *numbers*: numbers may repeat between adjacent residues (e.g.
# across chains or with insertion codes), which would merge those residues
# and leave fewer chunks than ``nres``.
res_split = np.diff(protein.resindices).nonzero()[0] + 1

def get_contacts():
    """Compute protein-residue contacts for the current frame.

    For each atom group in the module-level ``analysis_residues``, a protein
    residue is flagged as in contact when at least one of its atoms lies
    closer than ``syst.opts.cutoff`` to any atom of that group.

    The function takes no arguments; it reads the module-level ``protein``,
    ``analysis_residues``, ``nanalysis``, ``res_split`` and ``syst``.

    Returns:
        A boolean array with one row per analysis selection and one column
        per residue chunk defined by ``res_split``.
    """
    # Builtin ``bool`` is used because the ``np.bool`` alias was removed in
    # NumPy 1.24 and raises AttributeError there.
    atom_contacts = np.empty((nanalysis, len(protein)), dtype=bool)
    for selection, row in zip(analysis_residues, atom_contacts):
        # NOTE(review): no ``box`` is passed to distance_array, so contacts
        # across periodic boundaries are not detected -- confirm intended.
        dists = MDAnalysis.lib.distances.distance_array(protein.positions, selection.positions)
        # ``row`` is a view into ``atom_contacts``; fill it in place.
        row[:] = np.any(dists < syst.opts.cutoff, axis=1)
    # Collapse the per-atom columns into one column per residue: a residue is
    # in contact if any of its atoms is. np.any handles single-atom residues
    # just as well, so no special case is needed.
    per_residue = [np.any(chunk, axis=1)
                   for chunk in np.split(atom_contacts, res_split, axis=1)]
    return np.column_stack(per_residue)

def loop_inplace_sum(arrlist):
    """Return the element-wise sum of the arrays in *arrlist* as floats.

    Args:
        arrlist: A non-empty sequence (list, tuple or array) of equally
            shaped arrays or array-likes. The inputs are never modified.

    Returns:
        A new float array holding the element-wise sum.

    Raises:
        IndexError: If *arrlist* is empty.
    """
    # performance comparison in http://stackoverflow.com/questions/20640396/quickly-summing-numpy-arrays-element-wise
    # np.array(..., dtype=float) always makes a fresh float copy, so the
    # in-place additions below cannot touch the first input. Builtin ``float``
    # replaces the ``np.float`` alias, which was removed in NumPy 1.24.
    sm = np.array(arrlist[0], dtype=float)
    for a in arrlist[1:]:
        sm += a
    return sm

# Run get_contacts through the trajectory handler. ``result`` is used below as
# a sequence with one (nanalysis, nres) boolean array per analyzed frame.
result = syst.do_in_parallel(get_contacts)

# Number of frames in each bootstrap block (block time / frame spacing).
nframes_per_block = int(syst.opts.bst/syst.trajectory.dt)
if nframes_per_block < 1:
    # A block shorter than the frame spacing would give zero frames per block
    # and a ZeroDivisionError below; the smallest usable block is one frame.
    sys.stderr.write("Warning: bootstrap block time is shorter than the frame spacing; using one frame per block.\n")
    nframes_per_block = 1
# Floor division keeps this an integer under both Python 2 and Python 3.
# Trailing frames that do not fill a whole block are left out of the blocks
# (they still contribute to the overall average below).
nblocks = len(syst) // nframes_per_block
block_result = []
print("Blocking...")
for n in range(nblocks):
    first = n*nframes_per_block
    block = result[first:first+nframes_per_block]
    # Per-block contact fraction: sum of the boolean frames / frames in block.
    block_result.append(loop_inplace_sum(block)/nframes_per_block)
if block_result:
    block_result = np.array(block_result) # shape: nblocks, nanalysis, nres
else:
    # Keep the three-dimensional layout even when there are no blocks.
    block_result = np.empty((0, nanalysis, nres))
print(block_result.shape)

# Overall contact fraction per selection and residue, over all frames.
avg = loop_inplace_sum(result)
avg /= len(syst)

print("Bootstrapping...")
# cis[0] and cis[1] hold the two error-bar values returned by bootstrap.ci.
# Initialised to zero, which is also the fallback when bootstrapping fails.
cis = np.zeros((2, nanalysis, nres))
if nblocks == 0:
    # The trajectory is shorter than one block, so there is nothing to
    # resample; report zero error bars instead of failing on empty data.
    sys.stderr.write("Warning: trajectory is shorter than one bootstrap block; error bars set to 0.\n")
else:
    for analysis in range(nanalysis):
        for res in range(nres):
            data = block_result[:, analysis, res]
            try:
                cis[:, analysis, res] = bootstrap.ci(data, output='errorbar', n_samples=2000)[:,0]
            except IndexError:
                # Existing best-effort behavior: if bootstrap.ci fails with an
                # IndexError for this residue, report zero error bars for it.
                cis[:, analysis, res] = 0., 0.

# Output columns: residue counter (1-based), then for each selection the
# average contact fraction, then the first and the second error-bar value.
np.savetxt(syst.opts.outfile, np.column_stack((np.arange(1, nres+1), avg.T, cis[0].T, cis[1].T)),
        fmt=["%d"] + nanalysis*3*["%.4f"],
        header=" {}\n Res# ".format(syst.info_header()) + " ".join(syst.opts.res))

