Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove query-time usage of ByteSequence::slice in PQVectors to reduce object allocations #403

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ protected CachingDecoder(PQVectors cv, VectorFloat<?> query, VectorSimilarityFun
}
}

protected float decodedSimilarity(ByteSequence<?> encoded) {
return VectorUtil.assembleAndSum(partialSums, cv.pq.getClusterCount(), encoded);
protected float decodedSimilarity(ByteSequence<?> encoded, int offset, int length) {
return VectorUtil.assembleAndSum(partialSums, cv.pq.getClusterCount(), encoded, offset, length);
}
}

Expand All @@ -65,7 +65,7 @@ public DotProductDecoder(PQVectors cv, VectorFloat<?> query) {

@Override
public float similarityTo(int node2) {
    // Reads the encoded vector in place (chunk + offset) instead of allocating a slice,
    // then rescales the decoded value via (1 + x) / 2.
    // (Diff-paste artifact: old and new return lines were interleaved; this is the new one.)
    return (1 + decodedSimilarity(cv.getChunk(node2), cv.getOffsetInChunk(node2), cv.pq.getSubspaceCount())) / 2;
}
}

Expand All @@ -76,7 +76,7 @@ public EuclideanDecoder(PQVectors cv, VectorFloat<?> query) {

@Override
public float similarityTo(int node2) {
    // Reads the encoded vector in place (chunk + offset) instead of allocating a slice,
    // then converts the decoded value to a (0, 1] score via 1 / (1 + x).
    // (Diff-paste artifact: old and new return lines were interleaved; this is the new one.)
    return 1 / (1 + decodedSimilarity(cv.getChunk(node2), cv.getOffsetInChunk(node2), cv.pq.getSubspaceCount()));
}
}

