/*
 * Decompiled with CFR 0.152.
 */
package ai.onnxruntime;

import ai.onnxruntime.OnnxJavaType;
import java.lang.reflect.Array;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.ArrayList;
import java.util.Arrays;

public final class OrtUtil {
    /** Maximum rank accepted by the shape helpers in this class. */
    private static final int MAX_DIMENSIONS = 8;

    private OrtUtil() {
    }

    /**
     * Converts a {@code long[]} shape into the {@code int[]} form accepted by
     * {@link Array#newInstance(Class, int...)}.
     *
     * @param shape the shape; must have between 1 and 8 entries, each in {@code [1, Integer.MAX_VALUE]}
     * @return a new {@code int[]} holding the same dimensions
     * @throws IllegalArgumentException if the rank or any dimension is out of range
     */
    public static int[] transformShape(long[] shape) {
        checkRank(shape.length);
        int[] newShape = new int[shape.length];
        for (int i = 0; i < shape.length; ++i) {
            long curDim = shape[i];
            if (curDim < 1L || curDim > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Invalid shape for a Java array, expected positive entries smaller than Integer.MAX_VALUE. Found " + Arrays.toString(shape));
            }
            newShape[i] = (int) curDim;
        }
        return newShape;
    }

    /**
     * Widens an {@code int[]} shape to the {@code long[]} form.
     *
     * @param shape the shape; must have between 1 and 8 entries, each positive
     * @return a new {@code long[]} holding the same dimensions
     * @throws IllegalArgumentException if the rank is out of range or any dimension is non-positive
     */
    public static long[] transformShape(int[] shape) {
        checkRank(shape.length);
        long[] newShape = new long[shape.length];
        for (int i = 0; i < shape.length; ++i) {
            int curDim = shape[i];
            if (curDim < 1) {
                throw new IllegalArgumentException("Invalid shape for a Java array, expected positive entries smaller than Integer.MAX_VALUE. Found " + Arrays.toString(shape));
            }
            newShape[i] = curDim;
        }
        return newShape;
    }

    /** Rejects ranks outside {@code [1, MAX_DIMENSIONS]} with the class's standard message. */
    private static void checkRank(int rank) {
        if (rank == 0 || rank > MAX_DIMENSIONS) {
            throw new IllegalArgumentException("Arrays with less than 1 and greater than 8 dimensions are not supported.");
        }
    }

    /** Allocates a (possibly nested) array of {@code componentType} with the validated shape. */
    private static Object newArray(Class<?> componentType, long[] shape) {
        return Array.newInstance(componentType, transformShape(shape));
    }

    /**
     * Allocates a nested {@code boolean} array of the given shape.
     *
     * @param shape the shape, validated by {@link #transformShape(long[])}
     * @return the new array, e.g. a {@code boolean[][]} for a rank-2 shape
     */
    public static Object newBooleanArray(long[] shape) {
        return newArray(boolean.class, shape);
    }

    /** Allocates a nested {@code byte} array of the given shape (see {@link #newBooleanArray}). */
    public static Object newByteArray(long[] shape) {
        return newArray(byte.class, shape);
    }

    /** Allocates a nested {@code short} array of the given shape (see {@link #newBooleanArray}). */
    public static Object newShortArray(long[] shape) {
        return newArray(short.class, shape);
    }

    /** Allocates a nested {@code int} array of the given shape (see {@link #newBooleanArray}). */
    public static Object newIntArray(long[] shape) {
        return newArray(int.class, shape);
    }

    /** Allocates a nested {@code long} array of the given shape (see {@link #newBooleanArray}). */
    public static Object newLongArray(long[] shape) {
        return newArray(long.class, shape);
    }

    /** Allocates a nested {@code float} array of the given shape (see {@link #newBooleanArray}). */
    public static Object newFloatArray(long[] shape) {
        return newArray(float.class, shape);
    }

    /** Allocates a nested {@code double} array of the given shape (see {@link #newBooleanArray}). */
    public static Object newDoubleArray(long[] shape) {
        return newArray(double.class, shape);
    }

    /** Allocates a nested {@code String} array of the given shape (see {@link #newBooleanArray}). */
    public static Object newStringArray(long[] shape) {
        return newArray(String.class, shape);
    }

    /**
     * Copies a flat array into a newly allocated nested array of {@code shape}, in row-major order.
     *
     * @param input the flat source; its length must equal the product of {@code shape}
     * @param shape the target shape (rank 1 to 8, positive dimensions)
     * @return the nested array, e.g. {@code boolean[][]} for a rank-2 shape
     * @throws IllegalArgumentException if the shape is invalid or does not match {@code input.length}
     */
    public static Object reshape(boolean[] input, long[] shape) {
        return reshapeFlat(input, boolean.class, shape);
    }

    /** Row-major reshape of a flat {@code byte[]}; see {@link #reshape(boolean[], long[])}. */
    public static Object reshape(byte[] input, long[] shape) {
        return reshapeFlat(input, byte.class, shape);
    }

