package com.alibaba.graphscope.common.ir.rex;

import com.alibaba.graphscope.common.ir.rel.graph.AbstractBindableTableScan;
import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalSource;
import com.alibaba.graphscope.common.ir.tools.GraphBuilder;
import com.alibaba.graphscope.common.ir.tools.Utils;
import com.alibaba.graphscope.common.ir.type.GraphLabelType;
import com.alibaba.graphscope.common.ir.type.GraphNameOrId;
import com.alibaba.graphscope.common.ir.type.GraphProperty;
import com.alibaba.graphscope.common.ir.type.GraphSchemaType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexDynamicParam;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Sarg;
import org.javatuples.Pair;

/**
 * Splits a filter condition into predicates that can be pushed into a graph scan ("schema"
 * filters, grouped per tag/alias id) and a residual condition ("other" filter).
 *
 * <p>Two kinds of schema filters are recognised:
 *
 * <ul>
 *   <li>{@link Filter.SchemaType#LABEL}: an operand of {@link GraphLabelType} compared with a
 *       literal via {@code =} or a point-only {@code SEARCH};
 *   <li>{@link Filter.SchemaType#UNIQUE_KEY}: an id property, or a property that is a key of all
 *       tables of a {@link GraphLogicalSource}, compared with a literal or dynamic parameter.
 * </ul>
 *
 * <p>An AND merges schema filters that share the same tag and kind. An OR becomes a schema filter
 * only when every branch is the same kind of schema filter on the same tag. Everything else is
 * kept as the residual filter.
 */
public class RexFilterClassifier extends RexVisitorImpl<RexFilterClassifier.Filter> {
    private final GraphBuilder builder;
    // May be null; unique-key detection is then skipped (see visitCall / isUniqueKeyEqualFilter).
    private final AbstractBindableTableScan tableScan;

    /** Result of visiting a condition: per-tag schema filters plus an optional residual filter. */
    public static class Filter {
        private final List<SchemaFilter> schemaFilters;
        private final RexNode otherFilter;

        /** A predicate on a single tag matching one of the {@link SchemaType} patterns. */
        public static class SchemaFilter {
            private final Integer tagId;
            private final RexNode filter;
            private final SchemaType schemaType;

            /**
             * @param tagId alias id of the variable the predicate refers to
             * @param filter the predicate itself
             * @param schemaType which schema pattern the predicate matches
             */
            public SchemaFilter(Integer tagId, RexNode filter, SchemaType schemaType) {
                this.tagId = tagId;
                this.filter = filter;
                this.schemaType = schemaType;
            }

            public Integer getTagId() {
                return this.tagId;
            }

            public RexNode getFilter() {
                return this.filter;
            }

            public SchemaType getSchemaType() {
                return this.schemaType;
            }
        }

        /** Kinds of predicates that can be pushed into a scan. */
        public enum SchemaType {
            LABEL,
            UNIQUE_KEY
        }

        /**
         * @param schemaFilters the recognised schema filters
         * @param otherFilter the residual condition, or {@code null} if there is none
         */
        public Filter(List<SchemaFilter> schemaFilters, RexNode otherFilter) {
            this.schemaFilters = schemaFilters;
            this.otherFilter = otherFilter;
        }

        /** Returns a read-only view of the schema filters. */
        public List<SchemaFilter> getSchemaFilters() {
            return Collections.unmodifiableList(this.schemaFilters);
        }

        /** Returns the residual condition, or {@code null} if everything was classified. */
        public RexNode getOtherFilter() {
            return this.otherFilter;
        }
    }

    /**
     * Collects the label values admitted by a label filter. An AND intersects its operands'
     * values, an OR unions them (deduplicated, in first-seen order), and a label-equality call
     * yields its literal's values. Any other node yields an empty list.
     *
     * <p>Decompiler metadata says this class was originally {@code private}. It is kept
     * {@code public} here so that its visibility does not change.
     */
    public static class LabelValueCollector extends RexVisitorImpl<List<Comparable>> {
        public LabelValueCollector() {
            super(true);
        }

        @Override
        public List<Comparable> visitCall(RexCall rexCall) {
            switch (rexCall.getOperator().getKind()) {
                case AND:
                    return intersect(rexCall.getOperands());
                case OR:
                    return union(rexCall.getOperands());
                case EQUALS:
                case SEARCH:
                    RexLiteral literal = isLabelEqualFilter0(rexCall);
                    if (literal != null) {
                        return Utils.getValuesAsList(literal.getValueAs(Comparable.class));
                    }
                    return ImmutableList.of();
                default:
                    return ImmutableList.of();
            }
        }

