/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.mapper;

import java.util.Arrays;
import java.util.Locale;
import lombok.Generated;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.util.BytesRef;
import org.opensearch.index.mapper.ParametrizedFieldMapper;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.common.KNNValidationUtil;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;

public class KNNVectorFieldMapperUtil {
    /**
     * Model metadata store used to resolve dimensions of model-based fields. It stays {@code null}
     * until {@link #initialize(ModelDao)} is called; model lookups before then will fail.
     */
    private static ModelDao modelDao;

    /**
     * Installs the {@link ModelDao} used by {@link #getExpectedVectorLength} for model-based fields.
     *
     * @param modelDao the model metadata store
     */
    public static void initialize(ModelDao modelDao) {
        KNNVectorFieldMapperUtil.modelDao = modelDao;
    }

    /**
     * Validates a single vector component destined for an fp16 scalar-quantized field.
     *
     * <p>General float validation is delegated to {@link KNNValidationUtil#validateFloatVectorValue};
     * afterwards the value must lie within {@code [FP16_MIN_VALUE, FP16_MAX_VALUE]}.
     *
     * @param value the vector component to check
     * @throws IllegalArgumentException if the value is outside the fp16 range
     */
    public static void validateFP16VectorValue(float value) {
        KNNValidationUtil.validateFloatVectorValue(value);
        if (value < KNNConstants.FP16_MIN_VALUE.floatValue() || value > KNNConstants.FP16_MAX_VALUE.floatValue()) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]", "sq", "fp16", KNNConstants.FP16_MIN_VALUE, KNNConstants.FP16_MAX_VALUE));
        }
    }

    /**
     * Clamps a vector component into the fp16 range instead of rejecting it.
     *
     * <p>General float validation is still delegated to {@link KNNValidationUtil#validateFloatVectorValue}
     * before clamping.
     *
     * @param value the vector component
     * @return {@code FP16_MIN_VALUE} if {@code value} is below it, {@code FP16_MAX_VALUE} if above it,
     *     otherwise {@code value} unchanged
     */
    public static float clipVectorValueToFP16Range(float value) {
        KNNValidationUtil.validateFloatVectorValue(value);
        if (value < KNNConstants.FP16_MIN_VALUE.floatValue()) {
            return KNNConstants.FP16_MIN_VALUE.floatValue();
        }
        if (value > KNNConstants.FP16_MAX_VALUE.floatValue()) {
            return KNNConstants.FP16_MAX_VALUE.floatValue();
        }
        return value;
    }

    /**
     * Checks that the requested vector data type is supported by the field's engine and method.
     *
     * <ul>
     *   <li>{@code FLOAT} is always accepted.</li>
     *   <li>{@code BYTE} is accepted only for the Lucene engine.</li>
     *   <li>{@code BINARY} is accepted only for the Faiss engine with the {@code hnsw} method.</li>
     * </ul>
     *
     * @param methodContext the field's method context (engine and method component)
     * @param vectorDataType the data type declared in the mapping
     * @throws IllegalArgumentException if the combination is unsupported or the data type is unrecognized
     */
    public static void validateVectorDataType(KNNMethodContext methodContext, VectorDataType vectorDataType) {
        if (VectorDataType.FLOAT == vectorDataType) {
            return;
        }
        if (VectorDataType.BYTE == vectorDataType) {
            if (KNNEngine.LUCENE == methodContext.getKnnEngine()) {
                return;
            }
            throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] field with value [%s] is only supported for [%s] engine", "data_type", vectorDataType.getValue(), "lucene"));
        }
        if (VectorDataType.BINARY == vectorDataType) {
            if (KNNEngine.FAISS == methodContext.getKnnEngine()) {
                if ("hnsw".equals(methodContext.getMethodComponentContext().getName())) {
                    return;
                }
                throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] field with value [%s] is only supported for [%s] method", "data_type", vectorDataType.getValue(), "hnsw"));
            }
            throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] field with value [%s] is only supported for [%s] engine", "data_type", vectorDataType.getValue(), "faiss"));
        }
        throw new IllegalArgumentException("This line should not be reached");
    }

    /**
     * Rejects non-{@code FLOAT} data types when {@code knnIndexSetting} is {@code true}.
     *
     * <p>The error message reports the {@code nmslib} engine as the unsupported engine.
     *
     * @param knnIndexSetting whether the k-NN index setting is enabled for this path
     * @param vectorDataType the mapping parameter holding the declared data type
     * @throws IllegalArgumentException if the data type is not {@code FLOAT} and {@code knnIndexSetting} is set
     */
    public static void validateVectorDataTypeWithKnnIndexSetting(boolean knnIndexSetting, ParametrizedFieldMapper.Parameter<VectorDataType> vectorDataType) {
        if (VectorDataType.FLOAT == vectorDataType.getValue()) {
            return;
        }
        if (knnIndexSetting) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] field with value [%s] is not supported for [%s] engine", "data_type", vectorDataType.getValue().getValue(), "nmslib"));
        }
    }

    /**
     * Builds a frozen {@link FieldType} for vector doc values.
     *
     * <p>The field type uses {@link DocValuesType#BINARY} doc values and carries an {@code engine}
     * attribute set to the engine's name.
     *
     * @param knnEngine the engine whose name is recorded on the field type
     * @return an immutable (frozen) field type
     */
    public static FieldType buildDocValuesFieldType(KNNEngine knnEngine) {
        FieldType field = new FieldType();
        field.putAttribute("engine", knnEngine.getName());
        field.setDocValuesType(DocValuesType.BINARY);
        field.freeze();
        return field;
    }

    /**
     * Creates a stored field holding the raw bytes of a byte vector.
     *
     * @param name field name
     * @param vector vector components, stored as-is
     * @return the stored field
     */
    public static StoredField createStoredFieldForByteVector(String name, byte[] vector) {
        return new StoredField(name, vector);
    }

    /**
     * Creates a stored field holding a float vector serialized by the default vector serializer.
     *
     * @param name field name
     * @param vector vector components
     * @return the stored field
     */
    public static StoredField createStoredFieldForFloatVector(String name, float[] vector) {
        return new StoredField(name, KNNVectorSerializerFactory.getDefaultSerializer().floatToByteArray(vector));
    }

    /**
     * Converts a stored vector back into its user-facing representation.
     *
     * <p>For {@code BYTE} vectors, each stored (signed) byte is widened to an {@code int}. All other
     * data types are decoded by {@link VectorDataType#getVectorFromBytesRef(BytesRef)}.
     *
     * @param storedVector the stored bytes; only the {@code [offset, offset + length)} slice is read
     * @param vectorDataType the field's vector data type
     * @return an {@code int[]} for {@code BYTE} vectors, otherwise whatever the data type decodes to
     */
    public static Object deserializeStoredVector(BytesRef storedVector, VectorDataType vectorDataType) {
        if (VectorDataType.BYTE == vectorDataType) {
            // A BytesRef may be a window into a larger shared buffer; reading the whole backing
            // array would pick up bytes that do not belong to this vector.
            final byte[] bytes = storedVector.bytes;
            final int offset = storedVector.offset;
            final int[] byteAsIntArray = new int[storedVector.length];
            Arrays.setAll(byteAsIntArray, i -> bytes[offset + i]);
            return byteAsIntArray;
        }
        return vectorDataType.getVectorFromBytesRef(storedVector);
    }

    /**
     * Returns the number of array elements a document's vector must contain for this field.
     *
     * <p>For model-based fields (dimension {@code -1}), the dimension is read from the model's
     * metadata. For {@code BINARY} fields the dimension is divided by 8, presumably because the
     * dimension counts bits packed eight per byte.
     *
     * @param knnVectorFieldType the field type
     * @return the expected vector array length
     * @throws IllegalArgumentException if the field is model-based and the model is missing or not created
     */
    public static int getExpectedVectorLength(KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) {
        int expectedDimensions = knnVectorFieldType.getDimension();
        if (KNNVectorFieldMapperUtil.isModelBasedIndex(expectedDimensions)) {
            ModelMetadata modelMetadata = KNNVectorFieldMapperUtil.getModelMetadataForField(knnVectorFieldType);
            expectedDimensions = modelMetadata.getDimension();
        }
        return VectorDataType.BINARY == knnVectorFieldType.getVectorDataType() ? expectedDimensions / 8 : expectedDimensions;
    }

    /** Returns true when the field's dimension is the {@code -1} sentinel used for model-based fields. */
    private static boolean isModelBasedIndex(int expectedDimensions) {
        return expectedDimensions == -1;
    }

    /**
     * Fetches the metadata of the model backing a model-based field.
     *
     * @param knnVectorField the field type; must carry a model id
     * @return metadata of a model that {@link ModelUtil#isModelCreated} reports as created
     * @throws IllegalArgumentException if the field has no model id or the model is not created
     */
    private static ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
        String modelId = knnVectorField.getModelId();
        if (modelId == null) {
            // Report the field's own name: the method component name identifies the algorithm, not
            // the field, and the method context is not null-checked on this path.
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' does not have model.", knnVectorField.name()));
        }
        ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
        if (!ModelUtil.isModelCreated(modelMetadata)) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId));
        }
        return modelMetadata;
    }

    /** Utility class; not instantiable. */
    @Generated
    private KNNVectorFieldMapperUtil() {
    }
}

