/*
This file is part of FixBox, version 1.3.

Copyright (c) 2021-2022, Instituto de Tecnologia Quimica e Biologica,
Universidade Nova de Lisboa, Portugal.

FixBox is free software: you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation, either version 2 of the License, or (at your
option) any later version.

FixBox is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
General Public License for more details.

You should have received a copy of the GNU General Public License
along with FixBox.  If not, see <http://www.gnu.org/licenses/>.

For further details and info, check the README file.

You can get FixBox at www.itqb.unl.pt/simulation
*/



#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <stdarg.h>
#include <math.h>

/* Program version, embedded in the usage text printed by message(). */
#define VERSION "1.3"
/* Base size of the line and name buffers used when reading input. */
#define MAXSTR 200
/* Very large length; used (plain or squared) as the starting value of
   minimum/maximum searches. */
#define BIG_LENGTH 1e10
/* Very small length (not referenced in the code of this file). */
#define TINY_LENGTH 1e-10
/* A box whose determinant is below this (in absolute value) is rejected. */
#define TINY_VOLUME 1e-10

/* One atom of a frame.
   anumb, rnumb : atom and residue numbers as read from the atom line.
   mol          : index of the molecule owning the atom (0 = not assigned).
   r, s         : physical coordinates and box-scaled coordinates.
   aname, rname : atom and residue names.
   vels         : remaining text of the atom line (up to 24 characters,
                  presumably the velocities), echoed verbatim on output. */
typedef struct {
  int anumb, rnumb, mol ;
  float r[3], s[3] ;
  char aname[6+1], rname[6+1], vels[24+1] ;
} atom_t ;

/* A named group of molecules.
   nmols : number of molecules in the group.
   mol   : molecule indices, stored in entries 1..nmols.
   name  : group name as given in the molecule-definition file. */
typedef struct {
  int nmols, *mol ;
  char name[MAXSTR] ;
} group_t ;

/* A molecule, described as the inclusive range of (1-based) atom indices
   firstatom..lastatom. */
typedef struct {
  int firstatom, lastatom ;
} mol_t ;

/* Distance record between two objects, taken at their closest atom pair.
   dRMI2 : squared minimum-image distance between those two atoms.
   dS    : difference of their scaled coordinates (before applying the
           minimum-image shift). */
typedef struct {
  float dRMI2, dS[3] ;
} dist_t ;

/* An object to be assembled.
   mol       : index of the molecule the object refers to.
   dN        : integer box displacement (per box vector) assigned to the
               object during assembling.
   assembled : 1 once the object has been attached to the cluster, else 0. */
typedef struct {
  int mol, dN[3], assembled ;
} obj_t ;


/* Function prototypes (all of these are defined after main() in this
   file; main() calls the per-frame ones in the order read_frame, r_to_s,
   compute_objdist, assemble_objs, center, apply_PBCs, s_to_r,
   write_frame). */
void parse_arguments(int argc, char **argv) ;
void read_frame(void) ;
void read_moldef(void) ;
void write_frame(void) ;
int gnumb(char groupname[]) ;
void r_to_s(void) ;
void s_to_r(void) ;
void compute_objdist(void) ;
void assemble_objs(void) ;
void center(void) ;
void apply_PBCs(void) ;
int nint(float x) ;
void message(char mtype, char *format, ...) ;

/* Global variables

   natoms            : number of atoms of the current frame.
   ngroups, nstages  : number of G-lines and A-lines in the moldef file.
   nmols             : total number of molecules defined.
   nobjs[st]         : cumulative number of objects up to stage st.
   nobjs_tot         : total number of objects over all stages.
   gcent[k], gPBC[k] : group indices given in the C-line and P-line for
                       box vector k.
   nframe            : counter of the frame being processed (starts at 1).
   nClines, nPlines  : number of C-lines and P-lines found.
   B, iB, boxC       : box matrix, its inverse, and the box center.
   halfbox2          : not referenced in the code of this file.
   title, boxline    : title and box lines of the frame, echoed on output.
   centEW[k]         : message code ('E' or 'W') used by center() for box
                       vector k.
   empty_gname       : name reserved for the built-in empty group 0.
   atom, mol, obj    : arrays whose entries are used from index 1 on.
   group             : group[0] is the empty group; user groups start at 1.
   dist[o1][o2]      : distance record between objects o1 and o2.
   fp_gro, fp_moldef : input streams opened by parse_arguments(). */
