/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.msq.querykit;

import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.IntIterator;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.druid.frame.Frame;
import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.frame.channel.WritableFrameChannel;
import org.apache.druid.frame.processor.FrameProcessor;
import org.apache.druid.frame.processor.FrameProcessors;
import org.apache.druid.frame.processor.ReturnOrAwait;
import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.segment.FrameCursor;
import org.apache.druid.msq.indexing.error.BroadcastTablesTooLargeFault;
import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.input.ReadableInput;
import org.apache.druid.msq.querykit.InputNumberDataSource;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.JoinAlgorithm;
import org.apache.druid.query.Query;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.Cursor;
import org.apache.druid.segment.SegmentReference;
import org.apache.druid.segment.column.RowSignature;

/**
 * Frame processor that reads a set of "side" (broadcast) input channels fully into memory and then returns a
 * segment-mapping function for {@link #query}.
 *
 * <p>The returned function is built from the query's data source after every {@link InputNumberDataSource} that
 * refers to one of the side channels has been replaced by an {@link InlineDataSource} holding that channel's rows.
 *
 * <p>The total {@link Frame#numBytes()} of all frames read is capped at {@code memoryReservedForBroadcastJoin}.
 * Exceeding the cap throws an {@link MSQException} wrapping a {@link BroadcastTablesTooLargeFault}.
 */
