#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <stdarg.h>
#include <ctype.h>
#include "pcma.h"
/*#include "new.h"  JP */

/*
 *   Prototypes
 */
static void create_tree(treeptr ptree, treeptr parent);
static void create_node(treeptr pptr, treeptr parent);
static treeptr insert_node(treeptr pptr);
static void skip_space(FILE *fd);
static treeptr avail(void);
static void set_info(treeptr p, treeptr parent, sint pleaf, char *pname, float pdist);
static treeptr reroot(treeptr ptree, sint nseqs);
static treeptr insert_root(treeptr p, float diff);
static float calc_root_mean(treeptr root, float *maxdist);
static float calc_mean(treeptr nptr, float *maxdist, sint nseqs);
static void order_nodes(void);
static sint calc_weight(sint leaf);
/*
static void group_seqs(treeptr p, sint *next_groups, sint nseqs);
static void mark_group1(treeptr p, sint *groups, sint n);
static void mark_group2(treeptr p, sint *groups, sint n);
static void save_set(sint n, sint *groups);
*/
static void clear_tree_nodes(treeptr p);

/* JP */
static streeptr savail(void);
static void copy_tree(treeptr t1, streeptr t2);
static void copy_content(treeptr t1, streeptr t2);
static void group_seqs(streeptr p, sint *next_groups, sint nseqs);
static void mark_group1(streeptr p, sint *groups, sint n);
static void mark_group2(streeptr p, sint *groups, sint n);
static void save_set(sint n, sint *groups, streeptr p);
extern int am2num_c(int c);


/*
 *   Global variables
 */
extern Boolean interactive;
extern Boolean distance_tree;
extern Boolean usemenu;
extern sint debug;
extern double **tmat;
extern sint **sets;
extern sint nsets;
extern char **names;
extern sint *seq_weight;
extern Boolean no_weights;
extern char **seq_array;
sint *seqlen_array;
extern char *amino_acid_codes;

char ch;
FILE *fd;
treeptr *lptr;
treeptr *olptr;
treeptr *nptr;
treeptr *ptrs;
sint nnodes = 0;
sint ntotal = 0;
Boolean rooted_tree = TRUE;
static treeptr seq_tree,root;
static sint *groups, numseq;
static sint nseqshere;

/* JP */
streeptr sroot;
streeptr *grp_ancestor; /* ancestor nodes for the groups */
streeptr *solptr;
streeptr *groupptr;
sint ngroups;
extern sint gap_pos1, gap_pos2;

/*
 *   calc_seq_weights
 *
 *   Fill sweight[first_seq .. last_seq-1] with integer sequence weights.
 *
 *   first_seq, last_seq : half-open range of sequence indices to weight.
 *   sweight             : output array, written at indices first_seq..last_seq-1.
 *
 *   When there are at least two sequences, a distance tree is available and
 *   weighting is enabled, the weights come from calc_weight() and are scaled
 *   so that they sum to (approximately) INT_SCALE_FACTOR, with a floor of 1.
 *   Otherwise every sequence gets the same weight INT_SCALE_FACTOR / nseqs.
 *
 *   Side effect: records the number of sequences in the file-level nseqshere.
 */
void calc_seq_weights(sint first_seq, sint last_seq, sint *sweight)
{
  sint   i, nseqs;
  sint   temp, sum, *weight;

  nseqs = last_seq-first_seq;
  nseqshere = nseqs;

/*
  An empty range has nothing to weight; returning here also keeps the
  identity-weight branch from dividing by zero.
*/
  if (nseqs <= 0) return;

/*
  If there are at least two sequences and tree based weights are wanted....
*/
  if ((nseqs >= 2) && (distance_tree == TRUE) && (no_weights == FALSE))
    {
/*
  Calculate sequence weights based on Phylip tree.
*/
      weight = (sint *)ckalloc((last_seq+1) * sizeof(sint));

      for (i=first_seq; i<last_seq; i++)
         weight[i] = calc_weight(i);

/*
  Normalise the weights, such that the sum of the weights = INT_SCALE_FACTOR
*/
      sum = 0;
      for (i=first_seq; i<last_seq; i++)
         sum += weight[i];

      if (sum == 0)
        {
/*
  All tree weights are zero: fall back to one unit per sequence.  The total
  is the number of sequences, not the last index (they only coincide when
  first_seq is 0).
*/
          for (i=first_seq; i<last_seq; i++)
             weight[i] = 1;
          sum = nseqs;
        }

      for (i=first_seq; i<last_seq; i++)
        {
          sweight[i] = (weight[i] * INT_SCALE_FACTOR) / sum;
          if (sweight[i] < 1) sweight[i] = 1;
        }

      weight=ckfree((void *)weight);
    }
  else
    {
/*
  Otherwise, use identity weights.
*/
      temp = INT_SCALE_FACTOR / nseqs;
      for (i=first_seq; i<last_seq; i++)
         sweight[i] = temp;
    }

}

/*
 *   create_sets
 *
 *   Rebuild the list of alignment sets for sequences first_seq..last_seq-1.
 *   Resets nsets, copies the guide tree into the streeptr tree (copy_tree
 *   stores the new root in sroot), allocates grp_ancestor, and then fills the
 *   sets through group_seqs() (two or more sequences) or a simple chained
 *   grouping (fewer than two).  The file-level scratch array 'groups' is
 *   allocated and released inside this call.
 */
void create_sets(sint first_seq, sint last_seq)
{
  sint outer, inner;
  sint count = last_seq-first_seq;

  nsets = 0;

  /* JP: generate the tree with sequences */
  copy_tree(root, sroot);
  grp_ancestor = ckalloc((count+1)*sizeof(streeptr *));

  /* both branches below need the same scratch array */
  groups = (sint *)ckalloc((count+1) * sizeof(sint));

  if (count >= 2)
    {
      group_seqs(sroot, groups, count);
    }
  else
    {
      for (outer=0; outer<count-1; outer++)
        {
          for (inner=0; inner<count; inner++)
            {
              if (inner<=outer) groups[inner] = 1;
              else if (inner==outer+1) groups[inner] = 2;
              else groups[inner] = 0;
            }
          save_set(count, groups, sroot);
        }
    }

  groups=ckfree((void *)groups);
}