int natoms, ngroups, nmols, nstages, *nobjs, nobjs_tot, gcent[3],
  gPBC[3], nframe = 0, nClines = 0, nPlines = 0 ;
float B[3][3], iB[3][3], boxC[3], halfbox2 ;
char title[MAXSTR+1], boxline[MAXSTR+1], centEW[3] ;
static const char empty_gname[] = "None" ;
atom_t *atom ;
group_t *group ;
mol_t *mol ;
dist_t **dist ;
obj_t *obj ;
FILE *fp_gro, *fp_moldef ;


/* Program entry point: opens the input files and then processes the .gro
   file frame by frame, writing each fixed frame to standard output. */
int main(int argc, char **argv)
{
  parse_arguments(argc, argv) ;

  /* Every frame begins with a title line; stop when none is left. */
  for (;;)
  {
    if (fgets(title, sizeof title, fp_gro) == NULL) break ;

    nframe += 1 ;
    read_frame() ;
    /* Molecule definitions are read only once, using the first frame. */
    if (nframe == 1)
    {
      read_moldef() ;
    }
    r_to_s() ;
    compute_objdist() ;
    assemble_objs() ;
    center() ;
    apply_PBCs() ;
    s_to_r() ;
    write_frame() ;
  }

  fclose(fp_gro) ;

  /* Heap memory is not released explicitly before leaving. */
  return 0 ;
}


/* Reads the command-line arguments and opens the two input files.
   argv[1] is the .gro file (opened into fp_gro) and argv[2] the
   molecule-definition file (opened into fp_moldef).  A wrong argument
   count prints the usage text and exits; a file that cannot be opened is
   reported as an error and the program exits. */
void parse_arguments(int argc, char **argv)
{
  if (argc != 3) message('U', "Wrong number of arguments.\n") ;

  /* The file names are used straight from argv: no copies are needed,
     so nothing has to be allocated (or freed) here. */
  if ((fp_gro = fopen(argv[1], "r")) == NULL)
    message('E', "Opening \"%s\": %s\n", argv[1], strerror(errno)) ;

  if ((fp_moldef = fopen(argv[2], "r")) == NULL)
    message('E', "Opening \"%s\": %s\n", argv[2], strerror(errno)) ;
}


/* Reads one frame from fp_gro (the title line was already consumed by the
   caller): the atom count, the fixed-column atom lines and the box line.
   Fills natoms, atom[1..natoms], boxline, the box matrix B, its inverse
   iB and the box center boxC.  Any truncated or malformed input is
   reported through message('E', ...), which terminates the program. */
