/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.internal.vectorization;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.util.Optional;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport;
import org.apache.lucene.store.FilterIndexInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;

/**
 * {@link RandomVectorScorerSupplier} for byte vectors whose raw bytes are read directly from a
 * {@link MemorySegmentAccessInput}.
 *
 * <p>Vector {@code ord} occupies bytes {@code [ord * vectorByteSize, (ord + 1) * vectorByteSize)}
 * of the input. When the input can expose that range as a {@link MemorySegment} slice, the slice
 * is used directly; otherwise the bytes are copied into a lazily allocated scratch array which is
 * then wrapped as a heap segment.
 *
 * <p>All scorers share the same ordinal handling and segment access (see {@link #scorer()}); the
 * sealed subclasses only differ in how two vectors are compared and normalized, via
 * {@link #scoreSegments(MemorySegment, MemorySegment)}.
 */
public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
implements RandomVectorScorerSupplier {
    /** Length in bytes of a single vector. */
    final int vectorByteSize;
    /** Number of vectors; valid ordinals are {@code [0, maxOrd)}. */
    final int maxOrd;
    /** Source of the raw vector bytes. */
    final MemorySegmentAccessInput input;
    /** Vector values this supplier was created for; handed to every scorer and to copies. */
    final KnnVectorValues values;
    // Scratch buffers used only when `input` cannot hand out a segment slice: scratch1 holds the
    // query vector, scratch2 the vector it is compared against. They are mutable state shared by
    // every scorer returned from this supplier. NOTE(review): presumably a supplier and its scorers
    // are meant to be confined to one thread, with copy() used for others -- confirm against the
    // RandomVectorScorerSupplier contract.
    byte[] scratch1;
    byte[] scratch2;

    /**
     * Returns a scorer supplier for {@code type} if {@code input}, after
     * {@link FilterIndexInput#unwrapOnlyTest} unwrapping, is a {@link MemorySegmentAccessInput};
     * otherwise returns an empty optional.
     *
     * @param type the similarity function used to score vector pairs
     * @param input input holding {@code values.size()} consecutive vectors of
     *     {@code values.getVectorByteLength()} bytes each
     * @param values the byte vector values backed by {@code input}
     * @return the supplier, or an empty optional if the input does not support segment access
     * @throws IllegalArgumentException if the input is shorter than the vector data it must hold
     */
    static Optional<RandomVectorScorerSupplier> create(
            VectorSimilarityFunction type, IndexInput input, KnnVectorValues values) {
        assert values instanceof ByteVectorValues;
        IndexInput unwrapped = FilterIndexInput.unwrapOnlyTest(input);
        if (!(unwrapped instanceof MemorySegmentAccessInput msInput)) {
            return Optional.empty();
        }
        checkInvariants(values.size(), values.getVectorByteLength(), unwrapped);
        // Exhaustive enum switch expression without a default branch: a new similarity constant
        // becomes a compile error here instead of a silent runtime fall-through.
        RandomVectorScorerSupplier supplier = switch (type) {
            case COSINE -> new CosineSupplier(msInput, values);
            case DOT_PRODUCT -> new DotProductSupplier(msInput, values);
            case EUCLIDEAN -> new EuclideanSupplier(msInput, values);
            case MAXIMUM_INNER_PRODUCT -> new MaxInnerProductSupplier(msInput, values);
        };
        return Optional.of(supplier);
    }

    Lucene99MemorySegmentByteVectorScorerSupplier(MemorySegmentAccessInput input, KnnVectorValues values) {
        this.input = input;
        this.values = values;
        this.vectorByteSize = values.getVectorByteLength();
        this.maxOrd = values.size();
    }

    /**
     * Verifies that {@code input} is long enough to hold {@code maxOrd} vectors of
     * {@code vectorByteLength} bytes each.
     *
     * @throws IllegalArgumentException if the input is too short
     */
    static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
        // Multiply in long arithmetic so the expected size cannot overflow int.
        if (input.length() < (long) vectorByteLength * (long) maxOrd) {
            throw new IllegalArgumentException("input length is less than expected vector data");
        }
    }

    /**
     * Rejects ordinals outside {@code [0, maxOrd)}.
     *
     * @throws IllegalArgumentException if {@code ord} is out of range
     */
    final void checkOrdinal(int ord) {
        if (ord < 0 || ord >= maxOrd) {
            throw new IllegalArgumentException("illegal ordinal: " + ord);
        }
    }

    /**
     * Returns the bytes of vector {@code ord} as a segment, copying into {@link #scratch1} when no
     * direct slice is available. The returned segment may be overwritten by the next call.
     */
    final MemorySegment getFirstSegment(int ord) throws IOException {
        long byteOffset = (long) ord * (long) vectorByteSize;
        MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize);
        if (seg == null) {
            if (scratch1 == null) {
                scratch1 = new byte[vectorByteSize];
            }
            input.readBytes(byteOffset, scratch1, 0, vectorByteSize);
            seg = MemorySegment.ofArray(scratch1);
        }
        return seg;
    }

    /**
     * Same as {@link #getFirstSegment(int)} but uses {@link #scratch2}, so both vectors of a
     * comparison can be held at the same time.
     */
    final MemorySegment getSecondSegment(int ord) throws IOException {
        long byteOffset = (long) ord * (long) vectorByteSize;
        MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize);
        if (seg == null) {
            if (scratch2 == null) {
                scratch2 = new byte[vectorByteSize];
            }
            input.readBytes(byteOffset, scratch2, 0, vectorByteSize);
            seg = MemorySegment.ofArray(scratch2);
        }
        return seg;
    }

    /**
     * Computes the normalized similarity score between two vectors, each given as a segment of
     * {@link #vectorByteSize} bytes. Implementations must not retain either segment.
     *
     * @param query the query vector (the current scoring ordinal)
     * @param other the vector being scored against the query
     * @return the similarity-specific normalized score
     */
    abstract float scoreSegments(MemorySegment query, MemorySegment other);

    /**
     * Returns a scorer whose query is the vector at the scoring ordinal (initially 0, changed via
     * {@code setScoringOrdinal}). Both ordinals are range-checked, and the score is delegated to
     * {@link #scoreSegments(MemorySegment, MemorySegment)}.
     */
    @Override
    public UpdateableRandomVectorScorer scorer() {
        return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) {
            private int queryOrd = 0;

            @Override
            public float score(int node) throws IOException {
                // Unqualified calls bind to the enclosing supplier, not to this anonymous scorer.
                checkOrdinal(node);
                return scoreSegments(getFirstSegment(queryOrd), getSecondSegment(node));
            }

            @Override
            public void setScoringOrdinal(int node) {
                checkOrdinal(node);
                queryOrd = node;
            }
        };
    }

    /** Cosine similarity, mapped from {@code [-1, 1]} onto {@code [0, 1]}. */
    static final class CosineSupplier
    extends Lucene99MemorySegmentByteVectorScorerSupplier {
        CosineSupplier(MemorySegmentAccessInput input, KnnVectorValues values) {
            super(input, values);
        }

        @Override
        float scoreSegments(MemorySegment query, MemorySegment other) {
            float raw = PanamaVectorUtilSupport.cosine(query, other);
            return (1.0f + raw) / 2.0f;
        }

        @Override
        public CosineSupplier copy() throws IOException {
            return new CosineSupplier(input.clone(), values);
        }
    }

    /** Dot product of byte vectors, scaled into {@code [0, 1]}. */
    static final class DotProductSupplier
    extends Lucene99MemorySegmentByteVectorScorerSupplier {
        // dimension * 2^15. Assuming signed-byte components, each component product has magnitude
        // at most 128 * 128 = 2^14, so raw / denominator lies in [-0.5, 0.5]. Computed once here
        // instead of on every score() call.
        private final float denominator;

        DotProductSupplier(MemorySegmentAccessInput input, KnnVectorValues values) {
            super(input, values);
            this.denominator = (float) (values.dimension() * (1 << 15));
        }

        @Override
        float scoreSegments(MemorySegment query, MemorySegment other) {
            float raw = PanamaVectorUtilSupport.dotProduct(query, other);
            return 0.5f + raw / denominator;
        }

        @Override
        public DotProductSupplier copy() throws IOException {
            return new DotProductSupplier(input.clone(), values);
        }
    }

    /** Squared Euclidean distance {@code d}, converted to the similarity {@code 1 / (1 + d)}. */
    static final class EuclideanSupplier
    extends Lucene99MemorySegmentByteVectorScorerSupplier {
        EuclideanSupplier(MemorySegmentAccessInput input, KnnVectorValues values) {
            super(input, values);
        }

        @Override
        float scoreSegments(MemorySegment query, MemorySegment other) {
            float raw = PanamaVectorUtilSupport.squareDistance(query, other);
            return 1.0f / (1.0f + raw);
        }

        @Override
        public EuclideanSupplier copy() throws IOException {
            return new EuclideanSupplier(input.clone(), values);
        }
    }

    /**
     * Unnormalized dot product. Negative products map into {@code (0, 1)} and non-negative ones into
     * {@code [1, +inf)}, so the score stays positive and monotonic in the raw product.
     */
    static final class MaxInnerProductSupplier
    extends Lucene99MemorySegmentByteVectorScorerSupplier {
        MaxInnerProductSupplier(MemorySegmentAccessInput input, KnnVectorValues values) {
            super(input, values);
        }

        @Override
        float scoreSegments(MemorySegment query, MemorySegment other) {
            float raw = PanamaVectorUtilSupport.dotProduct(query, other);
            return raw < 0.0f ? 1.0f / (1.0f - raw) : raw + 1.0f;
        }

        @Override
        public MaxInnerProductSupplier copy() throws IOException {
            return new MaxInnerProductSupplier(input.clone(), values);
        }
    }
}