/*
 *   read_tree
 *
 *   Read a parenthesised tree from 'treefile' and build the node tree for the
 *   sequences first_seq..last_seq-1.
 *
 *   Returns 1 on success and 0 on any failure (file cannot be opened, first
 *   non-blank character is not '(', leaf count differs from the number of
 *   sequences, unrooted tree without distances, or a sequence name with no
 *   matching leaf).
 *
 *   Side effects: resets numseq/nnodes/ntotal/rooted_tree, sets distance_tree,
 *   allocates nptr, ptrs, lptr, olptr and solptr, and stores the resulting
 *   root in the file-level 'root'.  The arrays are NOT released on the error
 *   returns that follow their allocation.
 */
sint read_tree(char *treefile, sint first_seq, sint last_seq)
{

  char c;
  char name1[MAXNAMES+1], name2[MAXNAMES+1];
  sint i, j, k, len;
  Boolean found;

  numseq = 0;
  nnodes = 0;
  ntotal = 0;
  rooted_tree = TRUE;

#ifdef VMS
  if ((fd = fopen(treefile,"r","rat=cr","rfm=var")) == NULL)
#else
  if ((fd = fopen(treefile, "r")) == NULL)
#endif
    {
      error("cannot open %s", treefile);
      return((sint)0);
    }

  skip_space(fd);
  ch = (char)getc(fd);
  if (ch != '(')
    {
/*
  close the file before bailing out, otherwise the handle is leaked
*/
      fclose(fd);
      error("Wrong format in tree file %s", treefile);
      return((sint)0);
    }
  rewind(fd);

  distance_tree = TRUE;

/*
  Allocate memory for tree
*/
  nptr = (treeptr *)ckalloc(3*(last_seq-first_seq+1) * sizeof(treeptr));
  ptrs = (treeptr *)ckalloc(3*(last_seq-first_seq+1) * sizeof(treeptr));
  lptr = (treeptr *)ckalloc((last_seq-first_seq+1) * sizeof(treeptr));
  olptr = (treeptr *)ckalloc((last_seq+1) * sizeof(treeptr));
  solptr = (streeptr *)ckalloc((last_seq+1) * sizeof(streeptr));

  seq_tree = avail();
  set_info(seq_tree, NULL, 0, "", 0.0);

  create_tree(seq_tree,NULL);
  fclose(fd);


  if (numseq != last_seq-first_seq)
     {
/*
  the cast covers the whole difference so the argument matches %d
*/
         error("tree not compatible with alignment\n(%d sequences in alignment and %d in tree", (pint)(last_seq-first_seq),(pint)numseq);
         return((sint)0);
     }

/*
  If the tree is unrooted, reroot the tree - ie. minimise the difference
  between the mean root->leaf distances for the left and right branches of
  the tree.
*/

  if (distance_tree == FALSE)
     {
  	if (rooted_tree == FALSE)
          {
       	     error("input tree is unrooted and has no distances.\nCannot align sequences");
             return((sint)0);
          }
     }

  if (rooted_tree == FALSE)
     {
        root = reroot(seq_tree, last_seq-first_seq+1);
     }
  else
     {
        root = seq_tree;
     }

/*
  calculate the 'order' of each node.
*/
  order_nodes();

  if (numseq >= 2)
     {
/*
  If there are at least two sequences....
*/
/*
  assign the sequence nodes (in the same order as in the alignment file)
*/
      for (i=first_seq; i<last_seq; i++)
       {
         len = (sint)strlen(names[i+1]);
         if (len > MAXNAMES)
             warning("name %s is too long for PHYLIP tree format (max %d chars)", names[i+1],MAXNAMES);

/*
  normalise the alignment name: upper case -> lower case, blank -> '_',
  truncated to MAXNAMES characters
*/
         for (k=0; k<len && k<MAXNAMES ; k++)
           {
             c = names[i+1][k];
             if ((c>0x40) && (c<0x5b)) c=c | 0x20;
             if (c == ' ') c = '_';
             name2[k] = c;
           }
         name2[k]='\0';
         found = FALSE;
/*
  compare against every leaf name (lower-cased); the scan does not stop at
  the first hit, so the last matching leaf is the one recorded
*/
         for (j=0; j<numseq; j++)
           {
            len = (sint)strlen(lptr[j]->name);
            for (k=0; k<len && k<MAXNAMES ; k++)
              {
                c = lptr[j]->name[k];
                if ((c>0x40) && (c<0x5b)) c=c | 0x20;
                name1[k] = c;
              }
            name1[k]='\0';
            if (strcmp(name1, name2) == 0)
              {
                olptr[i] = lptr[j];
                found = TRUE;
              }
           }
         if (found == FALSE)
           {
             error("tree not compatible with alignment:\n%s not found", name2);
             return((sint)0);
           }
       }

     }
   return((sint)1);
}

/*
 *   create_tree
 *
 *   Recursive-descent reader for one subtree of the tree file open on the
 *   file-level stream 'fd'.  'ptree' is the (already allocated) node to fill
 *   in and 'parent' its parent (NULL for the top of the tree).
 *
 *   The current look-ahead character is kept in the file-level 'ch'; on
 *   return it holds the character that follows this subtree.  Every node is
 *   appended to ptrs[], internal nodes also to nptr[], leaves also to lptr[]
 *   (counters ntotal, nnodes, numseq).  A third child at one level makes the
 *   tree unrooted: an extra node is inserted and rooted_tree is cleared.
 *   A leaf without a ":distance" suffix clears distance_tree.
 */