void read_frame(void)
{
  /* Atom count of the first frame; molecule definitions are built from
     that frame only, so later frames must have the same count. */
  static int natoms_first = 0 ;
  int a, k, j ;
  float detB ;
  char line[MAXSTR+1], rnumb[6], rname[6], aname[6], anumb[6],
    ax[9], ay[9], az[9] ;

  /* Read number of atoms */
  if (fgets(line, sizeof line, fp_gro) == NULL
      || sscanf(line, "%d", &natoms) != 1 || natoms < 1)
    message('E', "Cannot read the number of atoms in frame %d.\n", nframe) ;
  if (natoms_first == 0) natoms_first = natoms ;
  else if (natoms != natoms_first)
    message('E', "Frame %d has %d atoms, but the first frame had %d.\n",
	    nframe, natoms, natoms_first) ;

  /* Release the previous frame (free(NULL) is a no-op for the first one)
     and start from zeroed atoms, so that fields missing in a line (e.g.
     velocities) end up empty rather than stale. */
  free(atom) ;
  atom = calloc(natoms+1, sizeof(atom_t)) ;
  if (atom == NULL)
    message('E', "Out of memory for %d atoms in frame %d.\n", natoms, nframe) ;

  /* Read atoms */
  for (a = 1 ; a <= natoms ; a++)
  {
    if (fgets(line, sizeof line, fp_gro) == NULL)
      message('E', "Unexpected end of file at atom %d of frame %d.\n",
	      a, nframe) ;
    /* At least the 7 fields up to the z coordinate must be present; the
       trailing text (8th field) is optional. */
    if (sscanf(line, "%5[^\n]%5[^\n]%5[^\n]%5[^\n]%8[^\n]%8[^\n]%8[^\n]%24[^\n]",
	       rnumb, rname, aname, anumb, ax, ay, az, atom[a].vels) < 7)
      message('E', "Malformed line for atom %d of frame %d.\n", a, nframe) ;
    sscanf(rnumb, "%d", &(atom[a].rnumb)) ;
    sscanf(rname, "%s", atom[a].rname) ;
    sscanf(aname, "%s", atom[a].aname) ;
    sscanf(anumb, "%d", &(atom[a].anumb)) ;
    if (sscanf(ax, "%f", &(atom[a].r[0])) != 1
	|| sscanf(ay, "%f", &(atom[a].r[1])) != 1
	|| sscanf(az, "%f", &(atom[a].r[2])) != 1)
      message('E', "Bad coordinates for atom %d of frame %d.\n", a, nframe) ;
  }

  if (fgets(boxline, sizeof boxline, fp_gro) == NULL)
    message('E', "Missing box line in frame %d.\n", nframe) ;
  /* Read box matrix B (free format space-separated reals, with 3 or 9
     reals): Order = a(x) b(y) c(z) a(y) a(z) b(x) b(z) c(x) c(y).  B is
     zeroed first, so that fields absent from the line are zero instead
     of keeping the values of a previous frame.  The usual actual format
     in GROMACS seems to be " %9.5" for each. */
  for (k = 0 ; k < 3 ; k++)
    for (j = 0 ; j < 3 ; j++)
      B[k][j] = 0 ;
  sscanf(boxline, "%f %f %f %f %f %f %f %f %f",
	 &(B[0][0]), &(B[1][1]), &(B[2][2]), &(B[1][0]), &(B[2][0]),
	 &(B[0][1]), &(B[2][1]), &(B[0][2]), &(B[1][2])) ;

  /* Compute inverse of B using the determinant+cofactors method: */
  detB = B[0][0] * (B[1][1] * B[2][2] - B[1][2] * B[2][1])
         - B[0][1] * (B[1][0] * B[2][2] - B[1][2] * B[2][0])
         + B[0][2] * (B[1][0] * B[2][1] - B[1][1] * B[2][0]) ;
  if (fabs(detB) < TINY_VOLUME)
    message('E', "Tiny box in frame %d. Box info may be missing...\n", nframe) ;
  /* Cyclic indices give the signed cofactor of element (j,k) directly. */
  for (k = 0 ; k < 3 ; k++)
    for (j = 0 ; j < 3 ; j++)
      iB[k][j] = ( B[(j+1)%3][(k+1)%3] * B[(j+2)%3][(k+2)%3]
	           - B[(j+1)%3][(k+2)%3] * B[(j+2)%3][(k+1)%3] ) / detB ;

  /* Compute box center, assuming origin at the "lowest" box corner */
  for (k = 0 ; k < 3 ; k++)
    boxC[k] = (B[k][0] + B[k][1] + B[k][2]) / 2 ;
}


/* calloc() wrapper for read_moldef(): reports an error (and thereby
   terminates the program) if the memory cannot be obtained. */
static void *moldef_calloc(size_t count, size_t size)
{
  void *p = calloc(count, size) ;

  if (p == NULL)
    message('E', "Out of memory while processing molecule definitions.\n") ;
  return p ;
}


