#include<iostream>
#include<fstream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cassert>
#include<string>
#include<stdint.h>
#include<limits>
#include<map>
#include<vector>
#include<sstream>
#include<time.h>
#include <algorithm>
#include <iterator>

#include "misc.h"
#include "SubmatClass.h"
#include "nullModelClass.h"
#include "dpcellClass.h"
#include "threeStateDPAClass.h"
#include "cmdLineParser.h"
#include "expectedAlignmentDistance.h"
#define VERSIONSTRING "2.5-1"

using namespace std;


/*
 * Writes the two-line coloured program banner to standard output:
 * line 1 carries the program name, version string and a one-line
 * description; line 2 points the reader at the README file.
 *
 * Side effect on cout: the function ends by selecting left adjustment
 * and a field width of 82. The final endl does not consume the width,
 * so both settings are still pending on cout when the function returns.
 */
void coutHeader() {
    // ANSI SGR escape sequences used by the banner.
    const char *const sgrReset       = "\033[0;0m";
    const char *const yellowOnBlue   = "\033[0;33;1;44m";
    const char *const whiteOnBlue    = "\033[0;37;1;44m";
    const char *const redOnBlue      = "\033[0;31;1;44m";
    const char *const whiteOnMagenta = "\033[0;1;37;1;45m";
    const char *const readmeStyle    = "\033[0;28;1;45m";

    // Banner line 1: "   Seq" + "MMLigner (v<version>) " + description.
    cout << yellowOnBlue << "   Seq" << sgrReset;
    cout << whiteOnBlue << "MMLigner (v" << redOnBlue << VERSIONSTRING
         << whiteOnBlue << ") " << sgrReset;
    cout << whiteOnBlue
         << setw(52) << "Sequence (Minimum Message Length) aligner of proteins "
         << sgrReset;
    cout << yellowOnBlue << "\033[0;0m\n";

    // Banner line 2: credits / README pointer.
    cout << whiteOnMagenta << "   Credits/References:" << " Refer "
         << readmeStyle << "README" << sgrReset;
    cout << whiteOnMagenta << " file in the main directory of the source  "
         << sgrReset;

    // Leaves left adjustment and width 82 set on cout (see header comment).
    cout << left << setw(82) << endl;
}

/*
 * Reads a (multi-)fasta file and returns the sequence text of each record.
 *
 * fname : path of the fasta file to read.
 *
 * Returns one string per record, in file order, with the lines of each
 * record concatenated. Header lines (those starting with '>') are not
 * returned. Text appearing before the first header is discarded if a
 * header follows; a file with no header at all yields a single entry
 * holding all of its text. The result always has at least one entry
 * (an empty string for an empty file).
 *
 * Lines of any length are supported, the final line is kept even when the
 * file does not end in a newline, and a trailing carriage return (CRLF
 * files) is stripped from every line.
 *
 * Prints an error and exits with status 1 if the file cannot be opened.
 */
vector<string> readFastaFile(char fname[]) {
    vector<string> seqs;

    ifstream infile(fname,ios::in);
    if (!infile) {
        cerr << "Error: unable to open fasta file '" << fname
             << "' for reading." << endl;
        exit(1);
    }

    string astr;        // sequence text of the record being accumulated
    string line;
    size_t cntr = 0;    // number of header lines seen so far
    while (getline(infile, line)) {
        // Drop the '\r' left behind by Windows line endings.
        if (!line.empty() && line[line.size()-1] == '\r') {
            line.erase(line.size()-1);
        }
        if (!line.empty() && line[0] == '>') {
            // A new record starts: flush the previous one (if any).
            if (cntr > 0) {
                seqs.push_back(astr);
            }
            cntr++;
            astr.clear();
        }
        else {
            astr.append(line);
        }
    }
    // Flush the last record (or the header-less / empty content).
    seqs.push_back(astr);
    return seqs;
}

