/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.exec.vector.mapjoin.fast;

import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.LongAccumulator;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.common.Pool;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.llap.LlapDaemonInfo;
import org.apache.hadoop.hive.ql.exec.HashTableLoader;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.MemoryMonitorInfo;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.exec.mapjoin.MapJoinMemoryExhaustionError;
import org.apache.hadoop.hive.ql.exec.mr.ExecMapperContext;
import org.apache.hadoop.hive.ql.exec.persistence.MapJoinTableContainer;
import org.apache.hadoop.hive.ql.exec.persistence.MapJoinTableContainerSerDe;
import org.apache.hadoop.hive.ql.exec.tez.TezContext;
import org.apache.hadoop.hive.ql.exec.vector.mapjoin.fast.VectorMapJoinFastTableContainer;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hive.common.util.FixedSizedObjectPool;
import org.apache.tez.common.counters.TezCounter;
import org.apache.tez.runtime.api.AbstractLogicalInput;
import org.apache.tez.runtime.api.LogicalInput;
import org.apache.tez.runtime.library.api.KeyValueReader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link HashTableLoader} that builds {@link VectorMapJoinFastTableContainer}s for the small-table
 * inputs of a vectorized map join.
 *
 * <p>Rows read from each small-table input are routed by hash code to one of
 * {@code numLoadThreads} partitions (always a power of two). The reading thread collects rows into
 * per-partition {@link HashTableElementBatch}es and queues each batch once it is full. One drain
 * task per partition takes batches from its queue and calls
 * {@link VectorMapJoinFastTableContainer#putRow} until it sees {@link #DONE_SENTINEL}.
 */
public class VectorMapJoinFastHashTableLoader
implements HashTableLoader {
    private static final Logger LOG = LoggerFactory.getLogger(VectorMapJoinFastHashTableLoader.class);
    /** Below this key-count estimate the table is loaded by a single drain thread. */
    private static final long PARALLEL_LOAD_MIN_KEY_COUNT = 1L << 20;
    /** How long {@link #load} waits for the drain threads after all rows have been queued. */
    private static final int HT_LOAD_TIMEOUT_MINUTES = 2;
    /** Capacity of the batch pool, per drain thread. */
    private static final int POOLED_BATCHES_PER_THREAD = 8;
    /**
     * Queued after the last batch of every partition to tell its drain task to stop.
     * Shared by all loader instances, so it must never be filled, reset or handed to a pool.
     */
    private static final HashTableElementBatch DONE_SENTINEL = new HashTableElementBatch();

    private Configuration hconf;
    protected MapJoinDesc desc;
    private TezContext tezContext;
    private String cacheKey;
    private TezCounter htLoadCounter;
    /** Rows inserted by all drain tasks of the table currently being loaded. */
    private LongAccumulator totalEntries;
    /** Number of partitions/drain threads; a power of two so that a hash mask selects the partition. */
    private int numLoadThreads;
    private ExecutorService loadExecService;
    /** Batch currently being filled by the reader thread, one per partition. */
    private HashTableElementBatch[] elementBatches;
    private FixedSizedObjectPool<HashTableElementBatch> batchPool;
    /** Full batches waiting to be drained, one queue per partition. */
    private BlockingQueue<HashTableElementBatch>[] loadBatchQueues;

    public VectorMapJoinFastHashTableLoader() {
    }

    public VectorMapJoinFastHashTableLoader(TezContext context, Configuration hconf, MapJoinOperator joinOp) {
        initialize(context, hconf, joinOp);
    }

    @Override
    public void init(ExecMapperContext context, MapredContext mrContext, Configuration hconf, MapJoinOperator joinOp) {
        initialize((TezContext) mrContext, hconf, joinOp);
    }

    /**
     * Shared by both initialization paths so that they resolve the same load-time counter:
     * {@code HASHTABLE_LOAD_TIME_MS} qualified by the vertex name, in the Hive counter group.
     */
    private void initialize(TezContext context, Configuration conf, MapJoinOperator joinOp) {
        this.tezContext = context;
        this.hconf = conf;
        this.desc = (MapJoinDesc) joinOp.getConf();
        this.cacheKey = joinOp.getCacheKey();
        String counterGroup = HiveConf.getVar(conf, HiveConf.ConfVars.HIVE_COUNTER_GROUP);
        String vertexName = conf.get("__hive.context.name", "");
        String counterName = Utilities.getVertexCounterName(
                HashTableLoader.HashTableLoaderCounters.HASHTABLE_LOAD_TIME_MS.name(), vertexName);
        this.htLoadCounter = context.getTezProcessorContext().getCounters().findCounter(counterGroup, counterName);
    }

    /**
     * Sizes and creates the per-table loading machinery: thread count, executor, batch pool,
     * per-partition queues and the first batch of every partition.
     *
     * @param estKeyCount estimated number of keys; negative when unknown
     */
    private void initHTLoadingService(long estKeyCount) {
        if (estKeyCount < PARALLEL_LOAD_MIN_KEY_COUNT) {
            this.numLoadThreads = 1;
        } else {
            int configuredThreads = HiveConf.getIntVar(this.hconf,
                    HiveConf.ConfVars.HIVE_MAPJOIN_PARALEL_HASHTABLE_THREADS);
            Preconditions.checkArgument(configuredThreads > 0, "The number of HT-loading-threads should be positive.");
            // Round down to a power of two: partitions are selected by masking the hash code.
            int powerOfTwoThreads = Integer.highestOneBit(configuredThreads);
            if (configuredThreads != powerOfTwoThreads) {
                LOG.info("Adjust the number of HT-loading-threads to {}. (Previous value: {})",
                        powerOfTwoThreads, configuredThreads);
            }
            this.numLoadThreads = powerOfTwoThreads;
        }
        this.totalEntries = new LongAccumulator(Long::sum, 0L);
        this.loadExecService = Executors.newFixedThreadPool(this.numLoadThreads,
                new ThreadFactoryBuilder()
                        .setDaemon(true)
                        .setPriority(Thread.NORM_PRIORITY)
                        .setNameFormat("HT-Load-Thread-%d")
                        .build());
        this.batchPool = new FixedSizedObjectPool<>(POOLED_BATCHES_PER_THREAD * this.numLoadThreads,
                new Pool.PoolObjectHelper<HashTableElementBatch>() {
                    @Override
                    public HashTableElementBatch create() {
                        return new HashTableElementBatch();
                    }

                    @Override
                    public void resetBeforeOffer(HashTableElementBatch elementBatch) {
                        elementBatch.reset();
                    }
                });
        this.elementBatches = new HashTableElementBatch[this.numLoadThreads];
        @SuppressWarnings("unchecked") // generic array creation; every slot is filled below
        BlockingQueue<HashTableElementBatch>[] queues = new BlockingQueue[this.numLoadThreads];
        this.loadBatchQueues = queues;
        for (int i = 0; i < this.numLoadThreads; ++i) {
            this.loadBatchQueues[i] = new LinkedBlockingQueue<>();
            this.elementBatches[i] = this.batchPool.take();
        }
    }

    /**
     * Starts one drain task per partition.
     *
     * @return the tasks' futures, indexed by partition id, so that their failures can be observed
     */
    private List<Future<?>> submitQueueDrainThreads(VectorMapJoinFastTableContainer vectorMapJoinFastTableContainer) {
        List<Future<?>> drainTasks = new ArrayList<>(this.numLoadThreads);
        for (int i = 0; i < this.numLoadThreads; ++i) {
            final int partitionId = i;
            drainTasks.add(this.loadExecService.submit(() -> {
                try {
                    LOG.info("Partition id {} with Queue size {}", partitionId, this.loadBatchQueues[partitionId].size());
                    this.drainAndLoadForPartition(partitionId, vectorMapJoinFastTableContainer);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException("Failed to start HT Load threads", e);
                } catch (HiveException e) {
                    throw new RuntimeException("Failed to start HT Load threads", e);
                }
            }));
        }
        return drainTasks;
    }

    /**
     * Drain-task body: inserts every row of every batch queued for {@code partitionId} into
     * {@code tableContainer} and returns each drained batch to the pool. Stops at
     * {@link #DONE_SENTINEL}, which is not processed or pooled.
     *
     * @throws InterruptedException if interrupted while waiting for a batch
     * @throws HiveException if inserting a row fails
     */
    private void drainAndLoadForPartition(int partitionId, VectorMapJoinFastTableContainer tableContainer)
            throws InterruptedException, HiveException {
        LOG.info("Starting draining thread {}", partitionId);
        BlockingQueue<HashTableElementBatch> queue = this.loadBatchQueues[partitionId];
        long totalProcessedEntries = 0L;
        for (HashTableElementBatch batch = queue.take(); batch != DONE_SENTINEL; batch = queue.take()) {
            int size = batch.getSize();
            LOG.debug("Draining thread {} batchSize {}", partitionId, size);
            for (int i = 0; i < size; ++i) {
                HashTableElement element = batch.getBatch(i);
                try {
                    tableContainer.putRow(element.getHashCode(), element.getKey(), element.getValue());
                } catch (Exception e) {
                    throw new HiveException("Exception in draining thread put row", e);
                }
            }
            totalProcessedEntries += size;
            LOG.debug("Draining thread {} added {} entries", partitionId, size);
            this.totalEntries.accumulate(size);
            this.batchPool.offer(batch);
        }
        LOG.info("Terminating draining thread {} after processing Entries {}", partitionId, totalProcessedEntries);
    }

    /** Queues each partition's last (possibly partial) batch, followed by the stop sentinel. */
    private void addQueueDoneSentinel() {
        for (int i = 0; i < this.numLoadThreads; ++i) {
            this.loadBatchQueues[i].add(this.elementBatches[i]);
            this.loadBatchQueues[i].add(DONE_SENTINEL);
        }
    }

    /**
     * Queues the now-full current batch of {@code partitionId} and starts a fresh one. Fails fast if
     * that partition's drain task has already finished. It can only finish normally after the
     * sentinel has been queued, so finishing earlier means it failed, and queuing more rows would
     * only grow its queue.
     */
    private void enqueueFullBatch(int partitionId, List<Future<?>> drainTasks)
            throws HiveException, InterruptedException {
        Future<?> drainer = drainTasks.get(partitionId);
        if (drainer.isDone()) {
            propagateDrainFailure(drainer);
            throw new HiveException("HT load thread for partition " + partitionId
                    + " exited before all rows were queued");
        }
        this.loadBatchQueues[partitionId].add(this.elementBatches[partitionId]);
        this.elementBatches[partitionId] = this.batchPool.take();
    }

    /** Rethrows the failure of a completed drain task as a {@link HiveException}; no-op on success. */
    private static void propagateDrainFailure(Future<?> drainTask) throws HiveException, InterruptedException {
        try {
            drainTask.get();
        } catch (ExecutionException e) {
            throw new HiveException("Hash table load thread failed", e.getCause());
        }
    }

    /**
     * Compares the container's estimated memory size with {@code effectiveThreshold}.
     *
     * @throws MapJoinMemoryExhaustionError if the estimate exceeds the threshold
     */
    private static void checkMemoryUsage(String inputName, VectorMapJoinFastTableContainer tableContainer,
            long receivedEntries, long effectiveThreshold, MemoryMonitorInfo memoryMonitorInfo) {
        long estMemUsage = tableContainer.getEstimatedMemorySize();
        if (estMemUsage > effectiveThreshold) {
            String msg = "Hash table loading exceeded memory limits for input: " + inputName + " numEntries: " + receivedEntries + " estimatedMemoryUsage: " + estMemUsage + " effectiveThreshold: " + effectiveThreshold + " memoryMonitorInfo: " + String.valueOf(memoryMonitorInfo);
            LOG.error(msg);
            throw new MapJoinMemoryExhaustionError(msg);
        }
        LOG.info("Checking hash table loader memory usage for input: {} numEntries: {} estimatedMemoryUsage: {} effectiveThreshold: {}",
                inputName, receivedEntries, estMemUsage, effectiveThreshold);
    }

    /**
     * Loads every small-table input (every position except the big table) into a new
     * {@link VectorMapJoinFastTableContainer} and stores the sealed container in
     * {@code mapJoinTables[pos]}. {@code mapJoinTableSerdes} is not used.
     *
     * @throws HiveException if an input cannot be started or read, a drain thread fails, or the
     *     drain threads do not finish within {@link #HT_LOAD_TIMEOUT_MINUTES} minutes
     * @throws MapJoinMemoryExhaustionError if memory monitoring is enabled and the estimated table
     *     size exceeds the effective threshold
     */
    @Override
    public void load(MapJoinTableContainer[] mapJoinTables, MapJoinTableContainerSerDe[] mapJoinTableSerdes) throws HiveException {
        Map<Integer, String> parentToInput = this.desc.getParentToInput();
        Map<Integer, Long> parentKeyCounts = this.desc.getParentKeyCounts();
        MemoryMonitorInfo memoryMonitorInfo = this.desc.getMemoryMonitorInfo();
        boolean doMemCheck = false;
        long effectiveThreshold = 0L;
        if (memoryMonitorInfo != null) {
            effectiveThreshold = memoryMonitorInfo.getEffectiveThreshold(this.desc.getMaxMemoryAvailable());
            if (!LlapDaemonInfo.INSTANCE.isLlap()) {
                memoryMonitorInfo.setLlap(false);
            }
            if (memoryMonitorInfo.doMemoryMonitoring()) {
                doMemCheck = true;
                LOG.info("Memory monitoring for hash table loader enabled. {}", memoryMonitorInfo);
            }
        }
        if (!doMemCheck) {
            LOG.info("Not doing hash table memory monitoring. {}", memoryMonitorInfo);
        }
        for (int pos = 0; pos < mapJoinTables.length; ++pos) {
            if (pos == this.desc.getPosBigTable()) {
                continue;
            }
            String inputName = parentToInput.get(pos);
            LogicalInput input = this.tezContext.getInput(inputName);
            try {
                input.start();
                this.tezContext.getTezProcessorContext().waitForAnyInputReady(Collections.singletonList(input));
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new HiveException(e);
            } catch (Exception e) {
                throw new HiveException(e);
            }
            try {
                KeyValueReader kvReader = (KeyValueReader) input.getReader();
                Long keyCountObj = parentKeyCounts.get(pos);
                long estKeyCount = keyCountObj == null ? -1L : keyCountObj;
                long inputRecords = -1L;
                try {
                    inputRecords = ((AbstractLogicalInput) input).getContext().getCounters()
                            .findCounter("org.apache.tez.common.counters.TaskCounter", "APPROXIMATE_INPUT_RECORDS")
                            .getValue();
                } catch (Exception e) {
                    LOG.debug("Failed to get value for counter APPROXIMATE_INPUT_RECORDS", e);
                }
                long keyCount = Math.max(estKeyCount, inputRecords);
                this.initHTLoadingService(keyCount);
                VectorMapJoinFastTableContainer tableContainer =
                        new VectorMapJoinFastTableContainer(this.desc, this.hconf, keyCount, this.numLoadThreads);
                LOG.info("Loading hash table for input: {} cacheKey: {} tableContainer: {} smallTablePos: {} estKeyCount : {} keyCount : {}",
                        inputName, this.cacheKey, tableContainer.getClass().getSimpleName(), pos, estKeyCount, keyCount);
                tableContainer.setSerde(null, null);
                List<Future<?>> drainTasks = this.submitQueueDrainThreads(tableContainer);
                long receivedEntries = 0L;
                long startTime = System.currentTimeMillis();
                while (kvReader.next()) {
                    BytesWritable currentKey = (BytesWritable) kvReader.getCurrentKey();
                    BytesWritable currentValue = (BytesWritable) kvReader.getCurrentValue();
                    long hashCode = tableContainer.getHashCode(currentKey);
                    // numLoadThreads is a power of two, so the mask yields a partition in [0, numLoadThreads).
                    int partitionId = (int) ((long) (this.numLoadThreads - 1) & hashCode);
                    HashTableElement element =
                            new HashTableElement(hashCode, currentValue.copyBytes(), currentKey.copyBytes());
                    if (this.elementBatches[partitionId].addElement(element)) {
                        this.enqueueFullBatch(partitionId, drainTasks);
                    }
                    // Count every row; the count is logged even when memory monitoring is off.
                    ++receivedEntries;
                    if (doMemCheck && receivedEntries % memoryMonitorInfo.getMemoryCheckInterval() == 0L) {
                        checkMemoryUsage(inputName, tableContainer, receivedEntries, effectiveThreshold, memoryMonitorInfo);
                    }
                }
                LOG.info("Finished loading the queue for input: {} waiting {} minutes for TPool shutdown",
                        inputName, HT_LOAD_TIMEOUT_MINUTES);
                this.addQueueDoneSentinel();
                this.loadExecService.shutdown();
                if (!this.loadExecService.awaitTermination(HT_LOAD_TIMEOUT_MINUTES, TimeUnit.MINUTES)) {
                    throw new HiveException("Failed to complete the hash table loader. Loading timed out.");
                }
                // awaitTermination() is also true when drain tasks died with an exception;
                // surface such failures rather than sealing an incomplete table.
                for (Future<?> drainTask : drainTasks) {
                    propagateDrainFailure(drainTask);
                }
                this.batchPool.clear();
                LOG.info("Total received entries: {} Threads {} HT entries: {}",
                        receivedEntries, this.numLoadThreads, this.totalEntries.get());
                long delta = System.currentTimeMillis() - startTime;
                this.htLoadCounter.increment(delta);
                tableContainer.seal();
                mapJoinTables[pos] = tableContainer;
                if (doMemCheck) {
                    LOG.info("Finished loading hash table for input: {} cacheKey: {} numEntries: {} estimatedMemoryUsage: {} Load Time : {} ",
                            inputName, this.cacheKey, receivedEntries, tableContainer.getEstimatedMemorySize(), delta);
                } else {
                    LOG.info("Finished loading hash table for input: {} cacheKey: {} numEntries: {} Load Time : {} ",
                            inputName, this.cacheKey, receivedEntries, delta);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new HiveException(e);
            } catch (HiveException e) {
                throw e;
            } catch (Exception e) {
                throw new HiveException(e);
            } finally {
                // Stops (interrupts) drain threads left running by any early exit above.
                if (this.loadExecService != null && !this.loadExecService.isTerminated()) {
                    this.loadExecService.shutdownNow();
                }
            }
        }
    }

    /**
     * Holds up to {@link #BATCH_SIZE} rows destined for one partition.
     * Instances are recycled through the loader's batch pool.
     */
    private static class HashTableElementBatch {
        private static final int BATCH_SIZE = 1024;
        private final HashTableElement[] batch = new HashTableElement[BATCH_SIZE];
        private int currentIndex = 0;

        /**
         * Appends a row. Must not be called again once it has returned {@code true}.
         *
         * @return {@code true} when this call filled the batch
         */
        public boolean addElement(HashTableElement h) {
            this.batch[this.currentIndex++] = h;
            return this.currentIndex == BATCH_SIZE;
        }

        /** Returns the {@code i}-th row; {@code i} must be below {@link #getSize()}. */
        public HashTableElement getBatch(int i) {
            return this.batch[i];
        }

        /** Returns the number of rows currently held. */
        public int getSize() {
            return this.currentIndex;
        }

        /** Empties the batch logically; the array slots are overwritten on reuse. */
        public void reset() {
            this.currentIndex = 0;
        }
    }

    /** One row: its precomputed hash code plus its key and value bytes. */
    private static class HashTableElement {
        private final long hashCode;
        private final byte[] keyBytes;
        private final byte[] valueBytes;

        // Note the argument order: value before key.
        public HashTableElement(long hashCode, byte[] valueBytes, byte[] keyBytes) {
            this.hashCode = hashCode;
            this.keyBytes = keyBytes;
            this.valueBytes = valueBytes;
        }

        public BytesWritable getKey() {
            return new BytesWritable(this.keyBytes, this.keyBytes.length);
        }

        public BytesWritable getValue() {
            return new BytesWritable(this.valueBytes, this.valueBytes.length);
        }

        public long getHashCode() {
            return this.hashCode;
        }
    }
}

