#include "misc.h"

/* Return the Catalan number C_K = (2K)! / ((K+1)! K!)  (1, 1, 2, 5, 14, 42, ...).
 *
 * Computed exactly in integer arithmetic with the recurrence
 *     C_{m+1} = C_m * 2(2m+1) / (m+2),
 * in which every division is exact. This avoids the truncation error of a
 * floating-point product converted with (int).
 *
 * K must lie in [0, 19]: C_19 = 1767263190 is the largest Catalan number
 * representable in a 32-bit int.
 */
int catalan_C(int K) {
    assert(K >= 0);
    assert(K <= 19); // C_20 = 6564120420 would overflow int
    long long c = 1; // C_0
    for (long long m = 0; m < K; m++) {
        // Multiply before dividing; the quotient is exactly C_{m+1}.
        c = c * 2 * (2 * m + 1) / (m + 2);
    }
    return static_cast<int>(c);
}

/* Code length, in bits, assigned to the non-negative integer n.
 *
 * Integers are ranked from 1 (rank = n+1) and grouped by k = 0, 1, 2, ...:
 * group k holds C_k consecutive ranks (C_k = catalan_C(k)) and every member
 * of group k gets a codeword of 2k+1 bits, so the result is always an odd
 * integer (returned as a double).
 * NOTE(review): this is presumably the Wallace tree ("WT") code over binary
 * trees with k internal nodes - confirm against the reference.
 *
 * The rank and the running Catalan sum are kept in long long: with int, the
 * sum overflows once n+1 exceeds C_0+...+C_18 = 656043857, and n+1 itself
 * overflows at INT_MAX.
 */
double wtCodeLen_bits(int n) {
    assert(n >= 0); // only non-negative integers
    const long long rank = static_cast<long long>(n) + 1; // map from 0 onwards
    if (rank == 1) return 1;

    long long cumul = 0;
    int k = 0;
    // Find the first group k whose cumulative size reaches rank (k <= 19 for int n).
    while (cumul < rank) {
        cumul += catalan_C(k++);
    }
    return 2*(k-1)+1;
}

/* Same code length as wtCodeLen_bits(n), expressed in nits (natural-log units). */
double wtCodeLen_nits(int n) {
    const double lenInBits = wtCodeLen_bits(n);
    return bits2nits(lenInBits);
}


/* Code lengths, in nits, for the integers 0..N under wtCodeLen_nits,
 * renormalised so that the implied probabilities exp(-length) sum to 1 over
 * 0..N. Element i of the result is the length for integer i.
 * Returns an empty vector when N < 0.
 */
vector<double> getWTCodelensNormalized_0_to_N_nits(int N) {
    vector<double> codelens;
    if (N < 0) return codelens;
    codelens.reserve(static_cast<size_t>(N) + 1);

    // Unnormalised lengths and their total probability mass.
    double sum_prob = 0.0;
    for (int i = 0; i <= N; i++) {
        const double l = wtCodeLen_nits(i);
        sum_prob += exp(-l);
        codelens.push_back(l);
    }

    // -log(exp(-l) / sum_prob) == l + log(sum_prob): renormalise in the log
    // domain instead of round-tripping through exp/log for every entry.
    const double log_sum = log(sum_prob);
    for (size_t i = 0; i < codelens.size(); i++) {
        codelens[i] += log_sum;
    }
    return codelens;
}

/* Length in bits of the log* code for n:
 *     log2(2.865) + log2(n) + log2(log2(n)) + ...   (positive terms only),
 * presumably Rissanen's universal integer code (2.865 approximating its
 * normalising constant). n == 0 costs 0 bits; n == 1 costs only the constant.
 */
double logStar_bits( size_t n ) {
    // n is unsigned, so zero is the only degenerate input.
    if (n == 0) return 0;

    const double norm_const = log(2.865)/log(2); // in bits
    if (n == 1) return norm_const;

    // Keep adding iterated base-2 logarithms while they stay positive.
    double total = log(n)/log(2);
    for (double term = log(total)/log(2); term > 0 && total >= 0; term = log(term)/log(2)) {
        total += term;
    }
    return total + norm_const;
}

/* Same code length as logStar_bits(n), expressed in nits.
 * n is unsigned, so zero (length 0) is the only special case. */
