Skip to content

Commit 8f5b428

Browse files
committed
more efficient to select out the frozen codes and then do projection in indices_to_codes for SimVQ
1 parent a0e8f2c commit 8f5b428

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.20.3"
3+
version = "1.20.4"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/sim_vq.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
if not exists(codebook_transform):
5757
codebook_transform = nn.Linear(dim, dim, bias = False)
5858

59-
self.codebook_to_codes = codebook_transform
59+
self.code_transform = codebook_transform
6060

6161
self.register_buffer('frozen_codebook', codebook)
6262

@@ -72,15 +72,16 @@ def __init__(
7272

7373
@property
7474
def codebook(self):
75-
return self.codebook_to_codes(self.frozen_codebook)
75+
return self.code_transform(self.frozen_codebook)
7676

7777
def indices_to_codes(
7878
self,
7979
indices
8080
):
8181
implicit_codebook = self.codebook
8282

83-
quantized = get_at('[c] d, b ... -> b ... d', implicit_codebook, indices)
83+
frozen_codes = get_at('[c] d, b ... -> b ... d', self.frozen_codebook, indices)
84+
quantized = self.code_transform(frozen_codes)
8485

8586
if self.accept_image_fmap:
8687
quantized = rearrange(quantized, 'b ... d -> b d ...')

0 commit comments

Comments
 (0)