package rts.ai.mcts.naivemcts;

import com.fossgalaxy.games.tbs.GameState;
import com.fossgalaxy.games.tbs.order.Order;
import com.fossgalaxy.object.ObjectFinder;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.codetome.hexameter.core.api.CubeCoordinate;
import rts.PlayerAction;
import rts.ai.mcts.MCTSNode;
import utils.MoveGenerator;
import utils.Sampler;

/* loaded from: input_file:rts/ai/mcts/naivemcts/NaiveMCTSNode.class */
/**
 * A node of the NaiveMCTS search tree.
 *
 * <p>Each non-terminal node keeps one local bandit per unit ({@link #unitActionTable}). A joint
 * action is sampled by drawing every unit's order independently from that unit's local
 * epsilon-greedy distribution, and each distinct joint action becomes one child. Children that
 * have already been sampled can also be revisited through a global bandit (epsilon-greedy or UCB1,
 * see {@link #E_GREEDY} / {@link #UCB1}).
 *
 * <p>Node {@code type} values used here: {@code -1} terminal (game over), {@code 0} maximising
 * (prefers higher average evaluations), {@code 1} minimising (prefers lower ones).
 */
public class NaiveMCTSNode extends MCTSNode {
    /** Global-strategy id for {@link #selectLeaf}: epsilon-greedy over already-sampled children. */
    public static final int E_GREEDY = 0;
    /** Global-strategy id for {@link #selectLeaf}: UCB1 over already-sampled children. */
    public static final int UCB1 = 1;
    /** Debug verbosity; at 3 or above every local sampling distribution is printed to stdout. */
    public static int DEBUG = 0;
    /**
     * Weight applied to the normalised exploitation term in {@link #selectFromAlreadySampledUCB1}
     * (the exploration term itself is unweighted).
     */
    public static float C = 0.05f;
    /**
     * When true and a unit still has unvisited actions, that unit's local distribution is
     * restricted to the unvisited actions.
     */
    boolean forceExplorationOfNonSampledActions;
    /** Generator of the per-unit order choices for this node's state; null for terminal nodes. */
    public MoveGenerator moveGenerator;
    /** One local bandit per unit, in the iteration order of the move generator's choices. */
    public List<UnitActionTableEntry> unitActionTable;
    /**
     * Bound used to map average evaluations onto [0, 1] for UCB1
     * (assumes evaluations lie in [-bound, bound] — TODO confirm with the evaluation function).
     */
    double evaluation_bound;
    /**
     * Mixed-radix place value of each unit (same index as {@link #unitActionTable}). A joint
     * action is encoded as sum(choiceIndex[u] * multipliers[u]) and used as the key of
     * {@link #childrenMap}.
     */
    public BigInteger[] multipliers;
    /** Not read or written by this class's methods. */
    boolean hasMoreActions = true;
    /** Children indexed by the encoded joint action that produced them. */
    HashMap<BigInteger, NaiveMCTSNode> childrenMap = new LinkedHashMap<>();