double logStar_nits( size_t n ) {
    return (n == 0) ? 0.0 : bits2nits(logStar_bits(n));
}

/* Convert a quantity from nits (natural-log units) to bits: divide by ln 2. */
double nits2bits( double x ) {
    static const double ln2 = log(2.0); // natural log of 2, computed once
    return x / ln2;
}

/* Convert a quantity from bits to nits (natural-log units): multiply by ln 2. */
double bits2nits( double x ) {
    static const double ln2 = log(2.0); // natural log of 2, computed once
    return x * ln2;
}

/* Print sequence s to stdout wrapped at 60 characters per line, preceded by
 * msg on its own line when msg is non-empty. The output always ends with a
 * newline unless s is empty. */
void printSeq2Term(const string s, string msg) {
    const size_t lineWidth = 60;

    if (!msg.empty()) {
        cout << msg << endl;
    }

    size_t col = 0;
    for (string::const_iterator it = s.begin(); it != s.end(); ++it) {
        cout << *it;
        if (++col % lineWidth == 0) cout << "\n"; // wrap after a full line
    }

    // Close off a trailing partial line.
    if (s.length() % lineWidth != 0) {
        cout << endl;
    }
}

/* Print the top-left nRows x nCols entries of mat to stdout, one matrix row
 * per line, each entry in fixed notation with p decimals and field width w.
 * msg, when non-empty, is printed on its own line first.
 * Note: the fixed/setprecision settings remain applied to cout afterwards.
 * The caller must ensure nRows/nCols do not exceed mat's dimensions. */
void printMatrixToTerm(
 const vector<vector<double> > mat, 
 int p,
 int w,
 string msg,
 int nRows, 
 int nCols 
) {
    if (!msg.empty()) cout << msg << endl;

    for (int r = 0; r < nRows; ++r) {
        for (int c = 0; c < nCols; ++c) {
            cout << fixed << setprecision(p) << setw(w) << mat[r][c];
        }
        cout << endl; // end of matrix row
    }
}

/* Element-wise sum A + B of two equally sized rectangular matrices (vectors
 * of rows). Asserts that B has as many rows as A and that every row of A and
 * B has the width of A's first row. Checking only row 0 would let a ragged
 * row be read out of bounds. */
vector<vector<double> > matrixAdd(
 vector<vector<double> > A, 
 vector<vector<double> > B
) {
    const size_t nRows = A.size();
    const size_t nCols = nRows > 0 ? A[0].size() : 0;
    assert(B.size() == nRows);

    vector<vector<double> > res(nRows, vector<double>(nCols, 0.0));
    for (size_t i = 0; i < nRows; i++) {
        assert(A[i].size() == nCols);
        assert(B[i].size() == nCols);
        for (size_t j = 0; j < nCols; j++) {
            res[i][j] = A[i][j] + B[i][j];
        }
    }
    return res;
}

/* Matrix product A * B of rectangular matrices stored as vectors of rows.
 * Requires (asserts) that A's column count equals B's row count and that every
 * row of A and of B has the same width as its first row.
 *
 * Loops run in i-k-j order so the innermost loop walks a row of B and a row of
 * the result contiguously. Each prod[i][j] is still accumulated over
 * k = 0, 1, 2, ... starting from 0, so results match the naive i-j-k order. */
vector<vector<double> > matrixMultiply(
 vector<vector<double> > A, 
 vector<vector<double> > B
) {
    const size_t nRows_A = A.size();
    const size_t nCols_A = nRows_A > 0 ? A[0].size() : 0;
    const size_t nCols_B = B.size() > 0 ? B[0].size() : 0;
    assert(nCols_A == B.size());
    for (size_t k = 0; k < nCols_A; k++) {
        assert(B[k].size() == nCols_B);
    }

    vector<vector<double> > prod(nRows_A, vector<double>(nCols_B, 0.0));
    for (size_t i = 0; i < nRows_A; i++) {
        assert(A[i].size() == nCols_A);
        vector<double> &out = prod[i];
        for (size_t k = 0; k < nCols_A; k++) {
            const double a_ik = A[i][k];
            const vector<double> &B_k = B[k];
            for (size_t j = 0; j < nCols_B; j++) {
                out[j] += a_ik * B_k[j];
            }
        }
    }
    return prod;
}

