package com.alibaba.graphscope.common.ir.meta.glogue.calcite.handler;

import com.alibaba.graphscope.common.ir.meta.glogue.PrimitiveCountEstimator;
import com.alibaba.graphscope.common.ir.meta.glogue.Utils;
import com.alibaba.graphscope.common.ir.rel.GraphExtendIntersect;
import com.alibaba.graphscope.common.ir.rel.GraphJoinDecomposition;
import com.alibaba.graphscope.common.ir.rel.GraphPattern;
import com.alibaba.graphscope.common.ir.rel.graph.AbstractBindableTableScan;
import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalExpand;
import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalGetV;
import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalPathExpand;
import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalSource;
import com.alibaba.graphscope.common.ir.rel.graph.GraphPhysicalExpand;
import com.alibaba.graphscope.common.ir.rel.graph.GraphPhysicalGetV;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.GlogueQuery;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.ElementDetails;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.FuzzyPatternEdge;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.FuzzyPatternVertex;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.Pattern;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.PatternEdge;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.PatternVertex;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.SinglePatternEdge;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.SinglePatternVertex;
import com.alibaba.graphscope.common.ir.rel.metadata.schema.EdgeTypeId;
import com.alibaba.graphscope.common.ir.tools.config.GraphOpt;
import com.google.common.collect.Lists;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.GraphOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.plan.volcano.VolcanoPlanner;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.metadata.BuiltInMetadata;
import org.apache.calcite.rel.metadata.RelMdRowCount;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.MultiJoin;

/* loaded from: input_file:com/alibaba/graphscope/common/ir/meta/glogue/calcite/handler/GraphRowCountHandler.class */
/**
 * Calcite row-count metadata handler for GraphScope relational nodes.
 *
 * <p>{@link GraphPattern} nodes are estimated in three steps:
 * <ol>
 *   <li>the glogue-backed {@link PrimitiveCountEstimator} is asked directly;</li>
 *   <li>if that gives nothing, an already-optimizable decomposition registered in the
 *       {@link VolcanoPlanner} subset is used (an extend-intersect or a join decomposition);</li>
 *   <li>as a last resort, the per-edge and per-vertex counts are combined.</li>
 * </ol>
 *
 * <p>Other node types are handled as follows:
 * <ul>
 *   <li>Vertex scans ({@link GraphLogicalSource}, {@link GraphLogicalGetV}) and edge expansions
 *       ({@link GraphLogicalExpand}) without a cached cost are rewritten to one-vertex or one-edge
 *       patterns.</li>
 *   <li>Joins reuse a cost cached on the {@link GraphOptCluster} when one is present.</li>
 *   <li>Standard relational operators delegate to Calcite's {@link RelMdRowCount}.</li>
 * </ul>
 */
public class GraphRowCountHandler implements BuiltInMetadata.RowCount.Handler {
    private final PrimitiveCountEstimator countEstimator;
    private final RelOptPlanner optPlanner;
    private final RelMdRowCount mdRowCount = new RelMdRowCount();

    /**
     * @param relOptPlanner planner whose subsets are consulted when a pattern has no direct statistic
     * @param glogueQuery statistics source used to build the {@link PrimitiveCountEstimator}
     */
    public GraphRowCountHandler(RelOptPlanner relOptPlanner, GlogueQuery glogueQuery) {
        this.optPlanner = relOptPlanner;
        this.countEstimator = new PrimitiveCountEstimator(glogueQuery);
    }

