
Commit a766304

allow for arbitrary dimensions into SimVQ and ResidualSimVQ (video and beyond)
1 parent 448c4a5 commit a766304

File tree

5 files changed: +23 -28 lines changed
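A minimal usage sketch of what this change enables, assuming SimVQ is exported from vector_quantize_pytorch and that forward returns (quantized, indices, commit_loss) as in the example at the bottom of sim_vq.py: with channel_first = True, a channel-first video feature map now quantizes directly, with one code index per spatio-temporal position.

import torch
from vector_quantize_pytorch import SimVQ  # assumed export path

sim_vq = SimVQ(
    dim = 512,
    codebook_size = 1024,
    channel_first = True                       # replaces the old accept_image_fmap flag
)

video_feats = torch.randn(1, 512, 8, 32, 32)   # (batch, dim, time, height, width)

quantized, indices, commit_loss = sim_vq(video_feats)

# quantized comes back channel-first with the input shape,
# indices carry one codebook entry per (time, height, width) position
assert quantized.shape == video_feats.shape
assert indices.shape == (1, 8, 32, 32)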

examples/autoencoder_sim_vq.py (+1 -1)

@@ -15,7 +15,7 @@
 num_codes = 256
 seed = 1234

-rotation_trick = True # rotation trick instead ot straight-through
+rotation_trick = True # rotation trick instead ot straight-through
 use_mlp = True # use a one layer mlp with relu instead of linear

 device = "cuda" if torch.cuda.is_available() else "cpu"

pyproject.toml (+1 -1)

@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.20.8"
+version = "1.20.9"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_readme.py (+1 -1)

@@ -385,7 +385,7 @@ def test_residual_sim_vq():
         dim = 512,
         num_quantizers = 4,
         codebook_size = 1024,
-        accept_image_fmap = True
+        channel_first = True
     )

     x = torch.randn(1, 512, 32, 32)
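The test keeps its 4-D image-shaped input; a hypothetical extension (not part of this commit) sketching the arbitrary-dimension behaviour the rename is meant to unlock, assuming ResidualSimVQ keeps returning (quantized, indices, commit_loss):

import torch
from vector_quantize_pytorch import ResidualSimVQ  # assumed export path

residual_sim_vq = ResidualSimVQ(
    dim = 512,
    num_quantizers = 4,
    codebook_size = 1024,
    channel_first = True
)

video = torch.randn(1, 512, 4, 32, 32)   # (batch, dim, time, height, width)

quantized, indices, commit_loss = residual_sim_vq(video)

# expected: quantized keeps the channel-first input shape,
# indices stack one code per position per quantizer along the last axis
assert quantized.shape == video.shape
assert indices.shape == (1, 4, 32, 32, 4)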

vector_quantize_pytorch/residual_sim_vq.py (+8 -14)

@@ -58,20 +58,20 @@ def __init__(
         quantize_dropout = False,
         quantize_dropout_cutoff_index = 0,
         quantize_dropout_multiple_of = 1,
-        accept_image_fmap = False,
+        channel_first = False,
         rotation_trick = True, # rotation trick from @cfifty, on top of sim vq
         **sim_vq_kwargs
     ):
         super().__init__()
         assert heads == 1, 'residual vq is not compatible with multi-headed codes'

-        self.accept_image_fmap = accept_image_fmap
+        self.channel_first = channel_first

         self.num_quantizers = num_quantizers

         # define sim vq across layers

-        self.layers = ModuleList([SimVQ(dim = dim, codebook_size = codebook_size, rotation_trick = rotation_trick, accept_image_fmap = accept_image_fmap, **sim_vq_kwargs) for _ in range(num_quantizers)])
+        self.layers = ModuleList([SimVQ(dim = dim, codebook_size = codebook_size, rotation_trick = rotation_trick, channel_first = channel_first, **sim_vq_kwargs) for _ in range(num_quantizers)])

         # quantize dropout

@@ -100,7 +100,7 @@ def get_codes_from_indices(self, indices):

         batch, quantize_dim = indices.shape[0], indices.shape[-1]

-        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)
+        # may also receive indices in the shape of 'b h w q' (images)

         indices, inverse = pack_one(indices, 'b * q')

@@ -122,11 +122,11 @@ def get_codes_from_indices(self, indices):

         all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)

-        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
+        # if (channel_first = True) then return shape (quantize, batch, height, width, dimension)

         all_codes = inverse(all_codes, 'q b * d')

-        if self.accept_image_fmap:
+        if self.channel_first:
             all_codes = rearrange(all_codes, 'q b ... d -> q b d ...')

         return all_codes

@@ -139,23 +139,17 @@ def get_output_from_indices(self, indices):
     def forward(
         self,
         x,
-        indices: Tensor | list[Tensor] | None = None,
         return_all_codes = False,
         rand_quantize_dropout_fixed_seed = None
     ):
-        num_quant, quant_dropout_multiple_of, return_loss, device = self.num_quantizers, self.quantize_dropout_multiple_of, exists(indices), x.device
-
-        assert not (self.accept_image_fmap and exists(indices))
+        num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device

         quantized_out = 0.
         residual = x

         all_losses = []
         all_indices = []

-        if isinstance(indices, list):
-            indices = torch.stack(indices)
-
         should_quantize_dropout = self.training and self.quantize_dropout and not return_loss

         # sample a layer index at which to dropout further residual quantization

@@ -175,7 +169,7 @@ def forward(
             if quant_dropout_multiple_of != 1:
                 rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1

-            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
+            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.channel_first else tuple(x.shape[:2])
             null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
             null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)

vector_quantize_pytorch/sim_vq.py (+12 -11)

@@ -41,15 +41,15 @@ def __init__(
         codebook_size,
         codebook_transform: Module | None = None,
         init_fn: Callable = identity,
-        accept_image_fmap = False,
+        channel_first = False,
         rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
         input_to_quantize_commit_loss_weight = 0.25,
         commitment_weight = 1.,
         frozen_codebook_dim = None # frozen codebook dim could have different dimensions than projection
     ):
         super().__init__()
         self.codebook_size = codebook_size
-        self.accept_image_fmap = accept_image_fmap
+        self.channel_first = channel_first

         frozen_codebook_dim = default(frozen_codebook_dim, dim)
         codebook = torch.randn(codebook_size, frozen_codebook_dim) * (frozen_codebook_dim ** -0.5)

@@ -92,7 +92,7 @@ def indices_to_codes(
         frozen_codes = get_at('[c] d, b ... -> b ... d', self.frozen_codebook, indices)
         quantized = self.code_transform(frozen_codes)

-        if self.accept_image_fmap:
+        if self.channel_first:
             quantized = rearrange(quantized, 'b ... d -> b d ...')

         return quantized

@@ -101,9 +101,10 @@ def forward(
         self,
         x
     ):
-        if self.accept_image_fmap:
-            x = rearrange(x, 'b d h w -> b h w d')
-            x, inverse_pack = pack_one(x, 'b * d')
+        if self.channel_first:
+            x = rearrange(x, 'b d ... -> b ... d')
+
+        x, inverse_pack = pack_one(x, 'b * d')

         implicit_codebook = self.codebook

@@ -131,11 +132,11 @@ def forward(

         quantized = (quantized - x).detach() + x

-        if self.accept_image_fmap:
-            quantized = inverse_pack(quantized)
-            quantized = rearrange(quantized, 'b h w d-> b d h w')
+        quantized = inverse_pack(quantized)
+        indices = inverse_pack(indices, 'b *')

-        indices = inverse_pack(indices, 'b *')
+        if self.channel_first:
+            quantized = rearrange(quantized, 'b ... d-> b d ...')

         return quantized, indices, commit_loss * self.commitment_weight

@@ -153,7 +154,7 @@ def forward(
            nn.Linear(1024, 512)
        ),
        codebook_size = 1024,
-       accept_image_fmap = True
+       channel_first = True
    )

    quantized, indices, commit_loss = sim_vq(x)
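The key to supporting arbitrary dimensions is visible in the forward hunk above: the hard-coded 'b d h w -> b h w d' rearrange becomes an ellipsis pattern, and the pack down to a single token axis now happens unconditionally (the library's pack_one helper, as used in the diff, wraps einops pack/unpack and returns an inverse function). A standalone einops sketch of that pattern, using plain einops calls rather than repo code:

import torch
from einops import rearrange, pack, unpack

x = torch.randn(2, 512, 8, 16, 16)          # channel-first video: (b, d, t, h, w)

x = rearrange(x, 'b d ... -> b ... d')      # (b, t, h, w, d) - works for any number of trailing dims
x, ps = pack([x], 'b * d')                  # (b, t*h*w, d) - flatten everything between b and d

# ... quantize each d-dimensional vector here ...

x, = unpack(x, ps, 'b * d')                 # back to (b, t, h, w, d)
x = rearrange(x, 'b ... d -> b d ...')      # back to (b, d, t, h, w)

assert x.shape == (2, 512, 8, 16, 16)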
