#include "expectedAlignmentDistance.h"

/* Exact (dynamic-programming) expected distance between alignments of S and T
 * and the reference alignment ref_algn, normalized by |S|+|T|.
 * The reference is first reduced to its per-cross-diagonal diagonal profile;
 * the work is then delegated to DPExpAlgnDist with the DP matrices dpmat,
 * the null model, the substitution model and markov_time.
 */
double expectedAlignmentDistance_formal(
 string S,
 string T,
 nullModelClass_t &nullMdl,
 SubmatClass_t    &submat,
 vector<string> &ref_algn,
 vector<vector<vector<dpcell_t> > > &dpmat, 
 int markov_time
) {
    // reference alignment as one diagonal (row-col) per cross-diagonal (row+col)
    vector<int> refDiags = computeDiagsFromAlgn(ref_algn);
    return DPExpAlgnDist(S, T, nullMdl, submat, markov_time, dpmat, refDiags);
}

/* Monte-Carlo estimate of the normalized expected alignment distance.
 * Draws nSamples alignments of S and T with sampleAlignment() from the DP
 * matrices, and averages their distance (computeTwoAlgnDist) to ref_algn.
 * The average is then divided by |S|+|T|.
 *
 * nSamples must be positive: the result is an average over the samples.
 */
double expectedAlignmentDistance_sampled(
 string S,
 string T,
 vector<string> &ref_algn,
 vector<vector<vector<dpcell_t> > > &dpmat,
 int nSamples
) {
    assert(nSamples > 0);

    // Seed the C RNG (used by sampleAlignment via rand()) once per process.
    // Re-seeding with time(0) on every call would give calls made within the
    // same second identical sample streams, i.e. correlated estimates.
    static bool rngSeeded = false;
    if (!rngSeeded) {
        srand(time(0));
        rngSeeded = true;
    }

    vector<int> diagsOfRefAlgn = computeDiagsFromAlgn(ref_algn);
    double ead_sampled = 0.0;
    for (int i = 0; i < nSamples; i++) {
        vector<string> sampled_algn = sampleAlignment(S,T,dpmat);
        vector<int> diagsOfSampledAlgn = computeDiagsFromAlgn(sampled_algn);
        ead_sampled += computeTwoAlgnDist(diagsOfSampledAlgn,diagsOfRefAlgn);
    }
    ead_sampled /= static_cast<double>(nSamples);
    // normalize by the total number of residues
    ead_sampled /= (S.length()+T.length());
    return ead_sampled;
}

/* Distance between a single (e.g. optimal) alignment opt_algn and the
 * reference alignment ref_algn, normalized by |S|+|T|.
 * Both alignments are reduced to diagonal profiles and compared with
 * computeTwoAlgnDist. Returns 0 when S and T are both empty, where the
 * normalization would otherwise be 0/0.
 */
double expectedAlignmentDistance_optimal(
 string S,
 string T,
 vector<string> &ref_algn,
 vector<string> &opt_algn
) {
    vector<int> diagsOfRefAlgn = computeDiagsFromAlgn(ref_algn);
    vector<int> diagsOfOptAlgn = computeDiagsFromAlgn(opt_algn);

    double dist = computeTwoAlgnDist(diagsOfOptAlgn, diagsOfRefAlgn);

    size_t totalLen = S.length() + T.length();
    if (totalLen == 0) return 0.0;
    // normalize by the total number of residues
    return dist / totalLen;
}


/* L1 distance between two equally long diagonal profiles:
 * sum over positions k of |A[k] - B[k]|. Asserts A and B have equal size.
 */
double computeTwoAlgnDist(vector<int> &A, vector<int> &B) {
    assert(A.size() == B.size());
    double total = 0;
    vector<int>::const_iterator other = B.begin();
    for (vector<int>::const_iterator it = A.begin(); it != A.end(); ++it, ++other) {
        total += fabs(*it - *other);
    }
    return total;
}

/* Roulette-wheel selection of a state index in {0, 1, 2}.
 * Index i is returned with probability prob[i] / (prob[0]+prob[1]+prob[2]),
 * drawing from the C rand() generator.
 *
 * The weights need not sum exactly to 1, so rounding in the caller's
 * exp()-normalized probabilities is harmless. A zero weight is never selected.
 * At least one weight must be positive.
 */
