package jp.ac.naist.dynamix.mpca;

import jp.ac.naist.dynamix.bitools.MatrixUtils;

/* loaded from: input_file:jp/ac/naist/dynamix/mpca/DAMPEM.class */

/**
 * Batch EM-style fitting of a mixture of {@link BPCAUnit}s, with pruning of
 * units whose responsibility mass falls below {@link #DELETION_THRESHOLD} and
 * pruning of latent features whose squared column norm falls below
 * {@link #FEATURE_DELETION_THRESHOLD}.
 *
 * <p>Indexing convention: the per-unit arrays {@code gamma}, {@code g},
 * {@code lng} and {@code tmpg} are indexed by {@code BPCAUnit.id}. The rows of
 * {@code cmf} and {@code newMu} are indexed by the unit's current position in
 * {@code u}. The id-indexed arrays are allocated with {@code maxNumUnits}
 * entries, so unit ids must stay below that bound.
 *
 * <p>NOTE(review): the package name suggests this is a variational
 * mixture-of-Bayesian-PCA model. The interpretation of the formulas as a free
 * energy is inferred, not documented here — confirm against the reference.
 *
 * <p>Scratch buffers ({@code dy}, {@code ex}, {@code tmpx}, {@code comp_mat},
 * {@code ET}) are instance fields reused by every call. An instance must
 * therefore not be used from several threads at once.
 */
public class DAMPEM {
    /** Active units; only the first {@code numUnits} entries are meaningful (array is kept exactly that long). */
    BPCAUnit[] u;
    private double[][] data;
    int numUnits;
    /** Maximum latent dimension, fixed to {@code iDim - 1} by the constructor. */
    int max_fDim;
    /** Observed data dimension, taken from the first unit. */
    int iDim;
    /** {@code cmf[position][datum]}: after {@link #batchDoStep()}, the normalised responsibility of each unit for each datum. */
    double[][] cmf;
    /** Per-datum log of the summed (unnormalised) unit scores, computed in {@link #batchDoStep()}. */
    double[] lnL;
    double[] dy;
    double[] ex;
    double[] tmpx;
    /** Per-unit constant part of the log score, indexed by unit id; filled by {@code calc_tmpg()}. */
    double[] tmpg;
    /** Responsibility-weighted data sums per unit position, used to update {@code mu}. */
    double[][] newMu;
    double[][] ET;
    double EtrS;
    /** Responsibility mass per unit id (sum over data of {@code cmf}). */
    double[] gamma;
    double[] g;
    /** Log mixing-weight term per unit id. */
    double[] lng;
    private double[][] comp_mat;
    public double FE;
    private boolean DEBUG = false;
    int maxNumUnits = 60;
    int numData = 0;
    public double gamma0 = 0.001d;
    /** A unit's M-step is run only when its responsibility mass exceeds this value. */
    public double M_STEP_CANCEL_THRESHOLD = 0.5d;
    /** A unit is removed when its responsibility mass is below this value. */
    public double DELETION_THRESHOLD = 0.5d;
    /** A latent feature is removed when its {@code diagWTW} entry is below this value. */
    public double FEATURE_DELETION_THRESHOLD = 1.0E-8d;
    /** Data points whose responsibility is below this value are skipped in the M-step. */
    public double SMALL_VALUE = 1.0E-10d;
    private final double ln2pi = Math.log(6.283185307179586d);

    /**
     * Creates the model around an initial set of units and allocates the
     * scratch buffers.
     *
     * @param bPCAUnitArr initial units; must be non-empty, since the data
     *                    dimension is read from {@code bPCAUnitArr[0].iDim}
     */
    public DAMPEM(BPCAUnit[] bPCAUnitArr) {
        this.u = bPCAUnitArr;
        this.iDim = this.u[0].iDim;
        this.max_fDim = this.iDim - 1;
        this.numUnits = this.u.length;
        this.gamma = new double[this.maxNumUnits];
        this.g = new double[this.maxNumUnits];
        this.lng = new double[this.maxNumUnits];
        this.comp_mat = new double[this.iDim][this.iDim];
        this.newMu = new double[this.maxNumUnits][this.iDim];
        this.ET = new double[this.iDim][this.max_fDim];
        this.dy = new double[this.iDim];
        this.ex = new double[this.iDim];
        this.tmpx = new double[this.iDim];
        this.tmpg = new double[this.maxNumUnits];
        MatrixUtils.init(this.iDim);
    }