static void create_tree(treeptr ptree, treeptr parent)
{
   treeptr p;

   sint i, type;
   float dist = 0.0;
   char name[MAXNAMES+1];
   int c;

/*
  is this a node or a leaf ?
*/
  skip_space(fd);
  ch = (char)getc(fd);
  if (ch == '(')
    {
/*
   this must be a node....
*/
      type = NODE;
      name[0] = '\0';
      ptrs[ntotal] = nptr[nnodes] = ptree;
      nnodes++;
      ntotal++;

      create_node(ptree, parent);

      p = ptree->left;
      create_tree(p, ptree);

      if ( ch == ',')
       {
          p = ptree->right;
          create_tree(p, ptree);
          if ( ch == ',')
            {
/*
   a third branch: wrap the node read so far in a new parent and read the
   third branch as that parent's right child
*/
               ptree = insert_node(ptree);
               ptrs[ntotal] = nptr[nnodes] = ptree;
               nnodes++;
               ntotal++;
               p = ptree->right;
               create_tree(p, ptree);
               rooted_tree = FALSE;
            }
       }

      skip_space(fd);
      ch = (char)getc(fd);
    }
/*
   ...otherwise, this is a leaf
*/
  else
    {
      type = LEAF;
      ptrs[ntotal++] = lptr[numseq++] = ptree;
/*
   get the sequence name; characters beyond MAXNAMES are read and dropped.
   The raw getc() result is tested for EOF so that a truncated file ends
   the name instead of looping forever.
*/
      name[0] = ch;
      i = 1;
      for (;;)
        {
          c = getc(fd);
          ch = (char)c;
          if ((c == EOF) || (ch == ':') || (ch == ',') || (ch == ')'))
             break;
          if (i < MAXNAMES) name[i++] = ch;
        }
      name[i] = '\0';
      if (ch != ':')
         {
           distance_tree = FALSE;
         }
    }

/*
   get the distance information (stays 0.0 when absent or unreadable)
*/
  if (ch == ':')
     {
       skip_space(fd);
       if (fscanf(fd,"%f",&dist) != 1) dist = 0.0;
       skip_space(fd);
       ch = (char)getc(fd);
     }
   set_info(ptree, parent, type, name, dist);


}

/*
 *   create_node
 *
 *   Record 'parent' as the parent of 'pptr' and give 'pptr' two freshly
 *   allocated, empty children (left first, then right).
 */
static void create_node(treeptr pptr, treeptr parent)
{
  pptr->parent = parent;
  pptr->left = avail();
  pptr->right = avail();
}

/*
 *   insert_node
 *
 *   Create a new internal node that takes the place of 'pptr': the new node
 *   inherits pptr's parent, adopts 'pptr' as its left child and gets a fresh
 *   empty right child.  Returns the new node.
 *
 *   The original parent is captured before 'pptr' is re-parented, so the new
 *   node no longer ends up as its own parent, and only the right child is
 *   allocated (the left slot is filled by 'pptr', so allocating a node for
 *   it just leaked that node).
 */
static treeptr insert_node(treeptr pptr)
{

   treeptr newnode;
   treeptr oldparent = pptr->parent;

   newnode = avail();
/*
   set_info comes first so that it cannot disturb the children set below
*/
   set_info(newnode, oldparent, NODE, "", 0.0);

   newnode->left = pptr;
   newnode->right = avail();
   pptr->parent = newnode;

   return(newnode);
}

/*
 *   skip_space
 *
 *   Consume white-space characters (as classified by isspace) from 'fd' and
 *   push the first character that is not white space back onto the stream,
 *   so the next read returns it.
 */
static void skip_space(FILE *fd)
{
  int next = getc(fd);

  while (isspace(next))
     next = getc(fd);

  ungetc(next, fd);
}

/*
 *   avail
 *
 *   Allocate one tree node with ckalloc and return it with no links,
 *   zero distance, leaf flag 0, order 0 and an empty name.
 */
static treeptr avail(void)
{
   treeptr node = ckalloc(sizeof(stree));

   node->parent = NULL;
   node->left = node->right = NULL;
   node->order = 0;
   node->leaf = 0;
   node->dist = 0.0;
   node->name[0] = '\0';

   return(node);
}


/* JP */
/*
 *   savail
 *
 *   Allocate one sequence-tree node with ckalloc and return it with every
 *   field in a defined empty state: no links, zero distance, leaf flag 0,
 *   order 0, no names, no sequences and zero counts/lengths.
 *
 *   seqlength is cleared as well; it used to be the only field left as
 *   whatever ckalloc returned.
 */
static streeptr savail(void)
{
	streeptr p;
	p = ckalloc(sizeof(sstree));
	p->left = NULL;
	p->right = NULL;
	p->parent = NULL;
	p->dist = 0.0;
	p->leaf = 0;
	p->order = 0;
	p->name = NULL;
	p->seq = NULL;
	p->seqnum = 0;
	p->seqlength = 0;
	p->abstractseq = NULL;
	p->abseqnum = 0;
	p->abseqlength = 0;
	return(p);
}



/*
 *   clear_tree
 *
 *   Release the node tree starting at 'p' (clear_tree_nodes substitutes the
 *   file-level root for a NULL 'p') and then the four node-pointer tables
 *   olptr, lptr, ptrs and nptr.  Each table pointer is overwritten with the
 *   value returned by ckfree.  solptr is not touched here.
 */
void clear_tree(treeptr p)
{
   clear_tree_nodes(p);

   /* the tables are independent of each other */
   olptr=ckfree((void *)olptr);
   lptr=ckfree((void *)lptr);
   ptrs=ckfree((void *)ptrs);
   nptr=ckfree((void *)nptr);
}

/*
 *   clear_tree_nodes
 *
 *   Free the subtree rooted at 'p' in post-order (children first, then the
 *   node itself).  A NULL 'p' means "start at the file-level root".
 *
 *   If there is no tree at all (both 'p' and root are NULL) the call is a
 *   no-op instead of dereferencing a null pointer.
 *
 *   NOTE(review): the file-level root/seq_tree pointers are not reset here,
 *   so they dangle after the tree has been released.
 */