/*
 * Loads the two amino-acid sequences to be aligned.
 *
 * fn : file names; fn[0] and fn[1] must each name a fasta file holding
 *      exactly one sequence. Any further entries are ignored.
 *
 * Returns a vector of two strings: the sequence read from fn[0] followed
 * by the sequence read from fn[1].
 *
 * Prints an error and exits with status 1 when fewer than two file names
 * are supplied or when either file does not contain exactly one sequence.
 * The checks are explicit (not assert-based) so they also hold in builds
 * compiled with NDEBUG.
 */
vector<string> getAASeqs(vector<string> fn) { //expects files in fasta format
    if (fn.size() < 2) {
        cerr << "Expecting 2 amino acid sequence files. Found "
             << fn.size() << " instead. Fix!\n";
        exit(1);
    }

    vector<string> s;
    for (size_t k = 0; k < 2; k++) {
        // readFastaFile takes a non-const char array but only reads the name.
        vector<string> recs = readFastaFile(const_cast<char *>(fn[k].c_str()));
        if (recs.size() != 1) {
            cerr << "Expecting exactly 1 amino acid sequence in '" << fn[k]
                 << "'. Found " << recs.size() << " instead. Fix!\n";
            exit(1);
        }
        s.push_back(recs[0]);
    }
    return s;
}

/* Returns the length of encoding the characters in the passed string
 * using the passed null model. Does   NOT   include the code length
 * of encoding the string length.
 *
 * s       : the string to encode.
 * nullMdl : null model; supplies its order, its alphabet size and, via
 *           pr(), the probability looked up for each (order+1)-character
 *           window of s.
 *
 * The leading symbols that lack a full context (at most `order` of them,
 * and never more than s actually contains) are each charged the uniform
 * cost -log(1/alphabetSize). Every later symbol is charged -log(pr(w)),
 * where w is the symbol preceded by its `order` predecessors.
 * The result is in natural-log units. */
double getNullEncodingLength(
 string s, 
 nullModelClass_t &nullMdl
) {
    const size_t order = nullMdl.getOrder();
    const size_t alphabetSize = nullMdl.getAlphabetSize();
    const size_t n = s.length();

    double cdlen = 0;

    //Encode symbols without full context uniformly over the alphabet.
    //Only symbols that exist are charged: a string shorter than the
    //model order contributes n (not order) uniform terms.
    const size_t numWithoutContext = (n < order) ? n : order;
    const double uniformCost = -log((double)1/alphabetSize);
    for (size_t i = 0; i < numWithoutContext; i++) {
        cdlen += uniformCost;
    }

    //Remaining symbols are encoded with their full context.
    for (size_t i = order; i < n; i++) {
        const double p = nullMdl.pr(s.substr(i-order,order+1));
        cdlen += -log(p);
    }
    return cdlen;
}

/*
 * Encodes the given sequences adaptively, starting from pseudo-counts
 * derived from an order-0 null model, and then overwrites the model's
 * symbol probabilities with the adapted frequencies.
 *
 * nullMdl : null model; must report order 0 (asserted). Its per-symbol
 *           probabilities are read via pr() and replaced via set_pr().
 * seqs    : sequences to encode; every character must belong to the
 *           model's alphabet (asserted).
 *
 * Returns the accumulated code length, sum of -log(count/total), in
 * natural-log units. The seed counts and every encoding step are also
 * printed to standard output.
 */
