Skip to content

Commit 5ae9c79

Browse files
committed
allow for 1-dimensional channel first training for residual sim vq, for @zaptrem
1 parent 55fa8f1 commit 5ae9c79

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.21.5"
3+
version = "1.21.7"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/residual_sim_vq.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ def forward(
169169
if quant_dropout_multiple_of != 1:
170170
rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1
171171

172-
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.channel_first else tuple(x.shape[:2])
172+
null_indices_shape = (x.shape[0], *x.shape[2:]) if self.channel_first else tuple(x.shape[:2])
173173
null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
174-
null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)
174+
null_loss = torch.full((), 0., device = device, dtype = x.dtype)
175175

176176
# save all inputs across layers, for use during expiration at end under shared codebook setting
177177

0 commit comments

Comments
 (0)