/*
 * Decompiled with CFR 0.152.
 */
package bartMachine;

import OpenSourceExtensions.StatUtil;
import bartMachine.bartMachineClassification;
import bartMachine.bartMachineRegressionMultThread;
import bartMachine.bartMachineTreeNode;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Classification counterpart of {@link bartMachineRegressionMultThread}.
 *
 * <p>Each Gibbs chain is a {@link bartMachineClassification}. For prediction, the
 * outputs of all trees in a post-burn-in sample are summed per row and passed
 * through {@code StatUtil.normal_cdf}; {@link #Evaluate} thresholds the averaged
 * result against {@code classification_rule} to yield a 0/1 label.
 */
public class bartMachineClassificationMultThread
extends bartMachineRegressionMultThread
implements Serializable {
    /** Threshold installed into {@code classification_rule} by {@link #SetupBARTModels()}. */
    private static final double DEFAULT_CLASSIFICATION_RULE = 0.5;
    /** Value that the sample-average evaluation must exceed for {@link #Evaluate} to return 1.0. */
    private double classification_rule;

    /**
     * Creates one {@link bartMachineClassification} per core ({@code num_cores}),
     * registers each via {@code SetupBartModel}, and resets the classification
     * threshold to its default.
     */
    @Override
    protected void SetupBARTModels() {
        this.bart_gibbs_chain_threads = new bartMachineClassification[this.num_cores];
        for (int t = 0; t < this.num_cores; ++t) {
            this.SetupBartModel(new bartMachineClassification(), t);
        }
        this.classification_rule = DEFAULT_CLASSIFICATION_RULE;
    }

    /**
     * Classifies a single input.
     *
     * @param input    values forwarded unchanged to {@code EvaluateViaSampAvg}
     * @param numCores forwarded unchanged to {@code EvaluateViaSampAvg}
     * @return {@code 1.0} if the sample-average evaluation is strictly greater
     *         than the classification threshold, otherwise {@code 0.0}
     */
    @Override
    public double Evaluate(double[] input, int numCores) {
        double averaged = this.EvaluateViaSampAvg(input, numCores);
        return averaged > this.classification_rule ? 1.0 : 0.0;
    }

    /**
     * Computes, for every row and every post-burn-in Gibbs sample, the normal CDF
     * of the summed tree outputs.
     *
     * @param records  rows to predict; {@code records.length} determines the
     *                 number of result rows
     * @param numCores number of chunks the samples are split into for parallel
     *                 evaluation; values {@code <= 1} run serially on the calling
     *                 thread (this also avoids a division by zero for {@code 0})
     * @return a {@code [records.length][numSamplesAfterBurning()]} matrix; all
     *         zeros (possibly with zero columns) when there are no rows or no
     *         samples
     * @throws RuntimeException wrapping any failure raised by a worker task, or
     *         an interruption while waiting for one
     */
    @Override
    protected double[][] getGibbsSamplesForPrediction(double[][] records, int numCores) {
        final int numSamples = this.numSamplesAfterBurning();
        final int numRecords = records.length;
        // Clamp so a non-positive sample count reaches the guard below instead of
        // failing the allocation with NegativeArraySizeException.
        final double[][] probabilities = new double[numRecords][Math.max(0, numSamples)];
        if (numRecords == 0 || numSamples <= 0) {
            return probabilities;
        }
        final int[] rowIndices = identityIndices(numRecords);

        if (numCores <= 1) {
            double[] sums = new double[numRecords];
            for (int s = 0; s < numSamples; ++s) {
                this.sumTreeOutputs(records, rowIndices, sums, s);
                for (int r = 0; r < numRecords; ++r) {
                    probabilities[r][s] = StatUtil.normal_cdf(sums[r]);
                }
            }
            return probabilities;
        }

        final int chunkSize = Math.max(1, numSamples / numCores);
        final ArrayList<Future<Object>> pending = new ArrayList<Future<Object>>();
        try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
            for (int start = 0; start < numSamples; start += chunkSize) {
                final int from = start;
                final int to = Math.min(numSamples, start + chunkSize);
                pending.add(executor.submit(() -> {
                    // Each task owns its scratch buffer; tasks write to disjoint
                    // columns [from, to) of the shared result matrix.
                    double[] sums = new double[numRecords];
                    for (int s = from; s < to; ++s) {
                        // Private copy of the index array per sample, as in the
                        // original; presumably evaluateBatch may reorder it - TODO confirm.
                        int[] localIndices = rowIndices.clone();
                        this.sumTreeOutputs(records, localIndices, sums, s);
                        for (int r = 0; r < numRecords; ++r) {
                            probabilities[r][s] = StatUtil.normal_cdf(sums[r]);
                        }
                    }
                    return null;
                }));
            }
            // Surface the first worker failure (if any) to the caller.
            for (Future<Object> task : pending) {
                task.get();
            }
        }
        catch (InterruptedException interrupted) {
            // Preserve the interrupt for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new RuntimeException(interrupted);
        }
        catch (Exception failure) {
            throw new RuntimeException(failure);
        }
        return probabilities;
    }

    /**
     * Computes, for every row, the mean over all post-burn-in Gibbs samples of the
     * normal CDF of the summed tree outputs.
     *
     * @param records  rows to predict; {@code records.length} determines the
     *                 length of the result
     * @param numCores number of chunks the samples are split into for parallel
     *                 evaluation; values {@code <= 1} run serially on the calling
     *                 thread (this also avoids a division by zero for {@code 0})
     * @return one averaged value per row; all zeros when there are no rows or no
     *         samples
     * @throws RuntimeException wrapping any failure raised by a worker task, or
     *         an interruption while waiting for one
     */
    @Override
    public double[] getPosteriorMeanForPrediction(double[][] records, int numCores) {
        final int numSamples = this.numSamplesAfterBurning();
        final int numRecords = records.length;
        final double[] means = new double[numRecords];
        if (numRecords == 0 || numSamples <= 0) {
            return means;
        }
        final int[] rowIndices = identityIndices(numRecords);

        if (numCores <= 1) {
            double[] sums = new double[numRecords];
            for (int s = 0; s < numSamples; ++s) {
                this.sumTreeOutputs(records, rowIndices, sums, s);
                for (int r = 0; r < numRecords; ++r) {
                    means[r] += StatUtil.normal_cdf(sums[r]);
                }
            }
        } else {
            final int chunkSize = Math.max(1, numSamples / numCores);
            final ArrayList<Future<double[]>> pending = new ArrayList<Future<double[]>>();
            try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
                for (int start = 0; start < numSamples; start += chunkSize) {
                    final int from = start;
                    final int to = Math.min(numSamples, start + chunkSize);
                    pending.add(executor.submit(() -> {
                        // Task-local accumulator, merged on the calling thread below,
                        // so workers never write to shared state.
                        double[] partial = new double[numRecords];
                        double[] sums = new double[numRecords];
                        for (int s = from; s < to; ++s) {
                            // Private copy of the index array per sample, as in the
                            // original; presumably evaluateBatch may reorder it - TODO confirm.
                            int[] localIndices = rowIndices.clone();
                            this.sumTreeOutputs(records, localIndices, sums, s);
                            for (int r = 0; r < numRecords; ++r) {
                                partial[r] += StatUtil.normal_cdf(sums[r]);
                            }
                        }
                        return partial;
                    }));
                }
                for (Future<double[]> task : pending) {
                    double[] partial = task.get();
                    for (int r = 0; r < numRecords; ++r) {
                        means[r] += partial[r];
                    }
                }
            }
            catch (InterruptedException interrupted) {
                // Preserve the interrupt for callers further up the stack.
                Thread.currentThread().interrupt();
                throw new RuntimeException(interrupted);
            }
            catch (Exception failure) {
                throw new RuntimeException(failure);
            }
        }

        // Turn the per-row totals into averages over the samples.
        for (int r = 0; r < numRecords; ++r) {
            means[r] /= (double)numSamples;
        }
        return means;
    }

    /**
     * Overrides the threshold used by {@link #Evaluate}.
     *
     * @param d the new classification threshold
     */
    public void setClassificationRule(double d) {
        this.classification_rule = d;
    }

    /** Returns the array {@code {0, 1, ..., length - 1}}. */
    private static int[] identityIndices(int length) {
        int[] indices = new int[length];
        for (int k = 0; k < length; ++k) {
            indices[k] = k;
        }
        return indices;
    }

    /**
     * Zeroes {@code sums} and then calls {@code evaluateBatch} on each of the
     * first {@code num_trees} trees of the given post-burn-in sample.
     * NOTE(review): evaluateBatch is expected to add each tree's per-row output
     * into {@code sums} - confirm against bartMachineTreeNode.
     */
    private void sumTreeOutputs(double[][] records, int[] rowIndices, double[] sums, int sample) {
        Arrays.fill(sums, 0.0);
        bartMachineTreeNode[] trees = this.gibbs_samples_of_bart_trees_after_burn_in[sample];
        for (int t = 0; t < this.num_trees; ++t) {
            trees[t].evaluateBatch(records, rowIndices, sums);
        }
    }
}