double getAdaptiveEncoding(nullModelClass_t &nullMdl, vector<string> seqs) {
    string alph = nullMdl.getAlphabet();
    assert(nullMdl.getOrder() == 0);

    // Seed pseudo-count per alphabet symbol: probability * 1000, truncated
    // to an integer, but never allowed to be 0.
    vector<int> seedCounts;
    int seedTotal = 0;
    for (size_t k = 0; k < alph.size(); ++k) {
        const double prob = nullMdl.pr(string(1, alph[k]));
        int c = static_cast<int>(prob * 1000);
        if (c == 0) {
            c = 1;
        }
        seedTotal += c;
        seedCounts.push_back(c);
    }
    assert(seedCounts.size() == alph.size());

    // Load the seed counts into the adaptive table, echoing each of them.
    map<char,int> tally;
    int total = 0;
    for (size_t k = 0; k < alph.length(); ++k) {
        tally.insert(map<char,int>::value_type(alph[k], seedCounts[k]));
        total += seedCounts[k];
        cout << alph[k] << " " << setw(5) << seedCounts[k] << endl;
    }
    cout << "---- Sum = " << total;
    assert(total == seedTotal);

    // Encode each symbol with the current frequencies, then update them.
    double msglen = 0.0;
    for (size_t si = 0; si < seqs.size(); ++si) {
        const string &seq = seqs[si];
        for (size_t pos = 0; pos < seq.length(); ++pos) {
            map<char,int>::iterator hit = tally.find(seq[pos]);
            assert(hit != tally.end());
            msglen += -log((double)hit->second / (double)total);
            cout << hit->first << " " << hit->second << " " << total << endl;
            hit->second++;
            total++;
        }
    }

    // Replace the null model's parameters with the adapted frequencies.
    for (map<char,int>::iterator e = tally.begin(); e != tally.end(); ++e) {
        string sym(1, e->first);
        double adapted = (double)e->second / (double)total;
        nullMdl.set_pr(sym, adapted);
    }
    return msglen;
}

/*
 * Computes the null-model message length of the sequence pair <S,T> and
 * prints a statistics table for it to standard output.
 *
 * S, T           : the two amino-acid sequences.
 * nullMdl        : null model used to encode the symbols of S and T.
 * lensaving      : (out) code length of stating |S| and |T| separately
 *                  minus the code length of stating them jointly.
 * adaptivesaving : (out) always set to 0; no adaptive encoding is carried
 *                  out here. It is assigned so that callers never read an
 *                  indeterminate value.
 *
 * Returns the joint length cost + NULL(S) + NULL(T). Values are presumably
 * in nits: they are passed through nits2bits() before being displayed.
 *
 * Side effect on cout: `fixed` notation and precision 3 remain set on cout
 * after the call.
 */