Expand Down Expand Up @@ -132,9 +132,10 @@ public float similarityTo(int node2) {

/**
 * Decoded cosine similarity for {@code node2}, reading the encoded vector directly from its
 * chunk (chunk reference + offset) to avoid allocating a slice object.
 * (Diff-paste artifact: removed and added lines were interleaved; this is the post-change version.)
 */
protected float decodedCosine(int node2) {
    ByteSequence<?> encoded = cv.getChunk(node2);
    int offset = cv.getOffsetInChunk(node2);
    return VectorUtil.pqDecodedCosineSimilarity(encoded, offset, cv.pq.getSubspaceCount(), cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,12 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
switch (similarityFunction) {
case DOT_PRODUCT:
return (node2) -> {
var encoded = get(node2);
var encodedChunk = getChunk(node2);
var encodedOffset = getOffsetInChunk(node2);
// compute the dot product of the query and the codebook centroids corresponding to the encoded points
float dp = 0;
for (int m = 0; m < pq.getSubspaceCount(); m++) {
int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
Expand All @@ -244,12 +245,13 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
case COSINE:
float norm1 = VectorUtil.dotProduct(centeredQuery, centeredQuery);
return (node2) -> {
var encoded = get(node2);
var encodedChunk = getChunk(node2);
var encodedOffset = getOffsetInChunk(node2);
// compute the dot product of the query and the codebook centroids corresponding to the encoded points
float sum = 0;
float norm2 = 0;
for (int m = 0; m < pq.getSubspaceCount(); m++) {
int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
var codebookOffset = centroidIndex * centroidLength;
Expand All @@ -262,11 +264,12 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
};
case EUCLIDEAN:
return (node2) -> {
var encoded = get(node2);
var encodedChunk = getChunk(node2);
var encodedOffset = getOffsetInChunk(node2);
// compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points
float sum = 0;
for (int m = 0; m < pq.getSubspaceCount(); m++) {
int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
Expand All @@ -279,17 +282,49 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
}
}

/**
 * Returns a {@link ByteSequence} for the encoded vector at the given ordinal.
 * @param ordinal the vector's ordinal
 * @return the {@link ByteSequence}
 * @throws IndexOutOfBoundsException if {@code ordinal} is negative or not less than {@link #count()}
 */
public ByteSequence<?> get(int ordinal) {
    boolean inRange = ordinal >= 0 && ordinal < count();
    if (!inRange) {
        throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + count());
    }
    return get(compressedDataChunks, ordinal, vectorsPerChunk, pq.getSubspaceCount());
}

/**
 * Slices the encoded vector for {@code ordinal} out of its containing chunk.
 * (Diff-paste artifact: both the old and new return statements were present; resolved to the
 * new chunk-based form. The now-unused {@code chunkIndex} local is removed.)
 */
static ByteSequence<?> get(ByteSequence<?>[] chunks, int ordinal, int vectorsPerChunk, int subspaceCount) {
    // Each encoded vector occupies subspaceCount bytes within its chunk.
    int vectorIndexInChunk = ordinal % vectorsPerChunk;
    int start = vectorIndexInChunk * subspaceCount;
    return getChunk(chunks, ordinal, vectorsPerChunk).slice(start, subspaceCount);
}

/**
 * Returns a reference to the {@link ByteSequence} chunk containing the vector for the given
 * ordinal. Only intended for use where the caller wants to avoid an allocation for the slice
 * object. After getting the chunk, callers should use {@link #getOffsetInChunk(int)} to get the
 * offset of the vector within the chunk and the pq's
 * {@link ProductQuantization#getSubspaceCount()} to get the length of the vector.
 * @param ordinal the vector's ordinal
 * @return the {@link ByteSequence} chunk containing the vector
 * @throws IndexOutOfBoundsException if {@code ordinal} is negative or not less than {@link #count()}
 */
ByteSequence<?> getChunk(int ordinal) {
    boolean inRange = ordinal >= 0 && ordinal < count();
    if (!inRange) {
        throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + count());
    }
    return getChunk(compressedDataChunks, ordinal, vectorsPerChunk);
}

/**
 * Returns the starting offset of the given ordinal's encoded vector within its chunk,
 * for use together with {@link #getChunk(int)}.
 * @param ordinal the vector's ordinal
 * @throws IndexOutOfBoundsException if {@code ordinal} is negative or not less than {@link #count()}
 */
int getOffsetInChunk(int ordinal) {
    boolean inRange = ordinal >= 0 && ordinal < count();
    if (!inRange) {
        throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + count());
    }
    // Position within the chunk scaled by the encoded length (one byte per subspace).
    return (ordinal % vectorsPerChunk) * pq.getSubspaceCount();
}

/** Locates the chunk holding {@code ordinal}, given a fixed number of vectors per chunk. */
static ByteSequence<?> getChunk(ByteSequence<?>[] chunks, int ordinal, int vectorsPerChunk) {
    return chunks[ordinal / vectorsPerChunk];
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,14 @@ public void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {

@Override
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets) {
    // Full-range case: forward to the ranged overload spanning the entire sequence.
    final int fullLength = baseOffsets.length();
    return assembleAndSum(data, dataBase, baseOffsets, 0, fullLength);
}

@Override
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) {
float sum = 0f;
for (int i = 0; i < baseOffsets.length(); i++) {
sum += data.get(dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i)));
for (int i = 0; i < baseOffsetsLength; i++) {
sum += data.get(dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i + baseOffsetsOffset)));
}
return sum;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ public static float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequen
return impl.assembleAndSum(data, dataBase, dataOffsets);
}

/**
 * Ranged assemble-and-sum: uses only {@code dataOffsetsLength} bytes of {@code dataOffsets}
 * starting at {@code dataOffsetsOffset}. Dispatches to the platform-specific implementation.
 */
public static float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> dataOffsets, int dataOffsetsOffset, int dataOffsetsLength) {
return impl.assembleAndSum(data, dataBase, dataOffsets, dataOffsetsOffset, dataOffsetsLength);
}

/**
 * Dispatches to the platform-specific implementation.
 * NOTE(review): the last two arguments swap order at the impl boundary — this wrapper takes
 * ({@code results}, {@code vsf}) but the impl takes ({@code vsf}, {@code results}).
 */
public static void bulkShuffleQuantizedSimilarity(ByteSequence<?> shuffles, int codebookCount, ByteSequence<?> quantizedPartials, float delta, float minDistance, VectorFloat<?> results, VectorSimilarityFunction vsf) {
impl.bulkShuffleQuantizedSimilarity(shuffles, codebookCount, quantizedPartials, delta, minDistance, vsf, results);
}
Expand Down Expand Up @@ -215,6 +219,10 @@ public static float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clust
return impl.pqDecodedCosineSimilarity(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
}

/**
 * Ranged PQ decoded cosine similarity: reads {@code encodedLength} bytes of {@code encoded}
 * starting at {@code encodedOffset}. Dispatches to the platform-specific implementation.
 */
public static float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int encodedOffset, int encodedLength, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude) {
return impl.pqDecodedCosineSimilarity(encoded, encodedOffset, encodedLength, clusterCount, partialSums, aMagnitude, bMagnitude);
}