/* Reads the molecule-definition file (fp_moldef) in three passes and
   builds group[], mol[], obj[], nobjs[] and the dist[][] table; it also
   sets gcent[]/centEW[] (C-line) and gPBC[] (P-line).  It must be called
   after the first frame was read, because 'n' entries are resolved
   against the atoms of that frame.  Lines are recognized by their first
   character:
     G name          start a new group called 'name'
     a first last    add a molecule made of atoms first..last
     n resname       add one molecule per run of atoms having that residue
                     name and the same residue number
     g name          add the molecules of an already defined group
     A name          add an assembling stage whose objects are the
                     molecules of group 'name'
     C g1 g2 g3 c1 c2 c3   centering groups and 'E'/'W' message codes
     P g1 g2 g3            groups to be wrapped into the box
   Any inconsistency is reported with message('E', ...), which exits.
   The file is closed at the end. */
void read_moldef(void)
{
  int k, g, a, m, gm, i, subg, st, gst, notot, o, first, last ;
  char line[MAXSTR+1], auxname[MAXSTR], auxstr[3][MAXSTR] ;

  /* 1st pass: count groups and assembling stages */
  ngroups = nstages = 0 ;
  while(fgets(line, sizeof line, fp_moldef) != NULL)
  {
    if (line[0] == 'G') ngroups++ ;
    else if (line[0] == 'A') nstages++ ;
  }
  group = moldef_calloc(ngroups+1, sizeof(group_t)) ; /* indices = {0 to ngroups} */
  nobjs = moldef_calloc(nstages+1, sizeof(int)) ;     /* indices = {0 to nstages} */
  /* Define empty group */
  strcpy(group[0].name, empty_gname) ;
  group[0].nmols = 0 ;

  /* 2nd pass: read group names, count total and per-group molecules */
  rewind(fp_moldef) ;
  g = 0 ;
  nmols = 0 ;
  st = 0 ;
  nobjs_tot = 0 ;
  while(fgets(line, sizeof line, fp_moldef) != NULL)
  {
    if (line[0] == 'G')
    {
      g++ ;
      if (sscanf(line, "G %s", group[g].name) != 1)
	message('E', "Missing group name in G-line.\n") ;
      if (strcmp(group[g].name, group[0].name) == 0)
	message('E', "Name '%s' not allowed for user-defined groups.\n",
		group[0].name) ;
    }
    else if (line[0] == 'a')
    {
      if (g == 0)
	message('E', "Entry 'a' found before a group was defined.\n") ;
      group[g].nmols++ ;
      nmols++ ;
    }
    else if (line[0] == 'n')
    {
      if (g == 0)
	message('E', "Entry 'n' found before a group was defined.\n") ;
      if (sscanf(line, "n %s", auxname) != 1)
	message('E', "Missing residue name in n-line.\n") ;
      for (a = 1 ; a <= natoms ; a++)
      {
	/* A molecule starts at a matching atom whose residue number
	   differs from that of the preceding atom; atom 1 has no
	   predecessor, so it always starts one.  (The 3rd pass must use
	   exactly the same rule.) */
	if (strcmp(auxname, atom[a].rname) == 0
	    && (a == 1 || atom[a].rnumb != atom[a-1].rnumb))
	{
	  group[g].nmols++ ;
	  nmols++ ;
	}
      }
    }
    else if (line[0] == 'g')
    {
      if (g == 0)
	message('E', "Entry 'g' found before a group was defined.\n") ;
      if (sscanf(line, "g %s", auxname) != 1)
	message('E', "Missing group name in g-line.\n") ;
      /* Molecules are shared with the subgroup, so nmols is unchanged. */
      group[g].nmols += group[gnumb(auxname)].nmols ;
    }
    else if (line[0] == 'A')
    {
      if (sscanf(line, "A %s", auxname) != 1)
	message('E', "Missing group name in A-line.\n") ;
      /* nobjs[] holds the cumulative object count up to each stage. */
      nobjs[++st] = nobjs_tot += group[gnumb(auxname)].nmols ;
    }
  }
  mol = moldef_calloc(nmols+1, sizeof(mol_t)) ;
  for (g = 1 ; g <= ngroups ; g++)
    group[g].mol = moldef_calloc(group[g].nmols+1, sizeof(int)) ;
  obj = moldef_calloc(nobjs_tot+1, sizeof(obj_t)) ;
  dist = moldef_calloc(nobjs_tot+1, sizeof(dist_t *)) ;
  for (o = 1 ; o <= nobjs_tot ; o++)
    dist[o] = moldef_calloc(nobjs_tot+1, sizeof(dist_t)) ;

  /* 3rd pass: assign molecules to groups and atoms to molecules. */
  rewind(fp_moldef) ;
  g = m = gm = 0 ;
  notot = 0 ;
  while(fgets(line, sizeof line, fp_moldef) != NULL)
  {
    if (line[0] == 'G')
    {
      g++ ;
      gm = 0 ;
    }
    else if (line[0] == 'a')
    {
      /* The range indexes atom[] directly, so it must lie within it. */
      if (sscanf(line, "a %d %d", &first, &last) != 2
	  || first < 1 || last < first || last > natoms)
	message('E', "Invalid atom range in a-line "
		"(need 1 <= first <= last <= %d).\n", natoms) ;
      group[g].mol[++gm] = ++m ;
      mol[m].firstatom = first ;
      mol[m].lastatom = last ;
      for (a = mol[m].firstatom ; a <= mol[m].lastatom ; a++)
      {
	if (atom[a].mol == 0) atom[a].mol = m ;
	else message('E', "Atom index %d already assigned to molecule.\n", a) ;
      }
    }
    else if (line[0] == 'n')
    {
      sscanf(line, "n %s", auxname) ;  /* already validated in 2nd pass */
      for (a = 1 ; a <= natoms ; a++)
      {
	if (strcmp(auxname, atom[a].rname) == 0)
	{
	  /* Same molecule-start rule as in the 2nd pass. */
	  if (a == 1 || atom[a].rnumb != atom[a-1].rnumb)
	    group[g].mol[++gm] = ++m ;
	  if (mol[m].firstatom == 0) mol[m].firstatom = a ;
	  mol[m].lastatom = a ;
	  if (atom[a].mol == 0) atom[a].mol = m ;
	  else message('E', "Atom %d already assigned to molecule.\n",
		       atom[a].anumb) ;
	}
      }
    }
    else if (line[0] == 'g')
    {
      sscanf(line, "g %s", auxname) ;  /* already validated in 2nd pass */
      subg = gnumb(auxname) ;
      for (i = 1 ; i <= group[subg].nmols ; i++)
	group[g].mol[++gm] = group[subg].mol[i] ;
    }
    else if (line[0] == 'A')
    {
      sscanf(line, "A %s", auxname) ;  /* already validated in 2nd pass */
      gst = gnumb(auxname) ;
      for (i = 1 ; i <= group[gst].nmols ; i++)
	obj[++notot].mol = group[gst].mol[i] ;
    }
    else if (line[0] == 'C')
    {
      nClines++ ;
      if (sscanf(line, "C %s %s %s %c %c %c", auxstr[0], auxstr[1], auxstr[2],
		 &(centEW[0]), &(centEW[1]), &(centEW[2])) != 6)
	message('E', "C-line needs three group names and three message codes.\n") ;
      for (k = 0 ; k < 3 ; k++)
      {
	gcent[k] = gnumb(auxstr[k]) ;
	if (centEW[k] != 'E' && centEW[k] != 'W')
	  message('E', "Give either 'E' or 'W' as message codes in C-line.\n") ;
      }
    }
    else if (line[0] == 'P')
    {
      nPlines++ ;
      if (sscanf(line, "P %s %s %s", auxstr[0], auxstr[1], auxstr[2]) != 3)
	message('E', "P-line needs three group names.\n") ;
      for (k = 0 ; k < 3 ; k++)	gPBC[k] = gnumb(auxstr[k]) ;
    }
  }

  /* Check if everything was properly read (useful for debugging) */
  if (0)
  {
    fprintf(stderr, "%d %d %d\n", ngroups, nmols, natoms) ;
    for (g = 1 ; g <= ngroups ; g++)
    {
      fprintf(stderr, "%s  %d\n", group[g].name, group[g].nmols) ;
      for (i = 1 ; i <= group[g].nmols ; i++)
      {
	m = group[g].mol[i] ;
	fprintf(stderr, "  %6d  %6d\n", mol[m].firstatom, mol[m].lastatom) ;
      }
    }
    for (st = 1 ; st <= nstages ; st++)
    {
      fprintf(stderr, "O  %d\n", st) ;
      for (o = nobjs[st-1]+1 ; o <= nobjs[st] ; o++)
	fprintf(stderr, "O      %d\n", obj[o].mol) ;
    }
    fprintf(stderr, "C %d %d %d\n", gcent[0], gcent[1], gcent[2]) ;
    fprintf(stderr, "P %d %d %d\n", gPBC[0], gPBC[1], gPBC[2]) ;
  }

  /* Some final checks for errors or warnings. */
  for (g = 1 ; g <= ngroups ; g++)
    if (group[g].nmols == 0)
      message('E', "Group '%s' is empty.\n", group[g].name) ;
  for (a = 1 ; a <= natoms ; a++)
    if (atom[a].mol == 0)
      message('E', "Atom %d not assigned to a molecule.\n", a) ;
  if (nstages == 0) message('E', "Give at least one A-lines.\n") ;
  if (nClines != 1) message('E', "Give one (and only one) C-line.\n") ;
  if (nPlines != 1) message('E', "Give one (and only one) P-line.\n") ;

  fclose(fp_moldef) ;
}


