/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalSnapshot;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableList;
import org.apache.flink.calcite.shaded.com.google.common.collect.Lists;
import org.apache.flink.table.planner.plan.utils.AggregateUtil;
import org.apache.flink.util.Preconditions;
import scala.Tuple2;
import scala.collection.JavaConverters;
import scala.collection.Seq;

/**
 * Planner rule that matches an {@link Aggregate} (with a {@code SIMPLE} group type) whose input
 * is a {@link Join}, and rewrites it so that aggregation / projection is performed on each join
 * input before the join.
 *
 * <p>The rewrite is only attempted for inner joins with a pure equi-join condition, and only
 * when every aggregate call is splittable (unwraps to {@link SqlSplittableAggFunction}), is not
 * distinct and has no filter argument. The rule never fires when the right join input contains a
 * {@link LogicalSnapshot} reachable through single-input nodes.
 */
public class FlinkAggregateJoinTransposeRule
extends RelOptRule {
    /** Instance that only fires for aggregates without aggregate calls (group-by only). */
    public static final FlinkAggregateJoinTransposeRule INSTANCE = new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class, RelFactories.LOGICAL_BUILDER, false);
    /** Instance that also fires for aggregates that contain (splittable) aggregate calls. */
    public static final FlinkAggregateJoinTransposeRule EXTENDED = new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class, RelFactories.LOGICAL_BUILDER, true);

    /** When {@code false}, {@link #onMatch} gives up if the aggregate has any aggregate call. */
    private final boolean allowFunctions;

    /**
     * Creates the rule.
     *
     * @param aggregateClass aggregate class to match; only {@code SIMPLE} group types match
     * @param joinClass join class to match as the aggregate's input
     * @param relBuilderFactory factory for the builder used to create the new plan
     * @param allowFunctions whether aggregates containing aggregate calls may be transposed
     */
    public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, Class<? extends Join> joinClass, RelBuilderFactory relBuilderFactory, boolean allowFunctions) {
        super(
                operandJ(
                        aggregateClass,
                        null,
                        aggregate -> aggregate.getGroupType() == Aggregate.Group.SIMPLE,
                        operand(joinClass, any())),
                relBuilderFactory,
                null);
        this.allowFunctions = allowFunctions;
    }

    /** @deprecated use the constructor that takes a {@link RelBuilderFactory}. */
    @Deprecated
    public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory) {
        this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory), false);
    }

    /** @deprecated use the constructor that takes a {@link RelBuilderFactory}. */
    @Deprecated
    public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory, boolean allowFunctions) {
        this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory), allowFunctions);
    }

    /** @deprecated use the constructor that takes a {@link RelBuilderFactory}. */
    @Deprecated
    public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory, RelFactories.ProjectFactory projectFactory) {
        this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), false);
    }

    /** @deprecated use the constructor that takes a {@link RelBuilderFactory}. */
    @Deprecated
    public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory, RelFactories.ProjectFactory projectFactory, boolean allowFunctions) {
        this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), allowFunctions);
    }

    /**
     * Returns whether {@code relNode}, or any node reachable from it by following only
     * single-input ({@link SingleRel}) nodes, is a {@link LogicalSnapshot}.
     *
     * <p>Planner wrappers are unwrapped first: a {@link RelSubset} is replaced by its original
     * rel and a {@link HepRelVertex} by its current rel.
     */
    private boolean containsSnapshot(RelNode relNode) {
        final RelNode original;
        if (relNode instanceof RelSubset) {
            original = ((RelSubset) relNode).getOriginal();
        } else if (relNode instanceof HepRelVertex) {
            original = ((HepRelVertex) relNode).getCurrentRel();
        } else {
            original = relNode;
        }
        if (original instanceof LogicalSnapshot) {
            return true;
        }
        if (original instanceof SingleRel) {
            return containsSnapshot(((SingleRel) original).getInput());
        }
        return false;
    }

    /**
     * Rejects the match when the right input of the join contains a snapshot.
     *
     * <p>Only the right input is inspected; the left input is not checked.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        final Join join = call.rel(1);
        return !containsSnapshot(join.getRight());
    }

    /**
     * Performs the transposition.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Normalize the aggregate via {@link #toRegularAggregate(Aggregate)}.</li>
     *   <li>Bail out unless all aggregate calls are splittable, non-distinct and unfiltered,
     *       the join is an inner equi-join, and (if {@code allowFunctions} is off) there are
     *       no aggregate calls.</li>
     *   <li>For each join input, either project (input is unique on the needed keys) or
     *       aggregate on the keys needed by the group set and the join condition.</li>
     *   <li>Re-create the join on the new inputs, then either a project (when every join
     *       column is covered by the key columns and every call has a singleton form) or a
     *       top aggregate.</li>
     *   <li>Re-apply the projection produced by the normalization step, if any.</li>
     * </ol>
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate origAgg = call.rel(0);
        final Join join = call.rel(1);
        final RexBuilder rexBuilder = origAgg.getCluster().getRexBuilder();
        final RelBuilder relBuilder = call.builder();

        final Pair<Aggregate, List<RexNode>> newAggAndProject = toRegularAggregate(origAgg);
        final Aggregate aggregate = newAggAndProject.left;
        // null when the aggregate was already "regular" and no trailing project is required
        final List<RexNode> projectAfterAgg = newAggAndProject.right;

        // Every aggregate call must be splittable, unfiltered and non-distinct.
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
                return;
            }
            if (aggregateCall.filterArg >= 0 || aggregateCall.isDistinct()) {
                return;
            }
        }
        if (join.getJoinType() != JoinRelType.INNER) {
            return;
        }
        if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
            return;
        }

        // Columns (in join output coordinates) that must survive below the join.
        final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
        final RelMetadataQuery mq = call.getMetadataQuery();
        final ImmutableBitSet groupKeyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
        final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
        final boolean allColumnsInAggregate = groupKeyColumns.contains(joinColumns);
        final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);

        // Only pure equi-joins are handled; any residual condition aborts the rewrite.
        final List<Integer> leftKeys = Lists.newArrayList();
        final List<Integer> rightKeys = Lists.newArrayList();
        final List<Boolean> filterNulls = Lists.newArrayList();
        final RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
        if (!nonEquiConj.isAlwaysTrue()) {
            return;
        }

        // Maps an old join-output column index to its index in the new join's output.
        final Map<Integer, Integer> map = new HashMap<>();
        final List<Side> sides = new ArrayList<>();
        int uniqueCount = 0;
        int offset = 0;
        int belowOffset = 0;
        for (int s = 0; s < 2; ++s) {
            final Side side = new Side();
            final RelNode joinInput = join.getInput(s);
            final int fieldCount = joinInput.getRowType().getFieldCount();
            final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
            final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
            for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
                map.put(c.e, belowOffset + c.i);
            }
            // Translates join-output coordinates into this input's own coordinates.
            final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
            final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);

            final boolean unique;
            if (!allowFunctions) {
                assert aggregate.getAggCallList().isEmpty();
                // Without aggregate calls the input is always handled through the
                // projection branch, regardless of actual uniqueness.
                unique = true;
            } else {
                final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
                unique = unique0 != null && unique0;
            }

            if (unique) {
                ++uniqueCount;
                side.aggregate = false;
                relBuilder.push(joinInput);
                final Map<Integer, Integer> belowAggregateKeyToNewProjectMap = new HashMap<>();
                final List<RexNode> inputProjects = new ArrayList<>();
                for (Integer i : belowAggregateKey) {
                    belowAggregateKeyToNewProjectMap.put(i, inputProjects.size());
                    inputProjects.add(relBuilder.field(i));
                }
                for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
                    final SqlAggFunction aggregation = aggCall.e.getAggregation();
                    final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
                    final boolean argsOnThisSide = !aggCall.e.getArgList().isEmpty() && fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()));
                    if (!argsOnThisSide) {
                        continue;
                    }
                    final RexNode singleton = splitter.singleton(rexBuilder, joinInput.getRowType(), aggCall.e.transform(mapping));
                    final RexNode targetSingleton = rexBuilder.ensureType(aggCall.e.type, singleton, false);
                    if (targetSingleton instanceof RexInputRef && belowAggregateKey.get(((RexInputRef) targetSingleton).getIndex())) {
                        // The singleton is a key column that is already projected: reuse it.
                        final int index = ((RexInputRef) targetSingleton).getIndex();
                        side.split.put(aggCall.i, belowAggregateKeyToNewProjectMap.get(index));
                    } else {
                        inputProjects.add(targetSingleton);
                        side.split.put(aggCall.i, inputProjects.size() - 1);
                    }
                }
                relBuilder.project(inputProjects);
                side.newInput = relBuilder.build();
            } else {
                side.aggregate = true;
                final List<AggregateCall> belowAggCalls = new ArrayList<>();
                final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
                final int oldGroupKeyCount = aggregate.getGroupCount();
                final int newGroupKeyCount = belowAggregateKey.cardinality();
                for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
                    final SqlAggFunction aggregation = aggCall.e.getAggregation();
                    final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
                    final AggregateCall call1;
                    if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
                        final AggregateCall splitCall = splitter.split(aggCall.e, mapping);
                        call1 = splitCall.adaptTo(joinInput, splitCall.getArgList(), splitCall.filterArg, oldGroupKeyCount, newGroupKeyCount);
                    } else {
                        call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
                    }
                    if (call1 != null) {
                        // Aggregate outputs follow the group keys in the new input's row type.
                        side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
                    }
                }
                side.newInput = relBuilder
                        .push(joinInput)
                        .aggregate(
                                relBuilder.groupKey(
                                        belowAggregateKey,
                                        org.apache.flink.shaded.guava32.com.google.common.collect.ImmutableList.of(belowAggregateKey)),
                                belowAggCalls)
                        .build();
            }
            offset += fieldCount;
            belowOffset += side.newInput.getRowType().getFieldCount();
            sides.add(side);
        }

        if (uniqueCount == 2) {
            // Both inputs took the projection branch; no aggregate was pushed down.
            return;
        }

        // Re-create the join on top of the two new inputs, remapping the condition.
        final Mapping mapping = (Mapping) Mappings.target(map::get, join.getRowType().getFieldCount(), belowOffset);
        final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
        relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);

        // Build the top-level aggregate calls that combine the per-side sub-totals.
        final List<AggregateCall> newAggCalls = new ArrayList<>();
        final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
        final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
        final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
        for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
            final SqlAggFunction aggregation = aggCall.e.getAggregation();
            final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
            final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
            final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
            newAggCalls.add(
                    splitter.topSplit(
                            rexBuilder,
                            registry(projects),
                            groupIndicatorCount,
                            relBuilder.peek().getRowType(),
                            aggCall.e,
                            leftSubTotal == null ? -1 : leftSubTotal,
                            rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
        }
        relBuilder.project(projects);

        boolean aggConvertedToProjects = false;
        if (allColumnsInAggregate) {
            // Try to replace the top aggregate with a projection of singleton expressions.
            final List<RexNode> projects2 = new ArrayList<>();
            for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
                projects2.add(relBuilder.field(key));
            }
            int aggCallIdx = projects2.size();
            for (AggregateCall newAggCall : newAggCalls) {
                final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
                if (splitter != null) {
                    final RelDataType rowType = relBuilder.peek().getRowType();
                    final RexNode singleton = splitter.singleton(rexBuilder, rowType, newAggCall);
                    final RelDataType originalAggCallType = aggregate.getRowType().getFieldList().get(aggCallIdx).getType();
                    projects2.add(rexBuilder.ensureType(originalAggCallType, singleton, false));
                }
                ++aggCallIdx;
            }
            // Only usable if every new aggregate call produced a singleton expression.
            if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
                relBuilder.project(projects2);
                aggConvertedToProjects = true;
            }
        }
        if (!aggConvertedToProjects) {
            // The explicit Iterable cast pins the same groupKey overload as the compiled code.
            relBuilder.aggregate(
                    relBuilder.groupKey(
                            Mappings.apply(mapping, aggregate.getGroupSet()),
                            (Iterable<? extends ImmutableBitSet>) Mappings.apply2(mapping, aggregate.getGroupSets())),
                    newAggCalls);
        }
        if (projectAfterAgg != null) {
            relBuilder.project(projectAfterAgg, origAgg.getRowType().getFieldNames());
        }
        call.transformTo(relBuilder.build());
    }

    /**
     * Converts an aggregate that carries auxiliary group columns into one whose group set is the
     * full group set and whose calls are only the regular aggregate calls.
     *
     * @param aggregate aggregate to normalize
     * @return the aggregate to work on (left) and the projection that restores the original
     *     output row type on top of it (right); the right side is {@code null} when the
     *     aggregate has no auxiliary group columns and is returned unchanged
     */
    private Pair<Aggregate, List<RexNode>> toRegularAggregate(Aggregate aggregate) {
        final Tuple2<int[], Seq<AggregateCall>> auxGroupAndRegularAggCalls = AggregateUtil.checkAndSplitAggCalls(aggregate);
        final int[] auxGroup = auxGroupAndRegularAggCalls._1;
        final Seq<AggregateCall> regularAggCalls = auxGroupAndRegularAggCalls._2;
        if (auxGroup.length == 0) {
            return new Pair<>(aggregate, null);
        }

        final int[] fullGroupSet = AggregateUtil.checkAndGetFullGroupSet(aggregate);
        final ImmutableBitSet newGroupSet = ImmutableBitSet.of(fullGroupSet);
        final List<AggregateCall> aggCalls = JavaConverters.seqAsJavaListConverter(regularAggCalls).asJava();
        final Aggregate newAgg = aggregate.copy(aggregate.getTraitSet(), aggregate.getInput(), newGroupSet, ImmutableList.of(newGroupSet), aggCalls);

        final List<RelDataTypeField> aggFields = aggregate.getRowType().getFieldList();
        final List<RexNode> projectAfterAgg = new ArrayList<>();
        // Group columns: locate each original group column inside the new group set.
        for (int i = 0; i < fullGroupSet.length; ++i) {
            final int index = newGroupSet.indexOf(fullGroupSet[i]);
            projectAfterAgg.add(new RexInputRef(index, aggFields.get(i).getType()));
        }
        // Remaining columns keep their position.
        final int fieldCntOfAgg = aggFields.size();
        for (int i = fullGroupSet.length; i < fieldCntOfAgg; ++i) {
            projectAfterAgg.add(new RexInputRef(i, aggFields.get(i).getType()));
        }
        Preconditions.checkArgument(projectAfterAgg.size() == fieldCntOfAgg);
        return new Pair<>(newAgg, projectAfterAgg);
    }

    /**
     * Returns {@code aggregateColumns} extended with every column that one of the given
     * predicates directly equates (via {@code colA = colB}) to an aggregate column.
     *
     * <p>Only direct equalities are followed; the result is not a transitive closure.
     */
    private static ImmutableBitSet keyColumns(ImmutableBitSet aggregateColumns, ImmutableList<RexNode> predicates) {
        final Map<Integer, BitSet> equivalence = new TreeMap<>();
        for (RexNode predicate : predicates) {
            populateEquivalences(equivalence, predicate);
        }
        ImmutableBitSet keyColumns = aggregateColumns;
        for (Integer aggregateColumn : aggregateColumns) {
            final BitSet equivalentColumns = equivalence.get(aggregateColumn);
            if (equivalentColumns != null) {
                keyColumns = keyColumns.union(equivalentColumns);
            }
        }
        return keyColumns;
    }

    /**
     * Records, in both directions, the column equivalence expressed by {@code predicate} if it
     * is an {@code EQUALS} call whose two operands are both input references; otherwise does
     * nothing.
     */
    private static void populateEquivalences(Map<Integer, BitSet> equivalence, RexNode predicate) {
        switch (predicate.getKind()) {
            case EQUALS: {
                final List<RexNode> operands = ((RexCall) predicate).getOperands();
                final RexNode left = operands.get(0);
                final RexNode right = operands.get(1);
                if (left instanceof RexInputRef && right instanceof RexInputRef) {
                    final int i0 = ((RexInputRef) left).getIndex();
                    final int i1 = ((RexInputRef) right).getIndex();
                    populateEquivalence(equivalence, i0, i1);
                    populateEquivalence(equivalence, i1, i0);
                }
                break;
            }
            default:
                break;
        }
    }

    /** Adds {@code i1} to the set of columns known to be equal to {@code i0}. */
    private static void populateEquivalence(Map<Integer, BitSet> equivalence, int i0, int i1) {
        equivalence.computeIfAbsent(i0, key -> new BitSet()).set(i1);
    }

    /**
     * Creates a registry backed by {@code list}: registering an element returns the index of an
     * equal element already in the list, or appends it and returns the new index.
     */
    private static <E> SqlSplittableAggFunction.Registry<E> registry(final List<E> list) {
        return new SqlSplittableAggFunction.Registry<E>(){

            @Override
            public int register(E e) {
                int i = list.indexOf(e);
                if (i < 0) {
                    i = list.size();
                    list.add(e);
                }
                return i;
            }
        };
    }

    /** Work area for one input of the join. */
    private static class Side {
        /** Maps an aggregate call's ordinal to the output column of {@link #newInput} holding its sub-total. */
        final Map<Integer, Integer> split = new HashMap<>();
        /** Replacement for the join input (a project or an aggregate on top of it). */
        RelNode newInput;
        /** Whether {@link #newInput} is an aggregate (as opposed to a project). */
        boolean aggregate;

        private Side() {
        }
    }
}