    /**
     * Estimates the number of rows produced by {@code relNode}.
     *
     * @param relNode the node to estimate
     * @param relMetadataQuery metadata query used for recursive estimates
     * @return the estimated row count; may be {@code null} when Calcite's join estimate is unknown
     * @throws IllegalArgumentException if a {@link GraphExtendIntersect} or
     *     {@link GraphJoinDecomposition} is not registered in a {@link VolcanoPlanner} subset, or if
     *     the node is a table scan of an unsupported kind
     */
    @Override // org.apache.calcite.rel.metadata.BuiltInMetadata.RowCount.Handler
    public Double getRowCount(RelNode relNode, RelMetadataQuery relMetadataQuery) {
        if (relNode instanceof GraphPattern) {
            return estimatePattern((GraphPattern) relNode, relMetadataQuery);
        }
        if (relNode instanceof RelSubset) {
            return relMetadataQuery.getRowCount(((RelSubset) relNode).getOriginal());
        }
        if (relNode instanceof GraphExtendIntersect || relNode instanceof GraphJoinDecomposition) {
            // Rels in one planner subset are equivalent, so the subset's estimate applies here.
            RelSubset subset = findSubset(relNode);
            if (subset != null) {
                return relMetadataQuery.getRowCount(subset);
            }
            throw new IllegalArgumentException("can not estimate row count for the node=" + relNode);
        }
        if (relNode instanceof AbstractBindableTableScan) {
            return estimateTableScan((AbstractBindableTableScan) relNode, relMetadataQuery);
        }
        if (relNode instanceof GraphLogicalPathExpand
                || relNode instanceof GraphPhysicalExpand
                || relNode instanceof GraphPhysicalGetV) {
            return relNode.estimateRowCount(relMetadataQuery);
        }
        if (relNode instanceof MultiJoin) {
            RelOptCost cachedCost = cachedClusterCost(relNode);
            return cachedCost != null
                    ? cachedCost.getRows()
                    : relNode.estimateRowCount(relMetadataQuery);
        }
        if (relNode instanceof Join) {
            RelOptCost cachedCost = cachedClusterCost(relNode);
            if (cachedCost != null) {
                return cachedCost.getRows();
            }
            // RelMdRowCount's join estimate is nullable; pass it through rather than unboxing it.
            return mdRowCount.getRowCount((Join) relNode, relMetadataQuery);
        }
        if (relNode instanceof Union) {
            return mdRowCount.getRowCount((Union) relNode, relMetadataQuery);
        }
        if (relNode instanceof Filter) {
            return mdRowCount.getRowCount((Filter) relNode, relMetadataQuery);
        }
        if (relNode instanceof Aggregate) {
            return mdRowCount.getRowCount((Aggregate) relNode, relMetadataQuery);
        }
        if (relNode instanceof Sort) {
            return mdRowCount.getRowCount((Sort) relNode, relMetadataQuery);
        }
        if (relNode instanceof Project) {
            return mdRowCount.getRowCount((Project) relNode, relMetadataQuery);
        }
        return relNode.estimateRowCount(relMetadataQuery);
    }

    /**
     * Estimates a pattern node.
     *
     * <p>The direct glogue statistic is tried first, then a planner decomposition, and finally the
     * edge-product fallback.
     */
    private Double estimatePattern(GraphPattern node, RelMetadataQuery mq) {
        Pattern pattern = node.getPattern();
        Double direct = countEstimator.estimate(pattern);
        if (direct != null) {
            return direct;
        }
        Double decomposed = estimateFromDecomposition(node, pattern, mq);
        if (decomposed != null) {
            return decomposed;
        }
        return estimateByEdgeProduct(pattern);
    }