double computeINullST(
 string &S,
 string &T,
 nullModelClass_t &nullMdl,
 double &lensaving,
 double &adaptivesaving) {

    const size_t nS = S.size();
    const size_t nT = T.size();

    vector<double> normwtcdlens = getWTCodelensNormalized_0_to_N_nits(nS+nT);
    // |nS - nT|, computed without leaving the unsigned domain.
    const size_t diff = (nS > nT) ? (nS - nT) : (nT - nS);

    // Symbol code lengths under the null model (computed once and reused).
    double nullModelMsgLen_S = getNullEncodingLength(S, nullMdl);
    double nullModelMsgLen_T = getNullEncodingLength(T, nullMdl);
    double mmlestMsgLen = nullModelMsgLen_S + nullModelMsgLen_T;

    //encode lengths: the sum |S|+|T| plus the (normalised) difference
    double combinedLengths_msglen = wtCodeLen_nits(nS+nT)+ normwtcdlens[diff];
    lensaving =  (wtCodeLen_nits(nS)+wtCodeLen_nits(nT))-combinedLengths_msglen;
    adaptivesaving = 0.0;

    cout << "-------------------" << getTerminalColorString("NULL MODEL STATS (SEQUENCE)",45) <<"--------------------\n";
    cout << "NULL(S)          = " << setw(10) << setprecision(3) << fixed << nits2bits(nullModelMsgLen_S) << " bits ";
    cout << "(" << setw(5) << setprecision(3) << fixed << nits2bits(nullModelMsgLen_S)/nS << " bits-per-aa " ;
    cout << "over " << nS << " a.a.)" << endl;
    cout << "NULL(T)          = " << setw(10) << setprecision(3) << fixed << nits2bits(nullModelMsgLen_T) << " bits ";
    cout << "(" << setw(5) << setprecision(3) << fixed << nits2bits(nullModelMsgLen_T)/nT << " bits-per-aa " ;
    cout << "over " << nT << " a.a.)" << endl;
    cout << "Msglen(|S|,|T|)  = " << setw(10) << setprecision(3) << fixed << nits2bits(combinedLengths_msglen) << " bits ";
    cout << "(saved = " << setw(7) << setprecision(3) << fixed << nits2bits(lensaving) << " bits) \n";
    stringstream ss;
    ss << setprecision(3) << fixed << nits2bits(combinedLengths_msglen+nullModelMsgLen_S+nullModelMsgLen_T);
    cout << "NULL(<S,T>)      =" << getTerminalColorString(ss.str(),43) << "bits ";
    cout << "(" << setw(5) << setprecision(3) << fixed << (nits2bits(combinedLengths_msglen+nullModelMsgLen_S+nullModelMsgLen_T))
                    /(double)(nS+nT) << " bits-per-aa " ;
    cout << "over " << nS+nT << " a.a.)" << endl << endl;

    stringstream ss3;
    ss3 << setprecision(3) << fixed << nits2bits(combinedLengths_msglen+mmlestMsgLen);
    cout << "MMLEst(<S,T>)    =" << getTerminalColorString(ss3.str(),43) << "bits ";
    cout << "(" << setw(5) << setprecision(3) << fixed << (nits2bits(combinedLengths_msglen+mmlestMsgLen))
                    /(double)(nS+nT) << " bits-per-aa " ;
    cout << "over " << nS+nT << " a.a.)" << endl << endl;
    cout << "--------------------------------------------------------------------\n";
    cout << endl;

    return combinedLengths_msglen+nullModelMsgLen_S+nullModelMsgLen_T;
}

/*
 * Returns a copy of s containing only the 20 standard amino-acid
 * one-letter codes (upper case), in their original order. Every other
 * character -- gaps, lower-case letters, ambiguity codes, whitespace --
 * is dropped.
 */
string removeNonStandardAA(string s) {
    const string standardCodes("ACDEFGHIKLMNPQRSTVWY");

    string kept;
    for (string::const_iterator ch = s.begin(); ch != s.end(); ++ch) {
        const bool isStandard = standardCodes.find(*ch) != string::npos;
        if (isStandard) {
            kept += *ch;
        }
    }
    return kept;
}

/*
 * Filters the columns of a pairwise alignment.
 *
 * a : the alignment; a[0] and a[1] are the two aligned rows. Columns are
 *     indexed over the length of a[0].
 *
 * Returns a two-row alignment that keeps only those columns in which both
 * characters are a gap ('-') or one of the 20 standard upper-case
 * amino-acid codes, and which are not gap-against-gap.
 *
 * For every offending character found, the two-character column is
 * printed on its own line to standard output (so a column whose two
 * characters are both non-standard is printed twice).
 */
vector<string> removeNonStandardAAColumnsInAlgn(vector<string> a) {
    const string allowed("-ACDEFGHIKLMNPQRSTVWY");

    vector<string> cleaned(2);
    const size_t ncols = a[0].length();
    for (size_t col = 0; col < ncols; ++col) {
        string column;
        column.push_back(a[0][col]);
        column.push_back(a[1][col]);

        // Report each non-standard character and mark the column as bad.
        bool keepColumn = true;
        for (size_t row = 0; row < 2; ++row) {
            if (allowed.find(column[row]) == string::npos) {
                keepColumn = false;
                cout << column << endl;
            }
        }
        if (!keepColumn) {
            continue;
        }

        // Gap-against-gap columns carry no information; skip them.
        const bool bothGaps = (column[0] == '-' && column[1] == '-');
        if (!bothGaps) {
            cleaned[0].push_back(column[0]);
            cleaned[1].push_back(column[1]);
        }
    }
    return cleaned;
}