    /** Row-major reshape of a flat {@code short[]}; see {@link #reshape(boolean[], long[])}. */
    public static Object reshape(short[] input, long[] shape) {
        return reshapeFlat(input, short.class, shape);
    }

    /** Row-major reshape of a flat {@code int[]}; see {@link #reshape(boolean[], long[])}. */
    public static Object reshape(int[] input, long[] shape) {
        return reshapeFlat(input, int.class, shape);
    }

    /** Row-major reshape of a flat {@code long[]}; see {@link #reshape(boolean[], long[])}. */
    public static Object reshape(long[] input, long[] shape) {
        return reshapeFlat(input, long.class, shape);
    }

    /** Row-major reshape of a flat {@code float[]}; see {@link #reshape(boolean[], long[])}. */
    public static Object reshape(float[] input, long[] shape) {
        return reshapeFlat(input, float.class, shape);
    }

    /** Row-major reshape of a flat {@code double[]}; see {@link #reshape(boolean[], long[])}. */
    public static Object reshape(double[] input, long[] shape) {
        return reshapeFlat(input, double.class, shape);
    }

    /** Row-major reshape of a flat {@code String[]}; see {@link #reshape(boolean[], long[])}. */
    public static Object reshape(String[] input, long[] shape) {
        return reshapeFlat(input, String.class, shape);
    }

    /**
     * Validates the shape and element count, allocates the nested output and fills it.
     * The count check runs before allocation so a mismatched shape never allocates.
     */
    private static Object reshapeFlat(Object input, Class<?> componentType, long[] shape) {
        int[] intShape = transformShape(shape);
        long expected = elementCount(shape);
        int actual = Array.getLength(input);
        if (actual != expected) {
            // Previously a longer input was silently truncated; a shorter one failed inside arraycopy.
            throw new IllegalArgumentException("Input array has " + actual + " elements but shape " + Arrays.toString(shape) + " requires " + expected + ".");
        }
        Object output = Array.newInstance(componentType, intShape);
        fillRowMajor(input, output, 0);
        return output;
    }

    /**
     * Recursively copies elements from the flat {@code input}, starting at {@code position},
     * into the nested {@code output} array in row-major order.
     *
     * @return the position in {@code input} just past the last element copied
     */
    private static int fillRowMajor(Object input, Object output, int position) {
        Class<?> outputClass = output.getClass();
        if (!outputClass.isArray()) {
            throw new IllegalStateException("Found element type when expecting an array. Class " + outputClass);
        }
        Class<?> componentType = outputClass.getComponentType();
        if (componentType.isPrimitive() || componentType == String.class) {
            // Innermost dimension (including the rank-1 case): one contiguous bulk copy.
            int length = Array.getLength(output);
            System.arraycopy(input, position, output, 0, length);
            return position + length;
        }
        for (Object element : (Object[]) output) {
            position = fillRowMajor(input, element, position);
        }
        return position;
    }

    /**
     * Computes the number of elements described by {@code shape}.
     * An empty shape yields 1.
     *
     * @param shape the dimensions, all of which must be positive
     * @return the product of the dimensions
     * @throws IllegalArgumentException if any dimension is non-positive or the product overflows a {@code long}
     */
    public static long elementCount(long[] shape) {
        long count = 1L;
        for (long dim : shape) {
            if (dim <= 0L) {
                throw new IllegalArgumentException("Received non-positive value in shape " + Arrays.toString(shape) + " .");
            }
            try {
                count = Math.multiplyExact(count, dim);
            } catch (ArithmeticException e) {
                throw new IllegalArgumentException("Element count of shape " + Arrays.toString(shape) + " overflows a long.", e);
            }
        }
        return count;
    }

    /**
     * Checks whether {@code shape} has at most 8 dimensions, each positive and representable as an {@code int}.
     * Note that an empty shape is accepted here, unlike {@link #transformShape(long[])}.
     *
     * @param shape the shape to check
     * @return {@code true} if the shape satisfies the constraints above
     */
    public static boolean validateShape(long[] shape) {
        boolean valid = true;
        for (long dim : shape) {
            valid &= dim > 0L;
            valid &= (long) ((int) dim) == dim;
        }
        return valid && shape.length <= MAX_DIMENSIONS;
    }

    /**
     * Flattens a {@code String[]} or an arbitrarily nested array of {@code String[]} into a single {@code String[]}.
     * A plain {@code String[]} argument is returned as-is (not copied).
     *
     * @param o a {@code String[]} or nested {@code Object[]} whose leaves are {@code String[]}
     * @return the strings in depth-first order
     * @throws IllegalStateException if a non-array or non-String leaf is encountered
     */
    public static String[] flattenString(Object o) {
        if (o instanceof String[]) {
            return (String[]) o;
        }
        ArrayList<String> output = new ArrayList<>();
        flattenString((Object[]) o, output);
        return output.toArray(new String[0]);
    }