void write_frame(void)
{
  int a ;

  printf("%s", title) ;  /* already '\n'-terminated */
  printf("%d\n", natoms) ;
  for (a = 1 ; a <= natoms ; a++)
  {
    printf("%5d%-5s%5s%5d%8.3f%8.3f%8.3f%s\n",
	   atom[a].rnumb, atom[a].rname, atom[a].aname, atom[a].anumb,
	   atom[a].r[0], atom[a].r[1], atom[a].r[2], atom[a].vels) ;
  }
  printf("%s", boxline) ;  /* already '\n'-terminated */
}


/* Look up a group by name and return its index in group[] (0..ngroups).
   If several groups share the name, the highest index is returned.  An
   unknown name is reported as an error through message(). */
int gnumb(char groupname[])
{
  int idx, found = -1 ;

  /* Scanning downwards makes the first hit the highest matching index. */
  for (idx = ngroups ; idx >= 0 && found < 0 ; idx--)
  {
    if (strcmp(group[idx].name, groupname) == 0)
      found = idx ;
  }

  if (found < 0)
    message('E', "Group '%s' not defined.\n", groupname) ;

  return found ;
}


/* Map from physical to scaled coordinates. */
void r_to_s(void)
{
  int a, k, j ;

  for (a = 1 ; a <= natoms ; a++)
    for (k = 0 ; k < 3 ; k++)
    {
      atom[a].s[k] = 0 ;
      for (j = 0 ; j < 3 ; j++)
	atom[a].s[k] += iB[k][j]*(atom[a].r[j]-boxC[j]) ;
    }
}