/*
 * Reads a pairwise alignment (two fasta records of equal length) from the
 * given file. `what` names the kind of file in error messages.
 * Prints an error and exits with status 1 if the file does not contain
 * exactly two records or if the two rows differ in length.
 */
static vector<string> readAlignmentPairOrExit(const string &fname, const char *what) {
    vector<string> algn = readFastaFile(const_cast<char *>(fname.c_str()));
    if (algn.size() != 2) {
        cerr << "Error: the " << what << " file '" << fname
             << "' must contain exactly 2 aligned sequences. Found "
             << algn.size() << " instead." << endl;
        exit(1);
    }
    if (algn[0].size() != algn[1].size()) {
        cerr << "Error: the two aligned sequences in the " << what << " file '"
             << fname << "' differ in length (" << algn[0].size() << " vs "
             << algn[1].size() << ")." << endl;
        exit(1);
    }
    return algn;
}

/*
 * Program entry point. Prints the banner, parses the command line, builds
 * the substitution-matrix and null-model objects, and then runs one of
 * three modes selected by the parsed run type:
 *   "default"     : align two sequences and report compression w.r.t. null
 *   "ivalue"      : report statistics for a given alignment
 *   "expalgndist" : report expected alignment distance to a reference
 * Any other run type performs no further work. Returns 0.
 */