    /**
     * Estimates {@code pattern} from a feasible decomposition found in the node's planner subset.
     *
     * @return the estimate, or {@code null} when there is no subset or no feasible decomposition
     */
    private Double estimateFromDecomposition(GraphPattern node, Pattern pattern, RelMetadataQuery mq) {
        RelSubset subset = findSubset(node);
        if (subset == null) {
            return null;
        }

        GraphExtendIntersect intersect = (GraphExtendIntersect) feasibleIntersects(subset);
        if (intersect != null) {
            PatternVertex target = pattern.getVertexByOrder(
                    intersect.getGlogueEdge().getExtendStep().getTargetVertexOrder().intValue());
            // The extension is the target vertex plus all its incident edges. It joins the
            // sub-pattern on the "extend-from" endpoint of each of those edges.
            Pattern extension = new Pattern();
            List<PatternVertex> joinVertices = Lists.newArrayList();
            for (PatternEdge edge : pattern.getEdgesOf(target)) {
                extension.addVertex(edge.getSrcVertex());
                extension.addVertex(edge.getDstVertex());
                extension.addEdge(edge.getSrcVertex(), edge.getDstVertex(), edge);
                joinVertices.add(Utils.getExtendFromVertex(edge, target));
            }
            return estimateJoinedRowCount(
                    (GraphPattern) subGraphPattern(intersect, 0),
                    new GraphPattern(node.getCluster(), node.getTraitSet(), extension),
                    joinVertices,
                    mq);
        }

        GraphJoinDecomposition decomposition =
                (GraphJoinDecomposition) feasibleJoinDecomposition(subset);
        if (decomposition != null) {
            Pattern buildPattern = decomposition.getBuildPattern();
            List<PatternVertex> joinVertices = decomposition.getJoinVertexPairs().stream()
                    .map(pair -> buildPattern.getVertexByOrder(pair.getRightOrderId()))
                    .collect(Collectors.toList());
            return estimateJoinedRowCount(
                    (GraphPattern) subGraphPattern(decomposition, 0),
                    (GraphPattern) subGraphPattern(decomposition, 1),
                    joinVertices,
                    mq);
        }
        return null;
    }

    /**
     * Fallback estimate: the product of all per-edge counts, corrected for shared vertices.
     *
     * <p>A vertex of degree {@code d > 0} is counted once in each of its {@code d} incident edge
     * estimates. The product is therefore divided by that vertex's count raised to {@code d - 1}.
     */
    private double estimateByEdgeProduct(Pattern pattern) {
        double count = 1.0d;
        for (PatternEdge edge : pattern.getEdgeSet()) {
            count *= countEstimator.estimate(edge);
        }
        for (PatternVertex vertex : pattern.getVertexSet()) {
            int degree = pattern.getEdgesOf(vertex).size();
            if (degree > 0) {
                count /= Math.pow(countEstimator.estimate(vertex), degree - 1);
            }
        }
        return count;
    }

    /**
     * Estimates a vertex scan or edge expansion.
     *
     * <p>If the scan has a cached cost, its own estimate is used. Otherwise the scan is turned into
     * a one-vertex pattern (for vertex scans) or a one-edge pattern (for expansions) and estimated
     * through {@code mq}.
     *
     * @throws IllegalArgumentException if the scan is neither a vertex scan nor a
     *     {@link GraphLogicalExpand}
     */
    private double estimateTableScan(AbstractBindableTableScan scan, RelMetadataQuery mq) {
        if (scan.getCachedCost() != null) {
            return scan.estimateRowCount(mq);
        }
        if (scan instanceof GraphLogicalSource || scan instanceof GraphLogicalGetV) {
            List<Integer> vertexTypeIds = Utils.getVertexTypeIds(scan);
            PatternVertex vertex = vertexTypeIds.size() == 1
                    ? new SinglePatternVertex(vertexTypeIds.get(0))
                    : new FuzzyPatternVertex(vertexTypeIds);
            return mq.getRowCount(
                            new GraphPattern(scan.getCluster(), scan.getTraitSet(), new Pattern(vertex)))
                    .doubleValue();
        }
        if (!(scan instanceof GraphLogicalExpand)) {
            throw new IllegalArgumentException("can not estimate row count for the rel=" + scan);
        }

        List<EdgeTypeId> edgeTypeIds = Utils.getEdgeTypeIds(scan);
        List<Integer> srcTypeIds =
                edgeTypeIds.stream().map(EdgeTypeId::getSrcLabelId).collect(Collectors.toList());
        List<Integer> dstTypeIds =
                edgeTypeIds.stream().map(EdgeTypeId::getDstLabelId).collect(Collectors.toList());
        // Vertex ids 0 and 1 identify the source and destination endpoints inside the pattern.
        PatternVertex src = srcTypeIds.size() == 1
                ? new SinglePatternVertex(srcTypeIds.get(0), 0)
                : new FuzzyPatternVertex(srcTypeIds, 0);
        PatternVertex dst = dstTypeIds.size() == 1
                ? new SinglePatternVertex(dstTypeIds.get(0), 1)
                : new FuzzyPatternVertex(dstTypeIds, 1);
        boolean bothDirections = ((GraphLogicalExpand) scan).getOpt() == GraphOpt.Expand.BOTH;
        PatternEdge edge = edgeTypeIds.size() == 1
                ? new SinglePatternEdge(src, dst, edgeTypeIds.get(0), 0, bothDirections, new ElementDetails())
                : new FuzzyPatternEdge(src, dst, edgeTypeIds, 0, bothDirections, new ElementDetails());

        Pattern pattern = new Pattern();
        pattern.addVertex(src);
        pattern.addVertex(dst);
        pattern.addEdge(src, dst, edge);
        return mq.getRowCount(new GraphPattern(scan.getCluster(), scan.getTraitSet(), pattern))
                .doubleValue();
    }