static void clear_tree_nodes(treeptr p)
{
   if (p==NULL) p = root;
   if (p==NULL) return;

   if (p->left != NULL)
     {
       clear_tree_nodes(p->left);
     }
   if (p->right != NULL)
     {
       clear_tree_nodes(p->right);
     }
   p->left = NULL;
   p->right = NULL;
   p=ckfree((void *)p);
}

/*
 *   set_info
 *
 *   Store the descriptive fields of node 'p': parent link, leaf flag,
 *   branch length and name (copied with strcpy).  The order counter is
 *   reset to 0.  When the leaf flag equals TRUE the child links are cleared.
 */
static void set_info(treeptr p, treeptr parent, sint pleaf, char *pname, float pdist)
{
   strcpy(p->name, pname);
   p->order = 0;
   p->dist = pdist;
   p->leaf = pleaf;
   p->parent = parent;

   if (p->leaf != TRUE)
      return;

   /* a leaf has no children */
   p->left = NULL;
   p->right = NULL;
}

/*
 *   reroot
 *
 *   Choose the branch on which to place a new root and insert it.
 *
 *   Every node in ptrs[] is scored with the difference between the mean
 *   leaf distances on its two sides (calc_root_mean for the parentless node,
 *   calc_mean for the rest).  A node is a candidate when that difference is
 *   zero, or positive and smaller than twice its own branch length; among
 *   the candidates the one with the smallest maximum leaf distance wins.
 *   If the winner is 'ptree' itself, its right child is used with the sum of
 *   the two child branch lengths as the difference.
 *
 *   Returns the new root node created by insert_root.
 */
static treeptr reroot(treeptr ptree, sint nseqs)
{

   treeptr cand, best, newroot;
   float   delta, best_delta = 0.0, best_depth = 1.0, maxdist;
   sint    idx;
   Boolean first = TRUE;

   best = ptree;
   for (idx=0; idx<ntotal; idx++)
     {
        cand = ptrs[idx];
        delta = (cand->parent == NULL) ? calc_root_mean(cand, &maxdist)
                                       : calc_mean(cand, &maxdist, nseqs);

        /* not a usable position for the root */
        if (!((delta == 0) || ((delta > 0) && (delta < 2 * cand->dist))))
           continue;

        /* usable, but no shallower than the best one so far */
        if (!((maxdist < best_depth) || (first == TRUE)))
           continue;

        first = FALSE;
        best = cand;
        best_depth = maxdist;
        best_delta = delta;
     }

/*
  insert a new node as the ancestor of the node which produces the shallowest
  tree.
*/
   if (best == ptree)
     {
        best_delta = best->left->dist + best->right->dist;
        best = best->right;
     }
   newroot = insert_root(best, best_delta);

   /* result is not used; the call is kept as in the original */
   (void)calc_root_mean(newroot, &maxdist);

   return(newroot);
}

/*
 *   insert_root
 *
 *   Insert a new root node on the branch between 'p' and its parent and
 *   re-orient the rest of the tree around it.  Returns the new root.
 *
 *   p    : node below the branch that receives the root.  p->parent is
 *          dereferenced without a check, so 'p' must not be the top node.
 *   diff : value whose half becomes the new branch length of 'p', clamped
 *          to the range 0 .. original length of that branch; the remainder
 *          of the branch is given to the old parent.
 *
 *   The old top node is unlinked from the tree at the end; it is not freed
 *   in this function.
 */
static treeptr insert_root(treeptr p, float diff)
{
   treeptr newp, prev, q, t;
   float dist, prevdist,td;

   newp = avail();

   /* t is the old parent of p; remember its branch length before it changes */
   t = p->parent;
   prevdist = t->dist;

   p->parent = newp;

   dist = p->dist;

   /* split the p--t branch: p keeps diff/2 (clamped), t gets the rest */
   p->dist = diff / 2;
   if (p->dist < 0.0) p->dist = 0.0;
   if (p->dist > dist) p->dist = dist;

   t->dist = dist - p->dist;

   newp->left = t;
   newp->right = p;
   newp->parent = NULL;
   newp->dist = 0.0;
   newp->leaf = NODE;

   /* the child slot of t that held p now points at t's old parent */
   if (t->left == p) t->left = t->parent;
   else t->right = t->parent;

   prev = t;
   q = t->parent;

   t->parent = newp;

   /*
      walk up the old ancestor chain, reversing each link: the slot that
      pointed down at 'prev' is redirected to the old parent, 'prev' becomes
      the parent, and each node takes over the branch length that was stored
      on the node below it.  The loop ends when the old parent is NULL, so
      'prev' finishes on the old top node.
   */
   while (q != NULL)
     {
        if (q->left == prev)
           {
              q->left = q->parent;
              q->parent = prev;
              td = q->dist;
              q->dist = prevdist;
              prevdist = td;
              prev = q;
              q = q->left;
           }
        else
           {
              q->right = q->parent;
              q->parent = prev;
              td = q->dist;
              q->dist = prevdist;
              prevdist = td;
              prev = q;
              q = q->right;
           }
    }

/*
   remove the old root node: after the reversal one of its child slots is
   NULL; its remaining child is attached in its place and inherits its
   branch length.
*/
   q = prev;
   if (q->left == NULL)
      {
         dist = q->dist;
         q = q->right;
         q->dist += dist;
         q->parent = prev->parent;
         if (prev->parent->left == prev)
            prev->parent->left = q;
         else
            prev->parent->right = q;
         prev->right = NULL;
      }
   else
      {
         dist = q->dist;
         q = q->left;
         q->dist += dist;
         q->parent = prev->parent;
         if (prev->parent->left == prev)
            prev->parent->left = q;
         else
            prev->parent->right = q;
         prev->left = NULL;
      }

   return(newp);
}

/*
 *   calc_root_mean
 *
 *   For the node 'root', sum the distance from every leaf in lptr[] up to
 *   (but not including) 'root', separately for leaves reached through
 *   root->left and for all others.  Returns mean(left) - mean(other) and
 *   stores the largest single leaf distance in *maxdist.
 *
 *   Every leaf is expected to have 'root' among its ancestors; the climb
 *   has no other stopping condition.
 */