        // Values of one operand. Non-call nodes make the visitor return null; treat them as empty.
        private List<Comparable> valuesOf(RexNode operand) {
            List<Comparable> values = operand.accept(this);
            return values == null ? ImmutableList.<Comparable>of() : values;
        }

        // Labels common to all operands; throws IllegalArgumentException when the set becomes empty.
        private List<Comparable> intersect(List<RexNode> operands) {
            List<Comparable> common = Lists.newArrayList();
            for (RexNode operand : operands) {
                List<Comparable> values = valuesOf(operand);
                // Snapshot before narrowing so that the error message shows the conflicting sets.
                List<Comparable> previous = new ArrayList<>(common);
                if (common.isEmpty()) {
                    common.addAll(values);
                } else {
                    common.retainAll(values);
                }
                if (common.isEmpty()) {
                    throw new IllegalArgumentException(
                            "cannot find common labels between values="
                                    + previous
                                    + " and values="
                                    + values);
                }
            }
            return common;
        }

        // Distinct labels admitted by any operand, in first-seen order.
        private List<Comparable> union(List<RexNode> operands) {
            List<Comparable> all = Lists.newArrayList();
            for (RexNode operand : operands) {
                all.addAll(valuesOf(operand));
            }
            return all.stream().distinct().collect(Collectors.toList());
        }
    }

    /**
     * @param graphBuilder builder used to combine predicates with AND/OR
     * @param abstractBindableTableScan the scan the filter applies to; unique-key filters are only
     *     recognised when this is a {@link GraphLogicalSource}
     */
    public RexFilterClassifier(
            GraphBuilder graphBuilder, AbstractBindableTableScan abstractBindableTableScan) {
        super(true);
        this.builder = graphBuilder;
        this.tableScan = abstractBindableTableScan;
    }

    /**
     * Classifies {@code rexNode} into label filters (and the label values they admit), unique-key
     * filters, and any remaining filters.
     */
    public ClassifiedFilter classify(RexNode rexNode) {
        Filter filter = classifyNode(rexNode);
        List<RexNode> labelFilters = Lists.newArrayList();
        List<RexNode> uniqueKeyFilters = Lists.newArrayList();
        for (Filter.SchemaFilter schemaFilter : filter.getSchemaFilters()) {
            switch (schemaFilter.getSchemaType()) {
                case LABEL:
                    labelFilters.add(schemaFilter.getFilter());
                    break;
                case UNIQUE_KEY:
                default:
                    uniqueKeyFilters.add(schemaFilter.getFilter());
                    break;
            }
        }
        List<RexNode> extraFilters = Lists.newArrayList();
        if (filter.getOtherFilter() != null) {
            extraFilters.add(filter.getOtherFilter());
        }
        List<Comparable> labelValues = Lists.newArrayList();
        for (RexNode labelFilter : labelFilters) {
            labelValues.addAll(getLabelValues(labelFilter));
        }
        return new ClassifiedFilter(labelFilters, labelValues, uniqueKeyFilters, extraFilters);
    }

    private List<Comparable> getLabelValues(RexNode rexNode) {
        List<Comparable> values = rexNode.accept(new LabelValueCollector());
        return values == null ? ImmutableList.<Comparable>of() : values;
    }

    // Calcite's RexVisitorImpl returns null for non-call nodes (e.g. a bare boolean variable or a
    // literal). Such nodes cannot be pushed down, so they become the residual filter.
    private Filter classifyNode(RexNode node) {
        Filter filter = node.accept(this);
        return filter != null ? filter : new Filter(ImmutableList.of(), node);
    }

    @Override
    public Filter visitCall(RexCall rexCall) {
        switch (rexCall.getOperator().getKind()) {
            case AND:
                List<Filter> operandFilters = Lists.newArrayList();
                for (RexNode operand : rexCall.getOperands()) {
                    operandFilters.add(classifyNode(operand));
                }
                return conjunctions(operandFilters);
            case OR:
                return disjunctions(rexCall);
            case EQUALS:
            case SEARCH:
                if (isLabelEqualFilter(rexCall)) {
                    return singleSchemaFilter(rexCall, Filter.SchemaType.LABEL);
                }
                if (this.tableScan != null && isUniqueKeyEqualFilter(rexCall)) {
                    return singleSchemaFilter(rexCall, Filter.SchemaType.UNIQUE_KEY);
                }
                break;
            default:
                break;
        }
        return new Filter(ImmutableList.of(), rexCall);
    }