/* Map from scaled to physical coordinates. */
void s_to_r(void)
{
  int a, k, j ;

  for (a = 1 ; a <= natoms ; a++)
    for (k = 0 ; k < 3 ; k++)
    {
      atom[a].r[k] = boxC[k] ;
      for (j = 0 ; j < 3 ; j++)
	atom[a].r[k] += B[k][j]*atom[a].s[j] ;
    }
}


/* Compute inter-object distances as the one between their closest atoms.
   For every pair of objects (o1 < o2) all atom pairs of the two molecules
   are scanned; the smallest squared minimum-image distance is stored in
   dist[o1][o2].dRMI2 and dist[o2][o1].dRMI2, together with the scaled
   coordinate difference of that atom pair (dS, with opposite sign in the
   transposed entry).  The minimum image is taken in scaled coordinates,
   by removing the nearest integer from each component. */
void compute_objdist(void)
{
  int o1, o2, m1, m2, a1, a2, k, j ;
  float d2_min, drMI2, drMI, ds[3], dsMI[3] ;

  for (o1 = 1 ; o1 < nobjs_tot ; o1++)
  for (o2 = o1+1 ; o2 <= nobjs_tot ; o2++)
  {
    d2_min = BIG_LENGTH * BIG_LENGTH ;
    m1 = obj[o1].mol ;
    m2 = obj[o2].mol ;
    for (a1 = mol[m1].firstatom ; a1 <= mol[m1].lastatom ; a1++)
    for (a2 = mol[m2].firstatom ; a2 <= mol[m2].lastatom ; a2++)
    {
      /* All three scaled components must be known before mapping back to
	 physical space: with a non-diagonal box matrix each physical
	 component depends on several scaled ones. */
      for (k = 0 ; k < 3 ; k++)
      {
	ds[k] = atom[a2].s[k] - atom[a1].s[k] ;
	dsMI[k] = ds[k] - nint(ds[k]) ;
      }
      drMI2 = 0 ;
      for (k = 0 ; k < 3 ; k++)
      {
	drMI = 0 ;
	for (j = 0 ; j < 3 ; j++) drMI += B[k][j] * dsMI[j] ;
	drMI2 += drMI * drMI ;
      }
      if (drMI2 < d2_min)
      {
	dist[o1][o2].dRMI2 = dist[o2][o1].dRMI2 = d2_min = drMI2 ;
	for (k = 0 ; k < 3 ; k++)
	{
	  dist[o1][o2].dS[k] = ds[k] ;
	  dist[o2][o1].dS[k] = -ds[k] ;
	}
      }
    }
  }
}