    /**
     * Creates a node for {@code gameState}.
     *
     * <p>If the game is over, the node is terminal ({@code type == -1}) and no action table is
     * built. Otherwise the node is maximising ({@code type == 0}) when {@code i == i2}, and
     * minimising ({@code type == 1}) otherwise. In both cases a per-unit action table is built
     * from a {@link MoveGenerator} for player {@code i}.
     *
     * @param i player id handed to the {@link MoveGenerator}
     * @param i2 player id compared with {@code i} to choose the node type
     * @param gameState state represented by this node (stored by reference, not copied)
     * @param naiveMCTSNode parent node, or {@code null} for the root
     * @param d evaluation bound used by UCB1 normalisation
     * @param i3 creation id recorded on the node
     * @param z value for {@link #forceExplorationOfNonSampledActions}
     * @throws Exception as declared; may propagate from the project helpers called here
     */
    public NaiveMCTSNode(int i, int i2, GameState gameState, NaiveMCTSNode naiveMCTSNode, double d, int i3, boolean z) throws Exception {
        this.parent = naiveMCTSNode;
        this.gs = gameState;
        this.depth = (naiveMCTSNode == null) ? 0 : naiveMCTSNode.depth + 1;
        this.evaluation_bound = d;
        this.creation_ID = i3;
        this.forceExplorationOfNonSampledActions = z;
        if (this.gs.isGameOver()) {
            this.type = -1;
            return;
        }
        this.type = (i == i2) ? 0 : 1;
        // NOTE(review): both node types build the generator for player i, never i2 — confirm this
        // is intended for minimising nodes.
        this.moveGenerator = new MoveGenerator(this.gs, i);
        this.actions = new ArrayList<>();
        this.children = new ArrayList<>();

        Map<UUID, List<Order>> choices = this.moveGenerator.getChoices();
        // Accessed by index during sampling, so use a random-access list.
        this.unitActionTable = new ArrayList<>(choices.size());
        this.multipliers = new BigInteger[choices.size()];
        BigInteger placeValue = BigInteger.ONE;
        int unitIndex = 0;
        for (Map.Entry<UUID, List<Order>> choice : choices.entrySet()) {
            UnitActionTableEntry entry = new UnitActionTableEntry();
            entry.u = choice.getKey();
            entry.actions = choice.getValue();
            entry.nactions = entry.actions.size();
            // New Java arrays are zero-filled, so the statistics start at 0.
            entry.accum_evaluation = new double[entry.nactions];
            entry.visit_count = new int[entry.nactions];
            this.unitActionTable.add(entry);
            this.multipliers[unitIndex++] = placeValue;
            // A unit with no actions contributes nothing to the key; give it radix 1 so it does
            // not zero the place values of every later unit (which would make keys collide).
            placeValue = placeValue.multiply(BigInteger.valueOf(Math.max(1, entry.nactions)));
        }
    }

    /**
     * Descends from this node to the node that should be expanded or simulated next.
     *
     * <p>Returns this node if it is terminal or already at depth {@code i4}. Otherwise, with
     * probability {@code f3} (or always, when there are no children yet), a joint action is
     * sampled from the local bandits. If not, an already-sampled child is chosen by the global
     * strategy {@code i3} and the descent continues from it.
     *
     * @param i player id forwarded to new child nodes
     * @param i2 player id forwarded to new child nodes
     * @param f epsilon of the local (per-unit) epsilon-greedy distributions
     * @param f2 epsilon of the global epsilon-greedy strategy
     * @param f3 probability of sampling via the local bandits when children exist
     * @param i3 global strategy, {@link #E_GREEDY} or {@link #UCB1}
     * @param i4 maximum depth; nodes at this depth are returned as leaves
     * @param i5 creation id assigned to any node created during this call
     * @return the selected leaf (possibly a newly created child)
     * @throws IllegalArgumentException if {@code i3} is not a known strategy and a global choice
     *     is needed
     * @throws Exception as declared; may propagate from node creation
     */
    public NaiveMCTSNode selectLeaf(int i, int i2, float f, float f2, float f3, int i3, int i4, int i5) throws Exception {
        if (this.unitActionTable == null || this.depth >= i4) {
            return this;
        }
        // Short-circuit keeps the random draw only for nodes that already have children.
        if (this.children.isEmpty() || r.nextFloat() < f3) {
            return selectLeafUsingLocalMABs(i, i2, f, f2, f3, i3, i4, i5);
        }
        NaiveMCTSNode next;
        if (i3 == E_GREEDY) {
            next = selectFromAlreadySampledEpsilonGreedy(f2);
        } else if (i3 == UCB1) {
            next = selectFromAlreadySampledUCB1(C);
        } else {
            throw new IllegalArgumentException("Unknown global strategy: " + i3);
        }
        return next.selectLeaf(i, i2, f, f2, f3, i3, i4, i5);
    }

    /**
     * Chooses an existing child epsilon-greedily.
     *
     * <p>With probability {@code f} a uniformly random child is returned. Otherwise the child with
     * the best average evaluation is returned: highest for maximising nodes, lowest otherwise.
     * Ties go to the earliest child.
     *
     * @param f exploration probability
     * @return the chosen child, or {@code null} if there are no children and no random pick was made
     */
    public NaiveMCTSNode selectFromAlreadySampledEpsilonGreedy(float f) throws Exception {
        if (r.nextFloat() < f) {
            return (NaiveMCTSNode) this.children.get(r.nextInt(this.children.size()));
        }
        boolean maximising = this.type == 0;
        NaiveMCTSNode best = null;
        double bestAverage = 0.0d;
        for (MCTSNode candidate : this.children) {
            double average = candidate.accum_evaluation / candidate.visit_count;
            if (best == null || (maximising ? average > bestAverage : average < bestAverage)) {
                best = (NaiveMCTSNode) candidate;
                bestAverage = average;
            }
        }
        return best;
    }

