/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.msq.statistics;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import it.unimi.dsi.fastutil.ints.IntRBTreeSet;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.stream.Collectors;
import org.apache.druid.frame.FrameType;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.frame.key.RowKeyReader;
import org.apache.druid.indexing.common.task.batch.TooManyBucketsException;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.KeyCollector;
import org.apache.druid.msq.statistics.KeyCollectorFactory;
import org.apache.druid.msq.statistics.KeyCollectorSnapshot;
import org.apache.druid.msq.statistics.KeyCollectors;
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.column.RowSignature;

/**
 * {@link ClusterByStatisticsCollector} that keeps one {@link KeyCollector} per bucket. A bucket key is a row key
 * trimmed to its first {@link ClusterBy#getBucketByCount()} columns; buckets are kept in a sorted map ordered by
 * {@code clusterBy.bucketComparator(rowSignature)}.
 *
 * <p>The estimated retained bytes of all key collectors are summed in {@code totalRetainedBytes}. Whenever that
 * total exceeds {@code maxRetainedBytes}, {@link #downSample()} shrinks collectors, least dense first. At most
 * {@code maxBuckets} buckets may exist; creating one more throws {@link TooManyBucketsException}.
 *
 * <p>When {@code checkHasMultipleValues} is set, the collector also records, per key column, whether any added key
 * had multiple values in that column.
 *
 * <p>Instances hold mutable state and perform no synchronization.
 */