static float calc_root_mean(treeptr root, float *maxdist)
{
   float path, left_total = 0.0, right_total = 0.0;
   float left_avg, right_avg, result;
   treeptr cur;
   sint leaf;
   sint left_count = 0, right_count = 0;

   *maxdist = 0;
   for (leaf=0; leaf<numseq; leaf++)
     {
        /* climb until 'cur' is a direct child of the root */
        path = 0.0;
        for (cur = lptr[leaf]; cur->parent != root; cur = cur->parent)
           path += cur->dist;
        path += cur->dist;

        if (cur == root->left)
          {
             left_total += path;
             left_count++;
          }
        else
          {
             right_total += path;
             right_count++;
          }

        if (path > (*maxdist)) *maxdist = path;
     }

   left_avg = left_total / left_count;
   right_avg = right_total / right_count;

   result = left_avg - right_avg;
   return(result);
}


/*
 *   calc_mean
 *
 *   Score the node 'nptr' as a possible root position.  Leaves that descend
 *   from 'nptr' (or are 'nptr') count as RIGHT, all others as LEFT.  For a
 *   RIGHT leaf the distance is the path from the leaf up to 'nptr'; for a
 *   LEFT leaf it is the path up to the common ancestor plus the cumulative
 *   distance from 'nptr' towards that ancestor.
 *
 *   Returns mean(LEFT) - mean(RIGHT).  *maxdist receives the largest
 *   leaf-side path length seen.  'nseqs' is only used as the capacity of the
 *   two scratch arrays.
 */
static float calc_mean(treeptr nptr, float *maxdist, sint nseqs)
{
   float run, left_total = 0.0, right_total = 0.0;
   float left_avg, right_avg, result;
   treeptr cur, *chain;
   float *cum;
   sint levels = 0, leaf, k, hit = 0;
   sint left_count = 0, right_count = 0;
   sint side, matched;

   chain = (treeptr *)ckalloc(nseqs * sizeof(treeptr));
   cum = (float *)ckalloc(nseqs * sizeof(float));

/*
   record every node from 'nptr' up to the top together with the running
   sum of branch lengths
*/
   *maxdist = 0;
   run = 0;
   for (cur = nptr; cur != NULL; cur = cur->parent)
     {
        chain[levels] = cur;
        run += cur->dist;
        cum[levels] = run;
        levels++;
     }

   for (leaf=0; leaf<numseq; leaf++)
     {
        cur = lptr[leaf];
        run = 0.0;
        side = RIGHT;

        if (cur != nptr)
          {
/*
   climb from the leaf until its parent is found on the recorded chain
*/
             matched = FALSE;
             hit = 0;
             while ((matched == FALSE) && (cur->parent != NULL))
               {
                  for (k=0; k<levels; k++)
                    {
                       if (cur->parent == chain[k])
                         {
                            matched = TRUE;
                            hit = k;
                         }
                    }
                  run += cur->dist;
                  cur = cur->parent;
               }
             side = (cur == nptr) ? RIGHT : LEFT;
          }

        if (side == LEFT)
          {
             left_total += run;
             left_total += cum[hit-1];
             left_count++;
          }
        else
          {
             right_total += run;
             right_count++;
          }

        if (run > (*maxdist)) *maxdist = run;
     }

   cum=ckfree((void *)cum);
   chain=ckfree((void *)chain);

   left_avg = left_total / left_count;
   right_avg = right_total / right_count;

   result = left_avg - right_avg;
   return(result);
}

/*
 *   order_nodes
 *
 *   For every leaf in lptr[], add one to the 'order' counter of the leaf
 *   and of each of its ancestors.  Afterwards a node's order equals the
 *   number of leaves at or below it (given that the counters started at 0).
 */
static void order_nodes(void)
{
   sint leaf;
   treeptr cur;

   for (leaf=0; leaf<numseq; leaf++)
      for (cur = lptr[leaf]; cur != NULL; cur = cur->parent)
         cur->order++;
}


/*
 *   calc_weight
 *
 *   Weight of the sequence whose leaf is olptr[leaf]: the sum, over every
 *   node from the leaf up to (not including) the parentless top node, of
 *   that node's branch length divided by its order.  The sum is multiplied
 *   by 100 and truncated to an integer.
 */
static sint calc_weight(sint leaf)
{
  treeptr cur;
  float total = 0.0;

  for (cur = olptr[leaf]; cur->parent != NULL; cur = cur->parent)
     total += cur->dist / cur->order;

  total *= 100.0;

  return((sint)total);
}

/* JP change a little bit: treeptr -> streeptr
static void group_seqs(treeptr p, sint *next_groups, sint nseqs)
*/

/*
 *   group_seqs
 *
 *   Post-order walk of the sequence tree below 'p'.  For each node with a
 *   right child, one set is saved (save_set) in which the sequences under
 *   the left child are marked 1 and those under the right child 2.
 *
 *   next_groups : in/out array of nseqs entries; on return it holds the
 *                 marking produced for 'p', which the caller reads as
 *                 "non-zero = belongs to this subtree".
 */
static void group_seqs(streeptr p, sint *next_groups, sint nseqs)
{
    sint idx;
    sint *merged;
    streeptr child;

    merged = (sint *)ckalloc((nseqs+1) * sizeof(sint));
    for (idx=0; idx<nseqs; idx++)
       merged[idx] = 0;

    /* left side: members become group 1 */
    child = p->left;
    if (child != NULL)
      {
         if (child->leaf != NODE)
           {
              mark_group1(child, merged, nseqs);
           }
         else
           {
              group_seqs(child, next_groups, nseqs);
              for (idx=0; idx<nseqs; idx++)
                {
                   if (next_groups[idx] != 0) merged[idx] = 1;
                }
           }
      }

    /* right side: members become group 2, then the set is recorded */
    child = p->right;
    if (child != NULL)
      {
         if (child->leaf != NODE)
           {
              mark_group2(child, merged, nseqs);
           }
         else
           {
              group_seqs(child, next_groups, nseqs);
              for (idx=0; idx<nseqs; idx++)
                {
                   if (next_groups[idx] != 0) merged[idx] = 2;
                }
           }
         save_set(nseqs, merged, p);
      }

    for (idx=0; idx<nseqs; idx++)
       next_groups[idx] = merged[idx];

    merged=ckfree((void *)merged);
}

