/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle.writer;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.stream.Collectors;

import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.spark.shuffle.handle.split.PartitionSplitInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.common.PartitionSplitMode;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.JavaUtils;

/** This class is to get the partition assignment for ShuffleWriter. */
public class TaskAttemptAssignment {
  private static final Logger LOGGER = LoggerFactory.getLogger(TaskAttemptAssignment.class);

  // partitionId -> servers currently available to the writer; recomputed on every handle update.
  private Map<Integer, List<ShuffleServerInfo>> assignment;
  private ShuffleHandleInfo handle;
  private final long taskAttemptId;

  // key: partitionId, values: exclusive servers.
  // This is for the partition split mechanism with load balance mode.
  // NOTE(review): ConcurrentSkipListSet keeps its elements in natural order, so adding more than
  // one server requires ShuffleServerInfo to implement Comparable — confirm.
  private final Map<Integer, Set<ShuffleServerInfo>> exclusiveServersForPartition;

  /**
   * Creates the assignment for one task attempt and computes the initial writer assignment.
   *
   * @param taskAttemptId the id of the task attempt owning this assignment
   * @param shuffleHandleInfo the initial shuffle handle; must not be null
   * @throws RssException if {@code shuffleHandleInfo} is null
   */
  public TaskAttemptAssignment(long taskAttemptId, ShuffleHandleInfo shuffleHandleInfo) {
    this.taskAttemptId = taskAttemptId;
    // Must be initialized before the assignment is computed, since the computation reads it.
    this.exclusiveServersForPartition = JavaUtils.newConcurrentMap();
    // Use the private helper instead of the overridable update() while still constructing.
    applyHandle(shuffleHandleInfo);
  }

  /**
   * Retrieves the partition's currently available shuffle servers for writing.
   *
   * @param partitionId the partition id
   * @return the servers assigned to the partition, or {@code null} if the current assignment has
   *     no entry for it
   */
  public List<ShuffleServerInfo> retrieve(int partitionId) {
    return assignment.get(partitionId);
  }

  /**
   * Replaces the current shuffle handle and recomputes the writer assignment from it, passing the
   * exclusive servers recorded for each partition.
   *
   * @param handle the new shuffle handle; must not be null
   * @throws RssException if {@code handle} is null
   */
  public void update(ShuffleHandleInfo handle) {
    applyHandle(handle);
  }

  /** Validates {@code newHandle}, recomputes the writer assignment from it, and stores it. */
  private void applyHandle(ShuffleHandleInfo newHandle) {
    if (newHandle == null) {
      throw new RssException("Errors on updating shuffle handle by the empty handleInfo.");
    }
    // Snapshot each partition's exclusive-server set into a list for the handle API.
    this.assignment =
        newHandle.getAvailablePartitionServersForWriter(
            this.exclusiveServersForPartition.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, x -> new ArrayList<>(x.getValue()))));
    this.handle = newHandle;
  }

  /** Returns true if the current handle marks the partition as split in LOAD_BALANCE mode. */
  private boolean hasBeenLoadBalanced(int partitionId) {
    PartitionSplitInfo splitInfo = this.handle.getPartitionSplitInfo(partitionId);
    return splitInfo.isSplit() && splitInfo.getMode() == PartitionSplitMode.LOAD_BALANCE;
  }

  /**
   * Records additional exclusive servers for a partition that has been split in load-balance mode
   * and recomputes the writer assignment. If the partition has not been load balanced, nothing is
   * changed and {@code false} is returned; presumably callers then fall back to reassignment.
   *
   * @param partitionId the partition id
   * @param exclusiveServers the servers to add to the partition's exclusive-server set
   * @return {@code true} if the assignment was updated, {@code false} otherwise
   */
  public boolean updatePartitionSplitAssignment(
      int partitionId, List<ShuffleServerInfo> exclusiveServers) {
    if (hasBeenLoadBalanced(partitionId)) {
      Set<ShuffleServerInfo> servers =
          this.exclusiveServersForPartition.computeIfAbsent(
              partitionId, k -> new ConcurrentSkipListSet<>());
      servers.addAll(exclusiveServers);
      update(this.handle);
      return true;
    }
    return false;
  }

  /**
   * Lists every shuffle server assigned to a partition, as seen by readers.
   *
   * @param partitionId the partition id
   * @return all assigned shuffle servers for the partition; an empty list when the handle has no
   *     reader assignment at all or none for this partition (never {@code null})
   */
  public List<ShuffleServerInfo> list(int partitionId) {
    Map<Integer, List<ShuffleServerInfo>> servers = this.handle.getAllPartitionServersForReader();
    if (servers == null) {
      return Collections.emptyList();
    }
    // Mirror the null-map case above so callers never receive a null list.
    List<ShuffleServerInfo> partitionServers = servers.get(partitionId);
    return partitionServers == null ? Collections.emptyList() : partitionServers;
  }
}