public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsCollector {
    /** Attempt limit for {@link #generatePartitionsWithMaxCount(int)}; exceeding it throws {@link ISE}. */
    private static final int MAX_COUNT_MAX_ITERATIONS = 500;

    /** Multiplier applied to the target partition weight after each attempt that yields too many partitions. */
    private static final double MAX_COUNT_ITERATION_GROWTH_FACTOR = 1.05;

    private static final Logger log = new Logger(ClusterByStatisticsCollectorImpl.class);

    private final ClusterBy clusterBy;
    private final RowKeyReader keyReader;
    private final KeyCollectorFactory<? extends KeyCollector<?>, ? extends KeyCollectorSnapshot> keyCollectorFactory;

    /** Bucket key (row key trimmed to the bucket-by columns) to the holder of that bucket's key collector. */
    private final SortedMap<RowKey, BucketHolder> buckets;

    private final boolean checkHasMultipleValues;

    /** One flag per clusterBy column; null when {@code checkHasMultipleValues} is false. */
    private final boolean[] hasMultipleValues;

    private final long maxRetainedBytes;
    private final int maxBuckets;

    /** Running sum of every bucket's {@code BucketHolder.retainedBytes}, truncated to a long on each update. */
    private long totalRetainedBytes;

    private ClusterByStatisticsCollectorImpl(
        final ClusterBy clusterBy,
        final RowSignature rowSignature,
        final RowKeyReader keyReader,
        final KeyCollectorFactory<?, ?> keyCollectorFactory,
        final long maxRetainedBytes,
        final int maxBuckets,
        final boolean checkHasMultipleValues
    ) {
        // Validate arguments before building any state.
        if (maxBuckets > maxRetainedBytes) {
            throw new IAE("maxBuckets[%s] cannot be larger than maxRetainedBytes[%s]", maxBuckets, maxRetainedBytes);
        }
        this.clusterBy = clusterBy;
        this.keyReader = keyReader;
        this.keyCollectorFactory = keyCollectorFactory;
        this.maxRetainedBytes = maxRetainedBytes;
        this.buckets = new TreeMap<>(clusterBy.bucketComparator(rowSignature));
        this.maxBuckets = maxBuckets;
        this.checkHasMultipleValues = checkHasMultipleValues;
        this.hasMultipleValues = checkHasMultipleValues ? new boolean[clusterBy.getColumns().size()] : null;
    }

    /**
     * Creates a collector for {@code clusterBy} over rows with the given signature.
     *
     * @param clusterBy              clustering key definition; also determines bucketing
     * @param signature              row signature used to build the key reader, key collectors and bucket comparator
     * @param frameType              frame type passed to {@link ClusterBy#keyReader}
     * @param maxRetainedBytes       retained-byte budget that triggers downsampling when exceeded
     * @param maxBuckets             maximum number of buckets; must not exceed {@code maxRetainedBytes}
     * @param aggregate              passed through to {@link KeyCollectors#makeStandardFactory}
     * @param checkHasMultipleValues whether to track per-column multi-value flags
     * @return a new, empty collector
     * @throws IAE if {@code maxBuckets > maxRetainedBytes}
     */
    public static ClusterByStatisticsCollector create(
        final ClusterBy clusterBy,
        final RowSignature signature,
        final FrameType frameType,
        final long maxRetainedBytes,
        final int maxBuckets,
        final boolean aggregate,
        final boolean checkHasMultipleValues
    ) {
        final RowKeyReader keyReader = clusterBy.keyReader((ColumnInspector) signature, frameType);
        final KeyCollectorFactory<?, ?> keyCollectorFactory =
            KeyCollectors.makeStandardFactory(clusterBy, aggregate, signature);
        return new ClusterByStatisticsCollectorImpl(
            clusterBy,
            signature,
            keyReader,
            keyCollectorFactory,
            maxRetainedBytes,
            maxBuckets,
            checkHasMultipleValues
        );
    }

    @Override
    public ClusterBy getClusterBy() {
        return clusterBy;
    }

    /**
     * Adds one key to the collector of its bucket, creating the bucket if needed, then downsamples if the
     * retained-byte budget is exceeded.
     *
     * @throws TooManyBucketsException if a new bucket is required but {@code maxBuckets} already exist
     */
    @Override
    public ClusterByStatisticsCollector add(final RowKey key, final int weight) {
        if (checkHasMultipleValues) {
            final int numColumns = clusterBy.getColumns().size();
            for (int i = 0; i < numColumns; i++) {
                hasMultipleValues[i] = hasMultipleValues[i] || keyReader.hasMultipleValues(key, i);
            }
        }

        final BucketHolder bucketHolder = getOrCreateBucketHolder(keyReader.trim(key, clusterBy.getBucketByCount()));
        bucketHolder.keyCollector.add(key, weight);
        trackRetainedBytes(bucketHolder);
        return this;
    }

    /**
     * Merges another collector into this one. Another {@code ClusterByStatisticsCollectorImpl} is merged bucket by
     * bucket; any other implementation is merged through its {@link ClusterByStatisticsCollector#snapshot()}.
     */
    @Override
    public ClusterByStatisticsCollector addAll(final ClusterByStatisticsCollector other) {
        if (!(other instanceof ClusterByStatisticsCollectorImpl)) {
            return addAll(other.snapshot());
        }

        final ClusterByStatisticsCollectorImpl that = (ClusterByStatisticsCollectorImpl) other;
        for (final Map.Entry<RowKey, BucketHolder> otherBucketEntry : that.buckets.entrySet()) {
            final BucketHolder bucketHolder = getOrCreateBucketHolder(otherBucketEntry.getKey());
            mergeKeyCollector(bucketHolder.keyCollector, otherBucketEntry.getValue().keyCollector);
            trackRetainedBytes(bucketHolder);
        }

        // 'that' has no flags array when it was created without checkHasMultipleValues; nothing to merge then.
        if (checkHasMultipleValues && that.hasMultipleValues != null) {
            final int numColumns = clusterBy.getColumns().size();
            for (int i = 0; i < numColumns; i++) {
                hasMultipleValues[i] = hasMultipleValues[i] || that.hasMultipleValues[i];
            }
        }
        return this;
    }

    /** Merges a snapshot into this collector, bucket by bucket. */
    @Override
    public ClusterByStatisticsCollector addAll(final ClusterByStatisticsSnapshot snapshot) {
        for (final ClusterByStatisticsSnapshot.Bucket otherBucket : snapshot.getBuckets().values()) {
            final KeyCollector<?> otherKeyCollector = keyCollectorFromSnapshot(otherBucket.getKeyCollectorSnapshot());
            final BucketHolder bucketHolder = getOrCreateBucketHolder(otherBucket.getBucketKey());
            mergeKeyCollector(bucketHolder.keyCollector, otherKeyCollector);
            trackRetainedBytes(bucketHolder);
        }

        if (checkHasMultipleValues) {
            // snapshot() below emits a null set when multi-value tracking is off, so tolerate that here.
            final Set<Integer> otherHasMultipleValues = snapshot.getHasMultipleValues();
            if (otherHasMultipleValues != null) {
                for (final int keyPosition : otherHasMultipleValues) {
                    hasMultipleValues[keyPosition] = true;
                }
            }
        }
        return this;
    }

    /** Returns the sum of every bucket collector's estimated total weight. */
    @Override
    public long estimatedTotalWeight() {
        long count = 0L;
        for (final BucketHolder bucketHolder : buckets.values()) {
            count += bucketHolder.keyCollector.estimatedTotalWeight();
        }
        return count;
    }

    @VisibleForTesting
    long getTotalRetainedBytes() {
        return totalRetainedBytes;
    }

    /**
     * Returns whether any added key had multiple values in column {@code keyPosition}.
     *
     * @throws ISE if this collector was created without {@code checkHasMultipleValues}
     * @throws IAE if {@code keyPosition} is outside {@code [0, clusterBy.getColumns().size())}
     */
    @Override
    public boolean hasMultipleValues(final int keyPosition) {
        if (!checkHasMultipleValues) {
            throw new ISE("hasMultipleValues not available for this collector");
        }
        if (keyPosition < 0 || keyPosition >= clusterBy.getColumns().size()) {
            throw new IAE("Invalid keyPosition [%d]", keyPosition);
        }
        return hasMultipleValues[keyPosition];
    }

    /** Drops all buckets and resets the retained-byte total. Multi-value flags are left untouched. */
    @Override
    public ClusterByStatisticsCollector clear() {
        buckets.clear();
        totalRetainedBytes = 0L;
        return this;
    }

    /**
     * Builds partitions by asking each bucket's collector for partitions of roughly {@code targetWeight}, then
     * concatenating them in bucket order.
     *
     * @return a single universal partition when there are no buckets
     * @throws IAE if {@code targetWeight < 1}
     * @throws ISE if the concatenated partitions do not all abut
     */
    @Override
    public ClusterByPartitions generatePartitionsWithTargetWeight(final long targetWeight) {
        if (targetWeight < 1L) {
            throw new IAE("Target weight must be positive");
        }
        assertRetainedByteCountsAreTrackedCorrectly();
        if (buckets.isEmpty()) {
            return ClusterByPartitions.oneUniversalPartition();
        }

        final List<ClusterByPartition> partitions = new ArrayList<>();
        for (final BucketHolder bucket : buckets.values()) {
            final List<ClusterByPartition> bucketPartitions =
                bucket.keyCollector.generatePartitionsWithTargetWeight(targetWeight).ranges();
            if (!partitions.isEmpty() && !bucketPartitions.isEmpty()) {
                // Replace the previous bucket's last partition with one that runs up to this bucket's first start,
                // so that consecutive ranges abut.
                final int lastIndex = partitions.size() - 1;
                partitions.set(
                    lastIndex,
                    new ClusterByPartition(partitions.get(lastIndex).getStart(), bucketPartitions.get(0).getStart())
                );
            }
            partitions.addAll(bucketPartitions);
        }

        final ClusterByPartitions retVal = new ClusterByPartitions(partitions);
        if (!retVal.allAbutting()) {
            throw new ISE("Partitions are not all abutting");
        }
        return retVal;
    }

    /**
     * Builds at most {@code maxNumPartitions} partitions. It starts from an even split of the estimated total weight
     * and grows the target weight by {@code MAX_COUNT_ITERATION_GROWTH_FACTOR} until the partition count fits.
     *
     * @throws IAE if {@code maxNumPartitions < 1}
     * @throws ISE if no fitting partitioning is found within {@code MAX_COUNT_MAX_ITERATIONS} attempts
     */
    @Override
    public ClusterByPartitions generatePartitionsWithMaxCount(final int maxNumPartitions) {
        if (maxNumPartitions < 1) {
            throw new IAE("Must have at least one partition");
        }
        if (buckets.isEmpty()) {
            return ClusterByPartitions.oneUniversalPartition();
        }
        if (maxNumPartitions == 1 && clusterBy.getBucketByCount() == 0) {
            final RowKey minKey = buckets.get(buckets.firstKey()).keyCollector.minKey();
            return new ClusterByPartitions(Collections.singletonList(new ClusterByPartition(minKey, null)));
        }

        // Clamp to 1: buckets may exist with zero estimated weight (e.g. merged empty collectors), and a target
        // of 0 would be rejected by generatePartitionsWithTargetWeight.
        long targetPartitionWeight =
            Math.max(1L, (long) Math.ceil((double) estimatedTotalWeight() / maxNumPartitions));

        ClusterByPartitions ranges;
        int iterations = 0;
        do {
            if (iterations++ > MAX_COUNT_MAX_ITERATIONS) {
                throw new ISE("Unable to compute partition ranges");
            }
            ranges = generatePartitionsWithTargetWeight(targetPartitionWeight);
            targetPartitionWeight = (long) Math.ceil(targetPartitionWeight * MAX_COUNT_ITERATION_GROWTH_FACTOR);
        } while (ranges.size() > maxNumPartitions);

        return ranges;
    }

    /**
     * Logs bucket collectors sorted by ascending {@code sketchAccuracyFactor}. At debug level it logs all of them;
     * otherwise it logs the first five at info level as the "most downsampled".
     */
    @Override
    public void logSketches() {
        if (log.isDebugEnabled()) {
            log.debug(
                "KeyCollectors at partition generation: [%s]",
                keyCollectorsByAccuracyFactor(Long.MAX_VALUE)
            );
        } else {
            log.info("Most downsampled keyCollectors: [%s]", keyCollectorsByAccuracyFactor(5L));
        }
    }

    /**
     * Returns a serializable view of every bucket plus, when tracked, the positions of multi-valued columns (or null
     * when multi-value tracking is off).
     */
    @Override
    public ClusterByStatisticsSnapshot snapshot() {
        assertRetainedByteCountsAreTrackedCorrectly();

        final Map<Long, ClusterByStatisticsSnapshot.Bucket> bucketSnapshots = new HashMap<>();
        final RowKeyReader trimmedRowReader = keyReader.trimmedKeyReader(clusterBy.getBucketByCount());

        for (final Map.Entry<RowKey, BucketHolder> bucketEntry : buckets.entrySet()) {
            final KeyCollectorSnapshot keyCollectorSnapshot = keyCollectorToSnapshot(bucketEntry.getValue().keyCollector);

            // With exactly one bucket-by column, key the snapshot by that column's value (read as a Long).
            // NOTE(review): for any other bucket-by count every bucket maps to Long.MIN_VALUE, so with more than
            // one bucket-by column buckets would overwrite each other — confirm the count is only ever 0 or 1.
            Long bucketKey = Long.MIN_VALUE;
            if (clusterBy.getBucketByCount() == 1) {
                bucketKey = (Long) trimmedRowReader.read(bucketEntry.getKey(), 0);
            }
            bucketSnapshots.put(
                bucketKey,
                new ClusterByStatisticsSnapshot.Bucket(bucketEntry.getKey(), keyCollectorSnapshot, totalRetainedBytes)
            );
        }

        final Set<Integer> hasMultipleValuesSet;
        if (checkHasMultipleValues) {
            final IntRBTreeSet positions = new IntRBTreeSet();
            for (int i = 0; i < hasMultipleValues.length; i++) {
                if (hasMultipleValues[i]) {
                    positions.add(i);
                }
            }
            hasMultipleValuesSet = positions;
        } else {
            hasMultipleValuesSet = null;
        }

        return new ClusterByStatisticsSnapshot(bucketSnapshots, hasMultipleValuesSet);
    }

    @VisibleForTesting
    List<KeyCollector<?>> getKeyCollectors() {
        return buckets.values().stream().map(holder -> holder.keyCollector).collect(Collectors.toList());
    }

    /**
     * Returns the holder for {@code bucketKey}, creating it with a fresh key collector if absent.
     *
     * @throws NullPointerException    if {@code bucketKey} is null
     * @throws TooManyBucketsException if a new bucket is needed but {@code maxBuckets} already exist
     */
    private BucketHolder getOrCreateBucketHolder(final RowKey bucketKey) {
        final BucketHolder existingHolder = buckets.get(Preconditions.checkNotNull(bucketKey, "bucketKey"));
        if (existingHolder != null) {
            return existingHolder;
        }
        if (buckets.size() >= maxBuckets) {
            throw new TooManyBucketsException(maxBuckets);
        }

        final BucketHolder newHolder = new BucketHolder(keyCollectorFactory.newKeyCollector());
        buckets.put(bucketKey, newHolder);
        // Count the fresh collector's baseline so totalRetainedBytes stays equal to the sum over all holders.
        totalRetainedBytes += newHolder.retainedBytes;
        return newHolder;
    }

    /** Folds {@code bucketHolder}'s change in retained bytes into the total, downsampling if over budget. */
    private void trackRetainedBytes(final BucketHolder bucketHolder) {
        totalRetainedBytes += bucketHolder.updateRetainedBytes();
        if (totalRetainedBytes > maxRetainedBytes) {
            log.debug(
                "Downsampling ClusterByStatisticsCollector as totalRetainedBytes[%s] is greater than maxRetainedBytes[%s]",
                totalRetainedBytes,
                maxRetainedBytes
            );
            downSample();
        }
    }

    /** Returns bucket collectors sorted by ascending {@code sketchAccuracyFactor}, truncated to {@code maxCount}. */
    private List<KeyCollector<?>> keyCollectorsByAccuracyFactor(final long maxCount) {
        return buckets.values()
                      .stream()
                      .map(bucketHolder -> bucketHolder.keyCollector)
                      .sorted(Comparator.comparingInt(KeyCollector::sketchAccuracyFactor))
                      .limit(maxCount)
                      .collect(Collectors.toList());
    }

    /** Merges {@code source} into {@code target}; both come from this collector's factory family. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static void mergeKeyCollector(final KeyCollector<?> target, final KeyCollector<?> source) {
        // Raw cast: the wildcard-typed collectors cannot otherwise be passed to addAll.
        ((KeyCollector) target).addAll(source);
    }

    /** Rebuilds a key collector from its snapshot via {@code keyCollectorFactory}. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private KeyCollector<?> keyCollectorFromSnapshot(final KeyCollectorSnapshot snapshot) {
        return (KeyCollector<?>) ((KeyCollectorFactory) keyCollectorFactory).fromSnapshot(snapshot);
    }

    /** Converts a key collector to its snapshot form via {@code keyCollectorFactory}. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private KeyCollectorSnapshot keyCollectorToSnapshot(final KeyCollector<?> collector) {
        return (KeyCollectorSnapshot) ((KeyCollectorFactory) keyCollectorFactory).toSnapshot(collector);
    }

    /**
     * Shrinks key collectors until the total retained bytes drops to {@code min(total / 2, maxRetainedBytes)} or
     * no candidates remain. Only holders retaining more than one key are candidates. They are processed in order
     * of ascending density (estimated total weight per retained key).
     */
    void downSample() {
        long newTotalRetainedBytes = totalRetainedBytes;
        final long targetTotalRetainedBytes = Math.min(totalRetainedBytes / 2L, maxRetainedBytes);
        final List<Pair<Long, BucketHolder>> sortedHolders = new ArrayList<>(buckets.size());
        final RowKeyReader trimmedRowReader = keyReader.trimmedKeyReader(clusterBy.getBucketByCount());

        for (final Map.Entry<RowKey, BucketHolder> entry : buckets.entrySet()) {
            final BucketHolder bucketHolder = entry.getValue();
            if (bucketHolder != null && bucketHolder.keyCollector.estimatedRetainedKeys() > 1) {
                final Long timeChunk =
                    clusterBy.getBucketByCount() == 0 ? null : (Long) trimmedRowReader.read(entry.getKey(), 0);
                sortedHolders.add(Pair.of(timeChunk, bucketHolder));
            }
        }

        sortedHolders.sort(
            Comparator.comparingDouble(
                (Pair<Long, BucketHolder> pair) ->
                    (double) pair.rhs.keyCollector.estimatedTotalWeight() / pair.rhs.keyCollector.estimatedRetainedKeys()
            )
        );

        int i = 0;
        while (i < sortedHolders.size() && newTotalRetainedBytes > targetTotalRetainedBytes) {
            final Long timeChunk = sortedHolders.get(i).lhs;
            final BucketHolder bucketHolder = sortedHolders.get(i).rhs;

            log.debug("Downsampling sketch for timeChunk [%s]: [%s]", timeChunk, bucketHolder.keyCollector);
            bucketHolder.keyCollector.downSample();
            newTotalRetainedBytes += bucketHolder.updateRetainedBytes();

            // Keep shrinking the same holder while it is still at least as large as the next candidate and has
            // more than one key. NOTE(review): progress relies on downSample() eventually reducing retained bytes
            // or keys; if it stops doing so, this loop would not advance.
            if (i == sortedHolders.size() - 1
                || sortedHolders.get(i + 1).rhs.retainedBytes > bucketHolder.retainedBytes
                || bucketHolder.keyCollector.estimatedRetainedKeys() <= 1) {
                i++;
            }
        }

        totalRetainedBytes = newTotalRetainedBytes;
    }

    /** Assertion-only check that cached per-holder and total retained-byte counts match the collectors. */
    private void assertRetainedByteCountsAreTrackedCorrectly() {
        assert buckets.values()
                      .stream()
                      .allMatch(holder -> holder.retainedBytes == (double) holder.keyCollector.estimatedRetainedBytes());
        assert (double) totalRetainedBytes
               == buckets.values().stream().mapToDouble(holder -> holder.keyCollector.estimatedRetainedBytes()).sum();
    }

    /** A bucket's key collector plus its last observed retained-byte estimate. */
    private static class BucketHolder {
        private final KeyCollector<?> keyCollector;
        private double retainedBytes;

        public BucketHolder(final KeyCollector<?> keyCollector) {
            this.keyCollector = keyCollector;
            this.retainedBytes = keyCollector.estimatedRetainedBytes();
        }

        /**
         * Re-reads the collector's retained-byte estimate, stores it, and returns the change since the previous
         * reading (negative when the collector shrank).
         */
        public double updateRetainedBytes() {
            final double newRetainedBytes = keyCollector.estimatedRetainedBytes();
            final double difference = newRetainedBytes - retainedBytes;
            retainedBytes = newRetainedBytes;
            return difference;
        }
    }
}

