/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.estim;

import java.util.Arrays;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseRowVector;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * Static helpers that determine the number of non-zeros (nnz) of a matrix
 * product by evaluating the product one output row at a time into a reusable
 * scratch row, so the output matrix itself is never materialized.
 */
public abstract class EstimationUtils {
    /** Initial capacity passed to the reusable sparse scratch row. */
    private static final int INIT_SPARSE_ROW_CAPACITY = 1024;

    /**
     * A dense scratch row is used for an output row once its nnz bound
     * exceeds {@code numColumns / DENSE_ROW_DIVISOR}; otherwise a sparse
     * scratch row is used.
     */
    private static final int DENSE_ROW_DIVISOR = 128;

    /**
     * Computes the output nnz of the self matrix product {@code m1 %*% m1}
     * without materializing the output.
     *
     * <p>NOTE(review): rows of {@code m1} are looked up by column index, so
     * the input is assumed to be square; this is not validated here.
     *
     * @param m1 dense or sparse input matrix
     * @return number of non-zeros counted over all output rows; {@code 0} if
     *         the underlying block of {@code m1} is not allocated
     */
    public static long getSelfProductOutputNnz(MatrixBlock m1) {
        final int m = m1.getNumRows();
        final int n = m1.getNumColumns();

        if (m1.isInSparseFormat()) {
            SparseBlock a = m1.getSparseBlock();
            return countSparseProductNnz(a, a, m, n);
        }

        DenseBlock a = m1.getDenseBlock();
        if (a == null) {
            // Defensive: no dense block allocated (presumably an empty
            // matrix - confirm); an all-zero input yields an all-zero product.
            return 0L;
        }

        long retNnz = 0L;
        double[] tmp = new double[n]; // scratch output row, reset per input row
        for (int i = 0; i < m; i++) {
            final double[] avals = a.values(i);
            final int aix = a.pos(i);
            Arrays.fill(tmp, 0.0);
            for (int k = 0; k < n; k++) {
                final double aval = avals[aix + k];
                if (aval == 0.0) {
                    continue; // zero entries contribute nothing to the row
                }
                final double[] bvals = a.values(k);
                final int bix = a.pos(k);
                for (int j = 0; j < n; j++) {
                    tmp[j] += aval * bvals[bix + j];
                }
            }
            retNnz += UtilFunctions.computeNnz(tmp, 0, n);
        }
        return retNnz;
    }

    /**
     * Computes the output nnz of the matrix product {@code m1 %*% m2} for two
     * sparse inputs without materializing the output.
     *
     * @param m1 left-hand-side input, must be in sparse format
     * @param m2 right-hand-side input, must be in sparse format
     * @return number of non-zeros counted over all output rows; {@code 0} if
     *         the sparse block of either input is not allocated
     * @throws DMLRuntimeException if either input is not in sparse format
     */
    public static long getSparseProductOutputNnz(MatrixBlock m1, MatrixBlock m2) {
        if (!m1.isInSparseFormat() || !m2.isInSparseFormat()) {
            throw new DMLRuntimeException("Invalid call to sparse output nnz estimation.");
        }
        return countSparseProductNnz(
            m1.getSparseBlock(), m2.getSparseBlock(), m1.getNumRows(), m2.getNumColumns());
    }

    /**
     * Shared row-wise kernel: for each non-empty row of {@code a}, accumulates
     * the corresponding row of {@code a %*% b} into a scratch row and adds the
     * scratch row's non-zero count to the result.
     *
     * <p>NOTE(review): the sparse scratch path counts {@code tmpS.size()}
     * while the dense path counts non-zero values; whether these agree when
     * products cancel to zero depends on {@code SparseRowVector.add} - confirm.
     *
     * @param a sparse block of the left-hand side (may be {@code null})
     * @param b sparse block of the right-hand side (may be {@code null})
     * @param m number of rows of the left-hand side
     * @param n number of columns of the right-hand side (output row length)
     * @return accumulated non-zero count over all output rows
     */
    private static long countSparseProductNnz(SparseBlock a, SparseBlock b, int m, int n) {
        if (a == null || b == null) {
            // Defensive: no sparse block allocated (presumably an empty
            // matrix - confirm); a product with an all-zero operand has no nnz.
            return 0L;
        }

        long retNnz = 0L;
        SparseRowVector tmpS = new SparseRowVector(INIT_SPARSE_ROW_CAPACITY);
        double[] tmpD = null; // allocated lazily on the first dense output row

        for (int i = 0; i < m; i++) {
            if (a.isEmpty(i)) {
                continue;
            }
            final int alen = a.size(i);
            final int apos = a.pos(i);
            final int[] aix = a.indexes(i);
            final double[] avals = a.values(i);

            // nnz figure reported for the rows of b referenced by row i of a,
            // capped at the output row length; selects the scratch row format
            final int nnz1 = (int) Math.min(UtilFunctions.computeNnz(b, aix, apos, alen), (long) n);
            final boolean ldense = nnz1 > n / DENSE_ROW_DIVISOR;

            // reset the scratch row that will be used for this output row
            if (ldense) {
                if (tmpD == null) {
                    tmpD = new double[n];
                }
                Arrays.fill(tmpD, 0.0);
            } else {
                tmpS.setSize(0);
            }

            // vector-matrix multiply of row i of a with b
            for (int k = apos; k < apos + alen; k++) {
                final int brow = aix[k];
                if (b.isEmpty(brow)) {
                    continue;
                }
                final int blen = b.size(brow);
                final int bpos = b.pos(brow);
                final int[] bix = b.indexes(brow);
                final double[] bvals = b.values(brow);
                final double aval = avals[k];
                if (ldense) {
                    for (int j = bpos; j < bpos + blen; j++) {
                        tmpD[bix[j]] += aval * bvals[j];
                    }
                } else {
                    for (int j = bpos; j < bpos + blen; j++) {
                        tmpS.add(bix[j], aval * bvals[j]);
                    }
                }
            }

            if (ldense) {
                retNnz += UtilFunctions.computeNnz(tmpD, 0, n);
            } else {
                retNnz += tmpS.size();
            }
        }
        return retNnz;
    }
}