    // Wraps `call` as a single schema filter on the first alias id it references.
    private Filter singleSchemaFilter(RexCall call, Filter.SchemaType schemaType) {
        RexVariableAliasCollector<Integer> aliasCollector =
                new RexVariableAliasCollector<>(true, variable -> variable.getAliasId());
        Integer tagId = call.accept(aliasCollector).get(0);
        return new Filter(
                ImmutableList.of(new Filter.SchemaFilter(tagId, call, schemaType)), null);
    }

    // Merges the operands of an AND: schema filters with the same (tag, kind) are ANDed together,
    // and residual filters are combined into one conjunction.
    private Filter conjunctions(List<Filter> operandFilters) {
        Map<Pair<Integer, Filter.SchemaType>, RexNode> merged = Maps.newLinkedHashMap();
        List<RexNode> residuals = Lists.newArrayList();
        for (Filter operandFilter : operandFilters) {
            for (Filter.SchemaFilter schemaFilter : operandFilter.getSchemaFilters()) {
                Pair<Integer, Filter.SchemaType> key =
                        Pair.with(schemaFilter.getTagId(), schemaFilter.getSchemaType());
                RexNode existing = merged.get(key);
                RexNode combined =
                        existing == null
                                ? schemaFilter.getFilter()
                                : this.builder.and(existing, schemaFilter.getFilter());
                if (!combined.equals(existing)) {
                    merged.put(key, combined);
                }
            }
            if (operandFilter.getOtherFilter() != null) {
                residuals.add(operandFilter.getOtherFilter());
            }
        }
        List<Filter.SchemaFilter> schemaFilters = Lists.newArrayList();
        for (Map.Entry<Pair<Integer, Filter.SchemaType>, RexNode> entry : merged.entrySet()) {
            Pair<Integer, Filter.SchemaType> key = entry.getKey();
            schemaFilters.add(
                    new Filter.SchemaFilter(key.getValue0(), entry.getValue(), key.getValue1()));
        }
        RexNode otherFilter =
                residuals.isEmpty()
                        ? null
                        : RexUtil.composeConjunction(this.builder.getRexBuilder(), residuals, false);
        return new Filter(schemaFilters, otherFilter);
    }

    // An OR becomes one schema filter when every branch is an EQUALS/SEARCH yielding exactly one
    // schema filter, and all branches share the same tag and kind. Otherwise the whole OR is
    // returned as a residual filter.
    private Filter disjunctions(RexCall rexCall) {
        Filter residual = new Filter(ImmutableList.of(), rexCall);
        if (rexCall.getOperator().getKind() != SqlKind.OR) {
            return residual;
        }
        Filter.SchemaFilter merged = null;
        for (RexNode operand : rexCall.getOperands()) {
            if (operand.getKind() != SqlKind.EQUALS && operand.getKind() != SqlKind.SEARCH) {
                return residual;
            }
            List<Filter.SchemaFilter> operandSchemaFilters =
                    classifyNode(operand).getSchemaFilters();
            if (operandSchemaFilters.size() != 1) {
                return residual;
            }
            Filter.SchemaFilter current = operandSchemaFilters.get(0);
            if (merged == null) {
                merged = current;
                continue;
            }
            // Tag ids are boxed Integers: compare by value, not by reference.
            if (!Objects.equals(merged.getTagId(), current.getTagId())
                    || merged.getSchemaType() != current.getSchemaType()) {
                return residual;
            }
            merged =
                    new Filter.SchemaFilter(
                            merged.getTagId(),
                            this.builder.or(merged.getFilter(), current.getFilter()),
                            merged.getSchemaType());
        }
        return merged == null ? residual : new Filter(ImmutableList.of(merged), null);
    }

    private boolean isLabelEqualFilter(RexCall rexCall) {
        return isLabelEqualFilter0(rexCall) != null;
    }

