/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.client.write;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.write.DataPusher;
import org.apache.celeborn.client.write.PushTask;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.common.write.PushState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Bounded queue of {@link PushTask}s awaiting push, with per-worker flow control applied when
 * tasks are taken.
 *
 * <p>Tasks are added with {@link #addPushTask(PushTask)}. They are drained in batches with
 * {@link #takePushTasks()}, which only releases a task when either:
 * <ul>
 *   <li>its target worker still has in-flight capacity, or
 *   <li>that worker has been waited on for {@code takeTaskMaxWaitAttempts} sleep intervals.
 * </ul>
 * The backing store is a {@link LinkedBlockingQueue}. It is scanned through its (weakly
 * consistent) iterator.
 */
public class DataPushQueue {
    private static final Logger logger = LoggerFactory.getLogger(DataPushQueue.class);
    /** Longest time {@link #addPushTask(PushTask)} blocks waiting for free queue space. */
    private static final long WAIT_TIME_NANOS = TimeUnit.MILLISECONDS.toNanos(500L);
    /** Pending tasks; bounded by {@code conf.clientPushQueueCapacity()}. */
    private final LinkedBlockingQueue<PushTask> workingQueue;
    /** Push state of this map attempt; source of per-worker in-flight push counts. */
    private final PushState pushState;
    /** Owner whose {@code stillRunning()} flag terminates {@link #takePushTasks()}. */
    private final DataPusher dataPusher;
    /** Maximum concurrent in-flight push requests allowed per worker. */
    private final int maxInFlightPerWorker;
    private final int shuffleId;
    private final int numMappers;
    private final int numPartitions;
    private final ShuffleClient client;
    /** Sleep between scans when no task could be released. */
    private final long takeTaskWaitIntervalMs;
    /** Number of empty scans after which one task is released to a saturated worker anyway. */
    private final int takeTaskMaxWaitAttempts;

    /**
     * Creates a queue for one map attempt of a shuffle.
     *
     * @param conf client configuration supplying queue capacity and flow-control limits
     * @param dataPusher owner whose running state bounds {@link #takePushTasks()}
     * @param client shuffle client used to resolve partition locations and push state
     * @param shuffleId shuffle id
     * @param mapId map task id
     * @param attemptId map task attempt id
     * @param numMappers number of mappers in the shuffle
     * @param numPartitions number of partitions in the shuffle
     */
    public DataPushQueue(CelebornConf conf, DataPusher dataPusher, ShuffleClient client, int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions) {
        this.shuffleId = shuffleId;
        this.numMappers = numMappers;
        this.numPartitions = numPartitions;
        this.client = client;
        this.dataPusher = dataPusher;
        String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
        this.pushState = client.getPushState(mapKey);
        this.maxInFlightPerWorker = conf.clientPushMaxReqsInFlightPerWorker();
        this.takeTaskWaitIntervalMs = conf.clientPushTakeTaskWaitIntervalMs();
        this.takeTaskMaxWaitAttempts = conf.clientPushTakeTaskMaxWaitAttempts();
        int capacity = conf.clientPushQueueCapacity();
        this.workingQueue = new LinkedBlockingQueue<>(capacity);
    }

    /**
     * Removes and returns a batch of tasks that may be pushed now.
     *
     * <p>The queue is scanned repeatedly while the owning {@link DataPusher} is still running. As
     * soon as a scan releases at least one task, that batch is returned. Otherwise the thread sleeps
     * for {@code takeTaskWaitIntervalMs} and every tracked worker's wait-attempt counter is
     * incremented.
     *
     * @return the released tasks; empty if the pusher stopped running before any task was released
     * @throws IOException propagated from resolving partition locations
     * @throws InterruptedException if interrupted while sleeping between scans
     */
    public ArrayList<PushTask> takePushTasks() throws IOException, InterruptedException {
        ArrayList<PushTask> tasks = new ArrayList<>();
        HashMap<String, Integer> workerCapacity = new HashMap<>();
        HashMap<String, AtomicInteger> workerWaitAttempts = new HashMap<>();
        while (this.dataPusher.stillRunning()) {
            // Capacity is recomputed every scan: in-flight pushes may have completed while sleeping.
            workerCapacity.clear();
            Iterator<PushTask> iterator = this.workingQueue.iterator();
            while (iterator.hasNext()) {
                PushTask task = iterator.next();
                if (canTake(task, workerCapacity, workerWaitAttempts)) {
                    iterator.remove();
                    tasks.add(task);
                }
            }
            if (!tasks.isEmpty()) {
                return tasks;
            }
            try {
                Thread.sleep(this.takeTaskWaitIntervalMs);
                workerWaitAttempts.values().forEach(AtomicInteger::incrementAndGet);
            } catch (InterruptedException ie) {
                logger.info("Thread interrupted while waiting push task.");
                throw ie;
            }
        }
        return tasks;
    }

    /**
     * Decides whether {@code task} may be released in the current scan.
     *
     * <p>This method updates the scan's bookkeeping: it reserves one unit of the worker's
     * capacity, and it resets the worker's wait counter when the wait limit forces a release.
     *
     * @param task the candidate task
     * @param workerCapacity remaining capacity per worker for the current scan
     * @param workerWaitAttempts empty scans observed per worker during this take call
     * @return true if the task should be removed from the queue and pushed now
     */
    private boolean canTake(PushTask task, HashMap<String, Integer> workerCapacity, HashMap<String, AtomicInteger> workerWaitAttempts) throws IOException {
        // NOTE(review): invoked once per queued task, exactly as before. Hoisting it to once per
        // scan is possible only if the returned map is stable across calls -- confirm in ShuffleClient.
        ConcurrentHashMap<Integer, PartitionLocation> partitionLocationMap = this.client.getPartitionLocation(this.shuffleId, this.numMappers, this.numPartitions);
        if (partitionLocationMap == null) {
            // No location information at all: flow control cannot apply, so release the task.
            return true;
        }
        PartitionLocation loc = partitionLocationMap.get(task.getPartitionId());
        if (loc == null) {
            // No known target worker for this partition: release the task without flow control.
            return true;
        }
        String worker = loc.hostAndPushPort();
        Integer capacity = workerCapacity.get(worker);
        if (capacity == null) {
            capacity = this.maxInFlightPerWorker - this.pushState.inflightPushes(worker);
            workerCapacity.put(worker, capacity);
        }
        AtomicInteger waitAttempts = workerWaitAttempts.computeIfAbsent(worker, w -> new AtomicInteger(0));
        if (capacity > 0) {
            workerCapacity.put(worker, capacity - 1);
            return true;
        }
        if (waitAttempts.get() >= this.takeTaskMaxWaitAttempts) {
            // A saturated worker still gets one task after waiting long enough, so its
            // tasks are not held indefinitely; the counter restarts for the next one.
            waitAttempts.set(0);
            return true;
        }
        return false;
    }

    /**
     * Enqueues a task, waiting up to 500 ms for free space if the queue is full.
     *
     * @param pushTask the task to enqueue; {@link LinkedBlockingQueue#offer} rejects null
     * @return true if enqueued, false if the queue stayed full for the whole wait
     * @throws InterruptedException if interrupted while waiting for space
     */
    public boolean addPushTask(PushTask pushTask) throws InterruptedException {
        return this.workingQueue.offer(pushTask, WAIT_TIME_NANOS, TimeUnit.NANOSECONDS);
    }

    /** Discards all pending tasks. */
    public void clear() {
        this.workingQueue.clear();
    }
}

