/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.common.write;

import java.io.IOException;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.BooleanSupplier;
import java.util.stream.Collectors;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.write.PushState;
import org.apache.celeborn.common.write.PushStrategy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tracks push requests that have been sent but not yet completed, both per destination address
 * (keyed by the {@code hostAndPushPort} string) and as a global total, and lets pushing threads
 * wait until the number of in-flight requests falls back under the configured limits.
 *
 * <p>Per-address batch ids are kept in concurrent sets inside a {@link ConcurrentHashMap}; the
 * total is kept in a {@link LongAdder}. The two are updated independently (not atomically
 * together), so they can briefly disagree while other threads add or remove batches.
 */
public class InFlightRequestTracker {
    private static final Logger logger = LoggerFactory.getLogger(InFlightRequestTracker.class);

    /** Upper bound, in ms, on how long the {@code limit*InFlight} methods wait. */
    private final long waitInflightTimeoutMs;
    /** Sleep interval, in ms, between two checks while waiting. */
    private final long delta;
    /** Shared push state; a non-null {@code exception} aborts any waiting. */
    private final PushState pushState;
    /** Decides per-address speed limiting and the current per-address in-flight limit. */
    private final PushStrategy pushStrategy;
    /** Source of ids handed out by {@link #nextBatchId()}. */
    private final AtomicInteger batchId = new AtomicInteger();
    /** In-flight batch ids, grouped by {@code hostAndPushPort}. */
    private final ConcurrentHashMap<String, Set<Integer>> inflightBatchesPerAddress =
        JavaUtils.newConcurrentHashMap();
    /** Global cap on in-flight requests across all addresses. */
    private final int maxInFlightReqsTotal;
    /** Number of in-flight requests across all addresses. */
    private final LongAdder totalInflightReqs = new LongAdder();

    /**
     * Creates a tracker configured from {@code conf}.
     *
     * @param conf source of the wait timeout, sleep delta, total in-flight cap and push strategy
     * @param pushState shared push state whose recorded exception aborts waiting
     */
    public InFlightRequestTracker(CelebornConf conf, PushState pushState) {
        this.waitInflightTimeoutMs = conf.clientPushLimitInFlightTimeoutMs();
        this.delta = conf.clientPushLimitInFlightSleepDeltaMs();
        this.pushState = pushState;
        this.pushStrategy = PushStrategy.getStrategy(conf);
        this.maxInFlightReqsTotal = conf.clientPushMaxReqsInFlightTotal();
    }

    /**
     * Records {@code batchId} as in flight towards {@code hostAndPushPort} and increments the
     * total in-flight counter.
     */
    public void addBatch(int batchId, String hostAndPushPort) {
        Set<Integer> batchIdSetPerPair =
            inflightBatchesPerAddress.computeIfAbsent(
                hostAndPushPort, id -> ConcurrentHashMap.newKeySet());
        batchIdSetPerPair.add(batchId);
        totalInflightReqs.increment();
    }

    /**
     * Removes {@code batchId} from the in-flight set of {@code hostAndPushPort} and decrements the
     * total in-flight counter. The counter is decremented even when no set exists for the address;
     * in that case a warning is logged.
     */
    public void removeBatch(int batchId, String hostAndPushPort) {
        Set<Integer> batchIdSet = inflightBatchesPerAddress.get(hostAndPushPort);
        if (batchIdSet != null) {
            batchIdSet.remove(batchId);
        } else {
            logger.warn("BatchIdSet of {} is null.", hostAndPushPort);
        }
        totalInflightReqs.decrement();
    }

    /** Forwards a successful push to {@code hostAndPushPort} to the push strategy. */
    public void onSuccess(String hostAndPushPort) {
        pushStrategy.onSuccess(hostAndPushPort);
    }

    /** Forwards a congestion-control signal for {@code hostAndPushPort} to the push strategy. */
    public void onCongestControl(String hostAndPushPort) {
        pushStrategy.onCongestControl(hostAndPushPort);
    }

    /**
     * Returns the live (mutable, concurrent) set of in-flight batch ids for {@code hostAndPort},
     * creating and registering an empty set if none exists yet.
     */
    public Set<Integer> getBatchIdSetByAddressPair(String hostAndPort) {
        return inflightBatchesPerAddress.computeIfAbsent(
            hostAndPort, pair -> ConcurrentHashMap.newKeySet());
    }