    /**
     * Returns the literal of a label-equality call: {@code EQUALS} or {@code SEARCH} between an
     * operand of {@link GraphLabelType} and a literal, in either order, where a {@link Sarg} literal
     * must contain only points. Returns {@code null} when {@code rexNode} is not such a call.
     */
    private static RexLiteral isLabelEqualFilter0(RexNode rexNode) {
        if (!(rexNode instanceof RexCall)) {
            return null;
        }
        RexCall rexCall = (RexCall) rexNode;
        SqlKind kind = rexCall.getOperator().getKind();
        if (kind != SqlKind.EQUALS && kind != SqlKind.SEARCH) {
            return null;
        }
        RexNode left = rexCall.getOperands().get(0);
        RexNode right = rexCall.getOperands().get(1);
        if (left.getType() instanceof GraphLabelType && right instanceof RexLiteral) {
            RexLiteral literal = (RexLiteral) right;
            return isPointLiteral(literal) ? literal : null;
        }
        if (right.getType() instanceof GraphLabelType && left instanceof RexLiteral) {
            RexLiteral literal = (RexLiteral) left;
            return isPointLiteral(literal) ? literal : null;
        }
        return null;
    }

    // True unless the literal is a Sarg that contains ranges (not only points).
    private static boolean isPointLiteral(RexLiteral literal) {
        Comparable value = literal.getValue();
        return !(value instanceof Sarg) || ((Sarg) value).isPoints();
    }

    /**
     * True when the scan is a {@link GraphLogicalSource} and {@code rexNode} is {@code EQUALS} or a
     * point-only {@code SEARCH} between a unique key and a literal or dynamic parameter, in either
     * order.
     */
    private boolean isUniqueKeyEqualFilter(RexNode rexNode) {
        if (!(this.tableScan instanceof GraphLogicalSource) || !(rexNode instanceof RexCall)) {
            return false;
        }
        RexCall rexCall = (RexCall) rexNode;
        SqlKind kind = rexCall.getOperator().getKind();
        if (kind != SqlKind.EQUALS && kind != SqlKind.SEARCH) {
            return false;
        }
        RexNode left = rexCall.getOperands().get(0);
        RexNode right = rexCall.getOperands().get(1);
        if (isUniqueKey(left, this.tableScan) && isLiteralOrDynamicParams(right)) {
            return isPointValue(right);
        }
        if (isUniqueKey(right, this.tableScan) && isLiteralOrDynamicParams(left)) {
            return isPointValue(left);
        }
        return false;
    }

    // Dynamic parameters are accepted as-is; literals must not be range Sargs.
    private static boolean isPointValue(RexNode value) {
        return !(value instanceof RexLiteral) || isPointLiteral((RexLiteral) value);
    }

    private boolean isUniqueKey(RexNode rexNode, RelNode relNode) {
        if (rexNode instanceof RexGraphVariable) {
            return isUniqueKey((RexGraphVariable) rexNode, relNode);
        }
        return false;
    }

    /**
     * An id property is always unique. A key property is unique when it resolves to a column of
     * the scan's row type that {@code isKey} in every table of the scan's table config.
     */
    private boolean isUniqueKey(RexGraphVariable rexGraphVariable, RelNode relNode) {
        if (rexGraphVariable.getProperty() == null) {
            return false;
        }
        switch (rexGraphVariable.getProperty().getOpt()) {
            case ID:
                return true;
            case KEY:
                ImmutableBitSet propertyIds =
                        getPropertyIds(
                                rexGraphVariable.getProperty(),
                                (GraphSchemaType)
                                        relNode.getRowType().getFieldList().get(0).getType());
                return !propertyIds.isEmpty()
                        && ((AbstractBindableTableScan) relNode)
                                .getTableConfig().getTables().stream()
                                        .allMatch(relOptTable -> relOptTable.isKey(propertyIds));
            case LABEL:
            case ALL:
            case LEN:
            default:
                return false;
        }
    }

    // Resolves a KEY property to its column index, by id or by name; empty when it cannot.
    private ImmutableBitSet getPropertyIds(
            GraphProperty graphProperty, GraphSchemaType graphSchemaType) {
        if (graphProperty.getOpt() != GraphProperty.Opt.KEY) {
            return ImmutableBitSet.of();
        }
        GraphNameOrId key = graphProperty.getKey();
        if (key.getOpt() == GraphNameOrId.Opt.ID) {
            return ImmutableBitSet.of(key.getId());
        }
        for (int i = 0; i < graphSchemaType.getFieldList().size(); i++) {
            if (graphSchemaType.getFieldList().get(i).getName().equals(key.getName())) {
                return ImmutableBitSet.of(i);
            }
        }
        return ImmutableBitSet.of();
    }

    private boolean isLiteralOrDynamicParams(RexNode rexNode) {
        return (rexNode instanceof RexLiteral) || (rexNode instanceof RexDynamicParam);
    }
}
