/*
 * Decompiled with CFR 0.152.
 */
package iitb.ugm;

import gnu.trove.TDoubleFunction;
import gnu.trove.TDoubleProcedure;
import gnu.trove.TIntArrayList;
import gnu.trove.TIntDoubleHashMap;
import iitb.inference.Solution;
import iitb.ugm.PairwiseModel;
import iitb.ugm.UDGM;
import iitb.ugm.UDGraph;
import iitb.ugm.UDGraphImpl;

public class PairwiseModelWrapper
extends PairwiseModel {
    public double constTheta = 0.0;
    private double originalZeroLabelScore;
    int[] zeroLabeling;

    public PairwiseModelWrapper(int n) {
        super(n);
        this.numOfNodes = n;
        this.nodeArities = new int[n];
        this.graph = new UDGraphImpl(n);
    }

    public PairwiseModelWrapper(UDGraph graph, int uniformArity) {
        super(graph, uniformArity, null);
    }

    public PairwiseModelWrapper(PairwiseModelWrapper model) {
        this(model.numOfNodes);
        this.maxNodeArity = model.getMaxArity();
        int i = 0;
        while (i < model.numOfNodes) {
            this.nodeArities[i] = model.nodeArities[i];
            if (this.nodeArities[i] > this.maxNodeArity) {
                System.err.println("ERROR: Node Arity exceeds MaxArity=" + this.maxNodeArity + ", change it in the code.");
                System.exit(-1);
            }
            ++i;
        }
    }

    public void setConstTheta() {
        this.zeroLabeling = new int[this.numOfNodes];
        this.originalZeroLabelScore = this.getScore(this.zeroLabeling);
    }

    public void setConstTheta(double constTheta) {
        this.constTheta = constTheta;
    }

    public void copyModelNodeParam(int nodeNum, UDGM oldModel, double magnifier, boolean add) {
        int i = 0;
        while (i < this.nodeArities[nodeNum]) {
            this.setNodePotential(nodeNum, i, (add ? this.getNodePotential(nodeNum, i) : 0.0) + oldModel.getNodePotential(nodeNum, i) * magnifier);
            ++i;
        }
    }

    public void copyModelEdgeParam(int head, int tail, PairwiseModelWrapper oldModel, double magnifier, boolean add) {
        int i = 0;
        while (i < this.nodeArities[head]) {
            int j = 0;
            while (j < this.nodeArities[tail]) {
                this.setEdgePotential(head, i, tail, j, (add ? this.getEdgePotential(head, i, tail, j) : 0.0) + oldModel.getEdgePotential(head, i, tail, j) * magnifier);
                ++j;
            }
            ++i;
        }
    }

    public int getMesgKey(int sender, int receiver, int recvLabel) {
        return this.numOfNodes * this.maxNodeArity * sender + receiver * this.maxNodeArity + recvLabel;
    }

    public int getNodeKey(int v, int label) {
        return v * this.maxNodeArity + label;
    }

    public int getEdgeKey(int v1, int label1, int v2, int label2) {
        return (v1 * this.maxNodeArity + label1) * (this.numOfNodes * this.maxNodeArity) + (v2 * this.maxNodeArity + label2);
    }

    public static double getNodeTheta(TIntDoubleHashMap thetas, int nodeNum, int label, int maxArity) {
        int key = nodeNum * maxArity + label;
        if (thetas.containsKey(key)) {
            return thetas.get(key);
        }
        return 0.0;
    }

    public void renormalize() {
        if (this.zeroLabeling != null) {
            this.constTheta = this.originalZeroLabelScore - this.getScore(this.zeroLabeling) + this.constTheta;
        }
    }

    public void reparameterize(TIntDoubleHashMap edgeMarginals, TIntDoubleHashMap nodeMarginals) {
        int i = 0;
        while (i < this.numOfNodes) {
            int xi = 0;
            while (xi < this.nodeArities[i]) {
                double val = PairwiseModelWrapper.getNodeTheta(nodeMarginals, i, xi, this.maxNodeArity);
                this.setNodePotential(i, xi, val);
                ++xi;
            }
            ++i;
        }
        i = 0;
        while (i < this.numOfNodes) {
            TIntArrayList nbrs = this.graph.getNeighbours(i);
            int k = 0;
            while (k < nbrs.size()) {
                int j = nbrs.get(k);
                if (j >= i) {
                    int l1 = 0;
                    while (l1 < this.nodeArities[i]) {
                        int l2 = 0;
                        while (l2 < this.nodeArities[j]) {
                            double val = edgeMarginals.get(this.getEdgeKey(i, l1, j, l2)) - this.getNodePotential(i, l1) - this.getNodePotential(j, l2);
                            this.setEdgePotential(i, l1, j, l2, val);
                            ++l2;
                        }
                        ++l1;
                    }
                }
                ++k;
            }
            ++i;
        }
    }

    public double getScore(Solution s) {
        return this.getScore(s.labeling);
    }

    public double getScore(int[] labeling) {
        double score = super.getScore(labeling);
        return score + this.constTheta;
    }

    public int getUnlabeledEdgeKey(int i, int j) {
        return i * this.numOfNodes + j;
    }

    public double maxAbsTheta() {
        private class MaxCompute
        implements TDoubleProcedure {
            double maxTheta = 0.0;

            MaxCompute() {
            }

            public boolean execute(double arg0) {
                this.maxTheta = Math.max(this.maxTheta, Math.abs(arg0));
                return false;
            }
        }
        MaxCompute maxVal = new MaxCompute();
        int i = 0;
        while (i < this.numOfNodes) {
            this.getNodePotentialTable(i).getPotentialMap().forEachValue(maxVal);
            ++i;
        }
        i = 0;
        while (i < this.numOfNodes) {
            int j = i + 1;
            while (j < this.numOfNodes) {
                if (this.graph.isAdj(i, j)) {
                    this.getEdgePotentialTable(i, j).getPotentialMap().forEachValue(maxVal);
                }
                ++j;
            }
            ++i;
        }
        return maxVal.maxTheta;
    }

    public void scaleAll(double divisor) {
        private class ScaleValue
        implements TDoubleFunction {
            double div;

            ScaleValue() {
            }

            public double execute(double arg0) {
                return arg0 * this.div;
            }
        }
        ScaleValue scale = new ScaleValue();
        scale.div = divisor;
        int i = 0;
        while (i < this.numOfNodes) {
            this.getNodePotentialTable(i).getPotentialMap().transformValues(scale);
            ++i;
        }
        i = 0;
        while (i < this.numOfNodes) {
            int j = i + 1;
            while (j < this.numOfNodes) {
                if (this.graph.isAdj(i, j)) {
                    this.getEdgePotentialTable(i, j).getPotentialMap().transformValues(scale);
                }
                ++j;
            }
            ++i;
        }
    }

    public String toString() {
        String s = "Numnodes:" + this.numOfNodes + ",\tNumedges:" + this.graph.getNumEdges() + "\nAdjacency:\n";
        int i = 0;
        while (i < this.numOfNodes) {
            s = String.valueOf(s) + "(" + i + ") ";
            int li = 0;
            while (li < this.nodeArities[i]) {
                s = String.valueOf(s) + this.getNodePotential(i, li) + " ";
                ++li;
            }
            s = String.valueOf(s) + "\n";
            ++i;
        }
        i = 0;
        while (i < this.numOfNodes) {
            int j = i + 1;
            while (j < this.numOfNodes) {
                if (this.graph.isAdj(i, j)) {
                    s = String.valueOf(s) + "(" + i + "," + j + ") ";
                    int li = 0;
                    while (li < this.nodeArities[i]) {
                        int lj = 0;
                        while (lj < this.nodeArities[j]) {
                            s = String.valueOf(s) + this.getEdgePotential(i, li, j, lj) + " ";
                            ++lj;
                        }
                        ++li;
                    }
                    s = String.valueOf(s) + "\n";
                }
                ++j;
            }
            ++i;
        }
        return s;
    }
}

