/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.plugin.flink;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.function.BiConsumer;
import java.util.function.Function;
import org.apache.celeborn.plugin.flink.RemoteShuffleOutputGate;
import org.apache.celeborn.plugin.flink.buffer.PartitionSortedBuffer;
import org.apache.celeborn.plugin.flink.buffer.SortBuffer;
import org.apache.celeborn.plugin.flink.utils.BufferUtils;
import org.apache.celeborn.plugin.flink.utils.Utils;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Write-path logic shared by remote shuffle result partitions.
 *
 * <p>Records are appended to one of two sort buffers, unicast or broadcast. Requesting one of them
 * flushes the other first, so at most one holds unflushed data at a time. Flushing a sort buffer
 * does four things:
 * <ul>
 *   <li>copies its data into memory segments taken from the output gate's buffer pool;</li>
 *   <li>reports each resulting buffer to the statistics consumer;</li>
 *   <li>compresses the buffer when the {@code canBeCompressed} predicate allows it;</li>
 *   <li>writes the buffer to the {@link RemoteShuffleOutputGate}, between {@code regionStart} and
 *       {@code regionFinish}.</li>
 * </ul>
 * Records too large for an empty sort buffer bypass sorting and are written directly.
 *
 * <p>Only {@link #close(Runnable)} is synchronized. Thread-safety of the other methods is not
 * established by this class.
 */
public class RemoteShuffleResultPartitionDelegation {
    public static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleResultPartitionDelegation.class);

    /**
     * Number of bytes reserved at the start of every network buffer written by this class.
     * Record data is copied after this offset, and compression skips it. The header contents are
     * presumably filled in by {@code BufferUtils.setCompressedDataWithHeader}; confirm there.
     */
    private static final int HEADER_LENGTH = 22;

    public int networkBufferSize;
    public SortBuffer broadcastSortBuffer;
    public SortBuffer unicastSortBuffer;
    public RemoteShuffleOutputGate outputGate;
    private boolean endOfDataNotified;
    private final int numSubpartitions;
    private BufferPool bufferPool;
    private BufferCompressor bufferCompressor;
    private Function<Buffer, Boolean> canBeCompressed;
    private Runnable checkProducerState;
    private final BiConsumer<SortBuffer.BufferWithChannel, Boolean> statisticsConsumer;

    /**
     * @param networkBufferSize size passed to every {@link PartitionSortedBuffer} this class creates
     * @param outputGate destination for all flushed data
     * @param statisticsConsumer called with every buffer written, plus whether it is a broadcast
     * @param numSubpartitions number of subpartitions of the sort buffers this class creates
     */
    public RemoteShuffleResultPartitionDelegation(int networkBufferSize, RemoteShuffleOutputGate outputGate, BiConsumer<SortBuffer.BufferWithChannel, Boolean> statisticsConsumer, int numSubpartitions) {
        this.networkBufferSize = networkBufferSize;
        this.outputGate = outputGate;
        this.numSubpartitions = numSubpartitions;
        this.statisticsConsumer = statisticsConsumer;
    }

    /**
     * Stores the runtime collaborators and sets up the output gate.
     *
     * @param bufferPool pool the sort buffers draw their memory from
     * @param bufferCompressor compressor used when {@code canBeCompressed} accepts a buffer
     * @param canBeCompressed decides per buffer whether compression is attempted
     * @param checkProduceState run at the start of every {@link #emit} call
     */
    public void setup(BufferPool bufferPool, BufferCompressor bufferCompressor, Function<Buffer, Boolean> canBeCompressed, Runnable checkProduceState) throws IOException {
        LOG.info("Setup {}", this);
        this.bufferPool = bufferPool;
        this.bufferCompressor = bufferCompressor;
        this.canBeCompressed = canBeCompressed;
        this.checkProducerState = checkProduceState;
        try {
            this.outputGate.setup();
        } catch (Throwable throwable) {
            LOG.error("Failed to setup remote output gate.", throwable);
            Utils.rethrowAsRuntimeException(throwable);
        }
    }

    /**
     * Appends a record to the unicast or broadcast sort buffer.
     *
     * <p>If the sort buffer is full, it is flushed and the emit is retried with a fresh buffer.
     * If the sort buffer is empty and still rejects the record, the record is written directly via
     * {@link #writeLargeRecord}.
     *
     * @param targetSubpartition destination subpartition; must be 0 when {@code isBroadcast}
     * @throws IOException if flushing or obtaining a sort buffer fails
     */
    public void emit(ByteBuffer record, int targetSubpartition, Buffer.DataType dataType, boolean isBroadcast) throws IOException {
        this.checkProducerState.run();
        if (isBroadcast) {
            Preconditions.checkState(targetSubpartition == 0, "Target subpartition index can only be 0 when broadcast.");
        }
        SortBuffer sortBuffer = isBroadcast ? this.getBroadcastSortBuffer() : this.getUnicastSortBuffer();
        if (sortBuffer.append(record, targetSubpartition, dataType)) {
            return;
        }
        try {
            if (!sortBuffer.hasRemaining()) {
                // The buffer holds no data yet and still rejected the record, so the record cannot
                // fit into any sort buffer: discard this one and write the record directly.
                sortBuffer.finish();
                sortBuffer.release();
                this.writeLargeRecord(record, targetSubpartition, dataType, isBroadcast);
                return;
            }
            // The buffer is full. Flushing finishes and releases it, so the retry below gets a
            // fresh sort buffer.
            this.flushSortBuffer(sortBuffer, isBroadcast);
        } catch (InterruptedException e) {
            // Preserve the interrupt for code further up the stack.
            Thread.currentThread().interrupt();
            LOG.error("Failed to flush the sort buffer.", e);
            Utils.rethrowAsRuntimeException(e);
        }
        this.emit(record, targetSubpartition, dataType, isBroadcast);
    }

    /**
     * Flushes any pending broadcast data, then returns the current unicast sort buffer.
     * A new one is created if there is none or the current one is finished.
     */
    @VisibleForTesting
    public SortBuffer getUnicastSortBuffer() throws IOException {
        this.flushBroadcastSortBuffer();
        if (this.unicastSortBuffer != null && !this.unicastSortBuffer.isFinished()) {
            return this.unicastSortBuffer;
        }
        this.unicastSortBuffer = new PartitionSortedBuffer(this.bufferPool, this.numSubpartitions, this.networkBufferSize, null);
        return this.unicastSortBuffer;
    }

    /**
     * Flushes any pending unicast data, then returns the current broadcast sort buffer.
     * A new one is created if there is none or the current one is finished.
     */
    public SortBuffer getBroadcastSortBuffer() throws IOException {
        this.flushUnicastSortBuffer();
        if (this.broadcastSortBuffer != null && !this.broadcastSortBuffer.isFinished()) {
            return this.broadcastSortBuffer;
        }
        this.broadcastSortBuffer = new PartitionSortedBuffer(this.bufferPool, this.numSubpartitions, this.networkBufferSize, null);
        return this.broadcastSortBuffer;
    }

    /** Flushes and releases the broadcast sort buffer; a no-op if it is null or already released. */
    public void flushBroadcastSortBuffer() throws IOException {
        this.flushSortBuffer(this.broadcastSortBuffer, true);
    }

    /** Flushes and releases the unicast sort buffer; a no-op if it is null or already released. */
    public void flushUnicastSortBuffer() throws IOException {
        this.flushSortBuffer(this.unicastSortBuffer, false);
    }

    /**
     * Finishes {@code sortBuffer} and writes its remaining data to the output gate as one region.
     *
     * <p>The sort buffer is released afterwards, also when writing fails, so its memory is
     * returned even on the error path.
     *
     * @param sortBuffer buffer to flush; a null or already released buffer is ignored
     * @param isBroadcast forwarded to {@code regionStart} and to the statistics consumer
     * @throws IOException if the thread is interrupted while flushing; the interrupt status is restored
     */
    @VisibleForTesting
    void flushSortBuffer(SortBuffer sortBuffer, boolean isBroadcast) throws IOException {
        if (sortBuffer == null || sortBuffer.isReleased()) {
            return;
        }
        try {
            sortBuffer.finish();
            if (sortBuffer.hasRemaining()) {
                this.outputGate.regionStart(isBroadcast);
                while (sortBuffer.hasRemaining()) {
                    MemorySegment segment = this.outputGate.getBufferPool().requestMemorySegmentBlocking();
                    SortBuffer.BufferWithChannel bufferWithChannel;
                    try {
                        bufferWithChannel = sortBuffer.copyIntoSegment(segment, (BufferRecycler) this.outputGate.getBufferPool(), HEADER_LENGTH);
                    } catch (Throwable t) {
                        // The copy failed, so give the segment back to the pool before propagating.
                        this.outputGate.getBufferPool().recycle(segment);
                        throw new FlinkRuntimeException("Shuffle write failure.", t);
                    }
                    Buffer buffer = bufferWithChannel.getBuffer();
                    int subpartitionIndex = bufferWithChannel.getChannelIndex();
                    this.statisticsConsumer.accept(bufferWithChannel, isBroadcast);
                    this.writeCompressedBufferIfPossible(buffer, subpartitionIndex);
                }
                this.outputGate.regionFinish();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IOException("Failed to flush the sort buffer, broadcast=" + isBroadcast, e);
        } finally {
            this.releaseSortBuffer(sortBuffer);
        }
    }

    /**
     * Compresses the payload of {@code buffer} (the bytes after the header) when
     * {@code canBeCompressed} allows it, stores the result back into {@code buffer} via
     * {@code BufferUtils.setCompressedDataWithHeader}, and writes {@code buffer} to the output gate.
     *
     * <p>On any failure before the write, {@code buffer} is recycled and a RuntimeException is thrown.
     * The intermediate compressed buffer is recycled in all cases when it reports itself as compressed.
     */
    public void writeCompressedBufferIfPossible(Buffer buffer, int targetSubpartition) throws InterruptedException {
        Buffer compressedBuffer = null;
        try {
            if (this.canBeCompressed.apply(buffer)) {
                Buffer dataBuffer = buffer.readOnlySlice(HEADER_LENGTH, buffer.getSize() - HEADER_LENGTH);
                compressedBuffer = Utils.checkNotNull(this.bufferCompressor).compressToIntermediateBuffer(dataBuffer);
            }
            BufferUtils.setCompressedDataWithHeader(buffer, compressedBuffer);
        } catch (Throwable throwable) {
            buffer.recycleBuffer();
            throw new RuntimeException("Shuffle write failure.", throwable);
        } finally {
            if (compressedBuffer != null && compressedBuffer.isCompressed()) {
                compressedBuffer.setReaderIndex(0);
                compressedBuffer.recycleBuffer();
            }
        }
        this.outputGate.write(buffer, targetSubpartition);
    }

    /**
     * Writes a record that does not fit into an empty sort buffer directly to the output gate as
     * its own region.
     *
     * <p>The record is split across as many pool segments as needed. Each segment reserves
     * {@link #HEADER_LENGTH} bytes at its start. Consumes {@code record}: its position is advanced
     * to its limit.
     */
    public void writeLargeRecord(ByteBuffer record, int targetSubpartition, Buffer.DataType dataType, boolean isBroadcast) throws InterruptedException {
        this.outputGate.regionStart(isBroadcast);
        while (record.hasRemaining()) {
            MemorySegment writeBuffer = this.outputGate.getBufferPool().requestMemorySegmentBlocking();
            int toCopy = Math.min(record.remaining(), writeBuffer.size() - HEADER_LENGTH);
            writeBuffer.put(HEADER_LENGTH, record, toCopy);
            NetworkBuffer buffer = new NetworkBuffer(writeBuffer, (BufferRecycler) this.outputGate.getBufferPool(), dataType, toCopy + HEADER_LENGTH);
            SortBuffer.BufferWithChannel bufferWithChannel = new SortBuffer.BufferWithChannel(buffer, targetSubpartition);
            this.statisticsConsumer.accept(bufferWithChannel, isBroadcast);
            this.writeCompressedBufferIfPossible(buffer, targetSubpartition);
        }
        this.outputGate.regionFinish();
    }

    /** Emits {@code record} to every subpartition through the broadcast sort buffer. */
    public void broadcast(ByteBuffer record, Buffer.DataType dataType) throws IOException {
        this.emit(record, 0, dataType, true);
    }

    /** Releases {@code sortBuffer} if it is non-null. */
    public void releaseSortBuffer(SortBuffer sortBuffer) {
        if (sortBuffer != null) {
            sortBuffer.release();
        }
    }

    /**
     * Flushes pending broadcast data and finishes the output gate.
     *
     * @throws IOException if the unicast buffer is still live (raised by the state check), if the
     *     flush fails, or if the thread is interrupted while finishing the gate. In the interrupt
     *     case the interrupt status is restored.
     */
    public void finish() throws IOException {
        Utils.checkState(this.unicastSortBuffer == null || this.unicastSortBuffer.isReleased(), "The unicast sort buffer should be either null or released.");
        this.flushBroadcastSortBuffer();
        try {
            this.outputGate.finish();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IOException("Output gate fails to finish.", e);
        }
    }

    /**
     * Releases both sort buffers, runs {@code closeHandler}, and closes the output gate.
     *
     * <p>Every step is attempted even if an earlier one fails. The first failure is rethrown at the
     * end via {@code Utils.rethrowAsRuntimeException}; all failures are logged.
     */
    public synchronized void close(Runnable closeHandler) {
        Throwable closeException = null;
        closeException = this.checkException(() -> this.releaseSortBuffer(this.unicastSortBuffer), closeException, "Failed to release unicast sort buffer.");
        closeException = this.checkException(() -> this.releaseSortBuffer(this.broadcastSortBuffer), closeException, "Failed to release broadcast sort buffer.");
        // Keep the lambda: a method reference would throw an NPE for a null handler outside checkException.
        closeException = this.checkException(() -> closeHandler.run(), closeException, "Failed to call super#close() method.");
        try {
            this.outputGate.close();
        } catch (Throwable throwable) {
            closeException = closeException == null ? throwable : closeException;
            LOG.error("Failed to close remote shuffle output gate.", throwable);
        }
        if (closeException != null) {
            Utils.rethrowAsRuntimeException(closeException);
        }
    }

    /**
     * Runs {@code runnable}, logging any throwable with {@code errorMessage}.
     *
     * @param exception the first failure recorded so far, or null
     * @return the first failure overall: {@code exception} if it is non-null, otherwise whatever
     *     {@code runnable} threw, or null when neither failed
     */
    public Throwable checkException(Runnable runnable, Throwable exception, String errorMessage) {
        // Start from the previous failure so that a successful step does not discard it.
        Throwable newException = exception;
        try {
            runnable.run();
        } catch (Throwable throwable) {
            newException = exception == null ? throwable : exception;
            LOG.error(errorMessage, throwable);
        }
        return newException;
    }

    /** Flushes both sort buffers, rethrowing any failure via {@code Utils.rethrowAsRuntimeException}. */
    public void flushAll() {
        try {
            this.flushUnicastSortBuffer();
            this.flushBroadcastSortBuffer();
        } catch (Throwable t) {
            LOG.error("Failed to flush the current sort buffer.", t);
            Utils.rethrowAsRuntimeException(t);
        }
    }

    public RemoteShuffleOutputGate getOutputGate() {
        return this.outputGate;
    }

    public boolean isEndOfDataNotified() {
        return this.endOfDataNotified;
    }

    public void setEndOfDataNotified(boolean endOfDataNotified) {
        this.endOfDataNotified = endOfDataNotified;
    }
}