void assemble_objs(void)
{
  int st, not_empty, o1, o2, o1_min = -1, o2_min = -1, a, k, m ;
  float d2_min ;

  for (o1 = 1 ; o1 <= nobjs_tot ; o1++) obj[o1].assembled = 0 ;

  /* Use first object as nucleation point (doesn't matter which one is
     used, since it always gives the same assembling, but translated). */
  obj[1].assembled = 1 ;
  for (k = 0 ; k < 3 ; k++) obj[1].dN[k] = 0 ;

  /* Cluster objects by sequential assembling stages. */
  for (st = 1 ; st <= nstages ; st++)
  {
    do
    {
      d2_min = BIG_LENGTH * BIG_LENGTH ;
      not_empty = 0 ;
      for (o1 = nobjs[st-1]+1 ; o1 <= nobjs[st] ; o1++)
      {
	if (obj[o1].assembled == 0)
	{
	  not_empty = 1 ;
	  for (o2 = 1 ; o2 <= nobjs[st] ; o2++)
	  {
	    /* I tested switching the two conditions, but got no speed
	       difference. But may be system-dependent... */
	    if (obj[o2].assembled == 1)
	    {
	      if (dist[o1][o2].dRMI2 < d2_min)
	      {
		d2_min = dist[o1][o2].dRMI2 ;
		o1_min = o1 ;
		o2_min = o2 ;
	      }
	    }
	  }
	}
      }
      if (o1_min != -1 && o2_min != -1)
      {
	obj[o1_min].assembled = 1 ;
	for (k = 0 ; k < 3 ; k++)
	  /* Since o2 is already assembled, its dN was already assigned. */
	  obj[o1_min].dN[k] = obj[o2_min].dN[k] + nint(dist[o1_min][o2_min].dS[k]) ;
      }
    }
    while (not_empty) ;
  }

  /* Apply box displacements from assembling. */
  for (o1 = 1 ; o1 <= nobjs_tot ; o1++)
  {
    m = obj[o1].mol ;
    for (a = mol[m].firstatom ; a <= mol[m].lastatom ; a++)
      for (k = 0 ; k < 3 ; k++) atom[a].s[k] += obj[o1].dN[k] ;
  }
}


