#include "SubmatClass.h"
#define NMATRICES 600

// Default constructor: the body is empty, so none of the setup performed by
// SubmatClass_t(string) (alphabet, index map, matrix loading, table
// generation) happens for an object created this way.
SubmatClass_t::SubmatClass_t() {
}

// Builds a fully populated object for one matrix family.
//
// matrixType selects the input files: the three factor matrices are read from
//   data/<matrixType>_S_double_precision.txt
//   data/<matrixType>_D_double_precision.txt
//   data/<matrixType>_Sinv_double_precision.txt
// (paths are relative to the current working directory). After loading, the
// stationary distribution and the conditional and joint tables are generated,
// in that order, because each step reads the output of the previous one.
SubmatClass_t::SubmatClass_t(string matrixType) {
  // Fixed amino-acid ordering; aa2idx records each letter's position in it.
  alphabet = "ARNDCQEGHILKMFPSTWYV";
  const int nLetters = alphabet.length();
  for (int pos = 0; pos < nLetters; ++pos) {
      aa2idx.insert(pair<char,int>(alphabet[pos], pos));
  }

  // All three file names share the same prefix and suffix.
  const string prefix = "data/" + matrixType;
  const string suffix = "_double_precision.txt";

  // readMatrix reads alphabet.size() x alphabet.size() entries from each file.
  S    = readMatrix(prefix + "_S" + suffix);
  D    = readMatrix(prefix + "_D" + suffix);
  invS = readMatrix(prefix + "_Sinv" + suffix);
  // Loading of the Dirichlet priors is currently disabled:
  //dirichletpriors = readDirichletPriors("data/DAM_DirichletPriorsGivenN.txt");

  generateStationaryDistribution();
  generateConditionalMatrices();
  generateJointMatrices();

  // Keep a separate copy of the N=1 conditional matrix.
  Mpow1 = cMat[1];
}

// Reads a square matrix of doubles from a whitespace-separated text file.
//
// fn: path of the file to read.
// Returns a matrix with alphabet.size() rows and columns, filled in row-major
// order. The dimension is taken from the alphabet, not from the file.
//
// Asserts that the file could be opened and that every single entry was
// extracted successfully, so a truncated or non-numeric file is reported in
// debug builds instead of silently producing a matrix with bogus entries.
vector<vector<double> > SubmatClass_t::readMatrix(string fn) {
    ifstream infile(fn.c_str(),ios::in);
    assert(infile);

    const size_t dim = alphabet.size();
    vector<vector<double> > mat;
    mat.reserve(dim);
    for (size_t i = 0; i < dim; i++) {
        vector<double> row;
        row.reserve(dim);
        for (size_t j = 0; j < dim; j++) {
            // Initialised so a failed extraction can never expose an
            // indeterminate value, even when asserts are compiled out.
            double value = 0.0;
            infile >> value;
            // Kept outside the assert expression: the read itself must still
            // happen when NDEBUG is defined.
            assert(!infile.fail());
            row.push_back(value);
        }
        mat.push_back(row);
    }
    infile.close();
    return mat;
}

// Reads per-N Dirichlet prior parameters from a text file.
//
// fn: path of the file to read.
//
// File format, as parsed below: lines starting with '#' are comments and empty
// lines are skipped. Every other line holds six whitespace-separated fields:
//   <k> <Malpha0> <Malpha1> <Ialpha0> <Ialpha1> <Ialpha2>
// where <k> must count up 1, 2, 3, ... in file order.
//
// Returns a two-element vector: element 0 holds the 2-parameter objects built
// from the M alphas, element 1 the 3-parameter objects built from the I alphas.
// Each inner vector is asserted to contain exactly NMATRICES entries.
vector<vector<DirichletClass_t> > SubmatClass_t::readDirichletPriors(string fn) {
    ifstream infile(fn.c_str(),ios::in);
    assert(infile);
    vector<DirichletClass_t> priors_Mstate;
    vector<DirichletClass_t> priors_Istate;

    int expectedK = 1;
    string line;
    // Line-based loop on the stream state: this also consumes a final record
    // that is not newline-terminated and copes with arbitrarily long lines.
    while (getline(infile, line)) {
        if (line.empty() || line[0] == '#') continue; //ignore blank/comment lines

        int pamK = 0;
        vector<double> Malphas(2);
        vector<double> Ialphas(3);
        stringstream ss;
        ss << line;
        ss >> pamK
           >> Malphas[0] >> Malphas[1]
           >> Ialphas[0] >> Ialphas[1] >> Ialphas[2];
        // All six fields must have parsed.
        assert(!ss.fail());

        // Records must be numbered consecutively from 1. The increment is kept
        // out of the assert so it does not vanish under NDEBUG.
        assert(pamK == expectedK);
        ++expectedK;

        DirichletClass_t objM(2,Malphas);
        DirichletClass_t objI(3,Ialphas);
        priors_Mstate.push_back(objM);
        priors_Istate.push_back(objI);
    }
    infile.close();
    assert(priors_Mstate.size() == NMATRICES);
    assert(priors_Istate.size() == NMATRICES);

    vector<vector<DirichletClass_t> > priors;
    priors.push_back(priors_Mstate);
    priors.push_back(priors_Istate);
    assert(priors.size()==2);
    return priors;
}