/* Scale every entry of A by the integer lambda and return the result.
 *
 * A is received by value, so it is already a private copy: scale it in place
 * and return it rather than allocating a second matrix. Each row is walked to
 * its own length, so ragged input is scaled entry-for-entry instead of being
 * read out of bounds. */
vector<vector<double> > matrixMultiply(
 vector<vector<double> > A, 
 int lambda
) {
    for (size_t i = 0; i < A.size(); i++) {
        vector<double> &row = A[i];
        for (size_t j = 0; j < row.size(); j++) {
            row[j] *= lambda;
        }
    }
    return A;
}


/* Print the alignment of S and T encoded by the operation string fsastr in
 * blocks of 60 columns, followed by identity statistics.
 *
 * fsastr is read one character at a time:
 *   'm' - take the next character of S and of T;
 *   'i' - gap ('-') in the S row, take the next character of T;
 *   'd' - take the next character of S, gap ('-') in the T row.
 * fsastr must consume S and T exactly (checked by assert).
 *
 * Each block shows a ruler ('.' every 10th column), the S row, the T row and
 * a '*' under every column whose two characters are equal. Identity is then
 * reported as a percentage of the alignment length |A| and of |S| and |T|.
 * A zero denominator is reported as 0 instead of dividing by zero.
 */
void printFSAstr2Algn(const string &fsastr, const string &S, const string &T) {
    size_t i = 0,j = 0;
    vector<string> algn(2);
    for (size_t l = 0; l < fsastr.length(); l++){
        switch (fsastr[l]) {
            case 'm': algn[0].push_back(S[i++]);
                      algn[1].push_back(T[j++]);
                      break;
            case 'i': algn[0].push_back('-');
                      algn[1].push_back(T[j++]);
                      break;
            case 'd': algn[0].push_back(S[i++]);
                      algn[1].push_back('-');
                      break;
        }
    }
    assert(S.length()==i);
    assert(T.length()==j);
    assert(algn[0].length() == fsastr.length());
    assert(algn[1].length() == fsastr.length());

    const size_t nCharsPerLine = 60;
    const size_t len = fsastr.length();
    // Round up: an exact multiple of nCharsPerLine must not produce a trailing
    // empty block, and an empty alignment produces no block at all.
    const size_t nLines = (len + nCharsPerLine - 1) / nCharsPerLine;
    int nEquiv = 0;
    for (size_t n = 0; n < nLines; n++) {
        const size_t start = n * nCharsPerLine;
        // Every block is full width except possibly the last one.
        const size_t nc = (len - start < nCharsPerLine) ? len - start : nCharsPerLine;
        cout << endl;
        for (size_t c = 0; c < nc; c++) {
            if((c+1)%10 == 0) cout << ".";
            else cout << " ";
        }
        cout << endl;
        for (size_t c = 0; c < nc; c++) {
            cout << algn[0][start+c];
        }
        cout << endl;
        for (size_t c = 0; c < nc; c++) {
            cout << algn[1][start+c];
        }
        cout << endl;
        for (size_t c = 0; c < nc; c++) {
            if (algn[0][start+c] == algn[1][start+c]) {
                cout << "*";
                nEquiv++;
            }
            else {
                cout << " ";
            }
        }
    }
    cout << endl;
    cout << "nEquivalences = " << nEquiv << endl;
    // Computed in double so nEquiv*100 cannot overflow int.
    const double pctA = (len == 0) ? 0.0 : 100.0 * nEquiv / len;
    const double pctS = (S.length() == 0) ? 0.0 : 100.0 * nEquiv / S.length();
    const double pctT = (T.length() == 0) ? 0.0 : 100.0 * nEquiv / T.length();
    cout << "%-Identity (as a function of |A|)  = " << pctA << endl;
    cout << "%-Identity (as a function of |S|)  = " << pctS << endl;
    cout << "%-Identity (as a function of |T|)  = " << pctT << endl;
    cout << endl;
}

