#include "DirichletClass.h"
#include "misc.h"

#include <cmath>
#include <ios>



/* Construct a Dirichlet distribution of the given order (number of
 * categories K) with concentration parameters alpha_1..alpha_K.
 *
 * Preconditions (checked with assert):
 *   - order >= 2
 *   - concparams.size() == order
 *   - every alpha_i is strictly positive (and not NaN). alpha_i == 0 would
 *     make Gamma(alpha_i) diverge in the density, and an all-zero vector
 *     would make kappa == 0 so that getMean() divides by zero.
 *
 * Sets nFreeParams = order, alphas = concparams, kappa = sum of alphas and
 * mu = alpha_i / kappa (the mean vector). */
DirichletClass_t::DirichletClass_t(int order, vector<double> concparams) {
    assert(order >= 2);
    assert(concparams.size() == static_cast<size_t>(order));
    bool isValidConcparams = true;
    for (size_t i = 0; i < concparams.size(); i++) {
        // Written as !(x > 0) so that NaN is rejected as well as x <= 0.
        if (!(concparams[i] > 0)) {
            isValidConcparams = false;
            break;
        }
    }
    assert(isValidConcparams);

    nFreeParams = order;
    alphas  = concparams;

    // kappa is the total concentration, sum of alpha_i.
    kappa = 0.0;
    for (int i = 0; i < nFreeParams; i++) kappa += alphas[i];
    mu = getMean();
}

/* Accessor: hands back a copy of the concentration parameters
 * alpha_1..alpha_K. The caller may modify it freely; the stored
 * parameters are not affected. */
vector<double> DirichletClass_t::getAlphas() {
    vector<double> alphasCopy(alphas);
    return alphasCopy;
}

/* Mean vector of the distribution: E[X_i] = alpha_i / kappa for
 * i = 0..nFreeParams-1, where kappa is the sum of all alphas. */
vector<double> DirichletClass_t::getMean() {
    vector<double> meanVec;
    int idx = 0;
    while (idx < nFreeParams) {
        meanVec.push_back(alphas[idx] / kappa);
        ++idx;
    }
    return meanVec;
}

/* Mode of the distribution: m_i = (alpha_i - 1) / (kappa - K), with
 * K = nFreeParams.
 *
 * NOTE: this formula is the (interior) mode only when every alpha_i > 1.
 * If some alpha_i < 1 while kappa > K, the corresponding component returned
 * here is negative, so callers should check the alphas before relying on
 * the result. kappa == K (e.g. all alpha_i == 1, the uniform Dirichlet)
 * gives a zero denominator and is rejected by assert. */
vector<double> DirichletClass_t::getMode() {
    // Loop-invariant denominator; zero means the formula is undefined.
    const double denom = kappa - nFreeParams;
    assert(denom != 0.0);

    vector<double> mode;
    double component_sum = 0.0;
    for (int i = 0; i < nFreeParams; i++) {
        double res = (alphas[i] - 1) / denom;
        component_sum += res;
        mode.push_back(res);
    }

    // Analytically sum(alpha_i - 1) == kappa - K, so component_sum is 1.
    // Re-normalizing only removes floating-point drift.
    double sanity_check_sum = 0.0;
    for (int i = 0; i < nFreeParams; i++) {
        double res = mode[i] / component_sum;
        mode[i] = res;
        sanity_check_sum += res;
    }
    // Symmetric tolerance around 1 (previously 1e-9 below vs 1e-8 above).
    assert(sanity_check_sum > 1.0 - 1e-8 && sanity_check_sum < 1.0 + 1e-8);
    return mode;
}

/* Probability density of the Dirichlet distribution at x:
 *   pdf(x)     = normaConst * Prod_{1<=i<=nFreeParams} x_i^(alpha_i - 1)
 *   normaConst = Gamma(kappa) / Prod_{1<=i<=nFreeParams} Gamma(alpha_i)
 *
 * The whole expression is accumulated in log space and exponentiated once.
 * Evaluating tgamma() directly overflows to inf for arguments above ~171,
 * and the resulting inf/inf used to produce -nan. The returned value can
 * still underflow to 0 or overflow to inf for extreme inputs, so prefer
 * getLogPrDensity when working with likelihoods.
 *
 * Precondition (asserted): x has nFreeParams components. */
double DirichletClass_t::getPrDensity(vector<double> x) {
    assert(x.size() == static_cast<size_t>(nFreeParams));

    // log(normaConst) = lgamma(kappa) - sum_i lgamma(alpha_i)
    double logDensity = lgamma(kappa);
    for (int i = 0; i < nFreeParams; i++) logDensity -= lgamma(alphas[i]);

    for (int i = 0; i < nFreeParams; i++) {
        // x_i^0 == 1 even when x_i == 0; (alpha_i-1)*log(x_i) would be
        // 0 * -inf = NaN there, so skip factors with a zero exponent.
        if (alphas[i] != 1.0) logDensity += (alphas[i] - 1) * log(x[i]);
    }
    return exp(logDensity);
}

/* Natural log of the Dirichlet density at x:
 *   log pdf(x)    = lognormaConst + sum_{1<=i<=nFreeParams} (alpha_i - 1) * log(x_i)
 *   lognormaConst = lgamma(kappa) - sum_{1<=i<=nFreeParams} lgamma(alpha_i)
 *
 * Terms with alpha_i == 1 contribute log(x_i^0) = 0 and are skipped. This
 * avoids 0 * log(0) = 0 * -inf = NaN when x_i == 0 lies on the simplex
 * boundary.
 *
 * Precondition (asserted): x has nFreeParams components. */
double DirichletClass_t::getLogPrDensity(vector<double> x) {
    assert(x.size() == static_cast<size_t>(nFreeParams));

    double lognormaConst = lgamma(kappa);
    for (int i = 0; i < nFreeParams; i++) lognormaConst -= lgamma(alphas[i]);

    double logterm = 0.0;
    for (int i = 0; i < nFreeParams; i++) {
        if (alphas[i] == 1.0) continue;  // zero exponent: factor is exactly 1
        logterm += (alphas[i] - 1) * log(x[i]);
    }

    return lognormaConst + logterm;
}

/* Debug dump of the distribution to stdout: nFreeParams, kappa, the mean
 * vector, and the log-density evaluated at the mean.
 *
 * The old unused getMode() call has been removed. Its result was discarded,
 * and it aborted (via assert) whenever kappa == nFreeParams, e.g. for the
 * uniform Dirichlet. The stream precision changed by setprecision(3) is
 * restored on exit so it does not leak into unrelated output. */
void DirichletClass_t::printParams() {
    const std::streamsize savedPrecision = cout.precision();

    cout << "Printing DirichletClass params....\n";
    cout << "nFreeParams = " << nFreeParams << endl;
    cout << "kappa   = " << kappa << endl;
    cout << "meanvec = " ;
    for (int i = 0; i < nFreeParams; i++) cout << setprecision(3) << setw(7) << mu[i];
    cout << endl;
    cout << "sanity check -- evaluating pdf at mean\n";
    cout << "pdf value at mean:\n";
    cout << "log-density:" << getLogPrDensity(mu) << endl;

    cout.precision(savedPrecision);
}