// Derives the stationary distribution from the loaded S and D matrices and
// appends it to stationarydist.
//
// The wanted eigenvector is the column of S that corresponds to eigenvalue 1
// on the diagonal of D. To stay robust against floating-point representation,
// the largest diagonal entry of D is located instead of testing for equality
// with 1; the matching column of S is then scaled so its entries sum to 1.
//
// Asserts that the result has alphabet.size() entries, sums to 1 within 1e-6,
// and that every entry lies in [0,1].
void SubmatClass_t::generateStationaryDistribution() {
    const size_t n = alphabet.size();

    // Position of the largest eigenvalue; an integral index, since it is only
    // ever used as a subscript.
    size_t evaloneidx = 0;
    double maxval = D[0][0];
    for (size_t i = 1; i < n; i++) {
        if (D[i][i] > maxval) {
            maxval = D[i][i];
            evaloneidx = i;
        }
    }

    //read evaloneidx col in S matrix
    vector<double> evec;
    evec.reserve(n);
    double sum = 0.0;
    for (size_t i = 0; i < n; i++) {
        evec.push_back(S[i][evaloneidx]);
        sum += S[i][evaloneidx];
    }
    // A zero column sum cannot be normalised; fail here with a clear cause.
    assert(sum != 0.0);

    //normalize evec and write to stationarydist
    double normsum = 0.0;
    for (size_t i = 0; i < n; i++) {
        const double normval = evec[i]/sum;
        stationarydist.push_back(normval);
        normsum += normval;
    }
    assert(stationarydist.size() == n);
    assert(normsum > 0.999999 && normsum < 1.000001);
    //check if all elems lie between [0,1]
    for (size_t i = 0; i < n; i++) {
        assert(stationarydist[i] >=0 && stationarydist[i] <=1);
    }
}

void SubmatClass_t::generateConditionalMatrices() {
    assert( cMat.size()== 0); // sanity check;
    for (int N = 0; N <=NMATRICES; N++) {
        vector<vector<double> > DpowN = D;
        for (size_t i = 0;  i < D.size(); i++) {
            DpowN[i][i] = pow(D[i][i],N);
        }
        vector<vector<double> > mat =  matrixMultiply(S,matrixMultiply(DpowN,invS));
        //normalize mat.
        normalizeConditionalMat(mat);
        cMat.push_back(mat);
    }
    assert( cMat.size()== NMATRICES+1); // matrices for N=0 to N=NMATRICES;
}

void SubmatClass_t::generateJointMatrices() {
    assert( cMat.size()== NMATRICES+1 && jMat.size() == 0); //sanity check
    
    for (int N = 0; N <=NMATRICES; N++) {
        vector<vector<double> > mat;
        //alloc memory for mat
        for (int i =0; i < alphabet.size(); i++) {
            mat.push_back(vector<double>(alphabet.size(),0.0));
        }
        //compute joint using stationary and conditional
        for (int i = 0; i < alphabet.size(); i++) {
            for (int j = i; j < alphabet.size(); j++) {
                mat[i][j] = (double)(stationarydist[j]*cMat[N][i][j] 
                             +       stationarydist[i]*cMat[N][j][i])
                             /(double)2;
                if (i != j) mat[j][i] = mat[i][j];
            }
        }
        normalizeJointMat(mat);
        jMat.push_back(mat);
    }
    assert( jMat.size()== NMATRICES+1); // matrices for N=0 to N=NMATRICES;
}