    /**
     * Installs the data set (stored by reference, one row per datum).
     * Reallocates {@code cmf} and {@code lnL} to match the current unit count
     * and the number of data points.
     */
    public void setData(double[][] dArr) {
        this.numData = dArr.length;
        this.data = dArr;
        this.cmf = new double[this.numUnits][this.numData];
        this.lnL = new double[this.numData];
    }

    /** Replaces the id-indexed responsibility-mass array (stored by reference). */
    public void setGamma(double[] dArr) {
        this.gamma = dArr;
    }

    /**
     * Installs a responsibility matrix ({@code dArr[position][datum]}, stored by
     * reference). Then recomputes {@code gamma}, {@code g} and {@code lng} for
     * every active unit from it.
     */
    public void setCMF(double[][] dArr) {
        this.cmf = dArr;
        MatrixUtils.fillZero(this.gamma);
        for (int pos = 0; pos < this.numUnits; pos++) {
            int id = this.u[pos].id;
            for (int n = 0; n < this.numData; n++) {
                this.gamma[id] += dArr[pos][n];
            }
            this.g[id] = (this.gamma[id] + this.gamma0) / this.numData;
            this.lng[id] = Math.log(this.g[id]);
        }
    }

    /**
     * Builds a human-readable summary of the mean log-likelihood, the free
     * energy, and each unit's gamma, tau and active-feature count. Gamma and tau
     * are truncated to one decimal place.
     */
    public String profile() {
        StringBuilder sb = new StringBuilder();
        sb.append("logLikelihoodPerSingleDatum=").append(getMeanLogLikelihood());
        sb.append("\n FE=").append(getFreeEnergy());
        sb.append("\n gamma=[");
        for (int pos = 0; pos < this.numUnits; pos++) {
            sb.append(truncateToOneDecimal(this.gamma[this.u[pos].id])).append(" ");
        }
        sb.append("];\n  tau=[");
        for (int pos = 0; pos < this.numUnits; pos++) {
            sb.append(truncateToOneDecimal(this.u[pos].tau)).append(" ");
        }
        sb.append("];\n  activeFeature=[");
        for (int pos = 0; pos < this.numUnits; pos++) {
            sb.append(this.u[pos].numOfActiveFeatures()).append(" ");
        }
        sb.append("];\n");
        return sb.toString();
    }

    /** Truncates toward zero to one decimal place, as {@link #profile()} displays values. */
    private static double truncateToOneDecimal(double v) {
        return ((int) (v * 10.0d)) / 10.0d;
    }

    /**
     * Appends a freshly constructed unit.
     *
     * <p>The new unit receives an id not used by any active unit. It prefers
     * {@code numUnits} (the historical choice) and otherwise takes the smallest
     * free id, so ids never collide after deletions.
     *
     * @return the position of the new unit in {@code u}
     */
    public int addUnit() {
        return appendUnit(new BPCAUnit(nextUnitId(), this.iDim, this.max_fDim));
    }

    /**
     * Appends the given unit as-is. The caller is responsible for its id being
     * unique.
     *
     * @return the position of the new unit in {@code u}
     */
    public int addUnit(BPCAUnit bPCAUnit) {
        return appendUnit(bPCAUnit);
    }

    /** Grows {@code u} by one slot and stores {@code unit} in it; returns its position. */
    private int appendUnit(BPCAUnit unit) {
        BPCAUnit[] grown = new BPCAUnit[this.numUnits + 1];
        System.arraycopy(this.u, 0, grown, 0, this.numUnits);
        grown[this.numUnits] = unit;
        this.u = grown;
        this.numUnits++;
        return this.numUnits - 1;
    }

    /**
     * Chooses an id for a new unit: {@code numUnits} if free, otherwise the
     * smallest free id. A free id always exists in {@code [0, numUnits]}.
     */
    private int nextUnitId() {
        if (!isUnitIdInUse(this.numUnits)) {
            return this.numUnits;
        }
        int id = 0;
        while (isUnitIdInUse(id)) {
            id++;
        }
        return id;
    }

    /** Returns whether any active unit has the given id. */
    private boolean isUnitIdInUse(int id) {
        for (int pos = 0; pos < this.numUnits; pos++) {
            if (this.u[pos].id == id) {
                return true;
            }
        }
        return false;
    }