    /**
     * Throttles pushes to {@code hostAndPushPort}. First the push strategy's speed limit is
     * applied. Then the caller waits, sleeping {@code delta} ms per round, while the total number
     * of in-flight requests exceeds {@code maxInFlightReqsTotal} or the number of in-flight batches
     * for this address exceeds the strategy's current per-address limit.
     *
     * @param hostAndPushPort destination address key
     * @return {@code true} if a limit was still exceeded when the wait budget ran out,
     *     {@code false} otherwise
     * @throws IOException if a push failure is recorded in {@code pushState} before or during the
     *     wait, or if the waiting thread is interrupted
     */
    public boolean limitMaxInFlight(String hostAndPushPort) throws IOException {
        throwIfPushFailed();
        pushStrategy.limitPushSpeed(pushState, hostAndPushPort);
        int currentMaxReqsInFlight = pushStrategy.getCurrentMaxReqsInFlight(hostAndPushPort);
        Set<Integer> batchIdSet = getBatchIdSetByAddressPair(hostAndPushPort);

        boolean timedOut =
            awaitWhile(
                () ->
                    totalInflightReqs.sum() > (long) maxInFlightReqsTotal
                        || batchIdSet.size() > currentMaxReqsInFlight);
        if (timedOut) {
            // Report both limits: the wait may have been caused by either one.
            logger.warn(
                "After waiting for {} ms, there are still {} batches in flight for hostAndPushPort {}"
                    + " (current limit {}) and {} requests in flight in total (limit {}).",
                waitInflightTimeoutMs,
                batchIdSet.size(),
                hostAndPushPort,
                currentMaxReqsInFlight,
                totalInflightReqs.sum(),
                maxInFlightReqsTotal);
        }
        throwIfPushFailed();
        return timedOut;
    }

    /**
     * Waits, sleeping {@code delta} ms per round, until no request is in flight on any address.
     *
     * @return {@code true} if requests were still in flight when the wait budget ran out,
     *     {@code false} otherwise
     * @throws IOException if a push failure is recorded in {@code pushState} before or during the
     *     wait, or if the waiting thread is interrupted
     */
    public boolean limitZeroInFlight() throws IOException {
        throwIfPushFailed();
        boolean timedOut = awaitWhile(() -> totalInflightReqs.sum() != 0L);
        if (timedOut) {
            logger.error(
                "After waiting for {} ms, there are still {} batches in flight for hostAndPushPort {}, which exceeds the current limit 0.",
                waitInflightTimeoutMs,
                totalInflightReqs.sum(),
                describeAddressesWithInflightBatches());
        }
        throwIfPushFailed();
        return timedOut;
    }

    /** Returns the next batch id; ids start at 1 and increase by one per call. */
    protected int nextBatchId() {
        return batchId.incrementAndGet();
    }

    /**
     * Discards all per-address tracking, logging a warning if any address entries existed, and
     * clears the push strategy. NOTE(review): {@code totalInflightReqs} is not reset here;
     * confirm that this is intended.
     */
    public void cleanup() {
        if (!inflightBatchesPerAddress.isEmpty()) {
            logger.warn("Clear {}", getClass().getSimpleName());
            inflightBatchesPerAddress.clear();
        }
        pushStrategy.clear();
    }

    /**
     * Sleeps in {@code delta}-ms rounds while {@code stillWaiting} holds, for at most
     * {@code waitInflightTimeoutMs / delta} rounds. The condition is re-checked after every sleep,
     * including the last one, so a limit that clears during the final round is not reported as a
     * timeout. A zero-round budget also does not time out when the condition is already false.
     *
     * <p>If the thread is interrupted, the interrupt is recorded as a {@link CelebornIOException}
     * in {@code pushState.exception}, the thread's interrupt status is restored, and
     * {@code false} is returned; the caller then surfaces the recorded exception.
     *
     * @return {@code true} if the condition still held when the round budget was exhausted
     * @throws IOException if a push failure is recorded while waiting
     */
    private boolean awaitWhile(BooleanSupplier stillWaiting) throws IOException {
        long remainingRounds = waitInflightTimeoutMs / delta;
        try {
            while (stillWaiting.getAsBoolean()) {
                if (remainingRounds <= 0L) {
                    return true;
                }
                throwIfPushFailed();
                Thread.sleep(delta);
                remainingRounds--;
            }
            return false;
        } catch (InterruptedException e) {
            // Keep the interrupt visible to code further up the stack.
            Thread.currentThread().interrupt();
            pushState.exception.set(new CelebornIOException(e));
            return false;
        }
    }

    /**
     * Throws the push failure recorded in {@code pushState}, if any. The atomic is read once, so a
     * concurrent reset cannot turn this into {@code throw null}.
     */
    private void throwIfPushFailed() throws IOException {
        IOException failure = pushState.exception.get();
        if (failure != null) {
            throw failure;
        }
    }

    /** Formats the addresses that currently have at least one batch in flight as "[a, b]". */
    private String describeAddressesWithInflightBatches() {
        return inflightBatchesPerAddress.entrySet().stream()
            .filter(entry -> !entry.getValue().isEmpty())
            .map(Map.Entry::getKey)
            .collect(Collectors.joining(", ", "[", "]"));
    }
}