/** Dispatches the NVQ 8-bit dot product to the platform-specific implementation. */
public static float nvqDotProduct8bit(VectorFloat<?> vector, ByteSequence<?> bytes, float growthRate, float midpoint, float minValue, float maxValue) {
return impl.nvqDotProduct8bit(vector, bytes, growthRate, midpoint, minValue, maxValue);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,19 @@ public interface VectorUtilSupport {
*/
float assembleAndSum(VectorFloat<?> data, int baseIndex, ByteSequence<?> baseOffsets);

/**
 * Calculates the sum of sparse points in a vector, reading only a sub-range of the
 * offset sequence.
 *
 * @param data the vector of all datapoints
 * @param baseIndex the start of the data in the offset table
 *                  (scaled by the index of the lookup table)
 * @param baseOffsets bytes that represent offsets from the baseIndex
 * @param baseOffsetsOffset the offset into {@code baseOffsets} at which to start reading
 * @param baseOffsetsLength the number of bytes of {@code baseOffsets} to use
 * @return the sum of the points
 */
float assembleAndSum(VectorFloat<?> data, int baseIndex, ByteSequence<?> baseOffsets, int baseOffsetsOffset, int baseOffsetsLength);

int hammingDistance(long[] v1, long[] v2);


Expand Down Expand Up @@ -212,12 +225,17 @@ default void bulkShuffleQuantizedSimilarityCosine(ByteSequence<?> shuffles, int
float min(VectorFloat<?> v);

// Full-sequence convenience form: delegates to the ranged overload over all of 'encoded'.
default float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
{
return pqDecodedCosineSimilarity(encoded, 0, encoded.length(), clusterCount, partialSums, aMagnitude, bMagnitude);
}

default float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int encodedOffset, int encodedLength, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
{
float sum = 0.0f;
float aMag = 0.0f;

for (int m = 0; m < encoded.length(); ++m) {
int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
for (int m = 0; m < encodedLength; ++m) {
int centroidIndex = Byte.toUnsignedInt(encoded.get(m + encodedOffset));
var index = m * clusterCount + centroidIndex;
sum += partialSums.get(index);
aMag += aMagnitude.get(index);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,14 @@ public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> b
return NativeSimdOps.assemble_and_sum_f32_512(((MemorySegmentVectorFloat)data).get(), dataBase, ((MemorySegmentByteSequence)baseOffsets).get(), baseOffsets.length());
}

/**
 * Ranged assemble-and-sum. The native SIMD kernel only supports operating on a whole
 * ByteSequence, so the fast path applies only when the requested range covers the full
 * sequence; otherwise a scalar fallback (matching the generic implementation) is used.
 * Previously the range restriction was enforced only by {@code assert}, which silently
 * returned a wrong full-range result when assertions were disabled (the JVM default).
 */
@Override
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets, int baseOffsetsOffset, int baseOffsetsLength)
{
    if (baseOffsetsOffset == 0 && baseOffsetsLength == baseOffsets.length()) {
        return assembleAndSum(data, dataBase, baseOffsets);
    }
    // Scalar fallback for sliced ranges: byte i selects an entry within the i-th stride of data.
    float sum = 0f;
    for (int i = 0; i < baseOffsetsLength; i++) {
        sum += data.get(dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i + baseOffsetsOffset)));
    }
    return sum;
}

@Override
public int hammingDistance(long[] v1, long[] v2) {
return VectorSimdOps.hammingDistance(v1, v2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,13 @@ public void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {

/**
 * Full-range assemble-and-sum: reuses the ranged SIMD kernel over the entire sequence.
 * (Diff-paste artifact: the removed single-line return and the added two-line return were
 * both present; this is the post-change version.)
 */
@Override
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets) {
    return SimdOps.assembleAndSum(((ArrayVectorFloat) data).get(), dataBase, ((ByteSequence<byte[]>) baseOffsets),
                                  0, baseOffsets.length());
}

/** Ranged assemble-and-sum: unwraps to the on-heap representations and calls the SIMD kernel. */
@Override
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) {
    float[] raw = ((ArrayVectorFloat) data).get();
    return SimdOps.assembleAndSum(raw, dataBase, (ByteSequence<byte[]>) baseOffsets, baseOffsetsOffset, baseOffsetsLength);
}

@Override
Expand Down Expand Up @@ -177,9 +183,14 @@ public void quantizePartials(float delta, VectorFloat<?> partials, VectorFloat<?
}

@Override
public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude) {
return pqDecodedCosineSimilarity(encoded, 0, encoded.length(), clusterCount, partialSums, aMagnitude, bMagnitude);
}

@Override
public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int encodedOffset, int encodedLength, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
{
return SimdOps.pqDecodedCosineSimilarity((ByteSequence<byte[]>) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude);
return SimdOps.pqDecodedCosineSimilarity((ByteSequence<byte[]>) encoded, encodedOffset, encodedLength, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude);
}

@Override
Expand Down
Loading