/* JP: change a little bit: treeptr -> streeptr
static void mark_group1(treeptr p, sint *groups, sint n)
*/
/*
 *   mark_group1
 *
 *   Overwrite groups[0..n-1]: the entry whose solptr[] slot is the leaf 'p'
 *   becomes 1, every other entry becomes 0.
 */
static void mark_group1(streeptr p, sint *groups, sint n)
{
    sint idx;

    for (idx=0; idx<n; idx++)
       groups[idx] = (solptr[idx] == p) ? 1 : 0;
}

/* JP: change a little bit: treeptr -> streeptr
static void mark_group2(treeptr p, sint *groups, sint n)
*/
/*
 *   mark_group2
 *
 *   Update groups[0..n-1] for a right-hand leaf: the entry whose solptr[]
 *   slot is 'p' becomes 2; any other entry that is already non-zero is set
 *   to 1; zero entries are left alone.
 */
static void mark_group2(streeptr p, sint *groups, sint n)
{
    sint idx;

    for (idx=0; idx<n; idx++)
       {
          if (solptr[idx] == p)
            {
               groups[idx] = 2;
               continue;
            }
          if (groups[idx] != 0)
             groups[idx] = 1;
       }
}

/* JP : adding a parameter of the tree node
static void save_set(sint n, sint *groups)
*/
/*
 *   save_set
 *
 *   Append one set: nsets is advanced by one, groups[0..n-1] is copied into
 *   row nsets of the file-level 'sets' table (columns 1..n) and 'p' is
 *   recorded as that set's ancestor node in grp_ancestor[nsets].
 *   No capacity check is made on either table.
 */
static void save_set(sint n, sint *groups, streeptr p)
{
    sint col;

    nsets++;

    for (col=1; col<=n; col++)
       sets[nsets][col] = groups[col-1];

    /* JP */
    grp_ancestor[nsets] = p;
}



/*
 *   calc_similarities
 *
 *   Derive pairwise distances for sequences 0..nseqs-1 from the tree (path
 *   length between the two leaves in olptr[]), clamp them to 0.01 .. 1.0 and
 *   store 100 - 100*distance symmetrically in tmat[1..nseqs][1..nseqs].
 *   With fewer than two sequences the existing lower triangle of tmat is
 *   used as the distances instead.
 *
 *   Pairs whose tree distance exceeds 1.1 are counted as "too divergent";
 *   the first few are listed in a warning and, in interactive mode, the user
 *   is asked whether to continue.
 *
 *   Returns 1 on success, 0 if the user declined to continue (tmat is then
 *   left unchanged).
 *
 *   Changes from the previous version: the bad-pair bookkeeping is a fixed
 *   array of the five entries that can be reported (the old nseqs-sized
 *   heap arrays overflowed once more than nseqs pairs were flagged), all
 *   scratch memory is released on the early-exit path too, and the message
 *   is assembled with bounded writes.
 *
 *   NOTE(review): dist2node[n-1] assumes the common ancestor is never at
 *   position 0 of the path (true when olptr[] holds leaves of one tree).
 */
sint calc_similarities(sint nseqs)
{
   enum { MAX_REPORTED = 5 };

   sint depth = 0, i,j, k, n;
   sint found;
   sint nerrs = 0;
   sint status = 1;
   sint seq1[MAX_REPORTED], seq2[MAX_REPORTED];
   float bad_dist[MAX_REPORTED];
   treeptr p, *path2root;
   float dist;
   float *dist2node;
   double **dmat;
   char err_mess[1024],err1[MAXLINE];
   char reply;

   path2root = (treeptr *)ckalloc((nseqs) * sizeof(treeptr));
   dist2node = (float *)ckalloc((nseqs) * sizeof(float));
   dmat = (double **)ckalloc((nseqs) * sizeof(double *));
   for (i=0;i<nseqs;i++)
     dmat[i] = (double *)ckalloc((nseqs) * sizeof(double));

   if (nseqs >= 2)
    {
/*
   for each leaf, determine all nodes between the leaf and the root;
*/
      for (i = 0;i<nseqs; i++)
       {
          depth = dist = 0;
          p = olptr[i];
          while (p != NULL)
            {
                path2root[depth] = p;
                dist += p->dist;
                dist2node[depth] = dist;
                p = p->parent;
                depth++;
            }

/*
   for each pair....
*/
          for (j=0; j < i; j++)
            {
              p = olptr[j];
              dist = 0.0;
/*
   find the common ancestor.
*/
              found = FALSE;
              n = 0;
              while ((found == FALSE) && (p->parent != NULL))
                {
                    for (k=0; k< depth; k++)
                      if (p->parent == path2root[k])
                         {
                           found = TRUE;
                           n = k;
                         }
                    dist += p->dist;
                    p = p->parent;
                }

              dmat[i][j] = dist + dist2node[n-1];
            }
        }

/*
   clamp the distances and remember the first MAX_REPORTED offending pairs;
   nerrs still counts all of them.
*/
        for (i=0;i<nseqs;i++)
          {
             dmat[i][i] = 0.0;
             for (j=0;j<i;j++)
               {
                  if (dmat[i][j] < 0.01) dmat[i][j] = 0.01;
                  if (dmat[i][j] > 1.0) {
                  	if (dmat[i][j] > 1.1) {
                  		if (nerrs < MAX_REPORTED) {
                  			seq1[nerrs] = i;
                  			seq2[nerrs] = j;
                  			bad_dist[nerrs] = dmat[i][j];
                  		}
                  		nerrs++;
                  	}
                    dmat[i][j] = 1.0;
                  }
               }
          }
        if (nerrs>0)
          {
             strcpy(err_mess,"The following sequences are too divergent to be aligned:\n");
             for (i=0;i<nerrs && i<MAX_REPORTED;i++)
              {
             	snprintf(err1,sizeof(err1),"           %s and %s (distance %1.3f)\n",
             	                        names[seq1[i]+1],
					names[seq2[i]+1],bad_dist[i]);
             	strncat(err_mess,err1,sizeof(err_mess)-strlen(err_mess)-1);
              }
	     strncat(err_mess,"(All distances should be between 0.0 and 1.0)\n",
	             sizeof(err_mess)-strlen(err_mess)-1);
	     strncat(err_mess,"This may not be fatal but you have been warned!\n",
	             sizeof(err_mess)-strlen(err_mess)-1);
             strncat(err_mess,"SUGGESTION: Remove one or more problem sequences and try again",
                     sizeof(err_mess)-strlen(err_mess)-1);
             if(interactive)
             	    reply=prompt_for_yes_no(err_mess,"Continue ");
             else reply = 'y';
             if ((reply != 'y') && (reply != 'Y'))
                    status = 0;
          }
     }
   else
     {
        for (i=0;i<nseqs;i++)
          {
             for (j=0;j<i;j++)
               {
                  dmat[i][j] = tmat[i+1][j+1];
               }
          }
     }

/*
   convert distances to percent similarity, unless the user aborted
*/
   if (status != 0)
     {
        for (i=0;i<nseqs;i++)
          {
             tmat[i+1][i+1] = 0.0;
             for (j=0;j<i;j++)
               {
                  tmat[i+1][j+1] = 100.0 - (dmat[i][j]) * 100.0;
                  tmat[j+1][i+1] = tmat[i+1][j+1];
               }
          }
     }

/*
   release the scratch memory on every path
*/
   path2root=ckfree((void *)path2root);
   dist2node=ckfree((void *)dist2node);
   for (i=0;i<nseqs;i++) dmat[i]=ckfree((void *)dmat[i]);
   dmat=ckfree((void *)dmat);

   return(status);
}