    /**
     * Estimates the join of two patterns.
     *
     * <p>The result is the product of both sides' counts divided by the count of each join vertex.
     */
    private double estimateJoinedRowCount(
            GraphPattern left,
            GraphPattern right,
            List<PatternVertex> joinVertices,
            RelMetadataQuery mq) {
        double count = getRowCount(left, mq) * getRowCount(right, mq);
        for (PatternVertex vertex : joinVertices) {
            count /= countEstimator.estimate(vertex);
        }
        return count;
    }

    /**
     * Returns the planner subset containing {@code node}.
     *
     * @return the subset, or {@code null} if the planner is not a {@link VolcanoPlanner} or the
     *     node is not registered
     */
    private RelSubset findSubset(RelNode node) {
        return optPlanner instanceof VolcanoPlanner
                ? ((VolcanoPlanner) optPlanner).getSubset(node)
                : null;
    }

    /** Returns the cost cached in the node's {@link GraphOptCluster} local state, possibly null. */
    private static RelOptCost cachedClusterCost(RelNode node) {
        return ((GraphOptCluster) node.getCluster()).getLocalState().getCachedCost();
    }

    /**
     * Returns the first {@link GraphExtendIntersect} in {@code relSubset} whose input is a subset
     * that already has a best plan, or {@code null} if there is none.
     */
    private RelNode feasibleIntersects(RelSubset relSubset) {
        for (RelNode candidate : relSubset.getRelList()) {
            if (!(candidate instanceof GraphExtendIntersect)) {
                continue;
            }
            RelNode input = candidate.getInput(0);
            if (input instanceof RelSubset && ((RelSubset) input).getBest() != null) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Returns input {@code i} of {@code relNode}, unwrapping a {@link RelSubset} to its original
     * rel.
     */
    private RelNode subGraphPattern(RelNode relNode, int i) {
        RelNode input = relNode.getInput(i);
        return input instanceof RelSubset ? ((RelSubset) input).getOriginal() : input;
    }

    /**
     * Returns the first {@link GraphJoinDecomposition} in {@code relSubset} whose left and right
     * inputs are both subsets with a best plan, or {@code null} if there is none.
     */
    private RelNode feasibleJoinDecomposition(RelSubset relSubset) {
        for (RelNode candidate : relSubset.getRelList()) {
            if (!(candidate instanceof GraphJoinDecomposition)) {
                continue;
            }
            GraphJoinDecomposition decomposition = (GraphJoinDecomposition) candidate;
            RelNode left = decomposition.getLeft();
            RelNode right = decomposition.getRight();
            if (left instanceof RelSubset
                    && right instanceof RelSubset
                    && ((RelSubset) left).getBest() != null
                    && ((RelSubset) right).getBest() != null) {
                return candidate;
            }
        }
        return null;
    }
}