int rouletteForState(double prob[3]) {
    const double total = prob[0] + prob[1] + prob[2];
    assert(total > 0);

    // rand() lies in [0, RAND_MAX]. Dividing by RAND_MAX + 1 gives a fraction
    // in [0, 1), so u is in [0, total) and always lands inside the wheel.
    // rand() == 0 is a valid draw here.
    const double frac = static_cast<double>(rand()) /
                        (static_cast<double>(RAND_MAX) + 1.0);
    const double u = total * frac;

    if (u < prob[0]) return 0;
    if (u < prob[0] + prob[1]) return 1;
    return 2;
}


/* Draws one alignment of S and T by walking the DP grid from the sink cell
 * back to the origin (0,0).
 *
 * At each cell (r,c) a state is drawn with probability
 * exp(-NLTOTk(r,c)) / sum_k exp(-NLTOTk(r,c)). The states are:
 *   0 = M: emits S[r-1] vs T[c-1], moves to (r-1,c-1)
 *   1 = D: emits S[r-1] vs '-',    moves to (r-1,c)
 *   2 = I: emits '-'    vs T[c-1], moves to (r,c-1)
 * States are drawn with rouletteForState, i.e. the C rand() generator.
 *
 * NOTE(review): the state at each new cell is drawn from that cell's totals
 * alone. The transition probability into the state just emitted is not
 * factored in. Confirm this approximation of posterior sampling is intended.
 *
 * Returns {aligned S, aligned T} with '-' for gaps. Both strings are empty
 * when the grid is 1x1 (S and T empty).
 */
vector<string>  sampleAlignment(
 string S,
 string T,
 vector<vector<vector<dpcell_t> > > &dpmat
) {
    int nRows = dpmat[0].size();
    int nCols = dpmat[0][0].size();
    vector<string> algn(2,"");

    // start at the sink and trace back until the source (0,0) is reached
    int r = nRows-1, c = nCols-1;
    while (r > 0 || c > 0) {
        vector<pair<double,double> > vals;
        vals.push_back(make_pair(NLTOTM(r,c),1));
        vals.push_back(make_pair(NLTOTD(r,c),1));
        vals.push_back(make_pair(NLTOTI(r,c),1));
        double nlsum = NLSE_w_factors(vals);
        double prob[3];
        prob[0] = exp(nlsum-NLTOTM(r,c));
        prob[1] = exp(nlsum-NLTOTD(r,c));
        prob[2] = exp(nlsum-NLTOTI(r,c));

        int state = rouletteForState(prob);
        if (state == 0) {
            assert(r > 0 && c > 0);
            algn[0].push_back(S[r-1]);
            algn[1].push_back(T[c-1]);
            r--, c--;
        }
        else if (state == 1) {
            assert(r > 0);
            algn[0].push_back(S[r-1]);
            algn[1].push_back('-');
            r--;
        }
        else {
            assert(c > 0);
            algn[0].push_back('-');
            algn[1].push_back(T[c-1]);
            c--;
        }
    }
    // columns were collected sink-to-source
    reverse(algn[0].begin(), algn[0].end());
    reverse(algn[1].begin(), algn[1].end());
    return algn;
}

/* Walks a pairwise alignment column by column and returns the DP grid cells
 * (i, j) it visits, starting with the source (0,0).
 * i counts residues of algn[0] consumed so far; j counts those of algn[1].
 *   residue / residue : (i+1, j+1)
 *   '-'     / residue : (i,   j+1)
 *   residue / '-'     : (i+1, j)
 *   '-'     / '-'     : consumes no residue, so it is skipped and adds no cell
 * Asserts that algn holds exactly two rows of equal length.
 */
vector<vector<int> > computeCellPathFromAlgn(vector<string> &algn) {
    assert(algn.size()==2);
    assert(algn[0].length()==algn[1].length());
    const size_t algnlen = algn[0].length();

    vector<vector<int> > cellpath;
    cellpath.reserve(algnlen + 1);
    cellpath.push_back(vector<int>(2, 0)); // source cell (0,0)
    int sidx = 0, tidx = 0;
    for (size_t i = 0; i < algnlen; i++) {
        const bool sGap = (algn[0][i] == '-');
        const bool tGap = (algn[1][i] == '-');
        if (sGap && tGap) continue; // gap/gap column: no move on the grid
        if (!sGap) sidx++;
        if (!tGap) tidx++;
        vector<int> cell(2);
        cell[0] = sidx;
        cell[1] = tidx;
        cellpath.push_back(cell);
    }

    return cellpath;
}

/* Diagonal profile of a pairwise alignment.
 * For every cross-diagonal k = row+col (k = 0 .. last row + last col) of the
 * alignment's cell path (see computeCellPathFromAlgn), records the diagonal
 * row-col at which the path crosses it.
 * A step where both row and col advance jumps from cross-diagonal k to k+2.
 * The skipped cross-diagonal inherits the previous entry's diagonal.
 */