    /**
     * Chooses an existing child by UCB1.
     *
     * <p>Score = {@code f * normalisedValue + sqrt(ln(parentVisits) / childVisits)}. Here
     * {@code normalisedValue} maps the child's average from [-bound, bound] onto [0, 1], flipped
     * for minimising nodes so that higher is always better for this node.
     *
     * @param f weight of the exploitation term
     * @return the highest-scoring child (earliest on ties), or {@code null} if there are none
     */
    public NaiveMCTSNode selectFromAlreadySampledUCB1(float f) throws Exception {
        double span = 2.0d * this.evaluation_bound;
        double logParentVisits = Math.log(this.visit_count);
        NaiveMCTSNode best = null;
        double bestScore = 0.0d;
        for (MCTSNode candidate : this.children) {
            double average = candidate.accum_evaluation / candidate.visit_count;
            double exploitation = (this.type == 0)
                    ? (this.evaluation_bound + average) / span
                    : (this.evaluation_bound - average) / span;
            double score = f * exploitation + Math.sqrt(logParentVisits / candidate.visit_count);
            if (best == null || score > bestScore) {
                best = (NaiveMCTSNode) candidate;
                bestScore = score;
            }
        }
        return best;
    }

    /**
     * Samples a joint action from the per-unit local bandits.
     *
     * <p>Each unit with at least one action draws an order from its local distribution; units are
     * visited in random order. If the resulting joint action was seen before, selection continues
     * from that child through {@link #selectLeaf}. Otherwise a new child is created from a copy of
     * this node's state with the action applied, and that child is returned.
     *
     * <p>A unit whose draw fails is left out of the joint action. The exception's stack trace is
     * printed and sampling continues.
     *
     * @return the selected leaf or the newly created child
     * @throws Exception as declared; may propagate from state copying or node creation
     */
    public NaiveMCTSNode selectLeafUsingLocalMABs(int i, int i2, float f, float f2, float f3, int i3, int i4, int i5) throws Exception {
        // Distributions are indexed by position in unitActionTable (the same index as
        // multipliers); units without actions keep a null slot and are never sampled.
        double[][] distributions = new double[this.unitActionTable.size()][];
        List<Integer> pending = new ArrayList<>();
        int tableIndex = 0;
        for (UnitActionTableEntry entry : this.unitActionTable) {
            if (entry.nactions != 0) {
                distributions[tableIndex] = localDistribution(entry, f);
                pending.add(tableIndex);
            }
            tableIndex++;
        }

        PlayerAction playerAction = new PlayerAction();
        BigInteger key = BigInteger.ZERO;
        while (!pending.isEmpty()) {
            int unit = pending.remove(r.nextInt(pending.size()));
            try {
                UnitActionTableEntry entry = this.unitActionTable.get(unit);
                int choice = Sampler.weighted(distributions[unit]);
                playerAction.addUnitAction(entry.u, entry.actions.get(choice));
                key = key.add(BigInteger.valueOf(choice).multiply(this.multipliers[unit]));
            } catch (Exception e) {
                // Best effort: a unit whose order cannot be sampled is left out of the action.
                e.printStackTrace();
            }
        }

        NaiveMCTSNode existing = this.childrenMap.get(key);
        if (existing != null) {
            return existing.selectLeaf(i, i2, f, f2, f3, i3, i4, i5);
        }
        GameState nextState = new GameState(this.gs);
        playerAction.apply(nextState);
        NaiveMCTSNode child = new NaiveMCTSNode(i, i2, new GameState(nextState), this, this.evaluation_bound, i5, this.forceExplorationOfNonSampledActions);
        // Record the action only once the child exists, so actions and children stay aligned by
        // index even if the constructor throws.
        this.actions.add(playerAction);
        this.childrenMap.put(key, child);
        this.children.add(child);
        return child;
    }