/* JP */
/*
 *   copy_tree
 *
 *   Mirror the node tree below 't1' into the sequence tree below 't2'.
 *   When 't1' is the file-level root, a fresh top node is allocated for the
 *   copy, stored in sroot and used in place of the 't2' argument.
 *   For a node with a left child, both children are allocated, filled by
 *   copy_content and then copied recursively (left before right).
 *   Parent links of the copies are not set here.
 */
static void copy_tree(treeptr t1, streeptr t2)
{
	if (t1 == root)
	  {
		t2 = savail();
		sroot = t2;
		copy_content(t1, t2);
	  }

	/* only nodes with a left child have anything below them to copy */
	if (t1->left != NULL)
	  {
		t2->left = savail();
		t2->right = savail();
		copy_content(t1->left, t2->left);
		copy_content(t1->right, t2->right);

		copy_tree(t1->left, t2->left);
		copy_tree(t1->right, t2->right);
	  }
}

/* JP: for_align_list */
extern int seqFormat;
extern int *seqnumlist;
extern int filecount;
extern int *seqlen_array_all;
extern char **seq_array_all; /* for all the sequences */
extern char **names_all;
extern char *am;

/*
 *   copy_content
 *
 *   Copy the payload of tree node 't1' into sequence-tree node 't2':
 *   dist, leaf and order always; for a leaf (t1->leaf non-zero) also the
 *   name(s) and integer-coded sequence(s), and the matching solptr[] slot is
 *   pointed at 't2'.  Name and sequence arrays are indexed from 1.
 *
 *   Two layouts, selected by seqFormat:
 *     - not CLUSTALIST: one name (100-byte buffer) and one sequence, taken
 *       from seq_array / seqlen_array at the position where olptr[] holds
 *       't1'.
 *     - CLUSTALIST: the leaf stands for a group of seqnumlist[i] sequences
 *       taken from names_all / seq_array_all, where i is the index of the
 *       entry in names[] equal to t1->name.
 */