    /**
     * Removes the first active unit whose id equals {@code i}. The remaining
     * units are shifted down to keep {@code u} contiguous.
     *
     * @param i the unit id (not the position) to remove
     * @return {@code true} if a unit was removed, {@code false} if no active
     *         unit has that id
     */
    public boolean deleteUnit(int i) {
        int index = -1;
        for (int pos = 0; pos < this.numUnits; pos++) {
            if (this.u[pos].id == i) {
                index = pos;
                break;
            }
        }
        if (index < 0) {
            return false;
        }
        BPCAUnit[] shrunk = new BPCAUnit[this.numUnits - 1];
        System.arraycopy(this.u, 0, shrunk, 0, index);
        System.arraycopy(this.u, index + 1, shrunk, index, this.numUnits - 1 - index);
        this.u = shrunk;
        this.numUnits--;
        return true;
    }

    /** Returns the free energy last computed by {@link #calcFreeEnergy()}. */
    public double getFreeEnergy() {
        return this.FE;
    }

    public int getNumUnits() {
        return this.numUnits;
    }

    /** Returns a new array holding each active unit's {@code tau}, in position order. */
    public double[] getTau() {
        double[] dArr = new double[this.numUnits];
        for (int pos = this.numUnits - 1; pos >= 0; pos--) {
            dArr[pos] = this.u[pos].tau;
        }
        return dArr;
    }

    /** Returns the internal id-indexed gamma array (not a copy). */
    public double[] getGamma() {
        return this.gamma;
    }

    /** Returns the internal id-indexed {@code g} array (not a copy). */
    public double[] getG() {
        return this.g;
    }

    /** Returns the internal responsibility matrix (not a copy). */
    public double[][] getCMF() {
        return this.cmf;
    }

    /** Returns the mean of {@code lnL} over the data, via {@code MatrixUtils.mean}. */
    public double getMeanLogLikelihood() {
        return MatrixUtils.mean(this.lnL);
    }

    /**
     * Rebuilds the unit's {@code Rx} matrix and its inverse/log-determinant.
     * Also refreshes {@code diagWTW}.
     *
     * <p>For the upper triangle, {@code Rx[i][j] = tau * (W^T W)[i][j] +
     * (iDim / gamma) * invDw[i][j]}. The result is mirrored to the lower
     * triangle, and the identity is then added.
     */
    private void calcInvRx(BPCAUnit bPCAUnit) {
        int hiddenDim = bPCAUnit.getHiddenDim();
        double scale = this.iDim / bPCAUnit.gamma;
        for (int i = hiddenDim - 1; i >= 0; i--) {
            for (int j = hiddenDim - 1; j >= i; j--) {
                double wtw = 0.0d;
                for (int k = this.iDim - 1; k >= 0; k--) {
                    wtw += bPCAUnit.W[k][i] * bPCAUnit.W[k][j];
                }
                if (i == j) {
                    bPCAUnit.diagWTW[i] = wtw;
                }
                double entry = (wtw * bPCAUnit.tau) + (scale * bPCAUnit.invDw[i][j]);
                bPCAUnit.Rx[i][j] = entry;
                bPCAUnit.Rx[j][i] = entry;
                this.comp_mat[i][j] = entry;
                this.comp_mat[j][i] = entry;
            }
            bPCAUnit.Rx[i][i] += 1.0d;
            this.comp_mat[i][i] += 1.0d;
        }
        // comp_mat is a working copy handed to the helper so Rx itself is left intact.
        bPCAUnit.logdetRx = MatrixUtils.logDetWithInverse(hiddenDim, this.comp_mat, bPCAUnit.invRx);
    }

    /**
     * Updates the unit's per-feature {@code alpha}:
     * {@code alpha[i] = iDim / (tau * |W[:,i]|^2 + (iDim / gamma) * invDw[i][i])}.
     */
    private void calcAlpha(BPCAUnit bPCAUnit) {
        int hiddenDim = bPCAUnit.getHiddenDim();
        double scale = this.iDim / bPCAUnit.gamma;
        for (int i = 0; i < hiddenDim; i++) {
            double colNormSq = 0.0d;
            for (int k = 0; k < this.iDim; k++) {
                colNormSq += bPCAUnit.W[k][i] * bPCAUnit.W[k][i];
            }
            bPCAUnit.alpha[i] = this.iDim / ((bPCAUnit.tau * colNormSq) + (scale * bPCAUnit.invDw[i][i]));
        }
    }