vector<int> computeDiagsFromAlgn(vector<string> &algn) {
    vector<vector<int> > cellpath = computeCellPathFromAlgn(algn);
    const vector<int> &sink = cellpath.back();   // path always contains (0,0)
    const int nCrossDiags = sink[0] + sink[1] + 1;
    const int unset = numeric_limits<int>::max();

    vector<int> locus(nCrossDiags, unset);
    locus[0] = 0;
    for (size_t i = 0; i < cellpath.size(); i++) {
        int crossdiag = cellpath[i][0] + cellpath[i][1]; // n.b. crossdiagid = rowid+colid
        int diag      = cellpath[i][0] - cellpath[i][1]; // n.b. diagid = rowid-colid
        locus[crossdiag] = diag;
    }
    // fill cross-diagonals skipped by diagonal steps
    for (size_t i = 1; i < locus.size(); i++) {
        if (locus[i] == unset) locus[i] = locus[i-1];
    }
    return locus;
}

/* Transition table of the three-state (M, I, D) machine, keyed by the
 * two-letter string "<from><to>" ("MM", "MI", ..., "DD").
 * prMM, prII and prIM come from submat at the given markov_time. The rest
 * are derived:
 *   M -> I and M -> D : (1 - prMM) / 2 each
 *   D mirrors I       : prDD = prII, prDM = prIM
 *   I -> D and D -> I : 1 - prII - prIM
 */
map<string,double> getFSAParams(
 SubmatClass_t &submat,
 int markov_time
) {
   const double stayMatch  = submat.getprMM(markov_time);
   const double stayGap    = submat.getprII(markov_time);
   const double gapToMatch = submat.getprIM(markov_time);
   const double matchToGap = 0.5*(1.0-stayMatch);
   const double gapSwitch  = 1.0-stayGap-gapToMatch;

   map<string,double> fsaPr;
   fsaPr["MM"] = stayMatch;
   fsaPr["MI"] = matchToGap;
   fsaPr["MD"] = matchToGap;
   fsaPr["IM"] = gapToMatch;
   fsaPr["II"] = stayGap;
   fsaPr["ID"] = gapSwitch;
   fsaPr["DM"] = gapToMatch;
   fsaPr["DI"] = gapSwitch;
   fsaPr["DD"] = stayGap;
   return fsaPr;
}

/* NLSE_w_factors: weighted negative log-sum-exp.
 * Given pairs (V_i, f_i) with f_i >= 0, returns
 *     -log( sum_i f_i * exp(-V_i) ).
 * Each term is rewritten as exp(-W_i) with W_i = V_i - log(f_i). Taking
 * W_p = min_i W_i as the pivot, the result is computed stably as
 *     W_p - log1p( sum_{i != p} exp(W_p - W_i) ).
 * Terms with f_i == 0 or V_i == +inf contribute exactly 0 and are dropped.
 * Pivoting on such a term would divide by zero or form inf-inf, giving NaN.
 * If no term remains the sum is 0 and +inf (= -log 0) is returned.
 */
double NLSE_w_factors(vector<pair<double,double> > v) {
  const double inf = numeric_limits<double>::infinity();
  vector<double> w;   // W_i = V_i - log(f_i) of the contributing terms
  w.reserve(v.size());
  for (size_t i = 0; i < v.size(); i++) {
    assert(v[i].second >= 0);
    if (v[i].second == 0 || v[i].first == inf) continue;
    w.push_back(v[i].first - log(v[i].second));
  }
  if (w.empty()) return inf;

  const size_t p = min_element(w.begin(), w.end()) - w.begin();
  double tmp = 0;
  for (size_t i = 0; i < w.size(); i++) {
    if (i != p) tmp += exp(w[p] - w[i]);   // each term is in (0, 1]
  }
  return w[p] - log1p(tmp);
}

/* Appends nMatrices matrices of nRows x nCols entries to nlead, every entry
 * +inf. Then sets entry (0,0) of matrices 0, 1 and 2 to 0.
 * Requires nlead to end up with at least 3 matrices, and nRows, nCols >= 1.
 *
 * NOTE(review): in negative-log space 0 means exp(-0) = 1, whereas +inf
 * means 0. If these matrices hold -log(sum of Pr(path) * distance) and the
 * source has total probability 1, a 0 here adds one unit of distance to
 * every path; +inf would add none. Confirm which is intended.
 */
