package com.alibaba.graphscope.common.ir.rel.metadata.schema;

import com.alibaba.graphscope.common.ir.meta.IrMetaStats;
import com.alibaba.graphscope.common.ir.meta.glogue.Utils;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.PatternDirection;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.PatternEdge;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.PatternVertex;
import com.alibaba.graphscope.groot.common.schema.api.EdgeRelation;
import com.alibaba.graphscope.groot.common.schema.api.GraphEdge;
import com.alibaba.graphscope.groot.common.schema.api.GraphSchema;
import com.alibaba.graphscope.groot.common.schema.api.GraphStatistics;
import com.alibaba.graphscope.groot.common.schema.api.GraphVertex;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.AtomicDouble;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import org.jgrapht.Graph;
import org.jgrapht.graph.DirectedPseudograph;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/graphscope/common/ir/rel/metadata/schema/GlogueSchema.class */
/**
 * Label-level view of a graph schema used for pattern cost estimation.
 *
 * <p>Vertex label ids are the nodes of a directed pseudograph and every
 * (source label, target label, edge label) relation of the schema is one
 * {@link EdgeTypeId} edge between them. Alongside the graph, the class keeps one
 * cardinality per vertex type and per edge type. Lookups for unknown types fall back
 * to a cardinality of {@code 1.0}.
 */
public class GlogueSchema {
    private static final Logger logger = LoggerFactory.getLogger(GlogueSchema.class);

    /** Cardinality used when no count is available or the recorded count is zero. */
    private static final double DEFAULT_CARDINALITY = 1.0d;

    private final Graph<Integer, EdgeTypeId> schemaGraph;
    private final HashMap<Integer, Double> vertexTypeCardinality;
    private final HashMap<EdgeTypeId, Double> edgeTypeCardinality;

    /**
     * Builds the schema graph from {@code graphSchema} and uses the supplied cardinality maps.
     *
     * <p>The maps are stored as given (no defensive copy), so later changes made by the
     * caller are visible through this instance.
     *
     * @param graphSchema schema providing the vertex and edge type definitions
     * @param vertexTypeCardinality cardinality per vertex label id
     * @param edgeTypeCardinality cardinality per edge type
     */
    public GlogueSchema(
            GraphSchema graphSchema,
            HashMap<Integer, Double> vertexTypeCardinality,
            HashMap<EdgeTypeId, Double> edgeTypeCardinality) {
        this.schemaGraph = new DirectedPseudograph<>(EdgeTypeId.class);
        for (GraphVertex vertex : graphSchema.getVertexList()) {
            this.schemaGraph.addVertex(vertex.getLabelId());
        }
        for (GraphEdge edge : graphSchema.getEdgeList()) {
            for (EdgeRelation relation : edge.getRelationList()) {
                int srcLabelId = relation.getSource().getLabelId();
                int dstLabelId = relation.getTarget().getLabelId();
                this.schemaGraph.addEdge(
                        srcLabelId,
                        dstLabelId,
                        new EdgeTypeId(srcLabelId, dstLabelId, edge.getLabelId()));
            }
        }
        this.vertexTypeCardinality = vertexTypeCardinality;
        this.edgeTypeCardinality = edgeTypeCardinality;
    }

    /**
     * Builds the schema graph from {@code graphSchema} and assigns every vertex type and
     * every edge type the default cardinality {@code 1.0}.
     *
     * @param graphSchema schema providing the vertex and edge type definitions
     */
    public GlogueSchema(GraphSchema graphSchema) {
        this.schemaGraph = new DirectedPseudograph<>(EdgeTypeId.class);
        this.vertexTypeCardinality = new HashMap<>();
        this.edgeTypeCardinality = new HashMap<>();
        for (GraphVertex vertex : graphSchema.getVertexList()) {
            int labelId = vertex.getLabelId();
            this.schemaGraph.addVertex(labelId);
            this.vertexTypeCardinality.put(labelId, DEFAULT_CARDINALITY);
        }
        for (GraphEdge edge : graphSchema.getEdgeList()) {
            for (EdgeRelation relation : edge.getRelationList()) {
                int srcLabelId = relation.getSource().getLabelId();
                int dstLabelId = relation.getTarget().getLabelId();
                EdgeTypeId edgeType = new EdgeTypeId(srcLabelId, dstLabelId, edge.getLabelId());
                this.schemaGraph.addEdge(srcLabelId, dstLabelId, edgeType);
                this.edgeTypeCardinality.put(edgeType, DEFAULT_CARDINALITY);
            }
        }
        logger.debug("GlogueSchema created with default cardinality 1.0: {}", this);
    }

