#!/usr/bin/env python

##########################################################################      
# residence_buffered.py -- a lipid-protein binding lifetime calculator   #
# Copyright 2019 Manuel N. Melo                                          #       
# m.n.melo@itqb.unl.pt                                                   #       
#                                                                        #       
# This program is free software: you can redistribute it and/or modify   #       
# it under the terms of the GNU General Public License as published by   #       
# the Free Software Foundation, either version 3 of the License, or      #       
# (at your option) any later version.                                    #       
#                                                                        #       
# This program is distributed in the hope that it will be useful,        #       
# but WITHOUT ANY WARRANTY; without even the implied warranty of         #       
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          #       
# GNU General Public License for more details.                           #       
#                                                                        #       
# You should have received a copy of the GNU General Public License      #       
# along with this program.  If not, see <https://www.gnu.org/licenses/>. #       
##########################################################################  

import cPickle
import math
import os
import sys

import MDAnalysis
import mdreader
import numpy as np
import scikits.bootstrap as bootstrap
from scipy.stats import binned_statistic

def prepare_system(args):
    """Create the MDreader system and register this tool's options.

    Parameters
    ----------
    args : list of str
        Command-line arguments (without the program name), handed to
        ``mdreader.MDreader``.

    Returns
    -------
    The ``mdreader.MDreader`` instance, whose ``opts`` namespace is
    augmented with:

    * ``cutoff``   -- the user cutoff multiplied by 10 (the option is given
      in nm; presumably this converts to Angstrom -- confirm against the
      units MDAnalysis reports positions in);
    * ``cutoff2``  -- ``cutoff`` plus the buffer width, scaled the same way;
    * ``buffered`` -- True when a non-zero buffer width was requested;
    * ``npy_name`` -- the output file name with its extension replaced by
      ``.npy``.
    """
    syst = mdreader.MDreader(args)
    syst.add_argument("-cut", dest='cutoff', metavar='CUTOFF', type=float, help="real\tThe contact cutoff, in nm.", default=0.65)
    syst.add_argument("-buff", dest='buffer', metavar='BUFFER', type=float, help="real\tThe contact buffer width, in nm.", default=0.2)
    syst.add_argument("-sel", metavar='SELECTION', help="Selection string that defines the beads to get residence times from.")
    syst.add_argument("-smooth", metavar='SMOOTHFRAMES', type=int, help="int \tNumber of frames to smooth over (central frame is set to 1 if average is > 0.5.", default=5)
    syst.add_argument("-bs", action='store_true', help="bool\tWhether to count residences from the lipid's perspective (the default) or the binding-site's.")
    syst.setargs(o="residences.xvg")

    # Scale both distances by the same factor of 10; the outer cutoff is the
    # inner one widened by the buffer.
    syst.opts.cutoff *= 10
    syst.opts.cutoff2 = syst.opts.cutoff + syst.opts.buffer * 10
    syst.opts.buffered = bool(syst.opts.buffer)

    # Derive the npy file name. splitext only looks at the last path
    # component, so dots in directory names (e.g. "../run/out") are not
    # mistaken for an extension separator.
    root, _ext = os.path.splitext(syst.opts.outfile)
    syst.opts.npy_name = root + ".npy"

    return syst