void initialize(
 vector<vector<vector<double> > > &nlead,
 int nMatrices, int nRows, int nCols
) {
    const double inf = numeric_limits<double>::infinity();

    // one all-+inf matrix, copied nMatrices times
    vector<vector<double> > blank;
    int row = 0;
    while (row++ < nRows) blank.push_back(vector<double>(nCols, inf));

    for (int m = nMatrices; m > 0; --m) nlead.push_back(blank);

    // source cell of the first three matrices
    for (int k = 0; k < 3; ++k) nlead[k][0][0] = 0;
}

/* DPExpAlgnDist: exact expected distance, normalized by |S|+|T|, between
 * alignments of S and T and a reference alignment.
 * The reference is given as diagsOfRefAlgn: the diagonal row-col it crosses
 * on each cross-diagonal row+col.
 *
 * States: 0 = M (S[r-1] vs T[c-1]), 1 = D (S[r-1] vs '-'), 2 = I ('-' vs T[c-1]).
 * For each state k and cell (r,c), nlead (via the NLEAD* macros) is filled with
 *   -log( sum_pred exp(-NLEAD_pred) * transition * emission
 *         + exp(-NLTOT_k(r,c)) * dist2ref(r,c) )
 * where dist2ref(r,c) = |(r-c) - diagsOfRefAlgn[r+c]|. M moves also add the
 * cross-diagonal they skip.
 * Assuming NLTOT* (read from dpmat) is -log of the forward total for
 * (r,c,k) — TODO confirm against the dpcell_t macros — NLEAD is then
 * -log sum over paths of Pr(path) * accumulated distance.
 *
 * Returns sum_k exp(-NLEAD_k(sink)) / sum_k exp(-NLTOT_k(sink)) / (|S|+|T|).
 * Returns 0 when S and T are both empty (only the empty alignment exists,
 * and the normalization would divide by 0).
 */