/* Expand the operation string fsastr into a two-row alignment of S and T.
 * Row 0 is S with gaps, row 1 is T with gaps:
 *   'm' - next char of S over next char of T;
 *   'i' - '-' over next char of T;
 *   'd' - next char of S over '-'.
 * Other characters are skipped and then trip the length assertions, which
 * also require that S and T are consumed exactly. */
vector<string> fsastr2Algn(const string &fsastr, const string &S, const string &T) {
    vector<string> algn(2);
    string &top = algn[0];
    string &bottom = algn[1];
    size_t si = 0, ti = 0; // next unread position in S / T

    for (string::const_iterator op = fsastr.begin(); op != fsastr.end(); ++op) {
        if (*op == 'm') {
            top.push_back(S[si++]);
            bottom.push_back(T[ti++]);
        } else if (*op == 'i') {
            top.push_back('-');
            bottom.push_back(T[ti++]);
        } else if (*op == 'd') {
            top.push_back(S[si++]);
            bottom.push_back('-');
        }
    }

    assert(S.length() == si);
    assert(T.length() == ti);
    assert(top.length() == fsastr.length());
    assert(bottom.length() == fsastr.length());
    return algn;
}

/* Encode a two-row alignment (algn[0] over algn[1], equal lengths) as an
 * operation string: 'i' where row 0 has a gap, otherwise 'd' where row 1 has
 * a gap, otherwise 'm'. A column with gaps in both rows therefore maps to 'i'. */
string algn2Fsastr(const vector<string> &algn) {
    const string &top = algn[0];
    const string &bottom = algn[1];
    assert(top.length() == bottom.length());

    string fsastr;
    fsastr.reserve(top.length());
    for (size_t col = 0; col < top.length(); ++col) {
        char op;
        if (top[col] == '-') op = 'i';
        else if (bottom[col] == '-') op = 'd';
        else op = 'm';
        fsastr.push_back(op);
    }
    assert(top.length() == fsastr.length());
    assert(bottom.length() == fsastr.length());

    return fsastr;
}

/* Return a copy of r with its characters in reverse order.
 * Built from the reverse-iterator range rather than a hand-rolled swap loop.
 * This also avoids the size_t -> int narrowing of the loop index for very
 * long strings. */
string reverse(string r) {
    return string(r.rbegin(), r.rend());
}

/* Print s to stdout as "[ <s> ]", with s in green and underlined (ANSI
 * escapes, reset afterwards), then end the line with endl. */
void printToTerminalMsgStatus(string s) {
    static const char *const kOpen  = "[ \033[32;4m";
    static const char *const kClose = "\033[0m ]";
    cout << kOpen << s << kClose << endl;
}

/* Print s to stdout as "[ <s> ]", with s in green and underlined (ANSI
 * escapes, reset afterwards). No newline and no explicit flush. */
void printToTerminalMsgStatus_wo_newline(string s) {
    static const char *const kOpen  = "[ \033[32;4m";
    static const char *const kClose = "\033[0m ]";
    cout << kOpen << s << kClose;
}

/* Return s wrapped in an ANSI escape that applies colour code `color` plus
 * underline (";4"), followed by a reset. The result is padded with one
 * leading and one trailing space. */
string getTerminalColorString(string s,size_t color) {
    ostringstream out;
    out << " \033[" << color << ";4m"; // leading space + start attribute
    out << s;
    out << "\033[0m ";                 // reset attributes + trailing space
    return out.str();
}

/* Return the final component of a path with its extension removed.
 *
 * Both '/' and '\\' are treated as separators. The extension runs from the
 * last '.' of the final component to the end. A '.' inside a directory name
 * is ignored, and so is a leading '.', so ".bashrc" and "dir/.hidden" keep
 * their full names instead of collapsing to "". A path ending in a separator
 * yields "".
 *   "a/b/file.txt" -> "file",  "a.b/file" -> "file",  "x/.rc.txt" -> ".rc"
 */
string getFilenameFromPath(string fn) {
    const size_t sep = fn.find_last_of("/\\");
    // Index where the final path component begins.
    const size_t start = (sep == string::npos) ? 0 : sep + 1;

    size_t dot = fn.find_last_of('.');
    // A dot before the component (in a directory) or at its very first
    // character (hidden file) does not introduce an extension.
    if (dot == string::npos || dot <= start) dot = fn.length();

    return fn.substr(start, dot - start);
}
