/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.script.ScriptService;
import software.amazon.awssdk.core.internal.http.loader.DefaultSdkHttpClientBuilder;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.http.AbortableInputStream;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.http.HttpExecuteRequest;
import software.amazon.awssdk.http.HttpExecuteResponse;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.SdkHttpRequest;

/**
 * Remote connector executor registered for the {@code aws_sigv4} connector protocol.
 *
 * <p>Each prediction is sent as an HTTP POST to the connector's predict endpoint. The request is
 * signed via {@link ConnectorUtils#signRequest} using the connector's access key, secret key,
 * session token, service name and region. A successful (2xx) response body is converted to
 * {@link ModelTensors} via {@link ConnectorUtils#processOutput}.
 */
@ConnectorExecutor(value="aws_sigv4")
public class AwsConnectorExecutor
implements RemoteConnectorExecutor {
    @Generated
    private static final Logger log = LogManager.getLogger(AwsConnectorExecutor.class);
    private final AwsConnector connector;
    private final SdkHttpClient httpClient;
    private ScriptService scriptService;

    /**
     * Creates an executor that sends requests through the given HTTP client.
     *
     * @param connector the connector configuration; must be an {@link AwsConnector}
     * @param httpClient the client used to execute the signed requests
     * @throws ClassCastException if {@code connector} is not an {@link AwsConnector}
     */
    public AwsConnectorExecutor(Connector connector, SdkHttpClient httpClient) {
        this.connector = (AwsConnector) connector;
        this.httpClient = httpClient;
    }

    /**
     * Creates an executor backed by a default SDK HTTP client.
     *
     * <p>NOTE(review): {@code DefaultSdkHttpClientBuilder} lives in an AWS SDK {@code internal}
     * package, and this class never closes the client it builds here. Confirm that the client's
     * lifecycle is managed elsewhere.
     *
     * @param connector the connector configuration; must be an {@link AwsConnector}
     * @throws ClassCastException if {@code connector} is not an {@link AwsConnector}
     */
    public AwsConnectorExecutor(Connector connector) {
        this(connector, new DefaultSdkHttpClientBuilder().build());
    }

    /**
     * Sends {@code payload} to the remote model and appends the parsed output to
     * {@code tensorOutputs}.
     *
     * @param mlInput the original ML input (not read by this implementation)
     * @param parameters parameters used to resolve the predict endpoint and to process the output
     * @param payload the request body, sent as a string
     * @param tensorOutputs receives the {@link ModelTensors} parsed from a successful response
     * @throws OpenSearchStatusException if the response has no body, or if the remote service
     *     answers with a non-2xx status
     * @throws MLException if a checked exception or error occurs while calling the remote model
     */
    @Override
    public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) {
        try {
            SdkHttpFullRequest request = buildRequest(parameters, payload);
            HttpExecuteRequest executeRequest = HttpExecuteRequest.builder()
                .request(signRequest(request))
                .contentStreamProvider(request.contentStreamProvider().orElse(null))
                .build();
            // The explicit target type resolves the doPrivileged overload ambiguity.
            // call() throws a checked IOException, so PrivilegedExceptionAction is required.
            HttpExecuteResponse response = AccessController.doPrivileged(
                (PrivilegedExceptionAction<HttpExecuteResponse>) () -> this.httpClient.prepareRequest(executeRequest).call()
            );
            String modelResponse = readResponseBody(response);
            int statusCode = response.httpResponse().statusCode();
            // Error bodies must not be fed to processOutput as if they were model output.
            if (statusCode < 200 || statusCode >= 300) {
                RestStatus status = RestStatus.fromCode(statusCode);
                throw new OpenSearchStatusException(
                    "Error from remote service: " + modelResponse,
                    status != null ? status : RestStatus.INTERNAL_SERVER_ERROR
                );
            }
            ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, this.connector, this.scriptService, parameters);
            tensorOutputs.add(tensors);
        }
        catch (RuntimeException exception) {
            log.error("Failed to execute predict in aws connector: {}", exception.getMessage(), exception);
            throw exception;
        }
        catch (Throwable e) {
            log.error("Failed to execute predict in aws connector", e);
            throw new MLException("Fail to execute predict in aws connector", e);
        }
    }

    /**
     * Builds the unsigned POST request for the predict endpoint.
     *
     * @param parameters parameters used to resolve the predict endpoint
     * @param payload the request body
     * @return the request, carrying the connector's decrypted headers when present
     */
    private SdkHttpFullRequest buildRequest(Map<String, String> parameters, String payload) {
        String endpoint = this.connector.getPredictEndpoint(parameters);
        RequestBody requestBody = RequestBody.fromString(payload);
        SdkHttpFullRequest.Builder builder = SdkHttpFullRequest.builder()
            .method(SdkHttpMethod.POST)
            .uri(URI.create(endpoint))
            .contentStreamProvider(requestBody.contentStreamProvider());
        Map<String, String> headers = this.connector.getDecryptedHeaders();
        if (headers != null) {
            for (Map.Entry<String, String> header : headers.entrySet()) {
                builder.putHeader(header.getKey(), header.getValue());
            }
        }
        return builder.build();
    }

    /**
     * Reads the whole response body as UTF-8 text. Lines are concatenated without their line
     * terminators, and the body stream is closed afterwards.
     *
     * @param response the executed HTTP response
     * @return the response body text
     * @throws OpenSearchStatusException if the response carries no body
     * @throws IOException if reading or closing the body stream fails
     */
    private static String readResponseBody(HttpExecuteResponse response) throws IOException {
        AbortableInputStream body = response.responseBody().orElse(null);
        if (body == null) {
            throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST);
        }
        StringBuilder responseBuilder = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(body, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                responseBuilder.append(line);
            }
        }
        return responseBuilder.toString();
    }

    /**
     * Signs {@code request} with the connector's credentials, service name and region.
     *
     * @param request the unsigned request
     * @return the signed request produced by {@link ConnectorUtils#signRequest}
     */
    private SdkHttpFullRequest signRequest(SdkHttpFullRequest request) {
        String accessKey = this.connector.getAccessKey();
        String secretKey = this.connector.getSecretKey();
        String sessionToken = this.connector.getSessionToken();
        String signingName = this.connector.getServiceName();
        String region = this.connector.getRegion();
        return ConnectorUtils.signRequest(request, accessKey, secretKey, sessionToken, signingName, region);
    }

    /** @return the AWS connector this executor was created with */
    @Generated
    public AwsConnector getConnector() {
        return this.connector;
    }

    /** @param scriptService the script service passed to output processing */
    @Override
    @Generated
    public void setScriptService(ScriptService scriptService) {
        this.scriptService = scriptService;
    }

    /** @return the script service passed to output processing, or {@code null} if unset */
    @Override
    @Generated
    public ScriptService getScriptService() {
        return this.scriptService;
    }
}