    /** Precomputes, per unit id, the datum-independent part of the log score used by {@link #logLikelihood}. */
    private void calc_tmpg() {
        for (int pos = 0; pos < this.numUnits; pos++) {
            BPCAUnit unit = this.u[pos];
            int id = unit.id;
            this.tmpg[id] = ((0.5d * ((-unit.logdetRx) + (this.iDim * (unit.lntau - this.ln2pi)))) + this.lng[id])
                    - (this.iDim / (2.0d * (this.gamma[id] + unit.gmu0)));
        }
    }

    /**
     * Computes the (unnormalised) log score of datum {@code dArr} under
     * {@code bPCAUnit}. The E-step normalises these scores into
     * responsibilities.
     *
     * <p>Requires {@code calc_tmpg()} to have been run for the unit's current
     * parameters, since the result includes {@code tmpg[bPCAUnit.id]}.
     * Overwrites the scratch buffers {@code dy}, {@code ex} and {@code tmpx}.
     *
     * <p>NOTE(review): assumes {@code MatrixUtils.mul(m, n, v, M, out)} computes
     * {@code out = v^T M}, so {@code ex = tau * W^T (x - mu)}. Confirm.
     */
    public double logLikelihood(BPCAUnit bPCAUnit, double[] dArr) {
        int hiddenDim = bPCAUnit.getHiddenDim();
        for (int k = this.iDim - 1; k >= 0; k--) {
            this.dy[k] = dArr[k] - bPCAUnit.mu[k];
        }
        MatrixUtils.mul(this.iDim, hiddenDim, this.dy, bPCAUnit.W, this.ex);
        MatrixUtils.mulScalar(hiddenDim, bPCAUnit.tau, this.ex);
        MatrixUtils.mul(hiddenDim, hiddenDim, this.ex, bPCAUnit.invRx, this.tmpx);
        return ((0.5d * MatrixUtils.innerProduct(hiddenDim, this.ex, this.tmpx))
                - ((0.5d * bPCAUnit.tau) * MatrixUtils.innerProduct(this.iDim, this.dy, this.dy)))
                + this.tmpg[bPCAUnit.id];
    }