def calc_residences(syst):
    """Compute binding residence times and write their histograms.

    A lipid counts as being in contact when it is within the cutoff of at
    least two of the three hard-coded binding-site residue groups. The
    per-frame contact series are optionally passed through the buffer
    filter (``filter_buffer``), smoothed with a moving majority filter of
    ``syst.opts.smooth`` frames, and converted into residence lengths by
    ``get_times``.

    Parameters
    ----------
    syst
        The system returned by ``prepare_system``; its ``opts`` must carry
        ``sel``, ``cutoff``, ``cutoff2``, ``buffered``, ``bs``, ``smooth``,
        ``outfile`` and ``npy_name``.

    Side effects
    ------------
    * Saves the raw residence lengths (in frames) to ``opts.npy_name``.
    * Writes ``opts.outfile`` (event counts on logarithmic bins), plus
      ``scaled_``, ``scaled_norm_`` and ``lin_`` prefixed variants next to
      it.
    * Prints a message and calls ``sys.exit()`` if no residence event is
      found.
    """
    # Must see two out of these three groups to be in contact.
    bs1 = [58, 59, 60]
    bs2 = [73, 74, 75]
    bs3 = [81, 82]

    bs_sels = [syst.select_atoms(*["resnum {}".format(resn) for resn in bs])
               for bs in (bs1, bs2, bs3)]

    beads = syst.select_atoms(syst.opts.sel)
    # Number of selected beads in the first residue; every lipid is assumed
    # to contribute the same number, in contiguous order (required by the
    # reshape below) -- TODO confirm for mixed selections.
    nbeads = len(np.intersect1d(beads.indices, beads.residues[0].atoms.indices))
    # Floor division keeps this an int under true division as well;
    # reshape() rejects float dimensions.
    nlips = len(beads) // nbeads

    def get_ctx():
        """Return the contact state(s) of the current frame.

        Without buffer: one boolean per lipid (or a single-element list
        with ``-bs``). With buffer: a pair of such values, for the inner
        and the outer cutoff.
        """
        ctx = []
        ctx2 = []
        for res in bs_sels:
            dst = MDAnalysis.lib.distances.distance_array(res.positions, beads.positions)
            # (site atoms, lipids, beads per lipid)
            dst = dst.reshape(-1, nlips, nbeads)
            # A lipid touches this group if any of its beads is within
            # range of any atom of the group.
            ctx.append(np.any(dst <= syst.opts.cutoff, axis=(0, 2)))
            if syst.opts.buffered:
                ctx2.append(np.any(dst <= syst.opts.cutoff2, axis=(0, 2)))

        in_contact = np.sum(ctx, axis=0) >= 2
        if syst.opts.bs:
            # Binding-site perspective: occupied if any lipid is bound.
            in_contact = [np.any(in_contact)]
        if not syst.opts.buffered:
            return in_contact

        in_contact2 = np.sum(ctx2, axis=0) >= 2
        if syst.opts.bs:
            in_contact2 = [np.any(in_contact2)]
        return in_contact, in_contact2

    result = np.array(syst.do_in_parallel(get_ctx))

    # Concatenated trajectories expose their parts through .readers; each
    # part must be analysed separately so that no residence spans a
    # trajectory boundary.
    try:
        nframes = [rdr.n_frames for rdr in syst.trajectory.readers]
    except AttributeError:
        segments = [result]
    else:
        # NOTE(review): assumes axis 0 of `result` holds exactly
        # sum(nframes) frames -- confirm when frame skipping is in use.
        segments = np.split(result, np.cumsum(nframes)[:-1], axis=0)

    if syst.opts.buffered:
        segments = [filter_buffer(segment) for segment in segments]

    # One boolean time series per lipid (or per site) and per segment.
    result = [liptraj for segment in segments for liptraj in segment.T]

    # filtering: most out of N consecutive
    # (builtin float replaces the np.float alias removed from recent NumPy)
    conv = np.ones(syst.opts.smooth, dtype=float) / syst.opts.smooth
    result = [np.convolve(arr, conv, mode='same') > .5 for arr in result]

    # getting times and trimming edges
    residences = np.concatenate([get_times(arr) for arr in result])
    if not len(residences):
        print("No contacts found.")
        sys.exit()

    np.save(syst.opts.npy_name, residences)

    # Frames -> time; dt is divided by 1e6 (presumably ps -> microseconds).
    residences_frame = residences
    residences_time = residences * syst.trajectory.dt / 1.e6

    #weighted by itself, to reflect the mass center of the "scaled_residences" histogram.
    avg_time = np.average(residences_time, weights=residences_time)
    sys.stderr.write("Bootstrapping...")
    ci = bootstrap.ci(residences_time, statfunction=weighted_avg,
                      output='errorbar', n_samples=10000)[:,0]
    xvg_header = " {}\n Sel: {}\n Number of events: {}\n Residence: {} -{} +{}".format(syst.info_header(), syst.opts.sel, len(residences), avg_time, ci[0], ci[1])

    # Logarithmic bins from just below one frame length up to 10 time units;
    # xs are the geometric centres of the bins.
    logmint = math.log(syst.trajectory.dt * 1.e-6, 10) - 0.01
    bins = np.logspace(logmint, 1, num=20)
    xs = np.sqrt(bins[:-1] * bins[1:])
    bwdths = np.diff(bins)
    # residences_time cannot be empty here (handled by the exit above).
    hist = np.histogram(residences_time, bins)[0]
    scaledhist = binned_statistic(residences_time, residences_time, statistic='sum', bins=bins)[0]

    outfile = syst.opts.outfile
    np.savetxt(outfile, np.column_stack((xs, hist)), header=xvg_header)
    # The scaled version
    total_traj_time = len(syst) * syst.trajectory.dt / 1.e6
    np.savetxt(_prefixed_path("scaled_", outfile), np.column_stack((xs, scaledhist / total_traj_time)), header=xvg_header)
    np.savetxt(_prefixed_path("scaled_norm_", outfile), np.column_stack((xs, scaledhist / bwdths)), header=xvg_header)

    # Linear histogram: number of events per residence length in frames
    # (length 0 is dropped); x is that length multiplied by dt.
    hist_lin = np.bincount(residences_frame)[1:]
    x_lin = np.arange(1, len(hist_lin) + 1) * syst.trajectory.dt

    np.savetxt(_prefixed_path("lin_", outfile), np.column_stack((x_lin, hist_lin)), header=xvg_header)


