Commit a0e8f2c

add .indices_to_codes for SimVQ
1 parent 3bb00f5 commit a0e8f2c

File tree

pyproject.toml
tests/test_readme.py
vector_quantize_pytorch/sim_vq.py

3 files changed: +34 -3 lines changed


pyproject.toml (+1 -1)

@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.20.2"
+version = "1.20.3"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_readme.py (+14)

@@ -362,3 +362,17 @@ def test_latent_q():
 
     assert image_feats.shape == quantized.shape
     assert (quantized == quantizer.indices_to_codes(indices)).all()
+
+def test_sim_vq():
+    from vector_quantize_pytorch import SimVQ
+
+    sim_vq = SimVQ(
+        dim = 512,
+        codebook_size = 1024,
+    )
+
+    x = torch.randn(1, 1024, 512)
+    quantized, indices, commit_loss = sim_vq(x)
+
+    assert x.shape == quantized.shape
+    assert torch.allclose(quantized, sim_vq.indices_to_codes(indices), atol = 1e-6)
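
The new test round-trips the emitted indices through indices_to_codes and compares with allclose rather than exact equality, since decoding recomputes the transformed codebook in floating point. A standalone sketch of the same check, which also peeks at the new codebook property from the sim_vq.py diff below; the (1024, 512) codebook shape is an inference from codebook_size and dim, not something this commit asserts:

# Standalone sketch of the round trip the new test exercises.
# The (1024, 512) shape assertion is an assumption inferred from
# codebook_size and dim; it is not checked anywhere in this commit.

import torch
from vector_quantize_pytorch import SimVQ

sim_vq = SimVQ(
    dim = 512,
    codebook_size = 1024,
)

x = torch.randn(1, 1024, 512)
quantized, indices, commit_loss = sim_vq(x)

# `codebook` is now a property: the frozen buffer pushed through the
# learned transform (see the sim_vq.py diff below)
implicit_codebook = sim_vq.codebook
assert implicit_codebook.shape == (1024, 512)

# decoding indices reproduces the quantized output up to float error,
# hence allclose with a tolerance rather than exact equality
assert torch.allclose(quantized, sim_vq.indices_to_codes(indices), atol = 1e-6)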

vector_quantize_pytorch/sim_vq.py (+19 -2)

@@ -58,7 +58,7 @@ def __init__(
 
         self.codebook_to_codes = codebook_transform
 
-        self.register_buffer('codebook', codebook)
+        self.register_buffer('frozen_codebook', codebook)
 
 
         # whether to use rotation trick from Fifty et al.
@@ -70,6 +70,23 @@ def __init__(
 
         self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight
 
+    @property
+    def codebook(self):
+        return self.codebook_to_codes(self.frozen_codebook)
+
+    def indices_to_codes(
+        self,
+        indices
+    ):
+        implicit_codebook = self.codebook
+
+        quantized = get_at('[c] d, b ... -> b ... d', implicit_codebook, indices)
+
+        if self.accept_image_fmap:
+            quantized = rearrange(quantized, 'b ... d -> b d ...')
+
+        return quantized
+
     def forward(
         self,
         x
@@ -78,7 +95,7 @@ def forward(
             x = rearrange(x, 'b d h w -> b h w d')
             x, inverse_pack = pack_one(x, 'b * d')
 
-        implicit_codebook = self.codebook_to_codes(self.codebook)
+        implicit_codebook = self.codebook
 
         with torch.no_grad():
            dist = torch.cdist(x, implicit_codebook)
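
The accept_image_fmap branch in the new indices_to_codes mirrors the channel-first handling at the top of forward. Below is a hedged sketch of that path; accept_image_fmap = True as a constructor flag is an assumption, since the diff only shows self.accept_image_fmap being read, not where it is set:

# Hypothetical image-feature-map round trip. The accept_image_fmap
# constructor flag is assumed from self.accept_image_fmap in the diff;
# this commit does not show it being set.

import torch
from vector_quantize_pytorch import SimVQ

sim_vq = SimVQ(
    dim = 512,
    codebook_size = 1024,
    accept_image_fmap = True,  # assumption, see note above
)

images = torch.randn(1, 512, 32, 32)  # (batch, dim, height, width)
quantized, indices, commit_loss = sim_vq(images)

# indices_to_codes rearranges 'b ... d -> b d ...' in this mode, so the
# decoded codes come back channel-first, matching the input layout
codes = sim_vq.indices_to_codes(indices)
assert codes.shape == quantized.shape == images.shape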