    /**
     * M-step for a single unit. From responsibility-weighted statistics of the
     * data, it updates {@code W}, {@code invDw}, {@code tau}, {@code lntau},
     * {@code gamma}, {@code Rx}/{@code invRx} and {@code alpha}.
     *
     * @param bPCAUnit the unit to update (modified in place)
     * @param dArr     data rows
     * @param dArr2    the unit's responsibility for each datum; data with a
     *                 responsibility below {@link #SMALL_VALUE} are skipped
     */
    public void mStep(BPCAUnit bPCAUnit, double[][] dArr, double[] dArr2) {
        int hiddenDim = bPCAUnit.getHiddenDim();
        MatrixUtils.fillZero(this.ET);
        this.EtrS = 0.0d;
        for (int n = 0; n < this.numData; n++) {
            double resp = dArr2[n];
            if (resp >= this.SMALL_VALUE) {
                for (int k = this.iDim - 1; k >= 0; k--) {
                    this.dy[k] = dArr[n][k] - bPCAUnit.mu[k];
                }
                MatrixUtils.mul(this.iDim, hiddenDim, this.dy, bPCAUnit.W, this.ex);
                // Only the first hiddenDim entries of ex are computed; scale exactly those (as logLikelihood does).
                MatrixUtils.mulScalar(hiddenDim, bPCAUnit.tau, this.ex);
                MatrixUtils.mul(hiddenDim, hiddenDim, this.ex, bPCAUnit.invRx, this.tmpx);
                for (int k = this.iDim - 1; k >= 0; k--) {
                    this.EtrS += resp * this.dy[k] * this.dy[k];
                    for (int h = hiddenDim - 1; h >= 0; h--) {
                        this.ET[k][h] += resp * this.dy[k] * this.tmpx[h];
                    }
                }
            }
        }
        // Normalise the accumulated statistics by the unit's responsibility mass.
        double mass = this.gamma[bPCAUnit.id];
        MatrixUtils.mulScalar(this.iDim, hiddenDim, 1.0d / mass, this.ET);
        this.EtrS /= mass;
        bPCAUnit.gamma = mass;
        // invDw <- tau * ET^T W + I
        for (int r = hiddenDim - 1; r >= 0; r--) {
            for (int c = hiddenDim - 1; c >= 0; c--) {
                double acc = 0.0d;
                for (int k = this.iDim - 1; k >= 0; k--) {
                    acc += this.ET[k][r] * bPCAUnit.W[k][c];
                }
                bPCAUnit.invDw[r][c] = acc * bPCAUnit.tau;
            }
            bPCAUnit.invDw[r][r] += 1.0d;
        }
        // comp_mat <- invDw * invRx + diag(alpha / mass); then invDw <- comp_mat^-1 and W <- ET * invDw.
        MatrixUtils.mul(hiddenDim, hiddenDim, hiddenDim, bPCAUnit.invDw, bPCAUnit.invRx, this.comp_mat);
        for (int h = hiddenDim - 1; h >= 0; h--) {
            this.comp_mat[h][h] += bPCAUnit.alpha[h] / mass;
        }
        if (this.DEBUG) {
            MatrixUtils.disp("comp_mat00", this.comp_mat);
        }
        MatrixUtils.inverse(hiddenDim, this.comp_mat, bPCAUnit.invDw);
        MatrixUtils.mul(this.iDim, hiddenDim, hiddenDim, this.ET, bPCAUnit.invDw, bPCAUnit.W);
        if (this.DEBUG) {
            MatrixUtils.disp("ET", this.ET);
            MatrixUtils.disp("W", bPCAUnit.W);
            MatrixUtils.disp("invDw", bPCAUnit.invDw);
            MatrixUtils.disp("invRx", bPCAUnit.invRx);
            System.out.println("e1=" + mass);
        }
        MatrixUtils.orthogonalize(bPCAUnit.W);
        // Noise precision update; the result is clamped to [min_tau, max_tau] below.
        bPCAUnit.tau = (this.iDim + ((2.0d * bPCAUnit.gtau0) / bPCAUnit.gamma))
                / ((this.EtrS - MatrixUtils.matrixInnerProduct(this.iDim, hiddenDim, bPCAUnit.W, this.ET))
                        + (((MatrixUtils.innerProduct(this.iDim, bPCAUnit.mu, bPCAUnit.mu) * bPCAUnit.gmu0)
                                + ((2.0d * bPCAUnit.gtau0) / bPCAUnit.btau0)) / bPCAUnit.gamma));
        double shape = ((this.iDim * mass) / 2.0d) + bPCAUnit.gtau0;
        bPCAUnit.lntau = SpecialFunctions.digamma(shape) - Math.log(shape);
        bPCAUnit.tau = Math.min(Math.max(bPCAUnit.tau, bPCAUnit.min_tau), bPCAUnit.max_tau);
        bPCAUnit.lntau += Math.log(bPCAUnit.tau);
        if (this.DEBUG) {
            System.out.println("ugamma=" + bPCAUnit.gamma);
            System.out.println("EtrS=" + this.EtrS);
            System.out.println("tau=" + bPCAUnit.tau);
        }
        calcInvRx(bPCAUnit);
        calcAlpha(bPCAUnit);
    }