    /**
     * Builds the local epsilon-greedy sampling distribution for one unit; entry must have at
     * least one action.
     *
     * <p>The greedy action is the first unvisited action if any exists. Otherwise it is the action
     * with the best average (highest for maximising nodes, lowest otherwise; earliest on ties).
     * Every action gets {@code f / nactions}. A visited greedy action additionally gets
     * {@code 1 - f}. If the greedy action is unvisited and
     * {@link #forceExplorationOfNonSampledActions} is set, visited actions get 0 instead.
     */
    private double[] localDistribution(UnitActionTableEntry entry, float f) {
        int n = entry.nactions;
        boolean maximising = this.type == 0;
        double[] distribution = new double[n];
        int greedy = -1;
        double greedyAverage = 0.0d;
        int greedyVisits = 0;
        for (int a = 0; a < n; a++) {
            int visits = entry.visit_count[a];
            // Once the current greedy action is unvisited it is never replaced; a visited one is
            // replaced by any unvisited action or by a strictly better average.
            boolean replace = greedy == -1
                    || (greedyVisits != 0 && (visits == 0
                            || (maximising
                                    ? entry.accum_evaluation[a] / visits > greedyAverage
                                    : entry.accum_evaluation[a] / visits < greedyAverage)));
            if (replace) {
                greedy = a;
                greedyAverage = visits > 0 ? entry.accum_evaluation[a] / visits : 0.0d;
                greedyVisits = visits;
            }
            distribution[a] = f / n;
        }
        if (entry.visit_count[greedy] != 0) {
            distribution[greedy] = (1.0f - f) + (f / n);
        } else if (this.forceExplorationOfNonSampledActions) {
            for (int a = 0; a < distribution.length; a++) {
                if (entry.visit_count[a] > 0) {
                    distribution[a] = 0.0d;
                }
            }
        }
        if (DEBUG >= 3) {
            System.out.print("[ ");
            for (int a = 0; a < n; a++) {
                System.out.print("(" + entry.visit_count[a] + CubeCoordinate.SEP + (entry.accum_evaluation[a] / entry.visit_count[a]) + ")");
            }
            System.out.println(ObjectFinder.PARAM_END);
            System.out.print("[ ");
            for (double p : distribution) {
                System.out.print(p + " ");
            }
            System.out.println(ObjectFinder.PARAM_END);
        }
        return distribution;
    }

    /**
     * Returns the action-table entry of unit {@code uuid} (linear scan).
     *
     * @throws Error if the unit has no entry in this node's table
     */
    public UnitActionTableEntry getActionTableEntry(UUID uuid) {
        for (UnitActionTableEntry entry : this.unitActionTable) {
            if (entry.u.equals(uuid)) {
                return entry;
            }
        }
        throw new Error("Could not find Action Table Entry!");
    }

    /**
     * Backs up evaluation {@code d} from this node to the root.
     *
     * <p>This node's totals are updated. When {@code naiveMCTSNode} (the child the evaluation came
     * from) is non-null, every unit order of the joint action that produced that child is also
     * credited in this node's local bandits. The call then recurses into the parent.
     *
     * @param d evaluation to add
     * @param naiveMCTSNode child the evaluation came from, or {@code null} at the start of back-up
     * @throws IllegalStateException if an order of the child's joint action is missing from the
     *     unit's action table
     */
    public void propagateEvaluation(double d, NaiveMCTSNode naiveMCTSNode) {
        this.accum_evaluation += d;
        this.visit_count++;
        if (naiveMCTSNode != null) {
            for (Map.Entry<UUID, Order> order : this.actions.get(this.children.indexOf(naiveMCTSNode)).getOrders().entrySet()) {
                UnitActionTableEntry entry = getActionTableEntry(order.getKey());
                int index = entry.actions.indexOf(order.getValue());
                if (index == -1) {
                    System.out.println("Looking for action: " + order.getValue());
                    System.out.println("Available actions are: " + entry.actions);
                    // Previously this fell through to an opaque array access at index -1.
                    throw new IllegalStateException("Order " + order.getValue() + " of unit " + order.getKey() + " is not in its action table");
                }
                entry.accum_evaluation[index] += d;
                entry.visit_count[index]++;
            }
        }
        if (this.parent != null) {
            ((NaiveMCTSNode) this.parent).propagateEvaluation(d, this);
        }
    }

    /** Prints every unit's actions with their visit counts and average evaluations to stdout. */
    public void printUnitActionTable() {
        for (UnitActionTableEntry entry : this.unitActionTable) {
            System.out.println("Actions for unit " + entry.u);
            for (int a = 0; a < entry.nactions; a++) {
                System.out.println("   " + entry.actions.get(a) + " visited " + entry.visit_count[a] + " with average evaluation " + (entry.accum_evaluation[a] / entry.visit_count[a]));
            }
        }
    }
}
