/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.table.api.config.OptimizerConfigOptions;
import org.apache.flink.table.planner.plan.rules.logical.ImmutableRewriteMultiJoinConditionRule;
import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.util.Preconditions;
import org.immutables.value.Value;

/**
 * Planner rule that enriches the join filter of an inner {@link MultiJoin} with equi-join
 * conditions derived by transitivity.
 *
 * <p>The join filter is split into its conjuncts, which are partitioned into equalities ({@code =})
 * and everything else. For every expression {@code x} that is equated with two or more distinct
 * expressions, e.g. {@code x = a} and {@code x = b}, the rule adds {@code a = b}, unless an equal
 * condition (in either operand order) is already present. Non-equality conjuncts are carried over
 * unchanged. If nothing new can be derived, the rule produces no transformation.
 *
 * <p>The rule only matches a {@link MultiJoin} that has more inputs than {@link
 * OptimizerConfigOptions#TABLE_OPTIMIZER_BUSHY_JOIN_REORDER_THRESHOLD}, is not a full outer join,
 * uses only {@link JoinRelType#INNER} joins, and has more than one equality in its join filter.
 */
@Value.Enclosing
public class RewriteMultiJoinConditionRule
        extends RelRule<RewriteMultiJoinConditionRule.RewriteMultiJoinConditionRuleConfig> {

    /** Shared rule instance built from {@link RewriteMultiJoinConditionRuleConfig#DEFAULT}. */
    public static final RewriteMultiJoinConditionRule INSTANCE =
            RewriteMultiJoinConditionRuleConfig.DEFAULT.toRule();

    private RewriteMultiJoinConditionRule(RewriteMultiJoinConditionRuleConfig config) {
        super(config);
    }

    /**
     * Returns whether the matched {@link MultiJoin} is eligible for rewriting.
     *
     * <p>The cheap structural checks run first, so the join filter is only split into conjuncts
     * when they all pass.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        MultiJoin multiJoin = call.rel(0);
        int bushyTreeThreshold =
                ShortcutUtils.unwrapContext(multiJoin)
                        .getTableConfig()
                        .get(OptimizerConfigOptions.TABLE_OPTIMIZER_BUSHY_JOIN_REORDER_THRESHOLD);
        // NOTE(review): the reason for gating on the bushy-join-reorder threshold is not visible
        // in this file; the condition itself is preserved as-is.
        if (multiJoin.getInputs().size() <= bushyTreeThreshold || multiJoin.isFullOuterJoin()) {
            return false;
        }
        boolean isAllInnerJoin =
                multiJoin.getJoinTypes().stream()
                        .allMatch(joinType -> joinType == JoinRelType.INNER);
        return isAllInnerJoin && partitionJoinFilters(multiJoin).f0.size() > 1;
    }

    /**
     * Derives transitive equalities from the join filter and, if any are new, replaces the matched
     * {@link MultiJoin} with a copy whose join filter includes them.
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        MultiJoin multiJoin = call.rel(0);
        Tuple2<List<RexNode>, List<RexNode>> partitions = partitionJoinFilters(multiJoin);
        List<RexNode> equiJoinFilters = partitions.f0;
        List<RexNode> nonEquiJoinFilters = partitions.f1;

        RexBuilder rexBuilder = multiJoin.getCluster().getRexBuilder();
        List<RexNode> newEquiJoinFilters = new ArrayList<>(equiJoinFilters);
        for (Set<RexNode> peers : collectEqualOperands(equiJoinFilters).values()) {
            // An operand equated with a single other expression yields no new pair.
            if (peers.size() < 2) {
                continue;
            }
            // Every pair of distinct peers is equal via their shared operand.
            List<RexNode> operands = new ArrayList<>(peers);
            for (int i = 0; i < operands.size(); i++) {
                for (int j = i + 1; j < operands.size(); j++) {
                    RexNode newFilter =
                            rexBuilder.makeCall(
                                    SqlStdOperatorTable.EQUALS, operands.get(i), operands.get(j));
                    if (!containEquiJoinFilter(newFilter, newEquiJoinFilters)) {
                        newEquiJoinFilters.add(newFilter);
                    }
                }
            }
        }
        if (newEquiJoinFilters.size() == equiJoinFilters.size()) {
            // Nothing new was derived; leave the plan untouched.
            return;
        }
        RexNode newJoinFilter =
                call.builder()
                        .and(
                                Stream.concat(
                                                newEquiJoinFilters.stream(),
                                                nonEquiJoinFilters.stream())
                                        .collect(Collectors.toList()));
        MultiJoin newMultiJoin =
                new MultiJoin(
                        multiJoin.getCluster(),
                        multiJoin.getInputs(),
                        newJoinFilter,
                        multiJoin.getRowType(),
                        multiJoin.isFullOuterJoin(),
                        multiJoin.getOuterJoinConditions(),
                        multiJoin.getJoinTypes(),
                        multiJoin.getProjFields(),
                        multiJoin.getJoinFieldRefCountsMap(),
                        multiJoin.getPostJoinFilter());
        call.transformTo(newMultiJoin);
    }

    /**
     * Builds an adjacency map of the equality graph: each operand maps to the distinct operands it
     * is directly equated with.
     *
     * <p>Sets are used so that a repeated or reversed equality ({@code a = b} together with
     * {@code b = a}) does not list the same peer twice. Otherwise the pairing step would
     * synthesise self-comparisons such as {@code b = b}. Insertion order is preserved so derived
     * conditions follow the order of the original filter.
     *
     * @param equiJoinFilters conjuncts whose kind is {@link SqlKind#EQUALS}
     * @return insertion-ordered map from operand to its distinct directly-equated operands
     */
    private static Map<RexNode, Set<RexNode>> collectEqualOperands(List<RexNode> equiJoinFilters) {
        Map<RexNode, Set<RexNode>> equalOperands = new LinkedHashMap<>();
        for (RexNode filter : equiJoinFilters) {
            if (!(filter instanceof RexCall)) {
                continue;
            }
            Preconditions.checkState(filter.isA(SqlKind.EQUALS));
            List<RexNode> operands = ((RexCall) filter).getOperands();
            RexNode left = operands.get(0);
            RexNode right = operands.get(1);
            equalOperands.computeIfAbsent(left, k -> new LinkedHashSet<>()).add(right);
            equalOperands.computeIfAbsent(right, k -> new LinkedHashSet<>()).add(left);
        }
        return equalOperands;
    }

    /**
     * Returns whether {@code equiJoinFiltersList} already holds a condition equal to {@code
     * joinFilter}, either identical or as the same equality with its operands swapped.
     */
    private boolean containEquiJoinFilter(RexNode joinFilter, List<RexNode> equiJoinFiltersList) {
        return equiJoinFiltersList.stream()
                .anyMatch(f -> f.equals(joinFilter) || isReversedEquality(f, joinFilter));
    }

    /** Returns true when both nodes are binary EQUALS calls of the form {@code x = y} / {@code y = x}. */
    private static boolean isReversedEquality(RexNode first, RexNode second) {
        if (!(first instanceof RexCall) || !(second instanceof RexCall)) {
            return false;
        }
        if (!first.isA(SqlKind.EQUALS) || !second.isA(SqlKind.EQUALS)) {
            return false;
        }
        List<RexNode> a = ((RexCall) first).getOperands();
        List<RexNode> b = ((RexCall) second).getOperands();
        return a.size() == 2
                && b.size() == 2
                && a.get(0).equals(b.get(1))
                && a.get(1).equals(b.get(0));
    }

    /**
     * Splits the join filter of {@code multiJoin} into its conjuncts and partitions them.
     *
     * @return {@code f0}: conjuncts whose kind is {@link SqlKind#EQUALS}; {@code f1}: all other
     *     conjuncts
     */
    private Tuple2<List<RexNode>, List<RexNode>> partitionJoinFilters(MultiJoin multiJoin) {
        List<RexNode> joinFilters = RelOptUtil.conjunctions(multiJoin.getJoinFilter());
        Map<Boolean, List<RexNode>> partitioned =
                joinFilters.stream()
                        .collect(Collectors.partitioningBy(filter -> filter.isA(SqlKind.EQUALS)));
        // partitioningBy always populates both the true and false keys.
        return new Tuple2<>(partitioned.get(true), partitioned.get(false));
    }

    /** Configuration for {@link RewriteMultiJoinConditionRule}. */
    @Value.Immutable(singleton = false)
    public interface RewriteMultiJoinConditionRuleConfig extends RelRule.Config {
        /** Default configuration: matches any {@link MultiJoin}, whatever its inputs. */
        RewriteMultiJoinConditionRuleConfig DEFAULT =
                ImmutableRewriteMultiJoinConditionRule.RewriteMultiJoinConditionRuleConfig.builder()
                        .build()
                        .withOperandSupplier(b0 -> b0.operand(MultiJoin.class).anyInputs())
                        .withDescription("RewriteMultiJoinConditionRule");

        @Override
        default RewriteMultiJoinConditionRule toRule() {
            return new RewriteMultiJoinConditionRule(this);
        }
    }
}