int main( int argc , char *argv[] ) {
    coutHeader();
    vector<string> param_seqfn;
    string param_runType;
    string param_input_afasta;
    string param_output_prefix;
    string param_nullMarkovOrder = "0";
    int    param_startingSubmat_N = 250;
    string param_alignmentType = "optimal"; 
    string param_matrixType = "MMLSUM"; 
    string param_input_refalgn; // for ExpAlgnDist computation

    parseCmdLine(argc,argv,param_seqfn,param_runType,param_input_afasta,
            param_output_prefix, param_alignmentType,param_matrixType,param_input_refalgn);
    

    /* get SubmatClass_t obj */
    SubmatClass_t submat(param_matrixType); // instantiate MMLSUM series condl and joint matrices 

    /* get nullModelClass_t obj */
    nullModelClass_t nullMdl(param_nullMarkovOrder);
    assert(nullMdl.isCompatible(submat.getAlphabet()));


    // check if you are running the program in:
    //   ---> alignment mode (i.e. param_runType == default)
    //   or
    //   ---> ivalue computation mode (i.e. param_runType == ivalue)
    //   or
    //   --> expected alignment distance mode (i.e. param_runType == expalgndist)
    if (param_runType.compare("default")==0) { // that is, run in alignment mode
        /* read sequences*/
        vector<string> seqs = getAASeqs(param_seqfn);
        string S = removeNonStandardAA(seqs[0]);
        string T = removeNonStandardAA(seqs[1]);

        // Replaces any field width still pending on cout (coutHeader leaves
        // one set) so the next line is not padded.
        cout << setw(1);
        cout << ".................................................Running in"
            << getTerminalColorString(param_alignmentType,42)  << "mode\n";
        cout << "Using matrixType = " <<  param_matrixType << endl;

        printSeq2Term(S,"Seq S");
        printSeq2Term(T,"Seq T");

        /* Compute Null model message length */
        double lensaving = 0.0, adaptivesaving = 0.0;
        double nullModelMesgLen = computeINullST(S,T,nullMdl,lensaving,adaptivesaving);

        /* Compute Alignment model message length */
        threeStateDPAClass_t algnobj(S,T,nullMdl,submat,param_startingSubmat_N,
                param_alignmentType,param_output_prefix);
        double modelMsgLen;
        if (param_alignmentType.compare("optimal")==0) {
            modelMsgLen = algnobj.getOptAlignMsgLen();
        }
        else modelMsgLen = algnobj.getMarginalProbMsgLen();

        /* Compression */
        double compression = nullModelMesgLen-modelMsgLen;

        if (param_alignmentType.compare("optimal")==0) {
            
            string algnfsastr = algnobj.getOptAlignFSAstr();
            /* write alignment fasta file */
            vector<string> algn = fsastr2Algn(algnfsastr,S,T);
            assert(algn.size()==2);
            assert(algn[0].length() == algnfsastr.length());
            assert(algn[1].length() == algnfsastr.length());
            
            string outfn = param_output_prefix+".afasta";
            ofstream algnout(outfn.c_str(), ios::out);
            if (!algnout) {
                cerr << "Error: unable to open '" << outfn
                     << "' for writing the alignment." << endl;
                exit(1);
            }

            // Each aligned row is wrapped at 60 characters per line.
            algnout << ">aligned seq 1" << endl;
            for (size_t i = 0; i < algnfsastr.length(); i++) {
                algnout << algn[0][i];
                if ((i+1)%60 == 0 || (i+1) == algnfsastr.length()) algnout << endl;
            }
            algnout << endl;

            algnout << ">aligned seq 2" << endl;
            for (size_t i = 0; i < algnfsastr.length(); i++) {
                algnout << algn[1][i];
                if ((i+1)%60 == 0 || (i+1) == algnfsastr.length()) algnout << endl;
            }
            cout << "Alignment fasta (afasta) file written to:  " 
                << getTerminalColorString(param_output_prefix + ".afasta", 43) << endl;

            cout << "\n----------------------------+=o0o=+-------------------------------\n";
            cout <<  getTerminalColorString("COMPRESSION",46) << "wrt null =   ";
            stringstream ss;
            ss << nits2bits(compression);
            if (compression >0) cout << getTerminalColorString(ss.str(), 42);
            else cout << getTerminalColorString(ss.str(), 41);
            cout << " bits" << endl;
            if (compression > 0) {
                cout << " According to MML, this alignment hypothesis is 2^{"
                    << nits2bits(compression) 
                    << "} times" <<  getTerminalColorString("MORE LIKELY",42) << " than null.\n";
            }
            else {
                cout << "According to MML, this alignment hypothesis is worse than null, so should be"
                    << getTerminalColorString("REJECTED",41) << "\n";
                cout << "Suggestion: check the marginal probability of <S,T> by running in \'"
                    << getTerminalColorString("--criterion marginal",45) << "\' mode.\n";
            }
            cout << "----------------------------+=o0o=+-------------------------------\n";

        } else if (param_alignmentType.compare("marginal")==0) {
            cout << "\n----------------------------+=o0o=+-------------------------------\n";
            cout <<  getTerminalColorString("COMPRESSION",46) << "wrt null =   ";
            stringstream ss;
            ss << nits2bits(compression);

            if (compression >0)
                cout << getTerminalColorString(ss.str(), 42);
            else 
                cout << getTerminalColorString(ss.str(), 41);
            cout << " bits" << endl;

            if (compression > 0) {
                // Compression of 6 bits or less is reported as weak evidence.
                if (nits2bits(compression) > 6){
                    cout << "According to MML, the hypothesis that <S,T> are related is 2^{"
                    << nits2bits(compression) 
                    << "} times MORE LIKELY than null.\n";
                } else {
                    cout << "According to MML, the hypothesis that <S,T> are related is very weak but positive.\n";
                }
            } else {
                cout << "According to MML, the hypothesis that <S,T> are related is worse than null, so NOT related.\n";
            }
            cout << "----------------------------+=o0o=+-------------------------------\n";

        } else {
            cout << "----------------------------+=o0o=+-------------------------------\n";
        }

    } else if (param_runType.compare("ivalue") == 0) { // run in  --ivalue mode
        cout << "\n-------------IVALUE STATISTICS OF A GIVEN ALIGNMENT---------------\n";

        // Validated explicitly so malformed input is also caught when
        // assertions are compiled out.
        vector<string> algn = readAlignmentPairOrExit(param_input_afasta, "input alignment");
        vector<string> sanitized_algn = removeNonStandardAAColumnsInAlgn(algn);
        string algnfsastr = algn2Fsastr(sanitized_algn);
        string S = removeNonStandardAA(sanitized_algn[0]);
        string T = removeNonStandardAA(sanitized_algn[1]);

        printSeq2Term(sanitized_algn[0],"Alignment String of S:");
        printSeq2Term(sanitized_algn[1],"Alignment String of T:");

        /* Compute Null model message length */
        double lensaving = 0.0, adaptivesaving = 0.0;
        double nullModelMesgLen = computeINullST(S,T,nullMdl,lensaving,adaptivesaving);

        /* prepare ivalue stats*/
        threeStateDPAClass_t ivalueobj(S,T,algnfsastr,nullMdl,submat);

        double algnModelMsgLen = ivalueobj.getOptAlignMsgLen();
        /* Compression */
        double compression = nullModelMesgLen-algnModelMsgLen;
        cout << "\n----------------------------+=o0o=+-------------------------------\n";
        cout <<  "COMPRESSION wrt null =   " << nits2bits(compression)  << " bits" << endl;
        cout << "----------------------------+=o0o=+-------------------------------\n";

    } else if (param_runType.compare("expalgndist")==0) {
        //first read sequences as provided
        vector<string> seqs = getAASeqs(param_seqfn);
        string S = removeNonStandardAA(seqs[0]);
        string T = removeNonStandardAA(seqs[1]);

        //next read ref alignment...
        vector<string> ref_algn = readAlignmentPairOrExit(param_input_refalgn, "reference alignment");

        vector<string> sanitized_ref_algn = removeNonStandardAAColumnsInAlgn(ref_algn);
        string refalgnfsastr = algn2Fsastr(sanitized_ref_algn);
        //...and extract sequences from ref alignment
        string extracted_S = removeNonStandardAA(sanitized_ref_algn[0]);
        string extracted_T = removeNonStandardAA(sanitized_ref_algn[1]);

        //Now ensure S==extracted_S and T==extracted_T
        if (S.compare(extracted_S)!=0) {
            cerr << "Error: The  FIRST  a.a. sequence extracted from the provided reference "
                << "alignment differs from the a.a. in the sequence fasta files."
                << "Ensure they are the same (and preserve the order)." << endl;
            exit(1);
        } else if (T.compare(extracted_T)!=0) {
            cerr << "Error: The  SECOND  a.a. sequence extracted from the provided reference "
                << "alignment differs from the a.a. in the sequence fasta files."
                << "Ensure they are the same (and preserve the order)." << endl;
            exit(1);
        }

        /* Expected alignment distance computation (always uses the
         * "marginal" alignment type, whatever was given on the command line) */
        param_alignmentType = "marginal";
        threeStateDPAClass_t algnobj(S,T,nullMdl,submat,param_startingSubmat_N,
                param_alignmentType,param_output_prefix);
        vector<vector<vector<dpcell_t> > > dpmat = algnobj.getMarginalDPMatrices();
        int markov_time = algnobj.getInferredSubmatN();

        double ead_formal  = expectedAlignmentDistance_formal(S,T,nullMdl,submat,sanitized_ref_algn,dpmat,markov_time);
        double ead_sampled = expectedAlignmentDistance_sampled(S,T,sanitized_ref_algn,dpmat,10000);
        
        cout << "----------------------- EXPECTED ALIGNMENT DISTANCE STATISTICS -----------------------" << endl;
        cout << "Alignment Distance from randomly Sampled Alignments = " << fixed << setprecision(4) << ead_sampled << endl << endl;
        cout << "Expected Alignment Distance                         = " << fixed << setprecision(4) << ead_formal << endl;
        cout << "--------------------------------------+=o0o=+-----------------------------------------\n";
    }
    return 0;
}

