/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.runtime.library.common.sort.buffer;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.hadoop.io.RawComparator;
import org.apache.hadoop.io.serializer.Serializer;
import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.counters.TezCounter;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.runtime.library.common.sort.buffer.WriteBuffer;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.com.google.common.collect.Lists;
import org.apache.uniffle.com.google.common.collect.Maps;
import org.apache.uniffle.com.google.common.collect.Sets;
import org.apache.uniffle.com.google.common.util.concurrent.Uninterruptibles;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.apache.uniffle.common.util.ThreadUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class WriteBufferManager<K, V> {
    private static final Logger LOG = LoggerFactory.getLogger(WriteBufferManager.class);
    private long copyTime = 0L;
    private long sortTime = 0L;
    private long compressTime = 0L;
    private final Map<Integer, Integer> partitionToSeqNo = Maps.newHashMap();
    private long uncompressedDataLen = 0L;
    private final long maxMemSize;
    private final ExecutorService sendExecutorService;
    private final ShuffleWriteClient shuffleWriteClient;
    private final String appId;
    private final Set<Long> successBlockIds;
    private final Set<Long> failedBlockIds;
    private final ReentrantLock memoryLock = new ReentrantLock();
    private final AtomicLong memoryUsedSize = new AtomicLong(0L);
    private final AtomicLong inSendListBytes = new AtomicLong(0L);
    private final Condition full = this.memoryLock.newCondition();
    private final RawComparator<K> comparator;
    private final long maxSegmentSize;
    private final Serializer<K> keySerializer;
    private final Serializer<V> valSerializer;
    private final List<WriteBuffer<K, V>> waitSendBuffers = Lists.newLinkedList();
    private final Map<Integer, WriteBuffer<K, V>> buffers = Maps.newConcurrentMap();
    private final long maxBufferSize;
    private final double memoryThreshold;
    private final double sendThreshold;
    private final int batch;
    private final Codec codec;
    private final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
    private final Set<Long> allBlockIds = Sets.newConcurrentHashSet();
    private final Map<Integer, List<Long>> partitionToBlocks = Maps.newConcurrentMap();
    private final int numMaps;
    private final boolean isMemoryShuffleEnabled;
    private final long sendCheckInterval;
    private final long sendCheckTimeout;
    private final int bitmapSplitNum;
    private final long taskAttemptId;
    private TezTaskAttemptID tezTaskAttemptID;
    private final RssConf rssConf;
    private final int shuffleId;
    private final boolean isNeedSorted;
    private final TezCounter mapOutputByteCounter;
    private final TezCounter mapOutputRecordCounter;

    /**
     * Stores the supplied collaborators and tuning values, builds the compression codec from
     * {@code rssConf}, and creates a fixed-size pool of {@code sendThreadNum} sender threads
     * (thread factory name prefix {@code "send-thread"}).
     */
    public WriteBufferManager(TezTaskAttemptID tezTaskAttemptID, long maxMemSize, String appId, long taskAttemptId, Set<Long> successBlockIds, Set<Long> failedBlockIds, ShuffleWriteClient shuffleWriteClient, RawComparator<K> comparator, long maxSegmentSize, Serializer<K> keySerializer, Serializer<V> valSerializer, long maxBufferSize, double memoryThreshold, int sendThreadNum, double sendThreshold, int batch, RssConf rssConf, Map<Integer, List<ShuffleServerInfo>> partitionToServers, int numMaps, boolean isMemoryShuffleEnabled, long sendCheckInterval, long sendCheckTimeout, int bitmapSplitNum, int shuffleId, boolean isNeedSorted, TezCounter mapOutputByteCounter, TezCounter mapOutputRecordCounter) {
        // Task / shuffle identity.
        this.tezTaskAttemptID = tezTaskAttemptID;
        this.taskAttemptId = taskAttemptId;
        this.appId = appId;
        this.shuffleId = shuffleId;
        this.numMaps = numMaps;
        this.bitmapSplitNum = bitmapSplitNum;

        // Remote shuffle client, server assignment and shared block-id sets.
        this.shuffleWriteClient = shuffleWriteClient;
        this.partitionToServers = partitionToServers;
        this.successBlockIds = successBlockIds;
        this.failedBlockIds = failedBlockIds;
        this.isMemoryShuffleEnabled = isMemoryShuffleEnabled;

        // Record serialization and ordering.
        this.comparator = comparator;
        this.keySerializer = keySerializer;
        this.valSerializer = valSerializer;
        this.isNeedSorted = isNeedSorted;

        // Memory limits and send thresholds.
        this.maxMemSize = maxMemSize;
        this.maxSegmentSize = maxSegmentSize;
        this.maxBufferSize = maxBufferSize;
        this.memoryThreshold = memoryThreshold;
        this.sendThreshold = sendThreshold;
        this.batch = batch;
        this.sendCheckInterval = sendCheckInterval;
        this.sendCheckTimeout = sendCheckTimeout;

        // Counters updated once per added record.
        this.mapOutputByteCounter = mapOutputByteCounter;
        this.mapOutputRecordCounter = mapOutputRecordCounter;

        // Configuration-derived helpers; the codec is built before the pool, as in the original order.
        this.rssConf = rssConf;
        this.codec = Codec.newInstance(rssConf);
        this.sendExecutorService = Executors.newFixedThreadPool(sendThreadNum, ThreadUtils.getThreadFactory("send-thread"));
    }

    /**
     * Appends one key/value record to the buffer of {@code partitionId}, creating the buffer on
     * first use, and triggers sends when the per-buffer or overall memory thresholds are crossed.
     *
     * <p>Blocks while the accounted memory exceeds {@code maxMemSize}; sender tasks signal the
     * {@code full} condition after releasing memory.
     *
     * @param partitionId target partition of the record
     * @param key record key
     * @param value record value
     * @throws InterruptedException if interrupted while waiting for memory to be released
     * @throws IOException declared by the signature; presumably raised while the buffer serializes
     *     the record (confirm against {@code WriteBuffer.addRecord})
     * @throws RssException if failed blocks have been recorded, or if the length returned for the
     *     record exceeds {@code maxMemSize}
     */
    public void addRecord(int partitionId, K key, V value) throws InterruptedException, IOException {
        // Back-pressure: wait until in-flight sends bring the accounted memory back under the cap.
        this.memoryLock.lock();
        try {
            while (this.memoryUsedSize.get() > this.maxMemSize) {
                LOG.warn("memoryUsedSize {} is more than {}, inSendListBytes {}", this.memoryUsedSize, this.maxMemSize, this.inSendListBytes);
                this.full.await();
            }
        } finally {
            this.memoryLock.unlock();
        }

        // Fail fast if any earlier send already reported failed blocks.
        this.checkFailedBlocks();

        // Lazily create the partition's buffer and queue it for a later send.
        if (!this.buffers.containsKey(partitionId)) {
            WriteBuffer<K, V> created = new WriteBuffer<K, V>(this.isNeedSorted, partitionId, this.comparator, this.maxSegmentSize, this.keySerializer, this.valSerializer);
            this.buffers.putIfAbsent(partitionId, created);
            this.waitSendBuffers.add(created);
        }

        WriteBuffer<K, V> target = this.buffers.get(partitionId);
        int recordLength = target.addRecord(key, value);
        if ((long) recordLength > this.maxMemSize) {
            throw new RssException("record is too big");
        }
        this.memoryUsedSize.addAndGet(recordLength);

        // A single buffer that outgrew its limit is sent on its own.
        boolean bufferOverLimit = (long) target.getDataLength() > this.maxBufferSize;
        if (bufferOverLimit) {
            boolean wasQueued = this.waitSendBuffers.remove(target);
            if (!wasQueued) {
                LOG.error("waitSendBuffers don't contain buffer {}", target);
            } else {
                this.sendBufferToServers(target);
            }
        }

        // Overall memory pressure: flush a batch, unless too many bytes are already in flight.
        if ((double) this.memoryUsedSize.get() > (double) this.maxMemSize * this.memoryThreshold) {
            if ((double) this.inSendListBytes.get() <= (double) this.maxMemSize * this.sendThreshold) {
                this.sendBuffersToServers();
            }
        }

        this.mapOutputRecordCounter.increment(1L);
        this.mapOutputByteCounter.increment((long) recordLength);
    }

    /**
     * Converts a single buffer into a shuffle block and submits it to the sender pool.
     *
     * @param buffer the buffer to send; the caller has already removed it from the wait list
     */
    private void sendBufferToServers(WriteBuffer<K, V> buffer) {
        List<ShuffleBlockInfo> singleBlock = Lists.newArrayList();
        prepareBufferForSend(singleBlock, buffer);
        sendShuffleBlocks(singleBlock);
    }

    /**
     * Sends up to {@code batch} of the waiting buffers, largest first, as one asynchronous task.
     *
     * <p>Each selected buffer is converted to a shuffle block and removed from the wait list.
     * Nothing is submitted when no buffer is waiting.
     */
    void sendBuffersToServers() {
        if (this.waitSendBuffers.isEmpty()) {
            // Nothing to flush; avoid submitting an empty send task to the executor.
            return;
        }
        // Largest buffers first. Integer.compare avoids the subtraction idiom, which is
        // overflow-prone in general.
        this.waitSendBuffers.sort((left, right) -> Integer.compare(right.getDataLength(), left.getDataLength()));
        int sendSize = Math.min(this.batch, this.waitSendBuffers.size());
        List<ShuffleBlockInfo> shuffleBlocks = Lists.newArrayList();
        Iterator<WriteBuffer<K, V>> iterator = this.waitSendBuffers.iterator();
        for (int index = 0; iterator.hasNext() && index < sendSize; ++index) {
            WriteBuffer<K, V> buffer = iterator.next();
            this.prepareBufferForSend(shuffleBlocks, buffer);
            // Remove through the iterator so the list is not modified behind its back.
            iterator.remove();
        }
        this.sendShuffleBlocks(shuffleBlocks);
    }

    /**
     * Detaches {@code buffer} from the active buffer map, turns its content into a shuffle block,
     * clears the buffer, and records the new block id for later bookkeeping.
     *
     * @param shuffleBlocks output list that receives the created block
     * @param buffer the buffer whose data is being sent
     */
    private void prepareBufferForSend(List<ShuffleBlockInfo> shuffleBlocks, WriteBuffer<K, V> buffer) {
        // Unregister first so the next record for this partition starts a fresh buffer.
        this.buffers.remove(buffer.getPartitionId());
        ShuffleBlockInfo block = this.createShuffleBlock(buffer);
        buffer.clear();
        shuffleBlocks.add(block);
        long blockId = block.getBlockId();
        this.allBlockIds.add(blockId);
        // Single atomic lookup-or-create instead of containsKey + putIfAbsent + get.
        this.partitionToBlocks.computeIfAbsent(block.getPartitionId(), id -> new ArrayList<Long>()).add(blockId);
    }

    /**
     * Submits an asynchronous task that sends {@code shuffleBlocks} through the shuffle write
     * client, records the returned success/failed block ids, and then releases the memory
     * accounted for those blocks and wakes any thread waiting on the {@code full} condition.
     *
     * @param shuffleBlocks blocks to send in one request
     */
    private void sendShuffleBlocks(final List<ShuffleBlockInfo> shuffleBlocks) {
        this.sendExecutorService.submit(() -> {
            long size = 0L;
            try {
                for (ShuffleBlockInfo block : shuffleBlocks) {
                    size += block.getFreeMemory();
                }
                SendShuffleDataResult result = this.shuffleWriteClient.sendShuffleData(this.appId, shuffleBlocks, () -> false);
                this.successBlockIds.addAll(result.getSuccessBlockIds());
                this.failedBlockIds.addAll(result.getFailedBlockIds());
            } catch (Throwable t) {
                // NOTE(review): blocks of a send that throws are recorded neither as success nor as
                // failed here, so it looks like the failure only surfaces through the timeout in
                // waitSendFinished - confirm this is intended.
                LOG.warn("send shuffle data exception ", t);
            } finally {
                // Acquire the lock outside the inner try so unlock() only runs when lock() succeeded.
                this.memoryLock.lock();
                try {
                    LOG.debug("memoryUsedSize {} decrease {}", this.memoryUsedSize, size);
                    this.memoryUsedSize.addAndGet(-size);
                    this.inSendListBytes.addAndGet(-size);
                    this.full.signalAll();
                } finally {
                    this.memoryLock.unlock();
                }
            }
        });
    }

    /**
     * Flushes every waiting buffer, waits until all created blocks appear in the success set,
     * optionally sends a commit, and finally reports the shuffle result.
     *
     * @throws RssException if failed blocks are recorded, or if blocks are still outstanding after
     *     {@code sendCheckTimeout} milliseconds
     */
    public void waitSendFinished() {
        // Push out whatever is still buffered.
        while (!this.waitSendBuffers.isEmpty()) {
            this.sendBuffersToServers();
        }

        // Poll until every block id has been acknowledged, or give up after the timeout.
        final long waitStart = System.currentTimeMillis();
        while (true) {
            this.checkFailedBlocks();
            this.allBlockIds.removeAll(this.successBlockIds);
            if (this.allBlockIds.isEmpty()) {
                break;
            }
            LOG.info("Wait " + this.allBlockIds.size() + " blocks sent to shuffle server");
            Uninterruptibles.sleepUninterruptibly(this.sendCheckInterval, TimeUnit.MILLISECONDS);
            if (System.currentTimeMillis() - waitStart > this.sendCheckTimeout) {
                String errorMsg = "Timeout: failed because " + this.allBlockIds.size() + " blocks can't be sent to shuffle server in " + this.sendCheckTimeout + " ms.";
                LOG.error(errorMsg);
                throw new RssException(errorMsg);
            }
        }

        // The commit step is skipped when memory shuffle is enabled.
        long commitDuration = 0L;
        if (!this.isMemoryShuffleEnabled) {
            long commitStart = System.currentTimeMillis();
            this.sendCommit();
            commitDuration = System.currentTimeMillis() - commitStart;
        }

        final long reportStart = System.currentTimeMillis();
        TezVertexID tezVertexID = this.tezTaskAttemptID.getTaskID().getVertexID();
        TezDAGID tezDAGID = tezVertexID.getDAGId();
        LOG.info("tezVertexID is {}, tezDAGID is {}, shuffleId is {}", tezVertexID, tezDAGID, this.shuffleId);
        this.shuffleWriteClient.reportShuffleResult(this.partitionToServers, this.appId, this.shuffleId, this.taskAttemptId, this.partitionToBlocks, this.bitmapSplitNum);
        LOG.info("Report shuffle result for task[{}] with bitmapNum[{}] cost {} ms", this.taskAttemptId, this.bitmapSplitNum, System.currentTimeMillis() - reportStart);
        LOG.info("Task uncompressed data length {} compress time cost {} ms, commit time cost {} ms, copy time cost {} ms, sort time cost {} ms", this.uncompressedDataLen, this.compressTime, commitDuration, this.copyTime, this.sortTime);
    }

    /**
     * Throws if any block id has been recorded as failed.
     *
     * @throws RssException when the failed-block set is not empty
     */
    private void checkFailedBlocks() {
        if (this.failedBlockIds.size() <= 0) {
            return;
        }
        String errorMsg = "Send failed: failed because " + this.failedBlockIds.size() + " blocks can't be sent to shuffle server.";
        LOG.error(errorMsg);
        throw new RssException(errorMsg);
    }

    /**
     * Builds a {@link ShuffleBlockInfo} from the content of {@code wb}: compresses the buffer
     * data, computes a CRC32 over the compressed bytes, and assigns the next block id for the
     * buffer's partition. Also accumulates the copy/sort/compress timings and adds the buffer's
     * data length to the in-flight byte counter.
     *
     * @param wb buffer to convert
     * @return the block describing the compressed data and its target servers
     */
    ShuffleBlockInfo createShuffleBlock(WriteBuffer wb) {
        final byte[] rawBytes = wb.getData();
        this.copyTime += wb.getCopyTime();
        this.sortTime += wb.getSortTime();
        final int partition = wb.getPartitionId();
        final int rawLength = rawBytes.length;

        // Compression and checksum are timed together.
        final long compressStart = System.currentTimeMillis();
        final byte[] packed = this.codec.compress(rawBytes);
        final long checksum = ChecksumUtils.getCrc32(packed);
        this.compressTime += System.currentTimeMillis() - compressStart;

        final long blockId = RssTezUtils.getBlockId(partition, this.taskAttemptId, this.getNextSeqNo(partition));
        LOG.info("blockId is {}", blockId);

        this.uncompressedDataLen += (long) rawLength;
        this.inSendListBytes.addAndGet(wb.getDataLength());

        TezVertexID vertexId = this.tezTaskAttemptID.getTaskID().getVertexID();
        TezDAGID dagId = vertexId.getDAGId();
        LOG.info("tezVertexID is {}, tezDAGID is {}, shuffleId is {}", vertexId, dagId, this.shuffleId);

        final List<ShuffleServerInfo> targetServers = this.partitionToServers.get(partition);
        // The buffer's data length is passed as the block's free-memory value, which the sender
        // task later subtracts from the memory counters.
        final long freeMemory = (long) wb.getDataLength();
        return new ShuffleBlockInfo(this.shuffleId, partition, blockId, packed.length, checksum, packed, targetServers, rawLength, freeMemory, this.taskAttemptId);
    }

    /**
     * Sends a commit for this shuffle to every distinct shuffle server assigned to any partition
     * and waits for the result, polling with a doubling back-off capped at 5 seconds.
     *
     * @throws RssException if the commit call reports failure or completes exceptionally
     */
    protected void sendCommit() {
        // Collect the distinct servers across all partitions.
        Set<ShuffleServerInfo> serverInfos = Sets.newHashSet();
        for (List<ShuffleServerInfo> serverInfoList : this.partitionToServers.values()) {
            serverInfos.addAll(serverInfoList);
        }
        LOG.info("sendCommit  shuffle id is {}", this.shuffleId);

        ExecutorService executor = Executors.newSingleThreadExecutor();
        try {
            Future<Boolean> future = executor.submit(() -> this.shuffleWriteClient.sendCommit(serverInfos, this.appId, this.shuffleId, this.numMaps));
            long start = System.currentTimeMillis();
            int currentWait = 200;
            int maxWait = 5000;
            while (!future.isDone()) {
                LOG.info("Wait commit to shuffle server for task[" + this.taskAttemptId + "] cost " + (System.currentTimeMillis() - start) + " ms");
                Uninterruptibles.sleepUninterruptibly(currentWait, TimeUnit.MILLISECONDS);
                currentWait = Math.min(currentWait * 2, maxWait);
            }
            try {
                if (!future.get().booleanValue()) {
                    // Note: this is caught by the generic handler below and wrapped (existing behaviour).
                    throw new RssException("Failed to commit task to shuffle server");
                }
            } catch (InterruptedException ie) {
                // Keep the interrupt visible to callers instead of silently clearing it.
                Thread.currentThread().interrupt();
                LOG.warn("Ignore the InterruptedException which should be caused by internal killed");
            } catch (Exception e) {
                throw new RssException("Exception happened when get commit status", e);
            }
        } finally {
            // Always release the helper thread, including when submission or waiting fails.
            executor.shutdown();
        }
    }

    /**
     * Returns the live list of buffers that are queued for sending (not a copy).
     */
    List<WriteBuffer<K, V>> getWaitSendBuffers() {
        return waitSendBuffers;
    }

    /**
     * Returns the current sequence number for {@code partitionId} (starting at 0) and advances
     * the stored value by one.
     *
     * @param partitionId partition whose sequence is being advanced
     * @return the sequence number to use for the next block of this partition
     */
    private int getNextSeqNo(int partitionId) {
        int current = this.partitionToSeqNo.getOrDefault(partitionId, 0);
        this.partitionToSeqNo.put(partitionId, current + 1);
        return current;
    }

    /**
     * Shuts down the sender thread pool immediately via {@code shutdownNow()}.
     */
    public void freeAllResources() {
        sendExecutorService.shutdownNow();
    }
}