void center(void)
{
  int k, g, i, m, a ;
  float smin[3], smax[3], smid ;
  const char* ord[] = { "1st", "2nd", "3rd" } ;
  
  for (k = 0 ; k < 3 ; k++)
  {
    smin[k] = +BIG_LENGTH ;
    smax[k] = -BIG_LENGTH ;
    g = gcent[k] ;
    for (i = 1 ; i <= group[g].nmols ; i++)
    {
      m = group[g].mol[i] ;
      for (a = mol[m].firstatom ; a <= mol[m].lastatom ; a++)
      {
	if (atom[a].s[k] < smin[k]) smin[k] = atom[a].s[k] ;
	if (atom[a].s[k] > smax[k]) smax[k] = atom[a].s[k] ;
      }
    }
    smid = 0.5 * (smin[k] + smax[k]) ;
    for (a = 1 ; a <= natoms ; a++) atom[a].s[k] -= smid ;
    /* Check if the box range was exceeded by the centering group and, if
       it was, give an error or warning depending on centEW[k]. */
    if (smax[k] - smin[k] > 1)
    {
      message(centEW[k],
	      "Group '%s' exceeds range of %s box vector in frame %d.\n",
	      group[gcent[k]].name, ord[k], nframe) ;
    }
  }
}


void apply_PBCs(void)
{
  int k, g, i, m, a, dnCOM ;
  float sCOM ;

  for (k = 0 ; k < 3 ; k++)
  {
    g = gPBC[k] ;
    for (i = 1 ; i <= group[g].nmols ; i++)
    {
      m = group[g].mol[i] ;
      /* Compute k-coordinate of m's "COM" (actually, geometric center). */
      sCOM = 0 ;
      for (a = mol[m].firstatom ; a <= mol[m].lastatom ; a++)
	sCOM += atom[a].s[k] ;
      sCOM /= (mol[m].lastatom - mol[m].firstatom + 1) ;
      /* Move m's atoms to inside the reference box. */
      dnCOM = nint(sCOM) ;
      for (a = mol[m].firstatom ; a <= mol[m].lastatom ; a++)
	atom[a].s[k] -= dnCOM ;
    }
  }
}


/* Fortran-like named: returns the integer nearest to x, with halfway
   cases rounded away from zero.  This guarantees the useful property
   nint(-x) = -nint(x) for every x, including half-integers (a plain
   floor(x + 0.5) would give nint(-0.5) = 0 but nint(0.5) = 1).  The
   result is only meaningful when it fits in an int. */
int nint(float x)
{
  double y = x ;

  /* Converting a non-negative value to int truncates, i.e. floors it;
     negative arguments are handled by symmetry. */
  if (y < 0) return -(int) (-y + 0.5) ;
  return (int) (y + 0.5) ;
}


/* Prints a diagnostic to stderr, prefixed with the program name.
   mtype selects the kind of message:
     'W'  warning; the function returns to the caller.
     'E'  error; the program exits with status 1.
     'U'  error followed by the usage text; exits with status 1.
   Any other mtype is itself reported as an error.  The remaining
   arguments are a printf-style format and its values. */
void message(char mtype, char *format, ...)
{
  va_list ap ;
  const char prog[] = "fixbox" ;
  const char usage_text[] = "Usage: %s GRO_FILE MOLDEF_FILE > FIXED_GRO_FILE\n"
    "FixBox version " VERSION "\n" ;

  switch (mtype)
  {
    case 'W':
      fprintf(stderr, "%s: WARNING: ", prog) ;
      break ;
    case 'E':
    case 'U':
      fprintf(stderr, "%s: ERROR: ", prog) ;
      break ;
    default:
      /* Does not return: an 'E' message terminates the program. */
      message('E', "Wrong use of message() function.\n") ;
      break ;
  }

  va_start(ap, format) ;
  vfprintf(stderr, format, ap) ;
  va_end(ap) ;

  if (mtype == 'U')
  {
    fprintf(stderr, usage_text, prog) ;
  }
  if (mtype != 'W')
  {
    exit(1) ;
  }
}