double DPExpAlgnDist(
 string &S, 
 string &T,
 nullModelClass_t &nullMdl,
 SubmatClass_t &submat,
 int markov_time,
 vector<vector<vector<dpcell_t> > > &dpmat,
 vector<int> &diagsOfRefAlgn
) {
    size_t nRows = S.length()+1;
    size_t nCols = T.length()+1;

    size_t order = nullMdl.getOrder();
    size_t alphabetSize = nullMdl.getAlphabetSize();
    assert(order==0 && alphabetSize == 20);

    if (nRows == 1 && nCols == 1) return 0.0;
    // every cross-diagonal r+c of the grid is looked up below
    assert(diagsOfRefAlgn.size() >= nRows + nCols - 1);

    map<string,double> fsaPr = getFSAParams(submat, markov_time);
    // Hoisted out of the O(|S||T|) loops: each map lookup would otherwise
    // build a temporary std::string and walk the tree once per cell.
    const double prMM = fsaPr["MM"], prMI = fsaPr["MI"], prMD = fsaPr["MD"];
    const double prIM = fsaPr["IM"], prII = fsaPr["II"], prID = fsaPr["ID"];
    const double prDM = fsaPr["DM"], prDI = fsaPr["DI"], prDD = fsaPr["DD"];
    // weight of each of the three (0,0) predecessors when leaving the source
    const double prStart = 1.0/3.0;

    //neg log exp algn distance (nlead) matrices
    vector<vector<vector<double> > > nlead;
    initialize(nlead,dpmat.size(),dpmat[0].size(),dpmat[0][0].size());

    // Boundary conditions:
    // n.b. [0] (M=> Si vs Tj) [1] (D=> Si vs -) [2] (I=> - vs Tj) 
    // first col of the D matrix: S[0..r-1] aligned entirely against gaps
    for (size_t r = 1; r < nRows; r++) {
      double pr_null_Si = nullMdl.pr(S.substr(r-1,1));
      double pr_fsa_transition = (r == 1) ? prStart : prDD;
      double factor = pr_fsa_transition*pr_null_Si;

      int diagOfCurrCell = static_cast<int>(r);      // r - 0
      int crossdiagOfCurrCell = static_cast<int>(r); // r + 0
      double dist2ref = fabs(diagOfCurrCell-diagsOfRefAlgn[crossdiagOfCurrCell]);

      vector<pair<double,double> > vals;
      vals.push_back(make_pair(NLEADD((r-1),0),factor));
      vals.push_back(make_pair(NLTOTD(r,0),dist2ref));
      NLEADD(r,0) = NLSE_w_factors(vals);
    }

    // first row of the I matrix: T[0..c-1] aligned entirely against gaps
    for (size_t c = 1; c < nCols; c++) {
      double pr_null_Tj = nullMdl.pr(T.substr(c-1,1));
      double pr_fsa_transition = (c == 1) ? prStart : prII;
      double factor = pr_fsa_transition*pr_null_Tj;

      // computed in int: 0 - c in size_t would wrap around
      int diagOfCurrCell = -static_cast<int>(c);
      int crossdiagOfCurrCell = static_cast<int>(c);
      double dist2ref = fabs(diagOfCurrCell-diagsOfRefAlgn[crossdiagOfCurrCell]);

      vector<pair<double,double> > vals;
      vals.push_back(make_pair(NLEADI(0,c-1),factor));
      vals.push_back(make_pair(NLTOTI(0,c),dist2ref));
      NLEADI(0,c) = NLSE_w_factors(vals);
    }

    // Filling the rest
    for (size_t r = 1; r < nRows; r++) {
      double pr_null_Si = nullMdl.pr(S.substr(r-1,1));
      for (size_t c = 1; c < nCols; c++) {
        double pr_null_Tj = nullMdl.pr(T.substr(c-1,1));
        double pr_match_SiTj = submat.jointpr(S[r-1],T[c-1],markov_time);

        int diagOfCurrCell = static_cast<int>(r) - static_cast<int>(c);
        int crossdiagOfCurrCell = static_cast<int>(r + c);
        double dist2ref = fabs(diagOfCurrCell-diagsOfRefAlgn[crossdiagOfCurrCell]);

        // RECURRENCES GOING INTO M STATE
        // n.b. ALSO account for the cross-diagonal skipped during M move.
        double misseddist = fabs(diagOfCurrCell-diagsOfRefAlgn[crossdiagOfCurrCell-1]);
        double fromM, fromD, fromI;
        if (r == 1 && c == 1) {
          fromM = fromD = fromI = prStart*pr_match_SiTj;
        } else {
          fromM = prMM*pr_match_SiTj; //M->M
          fromD = prDM*pr_match_SiTj; //D->M
          fromI = prIM*pr_match_SiTj; //I->M
        }
        vector<pair<double,double> > valsM;
        valsM.push_back(make_pair(NLEADM(r-1,c-1),fromM));
        valsM.push_back(make_pair(NLEADD(r-1,c-1),fromD));
        valsM.push_back(make_pair(NLEADI(r-1,c-1),fromI));
        valsM.push_back(make_pair(NLTOTM(r  ,c  ),dist2ref+misseddist));
        NLEADM(r,c) = NLSE_w_factors(valsM);

        // RECURRENCES GOING INTO D STATE
        vector<pair<double,double> > valsD;
        valsD.push_back(make_pair(NLEADM(r-1,c),prMD*pr_null_Si)); // M->D
        valsD.push_back(make_pair(NLEADD(r-1,c),prDD*pr_null_Si)); // D->D
        valsD.push_back(make_pair(NLEADI(r-1,c),prID*pr_null_Si)); // I->D
        valsD.push_back(make_pair(NLTOTD(r  ,c),dist2ref));
        NLEADD(r,c) = NLSE_w_factors(valsD);

        // RECURRENCES GOING INTO I STATE
        vector<pair<double,double> > valsI;
        valsI.push_back(make_pair(NLEADM(r,c-1),prMI*pr_null_Tj)); // M->I
        valsI.push_back(make_pair(NLEADD(r,c-1),prDI*pr_null_Tj)); // D->I
        valsI.push_back(make_pair(NLEADI(r,c-1),prII*pr_null_Tj)); // I->I
        valsI.push_back(make_pair(NLTOTI(r,c  ),dist2ref));
        NLEADI(r,c) = NLSE_w_factors(valsI);
      }
    }

    vector<pair<double,double> > sinknlead;
    sinknlead.push_back(make_pair(NLEADM(nRows-1,nCols-1),1));
    sinknlead.push_back(make_pair(NLEADD(nRows-1,nCols-1),1));
    sinknlead.push_back(make_pair(NLEADI(nRows-1,nCols-1),1));
    double negLogExpAlgnDist = NLSE_w_factors(sinknlead);
    //now normalize ead by totpr
    vector<pair<double,double> > sinknltot;
    sinknltot.push_back(make_pair(NLTOTM(nRows-1,nCols-1),1));
    sinknltot.push_back(make_pair(NLTOTD(nRows-1,nCols-1),1));
    sinknltot.push_back(make_pair(NLTOTI(nRows-1,nCols-1),1));
    double negLogTotpr = NLSE_w_factors(sinknltot);
    double diff = negLogExpAlgnDist-negLogTotpr;
    double normalized_ead = exp(-diff)/(S.length()+T.length());
    return normalized_ead;
}