// Recomputes every joint matrix in jMat in place, using symbol probabilities
// supplied by nullMdl in place of the stored stationary distribution.
//
// nullMdl: queried via pr() with a one-character string for each alphabet
//          symbol.
//
// For each N = 0..NMATRICES, entry (i,j) of jMat[N] becomes the average of
// pr(j)*cMat[N][i][j] and pr(i)*cMat[N][j][i], mirrored into (j,i), after
// which jMat[N] is passed through normalizeJointMat. Requires cMat and jMat to
// be fully built already.
void SubmatClass_t::regenerateJointMatricesUsingNullModelEstimates(nullModelClass_t &nullMdl) {
    // NOTE(review): this banner looks like leftover debug output; it is kept
    // so the function's console output is unchanged.
    cout << "!!!!!@@@@@#####\n";
    assert( cMat.size()== NMATRICES+1 && jMat.size() == NMATRICES+1); //sanity check

    const size_t n = alphabet.size();
    for (int N = 0; N <= NMATRICES; N++) {
        //compute joint using null-model probabilities and conditional
        for (size_t i = 0; i < n; i++) {
            for (size_t j = i; j < n; j++) {
                string char_i(1,alphabet[i]);
                string char_j(1,alphabet[j]);

                const double avg = (nullMdl.pr(char_j)*cMat[N][i][j]
                                  + nullMdl.pr(char_i)*cMat[N][j][i])
                                  / 2.0;
                jMat[N][i][j] = avg;
                // Mirror inside the SAME matrix. Without the [N] index this
                // statement copied whole rows between different matrices and
                // left the lower triangle of jMat[N] stale.
                if (i != j) jMat[N][j][i] = avg;
            }
        }
        normalizeJointMat(jMat[N]);
    }
}


// Joint probability of symbols x and y at N: returns the jMat[N] entry whose
// row is the index of x and whose column is the index of y.
// Asserts that both symbols are present in aa2idx; N is not range-checked.
double SubmatClass_t::jointpr(char x, char y, int N) {
    assert(aa2idx.find(x) != aa2idx.end());
    assert(aa2idx.find(y) != aa2idx.end());
    const int row = aa2idx[x];
    const int col = aa2idx[y];
    return jMat[N][row][col];
}

// Conditional probability x|y at N: returns the cMat[N] entry whose row is the
// index of x and whose column is the index of y.
// Asserts that both symbols are present in aa2idx; N is not range-checked.
double SubmatClass_t::condlpr(char x, char y, int N) {
    assert(aa2idx.find(x) != aa2idx.end());
    assert(aa2idx.find(y) != aa2idx.end());
    const int row = aa2idx[x];
    const int col = aa2idx[y];
    return cMat[N][row][col];
}


// Returns a copy of the alphabet string held by this object.
string SubmatClass_t::getAlphabet() {
    const string &symbols = alphabet;
    return symbols;
}

// Returns entry (x,y) of the S matrix.
// Asserts that both indices lie in [0, alphabet.length()).
double SubmatClass_t::getS(int x,int y) {
    // Evaluated on every call for THIS object. A function-local static would
    // be initialised once per program from whichever object called first, so
    // an object with an empty alphabet could fix the bound at 0 for all others.
    const int n = static_cast<int>(alphabet.length());
    assert(x >= 0 && x < n);
    assert(y >= 0 && y < n);
    return S[x][y];
}

// Returns entry (x,y) of the invS matrix.
// Asserts that both indices lie in [0, alphabet.length()).
double SubmatClass_t::getSinv(int x,int y) {
    // Evaluated on every call for THIS object. A function-local static would
    // be initialised once per program from whichever object called first, so
    // an object with an empty alphabet could fix the bound at 0 for all others.
    const int n = static_cast<int>(alphabet.length());
    assert(x >= 0 && x < n);
    assert(y >= 0 && y < n);
    return invS[x][y];
}