def _prefixed_path(prefix, path):
    """Return *path* with *prefix* prepended to its file name, keeping the directory."""
    dirname, basename = os.path.split(path)
    return os.path.join(dirname, prefix + basename)

def get_times(data):
    """Return the lengths, in frames, of the complete bound stretches in *data*.

    Parameters
    ----------
    data : 1-D boolean sequence
        Bound (True) / unbound (False) state per frame.

    Returns
    -------
    numpy.ndarray of int
        One entry per residence that both starts and ends inside the
        series. A stretch that is already bound at the first frame, or
        still bound at the last one, is discarded because its true length
        is unknown. Empty input yields an empty array.
    """
    data = np.asarray(data)
    if data.size == 0:
        # Nothing to analyse; avoid indexing data[0] below.
        return np.zeros(0, dtype=np.intp)

    # Indices i where the state changes between frame i and i+1.
    events = np.flatnonzero(np.diff(data))
    # ignores if it starts bound: the first event is then an unbinding
    if data[0]:
        events = events[1:]
    # Events now alternate bind, unbind, bind, ...; every other difference
    # is therefore a bound duration. A trailing unpaired binding event
    # (last residents) produces no entry.
    return np.diff(events)[::2]

def weighted_avg(data, weights=None):
    """Return the weighted average of *data*.

    Parameters
    ----------
    data : array-like
        The values to average.
    weights : array-like, optional
        Weights for each value. When omitted, *data* weights itself, which
        gives sum(data**2) / sum(data) -- the centre of mass of a
        histogram in which each value is scaled by its own size.

    Raises
    ------
    ZeroDivisionError
        Propagated from ``numpy.average`` when the weights sum to zero.
    """
    if weights is None:
        weights = data
    return np.average(data, weights=weights)

def filter_buffer(data):
    """Apply the contact buffer to a series of inner/outer cutoff states.

    *data* is indexed as ``data[frame, 0]`` for the inner-cutoff contacts
    and ``data[frame, 1]`` for the outer-cutoff contacts. Frames where the
    two agree keep their combined value; frames where they differ (the
    buffer zone) inherit the state of the preceding frame, the frame before
    the first one counting as unbound. The propagation is repeated until
    the result no longer changes, and that result is returned with the
    inner/outer axis removed.
    """
    inner, outer = data[:, 0], data[:, 1]
    # Frames sitting between the two cutoffs; constant across iterations.
    in_buffer = inner != outer

    bound = inner * outer
    while True:
        # State of the previous frame, with nothing bound before frame 0.
        previous = np.roll(bound, 1, axis=0)
        previous[0] = False
        propagated = np.where(in_buffer, previous, bound)
        if np.array_equal(propagated, bound):
            break
        bound = propagated
    return bound

# Script entry point: build the configured system from the command-line
# arguments (program name excluded), then run the residence analysis on it.
if __name__ == "__main__":
    syst = prepare_system(sys.argv[1:])
    calc_residences(syst)