    /**
     * Builds the schema graph from {@code graphSchema} and takes cardinalities from
     * {@code graphStatistics}. A recorded count of zero is stored as {@code 1.0}.
     *
     * @param graphSchema schema providing the vertex and edge type definitions
     * @param graphStatistics source of the per-type counts
     * @throws IllegalArgumentException if the statistics return {@code null} for any vertex
     *     type or edge relation of the schema
     */
    public GlogueSchema(GraphSchema graphSchema, GraphStatistics graphStatistics) {
        this.schemaGraph = new DirectedPseudograph<>(EdgeTypeId.class);
        this.vertexTypeCardinality = new HashMap<>();
        this.edgeTypeCardinality = new HashMap<>();
        for (GraphVertex vertex : graphSchema.getVertexList()) {
            int labelId = vertex.getLabelId();
            this.schemaGraph.addVertex(labelId);
            Long vertexCount = graphStatistics.getVertexTypeCount(labelId);
            if (vertexCount == null) {
                throw new IllegalArgumentException(
                        "Vertex type count not found for vertex type: " + labelId);
            }
            this.vertexTypeCardinality.put(labelId, cardinalityOf(vertexCount));
        }
        for (GraphEdge edge : graphSchema.getEdgeList()) {
            for (EdgeRelation relation : edge.getRelationList()) {
                int srcLabelId = relation.getSource().getLabelId();
                int dstLabelId = relation.getTarget().getLabelId();
                EdgeTypeId edgeType = new EdgeTypeId(srcLabelId, dstLabelId, edge.getLabelId());
                this.schemaGraph.addEdge(srcLabelId, dstLabelId, edgeType);
                Long edgeCount =
                        graphStatistics.getEdgeTypeCount(
                                Optional.of(srcLabelId),
                                Optional.of(edge.getLabelId()),
                                Optional.of(dstLabelId));
                if (edgeCount == null) {
                    throw new IllegalArgumentException(
                            "Edge type count not found for edge type: " + edge.getLabelId());
                }
                this.edgeTypeCardinality.put(edgeType, cardinalityOf(edgeCount));
            }
        }
        logger.debug("GlogueSchema created with statistics: {}", this);
    }

    /**
     * Creates a schema from {@code irMetaStats}, using its statistics when present and the
     * default cardinality {@code 1.0} otherwise.
     *
     * @param irMetaStats meta object providing the schema and, optionally, statistics
     * @return the resulting schema
     */
    public static GlogueSchema fromMeta(IrMetaStats irMetaStats) {
        if (irMetaStats.getStatistics() == null) {
            return new GlogueSchema(irMetaStats.getSchema());
        }
        return new GlogueSchema(irMetaStats.getSchema(), irMetaStats.getStatistics());
    }

    /**
     * Sums the label-constraint delta cost of {@code patternEdge} for the direction(s) in
     * which it is extended towards {@code patternVertex}.
     *
     * <p>The OUT cost is included unless the extend direction is {@code IN}, and the IN cost
     * is included unless it is {@code OUT}; any other direction therefore counts both.
     *
     * @param patternEdge the edge being extended
     * @param patternVertex the vertex the edge is extended to
     * @return the accumulated delta cost, {@code 0.0} if no constraint applies
     */
    public Double getLabelConstraintsDeltaCost(PatternEdge patternEdge, PatternVertex patternVertex) {
        PatternDirection extendDirection = Utils.getExtendDirection(patternEdge, patternVertex);
        double cost = 0.0d;
        if (extendDirection != PatternDirection.IN) {
            cost += getLabelConstraintsDeltaCost(patternEdge, PatternDirection.OUT);
        }
        if (extendDirection != PatternDirection.OUT) {
            cost += getLabelConstraintsDeltaCost(patternEdge, PatternDirection.IN);
        }
        return cost;
    }