// Returns entry (x,y) of the D matrix.
// Asserts that both indices lie in [0, alphabet.length()).
double SubmatClass_t::getD(int x,int y) {
    // Evaluated on every call for THIS object. A function-local static would
    // be initialised once per program from whichever object called first, so
    // an object with an empty alphabet could fix the bound at 0 for all others.
    const int n = static_cast<int>(alphabet.length());
    assert(x >= 0 && x < n);
    assert(y >= 0 && y < n);
    return D[x][y];
}

// Returns the index stored in aa2idx for symbol x.
// Asserts that x occurs in the alphabet string.
int SubmatClass_t::getIndx(char x) {
    assert(string::npos != alphabet.find(x));
    const int idx = aa2idx[x];
    return idx;
}

// Normalizes a conditional probability matrix in place so that every column
// sums to 1 (up to double precision).
//
// M: matrix to normalize; the leading alphabet.length() rows and columns are
//    processed.
//
// Negative entries are first replaced by 0: for small N (e.g. N=1) some
// elements approach 0 from the negative side because of numerical error.
void SubmatClass_t::normalizeConditionalMat(vector<vector<double> > &M) {
    const int dim = alphabet.length();

    // Single sweep: clamp negatives and accumulate column totals. Each column
    // total still adds its rows in ascending order.
    vector<double> colsum(dim, 0.0);
    for (int row = 0; row < dim; row++) {
        for (int col = 0; col < dim; col++) {
            if (M[row][col] < 0) M[row][col] = 0.0;
            colsum[col] += M[row][col];
        }
    }

    // Scale every entry by its column total.
    for (int row = 0; row < dim; row++) {
        for (int col = 0; col < dim; col++) {
            M[row][col] /= colsum[col];
        }
    }
}

// Normalizes a joint probability matrix in place so that all of its entries
// together sum to 1 (up to double precision).
//
// M: matrix to normalize; the leading alphabet.length() rows and columns are
//    processed. Every entry is asserted to be non-negative before anything is
//    modified.
void SubmatClass_t::normalizeJointMat(vector<vector<double> > &M) {
    const int dim = alphabet.length();

    // Validate and total in one column-by-column sweep (same accumulation
    // order as a separate column-major summation).
    double total = 0.0;
    for (int col = 0; col < dim; col++) {
        for (int row = 0; row < dim; row++) {
            assert(M[row][col] >= 0);
            total += M[row][col];
        }
    }

    // Scale every entry by the grand total.
    for (int row = 0; row < dim; row++) {
        for (int col = 0; col < dim; col++) {
            M[row][col] /= total;
        }
    }
}


// Returns the Dirichlet prior object stored for the given state and N.
//
// state: "match" selects dirichletpriors[0], "insert" selects
//        dirichletpriors[1]; any other string trips an assertion.
// N:     1 <= N <= NMATRICES; the entry at position N-1 is returned.
//
// Asserts that the state string is valid, that N is in range, and that the
// priors table actually contains the requested entry.
DirichletClass_t SubmatClass_t::getDirichlet(string state,int N) {
    const bool isMatch  = (state.compare("match") == 0);
    const bool isInsert = (state.compare("insert") == 0);
    assert((isMatch || isInsert) && "getDirichlet state string invalid");

    // Single return path: control can no longer run off the end of this
    // non-void function when the assertion above is compiled out.
    const size_t which = isMatch ? 0 : 1;

    assert(N >= 1 && N <= NMATRICES);
    assert(dirichletpriors.size() > which
           && dirichletpriors[which].size() >= static_cast<size_t>(N));
    return dirichletpriors[which][N-1];
}

// Writes the stationary distribution to stdout: a header line, then one line
// per alphabet symbol with its probability (width 10, precision 7), then a
// "---" footer.
void SubmatClass_t::printStationaryDist() {
    cout << "--Printing stationary distribution of the matrix:" << endl;
    const size_t n = alphabet.size();
    for (size_t k = 0; k < n; ++k) {
        cout << alphabet[k] << " "
             << setw(10) << setprecision(7) << stationarydist[k]
             << endl;
    }
    cout << "---\n";
}