public class BroadcastJoinSegmentMapFnProcessor
        implements FrameProcessor<Function<SegmentReference, SegmentReference>> {
    private final Query<?> query;
    /** Input number (as referenced by {@link InputNumberDataSource}) to an index into {@link #channels}. */
    private final Int2IntMap inputNumberToProcessorChannelMap;
    private final List<ReadableFrameChannel> channels;
    /** Frame reader for each entry of {@link #channels}, at the same index. */
    private final List<FrameReader> channelReaders;
    /** Rows read so far, per channel index; {@code null} for channels that are not side channels. */
    private final List<List<Object[]>> channelData;
    /** Channel indexes that are read as broadcast tables: the values of {@link #inputNumberToProcessorChannelMap}. */
    private final IntSet sideChannelNumbers;
    private final long memoryReservedForBroadcastJoin;
    /** Running sum of {@link Frame#numBytes()} for every frame read so far. */
    private long memoryUsed = 0L;

    /**
     * Creates a processor over explicitly supplied channels.
     *
     * @param query                            query whose data source is rewritten; may be {@code null} only
     *                                         insofar as the fault-reporting path tolerates it
     * @param inputNumberToProcessorChannelMap input number to index into {@code channels}; its values define which
     *                                         channels are side channels
     * @param channels                         channels to read from
     * @param channelReaders                   frame reader for each channel, index-aligned with {@code channels}
     * @param memoryReservedForBroadcastJoin   maximum total bytes of frames that may be buffered
     */
    public BroadcastJoinSegmentMapFnProcessor(
            Query<?> query,
            Int2IntMap inputNumberToProcessorChannelMap,
            List<ReadableFrameChannel> channels,
            List<FrameReader> channelReaders,
            long memoryReservedForBroadcastJoin) {
        this.query = query;
        this.inputNumberToProcessorChannelMap = inputNumberToProcessorChannelMap;
        this.channels = channels;
        this.channelReaders = channelReaders;
        this.channelData = new ArrayList<>(channels.size());
        this.sideChannelNumbers = new IntOpenHashSet();
        this.sideChannelNumbers.addAll(inputNumberToProcessorChannelMap.values());
        this.memoryReservedForBroadcastJoin = memoryReservedForBroadcastJoin;
        for (int i = 0; i < channels.size(); ++i) {
            // Only side channels accumulate rows; other slots stay null so indexes line up with "channels".
            if (this.sideChannelNumbers.contains(i)) {
                this.channelData.add(new ArrayList<Object[]>());
            } else {
                this.channelData.add(null);
            }
        }
    }

    /**
     * Builds a processor from a map of input number to {@link ReadableInput}. Processor channel numbers are
     * assigned densely (0, 1, 2, ...) in the iteration order of {@code sideChannels}.
     *
     * @param query                          query whose data source is rewritten
     * @param sideChannels                   input number to the readable input backing it
     * @param memoryReservedForBroadcastJoin maximum total bytes of frames that may be buffered
     * @return a new processor reading every entry of {@code sideChannels} as a side channel
     */
    public static BroadcastJoinSegmentMapFnProcessor create(
            Query<?> query,
            Int2ObjectMap<ReadableInput> sideChannels,
            long memoryReservedForBroadcastJoin) {
        final Int2IntMap inputNumberToProcessorChannelMap = new Int2IntOpenHashMap();
        final List<ReadableFrameChannel> inputChannels = new ArrayList<>();
        final List<FrameReader> channelReaders = new ArrayList<>();
        for (final Int2ObjectMap.Entry<ReadableInput> sideChannelEntry : sideChannels.int2ObjectEntrySet()) {
            final ReadableInput readableInput = sideChannelEntry.getValue();
            inputNumberToProcessorChannelMap.put(sideChannelEntry.getIntKey(), inputChannels.size());
            inputChannels.add(readableInput.getChannel());
            channelReaders.add(readableInput.getChannelFrameReader());
        }
        return new BroadcastJoinSegmentMapFnProcessor(
                query,
                inputNumberToProcessorChannelMap,
                inputChannels,
                channelReaders,
                memoryReservedForBroadcastJoin);
    }

    @Override
    public List<ReadableFrameChannel> inputChannels() {
        return this.channels;
    }

    /** This processor writes no frames; its result is the returned segment-map function. */
    @Override
    public List<WritableFrameChannel> outputChannels() {
        return Collections.emptyList();
    }

    /**
     * Reads available frames from the side channels. Returns the segment-map function once every side channel is
     * finished; otherwise asks to be rescheduled when any still-unfinished side channel becomes readable.
     */
    @Override
    public ReturnOrAwait<Function<SegmentReference, SegmentReference>> runIncrementally(IntSet readableInputs) {
        if (this.buildBroadcastTablesIncrementally(readableInputs)) {
            return ReturnOrAwait.returnObject(this.createSegmentMapFunction());
        }
        // A finished channel never yields another frame, so awaiting it cannot lead to progress.
        // NOTE(review): if the executor treats finished channels as ready, awaiting them would wake this
        // processor repeatedly without work (a spin); awaiting only unfinished channels avoids that.
        // This set is non-empty here, because buildBroadcastTablesIncrementally returned false.
        return ReturnOrAwait.awaitAny(this.unfinishedSideChannelNumbers());
    }

    @Override
    public void cleanup() throws IOException {
        FrameProcessors.closeAll(this.inputChannels(), this.outputChannels());
    }

    /**
     * Materializes every row of {@code frame} as an {@code Object[]} (one element per signature column, in
     * signature order) and appends it to the row list of {@code channelNumber}.
     */
    private void addFrame(int channelNumber, Frame frame) {
        final List<Object[]> data = this.channelData.get(channelNumber);
        final FrameReader frameReader = this.channelReaders.get(channelNumber);
        final FrameCursor cursor = FrameProcessors.makeCursor(frame, frameReader);

        final List<String> columnNames = frameReader.signature().getColumnNames();
        final List<ColumnValueSelector<?>> selectors = new ArrayList<>(columnNames.size());
        for (final String columnName : columnNames) {
            selectors.add(cursor.getColumnSelectorFactory().makeColumnValueSelector(columnName));
        }

        while (!cursor.isDone()) {
            final Object[] row = new Object[selectors.size()];
            for (int i = 0; i < row.length; ++i) {
                row[i] = selectors.get(i).getObject();
            }
            data.add(row);
            cursor.advance();
        }
    }

    /** Builds the segment-map function from the query's data source with side-channel data inlined. */
    private Function<SegmentReference, SegmentReference> createSegmentMapFunction() {
        return this.inlineChannelData(this.query.getDataSource()).createSegmentMapFunction(this.query, new AtomicLong());
    }

    /**
     * Recursively rewrites {@code originalDataSource}. Each {@link InputNumberDataSource} whose input number maps to
     * a side channel is replaced by an {@link InlineDataSource} of that channel's buffered rows. Any other
     * {@link InputNumberDataSource} is returned unchanged, and other data sources are rebuilt with rewritten children.
     */
    DataSource inlineChannelData(DataSource originalDataSource) {
        if (originalDataSource instanceof InputNumberDataSource) {
            final int inputNumber = ((InputNumberDataSource) originalDataSource).getInputNumber();
            if (this.inputNumberToProcessorChannelMap.containsKey(inputNumber)) {
                final int channelNumber = this.inputNumberToProcessorChannelMap.get(inputNumber);
                if (this.sideChannelNumbers.contains(channelNumber)) {
                    return InlineDataSource.fromIterable(
                            this.channelData.get(channelNumber),
                            this.channelReaders.get(channelNumber).signature());
                }
            }
            return originalDataSource;
        }

        final List<DataSource> newChildren = new ArrayList<>(originalDataSource.getChildren().size());
        for (final DataSource child : originalDataSource.getChildren()) {
            newChildren.add(this.inlineChannelData(child));
        }
        return originalDataSource.withChildren(newChildren);
    }

    /**
     * Reads at most one frame from each side channel in {@code readableInputs} that can currently be read, and
     * buffers its rows.
     *
     * @param readableInputs channel indexes reported as readable; non-side channels are ignored
     * @return {@code true} once every side channel is finished
     * @throws MSQException wrapping {@link BroadcastTablesTooLargeFault} if the buffered bytes exceed
     *                      {@code memoryReservedForBroadcastJoin}
     */
    boolean buildBroadcastTablesIncrementally(IntSet readableInputs) {
        final IntIterator inputChannelIterator = readableInputs.iterator();
        while (inputChannelIterator.hasNext()) {
            final int channelNumber = inputChannelIterator.nextInt();
            if (!this.sideChannelNumbers.contains(channelNumber) || !this.channels.get(channelNumber).canRead()) {
                continue;
            }
            final Frame frame = this.channels.get(channelNumber).read();
            this.memoryUsed += frame.numBytes();
            // Checked before buffering the frame's rows, so an oversized frame is never materialized.
            if (this.memoryUsed > this.memoryReservedForBroadcastJoin) {
                throw new MSQException(new BroadcastTablesTooLargeFault(
                        this.memoryReservedForBroadcastJoin,
                        Optional.ofNullable(this.query)
                                .map(q -> q.context().getString("sqlJoinAlgorithm"))
                                .map(JoinAlgorithm::fromString)
                                .orElse(null)));
            }
            this.addFrame(channelNumber, frame);
        }
        return this.unfinishedSideChannelNumbers().isEmpty();
    }

    /** Returns the side channels that have not yet reported {@code isFinished()}. */
    private IntSet unfinishedSideChannelNumbers() {
        final IntSet unfinished = new IntOpenHashSet();
        final IntIterator sideIterator = this.sideChannelNumbers.iterator();
        while (sideIterator.hasNext()) {
            final int channelNumber = sideIterator.nextInt();
            if (!this.channels.get(channelNumber).isFinished()) {
                unfinished.add(channelNumber);
            }
        }
        return unfinished;
    }

    IntSet getSideChannelNumbers() {
        return this.sideChannelNumbers;
    }
}