    /** Depth-first accumulation of every {@code String} in the nested array {@code input} into {@code output}. */
    private static void flattenString(Object[] input, ArrayList<String> output) {
        for (Object element : input) {
            Class<?> elementClass = element.getClass();
            if (!elementClass.isArray()) {
                throw new IllegalStateException("Found an element type where there should have been an array. Class = " + elementClass);
            }
            Class<?> componentType = elementClass.getComponentType();
            if (componentType.isArray()) {
                flattenString((Object[]) element, output);
            } else if (componentType.equals(String.class)) {
                output.addAll(Arrays.asList((String[]) element));
            } else {
                throw new IllegalStateException("Found a non-String, non-array element type, " + elementClass);
            }
        }
    }

    /**
     * Wraps a boxed primitive in a single-element primitive array matching {@code javaType}.
     *
     * @param javaType the element type; UINT8 and INT8 both map to {@code byte[]}
     * @param data the boxed value, expected to be the wrapper class for {@code javaType}
     * @return the one-element array, or {@code null} for types not handled by the switch
     */
    static Object convertBoxedPrimitiveToArray(OnnxJavaType javaType, Object data) {
        switch (javaType) {
            case FLOAT:
                return new float[] {(Float) data};
            case DOUBLE:
                return new double[] {(Double) data};
            case UINT8:
            case INT8:
                return new byte[] {(Byte) data};
            case INT16:
                return new short[] {(Short) data};
            case INT32:
                return new int[] {(Integer) data};
            case INT64:
                return new long[] {(Long) data};
            case BOOL:
                return new boolean[] {(Boolean) data};
            default:
                return null;
        }
    }

    /**
     * Returns {@code size / 0.75 + 1}, truncated to an {@code int}.
     * Presumably used to presize hash-based collections for the default 0.75 load factor.
     */
    static int capacityFromSize(int size) {
        return (int) ((double) size / 0.75 + 1.0);
    }

    /**
     * Produces a direct buffer view of {@code data}'s remaining elements.
     *
     * <p>A direct {@code data} is returned untouched, with {@code pos} set to its position in bytes.
     * Otherwise the remaining elements are copied into a freshly allocated direct buffer in native
     * byte order. In that case {@code data}'s position is restored afterwards and {@code pos} is 0.
     *
     * <p>NOTE(review): the byte order of an already-direct input buffer is not checked here.
     *
     * @param data the source buffer; its concrete type must match {@code type} when it is not direct
     * @param type the element type, supplying the element size in bytes
     * @return the buffer, byte offset, byte size, element count and whether a copy was made
     * @throws IllegalStateException if the byte size is too large for a direct buffer, or {@code type}
     *     has no copy path
     */
    static BufferTuple prepareBuffer(Buffer data, OnnxJavaType type) {
        long bufferSizeLong = (long) data.remaining() * (long) type.size;
        if (bufferSizeLong > (long) (Integer.MAX_VALUE - 8 * type.size)) {
            throw new IllegalStateException("Cannot allocate a direct buffer of the requested size and type, size " + data.remaining() + ", type = " + type);
        }
        int bufferSize = data.remaining() * type.size;
        Buffer tmp;
        int bufferPos;
        if (data.isDirect()) {
            tmp = data;
            bufferPos = data.position() * type.size;
        } else {
            int origPosition = data.position();
            ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
            switch (type) {
                case FLOAT:
                    tmp = buffer.asFloatBuffer().put((FloatBuffer) data);
                    break;
                case DOUBLE:
                    tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data);
                    break;
                case UINT8:
                case INT8:
                    tmp = buffer.put((ByteBuffer) data);
                    break;
                case INT16:
                    tmp = buffer.asShortBuffer().put((ShortBuffer) data);
                    break;
                case INT32:
                    tmp = buffer.asIntBuffer().put((IntBuffer) data);
                    break;
                case INT64:
                    tmp = buffer.asLongBuffer().put((LongBuffer) data);
                    break;
                default:
                    throw new IllegalStateException("Impossible to reach here, managed to cast a buffer as an incorrect type");
            }
            // The bulk put advanced the source; restore it so the caller's buffer state is unchanged.
            data.position(origPosition);
            tmp.rewind();
            bufferPos = 0;
        }
        return new BufferTuple(tmp, bufferPos, bufferSize, data.remaining(), tmp != data);
    }

    /** Result of {@link #prepareBuffer}: a (possibly copied) buffer plus its layout. */
    static final class BufferTuple {
        /** The buffer to use; either the caller's direct buffer or a new direct copy. */
        final Buffer data;
        /** Offset of the first element, in bytes. */
        final int pos;
        /** Size of the payload in bytes. */
        final long byteSize;
        /** Number of elements in the payload. */
        final long size;
        /** {@code true} if {@link #data} is a copy rather than the caller's buffer. */
        final boolean isCopy;

        BufferTuple(Buffer data, int pos, long byteSize, long size, boolean isCopy) {
            this.data = data;
            this.pos = pos;
            this.byteSize = byteSize;
            this.size = size;
            this.isCopy = isCopy;
        }
    }
}