    /**
     * Computes the delta cost for a single direction.
     *
     * <p>For every distinct (endpoint label, edge label) pair among the pattern edge's types,
     * all schema edge types sharing that pair are collected. If the pattern edge already
     * lists every one of them, the pair adds nothing; otherwise the cardinalities of all
     * collected schema edge types are added to the result.
     */
    private Double getLabelConstraintsDeltaCost(PatternEdge patternEdge, PatternDirection patternDirection) {
        double totalCost = 0.0d;
        HashSet<EdgeTypeId> visitedKeys = new HashSet<>();
        for (EdgeTypeId patternType : patternEdge.getEdgeTypeIds()) {
            // The key is used only for de-duplication inside this call. Its argument order is
            // kept as-is even though it differs from the (source, target, edge label) order
            // used in the constructors; -1 fills the slot that is irrelevant for the direction.
            EdgeTypeId dedupKey =
                    patternDirection == PatternDirection.OUT
                            ? new EdgeTypeId(
                                    patternType.getSrcLabelId().intValue(),
                                    patternType.getEdgeLabelId().intValue(),
                                    -1)
                            : new EdgeTypeId(
                                    -1,
                                    patternType.getEdgeLabelId().intValue(),
                                    patternType.getDstLabelId().intValue());
            if (!visitedKeys.add(dedupKey)) {
                continue;
            }
            List<EdgeTypeId> schemaMatches = new ArrayList<>();
            for (EdgeTypeId schemaType : this.edgeTypeCardinality.keySet()) {
                if (sharesEndpointAndLabel(patternType, schemaType, patternDirection)) {
                    schemaMatches.add(schemaType);
                }
            }
            if (patternEdge.getEdgeTypeIds().containsAll(schemaMatches)) {
                continue;
            }
            double pairCost = 0.0d;
            for (EdgeTypeId schemaType : schemaMatches) {
                pairCost += getEdgeTypeCardinality(schemaType);
            }
            totalCost += pairCost;
        }
        return totalCost;
    }

    /**
     * Tells whether two edge types have the same edge label and the same endpoint label on
     * the side selected by {@code direction} (source for OUT, target for IN).
     *
     * <p>Label ids are boxed, so they are compared with {@code equals}; {@code ==} would
     * compare object identity and fail for ids outside the small-integer cache.
     */
    private static boolean sharesEndpointAndLabel(
            EdgeTypeId patternType, EdgeTypeId schemaType, PatternDirection direction) {
        switch (direction) {
            case OUT:
                return patternType.getSrcLabelId().equals(schemaType.getSrcLabelId())
                        && patternType.getEdgeLabelId().equals(schemaType.getEdgeLabelId());
            case IN:
                return patternType.getDstLabelId().equals(schemaType.getDstLabelId())
                        && patternType.getEdgeLabelId().equals(schemaType.getEdgeLabelId());
            default:
                return false;
        }
    }

    /** Converts a raw count into a cardinality, mapping a zero count to the default. */
    private static double cardinalityOf(Long count) {
        return count.longValue() == 0 ? DEFAULT_CARDINALITY : count.doubleValue();
    }

    /**
     * @return an unmodifiable snapshot of all vertex label ids in the schema graph
     */
    public List<Integer> getVertexTypes() {
        return List.copyOf(this.schemaGraph.vertexSet());
    }

    /**
     * @return an unmodifiable snapshot of all edge types in the schema graph
     */
    public List<EdgeTypeId> getEdgeTypes() {
        return List.copyOf(this.schemaGraph.edgeSet());
    }

    /**
     * @param vertexType vertex label id
     * @return an unmodifiable snapshot of the edge types touching {@code vertexType}
     */
    public List<EdgeTypeId> getAdjEdgeTypes(Integer vertexType) {
        return List.copyOf(this.schemaGraph.edgesOf(vertexType));
    }

    /**
     * @param srcType source vertex label id
     * @param dstType target vertex label id
     * @return an unmodifiable snapshot of the edge types from {@code srcType} to {@code dstType}
     */
    public List<EdgeTypeId> getEdgeTypes(Integer srcType, Integer dstType) {
        return List.copyOf(this.schemaGraph.getAllEdges(srcType, dstType));
    }

    /**
     * @param vertexType vertex label id
     * @return the stored cardinality, or {@code 1.0} when none is stored for this type
     */
    public Double getVertexTypeCardinality(Integer vertexType) {
        Double cardinality = this.vertexTypeCardinality.get(vertexType);
        if (cardinality == null) {
            logger.debug("Vertex type {} not found in schema, assuming cardinality 1.0", vertexType);
            return DEFAULT_CARDINALITY;
        }
        return cardinality;
    }

    /**
     * @param edgeTypeId edge type to look up
     * @return the stored cardinality, or {@code 1.0} when none is stored for this type
     */
    public Double getEdgeTypeCardinality(EdgeTypeId edgeTypeId) {
        Double cardinality = this.edgeTypeCardinality.get(edgeTypeId);
        if (cardinality == null) {
            logger.debug("Edge type {} not found in schema, assuming cardinality 1.0", edgeTypeId);
            return DEFAULT_CARDINALITY;
        }
        return cardinality;
    }
}