// Prints one conditional or joint matrix, with its row sums, column sums and
// grand total, followed by the Dirichlet priors and FSA parameters for N.
//
// mattype: "conditional" prints cMat[N], "joint" prints jMat[N]. For any other
//          value no matrix entries are read or printed and all sums are 0.
// N:       0 <= N <= NMATRICES (asserted).
//
// The Dirichlet priors are stored at position N-1, so they exist only for
// N >= 1 and only once the priors table has been filled; when no entry is
// available a placeholder line is printed instead of indexing out of range.
//
// NOTE: `cout << fixed` is left set on return, as before.
void SubmatClass_t::printMatpowN(string mattype,int N) { // mattype = {"conditional", "joint"}
    assert(N>=0 && N <=NMATRICES);
    cout << "--Printing details of " << mattype << "Mat(" << N << ")"<< endl;

    const int size = alphabet.length();

    // Resolve the matrix once instead of comparing strings per element.
    // Stays null for an unrecognised mattype.
    const vector<vector<double> > *mat = 0;
    if (mattype.compare("conditional") == 0) {
        mat = &cMat[N];
    }
    else if (mattype.compare("joint") == 0) {
        mat = &jMat[N];
    }

    // Row sums, column sums and grand total in a single row-major pass; every
    // accumulator still adds its terms in ascending index order.
    vector<double> rowsum(size, 0.0);
    vector<double> colsum(size, 0.0);
    double allsum = 0.0;
    if (mat != 0) {
        for (int i = 0;  i < size; i++) {
            for (int j = 0;  j < size; j++) {
                const double v = (*mat)[i][j];
                rowsum[i] += v;
                colsum[j] += v;
                allsum    += v;
            }
        }
    }

    const int w = 7, p = 3;
    cout << fixed;
    //alphabet row
    cout << setw(w) << " ";
    for (int i = 0;  i < size; i++) {
        cout << setw(w) << alphabet[i];
    }
    cout << " | ";
    cout << setw(w) << "rowsum";
    cout << endl;

    //matrix rows
    for (int i = 0;  i < size; i++) {
        cout << setw(w) << alphabet[i];
        if (mat != 0) {
            for (int j = 0;  j < size; j++) {
                cout << setw(w) << setprecision(p) << (*mat)[i][j];
            }
        }
        cout << " | ";
        cout << setw(w) << setprecision(p) << rowsum[i];
        cout << endl;
    }
    //col sum line
    cout << "------";
    cout << endl;
    cout << setw(w) << "colsum";
    for (int j = 0;  j < size; j++) {
        cout << setw(w) << setprecision(p) << colsum[j];
    }
    cout << setw(w) << setprecision(p) << allsum;
    cout << endl;
    cout << endl;

    cout << "Printing associated Dirichlet priors for N=" << N << endl;
    // Priors live at index N-1: N == 0 has no entry, and the table may not
    // have been loaded at all.
    const bool havePriors = N >= 1
        && dirichletpriors.size() >= 2
        && dirichletpriors[0].size() >= static_cast<size_t>(N)
        && dirichletpriors[1].size() >= static_cast<size_t>(N);
    if (havePriors) {
        cout << "Params for M-state priors: " << endl;
        dirichletpriors[0][N-1].printParams();
        cout << "Params for I-state priors: " << endl;
        dirichletpriors[1][N-1].printParams();
    }
    else {
        cout << "(no Dirichlet priors available for this N)" << endl;
    }
    cout << "---\n";
    cout << "Printing associated FSAparams for N=" << N << endl;
    cout << "PrMM = " << fsaobj.getprMM(N) << endl;
    cout << "PrII = " << fsaobj.getprII(N) << endl;
    cout << "PrIM = " << fsaobj.getprIM(N) << endl;
}

// Forwards to fsaobj.getprMM for the given N and returns its result.
double SubmatClass_t::getprMM(int N) {
    const double prMM = fsaobj.getprMM(N);
    return prMM;
}
// Forwards to fsaobj.getprII for the given N and returns its result.
double SubmatClass_t::getprII(int N) {
    const double prII = fsaobj.getprII(N);
    return prII;
}
// Forwards to fsaobj.getprIM for the given N and returns its result.
double SubmatClass_t::getprIM(int N) {
    const double prIM = fsaobj.getprIM(N);
    return prIM;
}