    /**
     * One full batch iteration.
     *
     * <ol>
     *   <li>E-step: compute per-datum unit scores, normalised into
     *       {@code cmf} with a running max (log-sum-exp).</li>
     *   <li>Update each unit's {@code mu}.</li>
     *   <li>Compute the free energy.</li>
     *   <li>Run {@link #mStep} for units whose mass exceeds
     *       {@link #M_STEP_CANCEL_THRESHOLD}.</li>
     *   <li>Delete units below {@link #DELETION_THRESHOLD}.</li>
     *   <li>Refresh {@code lng}.</li>
     *   <li>Prune dead features.</li>
     * </ol>
     */
    public void batchDoStep() {
        ensureCmfCapacity();
        MatrixUtils.fillZero(this.newMu);
        MatrixUtils.fillZero(this.gamma);
        calc_tmpg();
        for (int n = 0; n < this.numData; n++) {
            double[] x = this.data[n];
            double sum = 0.0d;
            double max = 1.0E-300d; // overwritten by the first unit's score
            for (int k = 0; k < this.numUnits; k++) {
                double ll = logLikelihood(this.u[k], x);
                if (k == 0) {
                    max = ll;
                } else if (ll > max) {
                    // Rescale the running sum to the new maximum to avoid overflow.
                    sum *= Math.exp(max - ll);
                    max = ll;
                }
                this.cmf[k][n] = ll;
                sum += Math.exp(ll - max);
            }
            this.lnL[n] = Math.log(sum) + max;
            for (int k = 0; k < this.numUnits; k++) {
                double resp = Math.exp(this.cmf[k][n] - max) / sum;
                this.cmf[k][n] = resp;
                this.gamma[this.u[k].id] += resp;
                double[] muAcc = this.newMu[k];
                for (int c = 0; c < this.iDim; c++) {
                    muAcc[c] += resp * x[c];
                }
            }
        }
        for (int k = 0; k < this.numUnits; k++) {
            BPCAUnit unit = this.u[k];
            for (int c = 0; c < this.iDim; c++) {
                unit.mu[c] = this.newMu[k][c] / (this.gamma[unit.id] + unit.gmu0);
            }
        }
        calcFreeEnergy();
        double total = this.numData;
        for (int k = 0; k < this.numUnits; k++) {
            if (this.gamma[this.u[k].id] > this.M_STEP_CANCEL_THRESHOLD) {
                mStep(this.u[k], this.data, this.cmf[k]);
            }
        }
        // On a successful delete the next unit shifts into slot k, so k is not advanced.
        // If the delete reports failure, move on instead of retrying forever.
        int k = 0;
        while (k < this.numUnits) {
            int id = this.u[k].id;
            if (this.gamma[id] < this.DELETION_THRESHOLD && deleteUnit(id)) {
                continue;
            }
            k++;
        }
        for (int pos = 0; pos < this.numUnits; pos++) {
            int id = this.u[pos].id;
            this.lng[id] = SpecialFunctions.digamma((this.gamma[id] + (this.gamma0 / this.numUnits)) + 1.0d)
                    - SpecialFunctions.digamma((total + this.gamma0) + this.numUnits);
        }
        deleteDeadFeatures();
    }

    /** Reallocates {@code cmf} when units were added after it was sized (otherwise the E-step would overrun it). */
    private void ensureCmfCapacity() {
        if (this.cmf == null || this.cmf.length < this.numUnits) {
            this.cmf = new double[this.numUnits][this.numData];
        }
    }

    /**
     * Sums each unit's {@code modelComplexity()} with the mixing-weight terms
     * built from the per-unit masses ({@code getGamma()}), the prior
     * {@code gamma0 / numUnits} and the id-indexed {@code lng}.
     */
    public double calcModelComplexity() {
        double unitComplexity = 0.0d;
        double posteriorMass = 0.0d;
        double priorMass = 0.0d;
        double mixingTerm = 0.0d;
        for (int pos = this.numUnits - 1; pos >= 0; pos--) {
            BPCAUnit unit = this.u[pos];
            double prior = this.gamma0 / this.numUnits;
            unitComplexity += unit.modelComplexity();
            double posterior = unit.getGamma() + prior;
            posteriorMass += posterior;
            priorMass += prior;
            // lng is indexed by unit id everywhere else; index it the same way here.
            mixingTerm += (SpecialFunctions.gammaln(posterior + 1.0d) - SpecialFunctions.gammaln(prior + 1.0d))
                    + ((prior - posterior) * this.lng[unit.id]);
        }
        return unitComplexity + mixingTerm
                + (SpecialFunctions.gammaln(priorMass + this.numUnits) - SpecialFunctions.gammaln(posteriorMass + this.numUnits));
    }

    /**
     * For each unit, finds the feature with the smallest {@code diagWTW} value
     * below 10. If that value is under {@link #FEATURE_DELETION_THRESHOLD}, the
     * feature is deleted. At most one feature per unit is removed per call.
     *
     * @return whether any feature was deleted
     */
    public boolean deleteDeadFeatures() {
        boolean deleted = false;
        for (int pos = 0; pos < this.numUnits; pos++) {
            BPCAUnit unit = this.u[pos];
            double smallest = 10.0d;
            int smallestIdx = 0;
            for (int f = unit.fDim - 1; f >= 0; f--) {
                if (unit.diagWTW[f] < smallest) {
                    smallest = unit.diagWTW[f];
                    smallestIdx = f;
                }
            }
            if (smallest < this.FEATURE_DELETION_THRESHOLD) {
                unit.deleteFeature(smallestIdx);
                deleted = true;
            }
        }
        return deleted;
    }

    /** Computes {@code FE = meanLogLikelihood * numData + modelComplexity}, stores it and returns it. */
    public double calcFreeEnergy() {
        this.FE = (getMeanLogLikelihood() * this.numData) + calcModelComplexity();
        return this.FE;
    }
}