static void copy_content(treeptr t1, streeptr t2)
{
	sint i,j,k;

	t2->dist = t1->dist;
	t2->leaf = t1->leaf;
	t2->order = t1->order;

	/*fprintf(stdout, "%2.1f %d %d\n", t2->dist, t2->leaf, t2->order); */
	if(t1->leaf) {

	    /* JP: for_align_list */
	    if(seqFormat!=CLUSTALIST) {
		/* slot 0 is unused; the single name lives in name[1] */
		t2->name = ckalloc(2*sizeof(char *));
		t2->name[1] = ckalloc(100* sizeof(char ));
		/* NOTE(review): assumes t1->name fits in 100 bytes - confirm against MAXNAMES */
		strcpy(t2->name[1], t1->name);
		/*fprintf(stdout, "%s ", t2->name[1]);*/
		t2->seqnum = 1;

		/*fprintf(stdout, "nseqshere %d \n", nseqshere);*/
		/* locate this leaf in olptr[]; the scan does not stop at the first hit */
		for(i=0;i<nseqshere;i++) {
		   if(olptr[i]==t1) {/*fprintf(stdout,"i: %d----------\n",i);*/
			   solptr[i] = t2;
		   	   t2->seq = ckalloc(2*sizeof(int *));
		   	   t2->seqlength = seqlen_array[i+1];
			   /*fprintf(stdout, "length: %d\n", seqlen_array[i+1]);*/
		   	   t2->seq[1] = ckalloc((seqlen_array[i+1]+1)*sizeof(int));
		   	   /* re-code every residue: internal code -> letter -> am2num() code */
		   	   for(j=1;j<=seqlen_array[i+1];j++)
		      	   {
				t2->seq[1][j] = am2num(amino_acid_codes[seq_array[i+1][j]]);
		   		/*fprintf(stdout, "%c", am[t2->seq[1][j]]);*/
				if(debug>1) fprintf(stdout, "%d ", am2num(amino_acid_codes[seq_array[i+1][j]] ));
			   }
			   if(debug>1)fprintf(stdout, "\n");
	   		}
		}
		/*fprintf(stdout, "+++\n");*/
	     }
	     /* JP: for_align_list */
	     else {
		/* find the name that matches t1->name */
		/* NOTE(review): the comparison is exact (case-sensitive); if nothing
		   matches, i ends at nseqshere+1 and is still used as an index below */
		for(i=1;i<=nseqshere;i++) {
		    if(strcmp(names[i], t1->name)==0)  break;
		}

		t2->name = ckalloc((seqnumlist[i]+1)*sizeof(char *));
		for(j=1;j<=seqnumlist[i];j++) t2->name[j] = ckalloc(100* sizeof(char ));
		t2->seqnum = seqnumlist[i];
		t2->seqlength = seqlen_array[i];
		t2->seq = ckalloc((seqnumlist[i]+1)*sizeof(int *));
		for(j=1;j<=seqnumlist[i];j++) t2->seq[j] = ckalloc((t2->seqlength+1)*sizeof(int));
		/* find the starting sequence number in the seq_array_all list */
		/* tmpcount = total number of sequences in the groups before group i */
		int tmpcount = 0;
		for(j=1;j<=i-1;j++) tmpcount+= seqnumlist[j]; 
		for(j=1;j<=seqnumlist[i];j++) {
		    strcpy(t2->name[j], names_all[tmpcount+j]);
		    for(k=1;k<=t2->seqlength;k++) {
			/* pay special attention to gaps in input sequences: they are stored as 0 */
			if(seq_array_all[tmpcount+j][k]==gap_pos2) t2->seq[j][k] = 0;
			else t2->seq[j][k] = am2num(amino_acid_codes[seq_array_all[tmpcount+j][k]]);
		    }
		}
		/* test the content: output the sequences */
		if(debug>11) {
		   fprintf(stdout, "group: %d\n", i);
		   for(j=1;j<=seqnumlist[i];j++) {
			fprintf(stdout, ">%s\n", t2->name[j]);
			for(k=1;k<=t2->seqlength;k++) {
			    fprintf(stdout, "%c", amino_acid_codes[seq_array_all[tmpcount+j][k]]);
			   /*fprintf(stdout, "%d%c%d ", seq_array_all[tmpcount+j][k], amino_acid_codes[seq_array_all[tmpcount+j][k]], t2->seq[j][k]);*/
			}
			fprintf(stdout, "\n");
		   }
		   fprintf(stdout, "\n");
		}

		/* register the copy in solptr[] at the slot where olptr[] holds t1 (i is reused) */
		for(i=0;i<nseqshere;i++) {
		   if(olptr[i]==t1) {
			solptr[i] = t2; break;
		   }
		}

	     }
	     
	}
}

/* JP */
/*
 *   average_group_identity
 *
 *   Mean of tmat[a][b] over every pair (a, b) with group[a] == 1 and
 *   group[b] == 2, for 1-based indices 1..nseqshere.  The mean is also
 *   printed when debug > 1.  No check is made for an empty pair list.
 */
double average_group_identity(sint *group)
{
	sint a, b;
	sint pairs = 0;
	double total = 0;

	for (a = 1; a <= nseqshere; a++) {
		if (group[a] != 1) continue;

		for (b = 1; b <= nseqshere; b++) {
			if (group[b] != 2) continue;

			total += tmat[a][b];
			pairs++;
		}
	}

	total = total/pairs;
	if (debug>1) fprintf(stdout, "sum: %5.3f\n", total);
	return total;
}

/* JP */
/*
 *   assign_node
 *
 *   Attach to node 'p' the names and integer-coded sequences of every
 *   sequence i (1..nseqshere) with aligned[i] non-zero.  p->seqnum becomes
 *   the number of such sequences and p->seqlength the longest of their
 *   lengths.  Names and rows are stored from index 1; gap positions
 *   (gap_pos1 / gap_pos2) are stored as 0, other residues as am2num() codes.
 *
 *   If 'p' already carries sequences a message is printed and the program
 *   exits.  NOTE(review): that exit uses status 0 and stdout; kept as is.
 *
 *   Changes from the previous version: each name buffer is sized to hold
 *   its name (never less than the former 54 bytes), so a long name can no
 *   longer overrun it, and row positions beyond a shorter sequence's own
 *   length are set to 0 (the value this function stores for gaps) instead of
 *   being left unwritten.
 */
void assign_node(streeptr p, sint *aligned)
{
	enum { MIN_NAME_BYTES = 54 };

	sint i,j;
	sint count = 0;
	sint length = 0;
	size_t name_bytes;

	if(p->seq) {
		fprintf(stdout, "assign nodes: sequences already exist\n");
		exit(0);
	}

	/* count the selected sequences and find the longest one */
	for(i=1;i<=nseqshere;i++) {
		if(aligned[i]) {
			count++;
			if(seqlen_array[i]> length) length = seqlen_array[i];
		}
	}
	p->seqlength = length;
	p->seqnum = count;

	p->seq = ckalloc( (count+1) * sizeof( sint *));
	for(i=1;i<=count;i++) {
		p->seq[i]= ckalloc((length+1) * sizeof(sint ));
	}
	p->name = ckalloc( (count+1) *sizeof( char *));

	count = 0;
	for(i=1;i<=nseqshere;i++) {
		if(!aligned[i]) continue;

		count++;

		name_bytes = strlen(names[i]) + 1;
		if(name_bytes < MIN_NAME_BYTES) name_bytes = MIN_NAME_BYTES;
		p->name[count] = ckalloc(name_bytes * sizeof(char));
		strcpy(p->name[count], names[i]);

		for(j=1;j<=seqlen_array[i];j++) {
			if( (seq_array[i][j] == gap_pos1) || (seq_array[i][j]==gap_pos2) ) {
				p->seq[count][j] = 0;
			}
			else {
				p->seq[count][j] = am2num(amino_acid_codes[seq_array[i][j]]);
			}
		}
		/* pad the remainder of the row so that every position is defined */
		for(j=seqlen_array[i]+1;j<=length;j++) {
			p->seq[count][j] = 0;
		}
	}
}
