Commit cd0fa8e

committed
still need commit loss from quantize to input to optimize the linear projection in SimVQ; the best combination is the rotation trick without straight-through and without the commit loss from input to quantize
1 parent 72ede73 commit cd0fa8e

File tree

2 files changed
+10 -11 lines changed

pyproject.toml

+1 -1
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.20.0"
+version = "1.20.1"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/sim_vq.py

+9 -10
@@ -40,8 +40,8 @@ def __init__(
         codebook_size,
         init_fn: Callable = identity,
         accept_image_fmap = False,
-        rotation_trick = True, # works even better with rotation trick turned on, with no asymmetric commit loss or straight through
-        commit_loss_input_to_quantize_weight = 0.25,
+        rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
+        input_to_quantize_commit_loss_weight = 0.25,
     ):
         super().__init__()
         self.accept_image_fmap = accept_image_fmap
@@ -59,11 +59,10 @@ def __init__(
         # https://arxiv.org/abs/2410.06424

         self.rotation_trick = rotation_trick
-        self.register_buffer('zero', torch.tensor(0.), persistent = False)

         # commit loss weighting - weighing input to quantize a bit less is crucial for it to work

-        self.commit_loss_input_to_quantize_weight = commit_loss_input_to_quantize_weight
+        self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight

     def forward(
         self,
@@ -83,18 +82,18 @@ def forward(

         quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)

+        # commit loss and straight through, as was done in the paper
+
+        commit_loss = F.mse_loss(x.detach(), quantized)
+
         if self.rotation_trick:
             # rotation trick from @cfifty
-
             quantized = rotate_from_to(quantized, x)
-
-            commit_loss = self.zero
         else:
-            # commit loss and straight through, as was done in the paper

             commit_loss = (
-                F.mse_loss(x, quantized.detach()) * self.commit_loss_input_to_quantize_weight +
-                F.mse_loss(x.detach(), quantized)
+                commit_loss +
+                F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
             )

             quantized = (quantized - x).detach() + x
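
The upshot of the change above: the quantize-to-input commit loss, F.mse_loss(x.detach(), quantized), is now computed before the branch so it applies in both paths, since it is what carries a training signal into SimVQ's linear projection of the frozen codebook; the rotation trick path drops both the straight-through estimator and the input-to-quantize term. Below is a minimal, self-contained sketch of the resulting logic, not the library's exact code: the quantize helper and its plain nearest-neighbor lookup stand in for SimVQ's implicit codebook, and rotate_from_to is a stand-in reimplementation of the rotation trick (https://arxiv.org/abs/2410.06424).

import torch
import torch.nn.functional as F

def rotate_from_to(quantized, x, eps = 1e-12):
    # stand-in for the library's rotate_from_to: rotate-and-rescale x onto
    # quantized with the transform detached, so the forward value equals
    # quantized while gradients flow back to x through a rotation instead
    # of the straight-through copy
    x_norm = x.norm(dim = -1, keepdim = True).clamp(min = eps)
    q_norm = quantized.norm(dim = -1, keepdim = True).clamp(min = eps)
    x_hat, q_hat = x / x_norm, quantized / q_norm

    r = x_hat + q_hat
    r = r / r.norm(dim = -1, keepdim = True).clamp(min = eps)

    # the rotation is treated as a constant w.r.t. autograd
    r, q_hat, x_hat = r.detach(), q_hat.detach(), x_hat.detach()

    # efficient rotation: v -> v - 2 r (r . v) + 2 q_hat (x_hat . v), which maps x_hat to q_hat
    rotated = x - 2 * (x * r).sum(dim = -1, keepdim = True) * r + 2 * (x * x_hat).sum(dim = -1, keepdim = True) * q_hat
    return rotated * (q_norm / x_norm).detach()

def quantize(x, codebook, rotation_trick = True, input_to_quantize_commit_loss_weight = 0.25):
    # plain nearest-neighbor lookup, standing in for SimVQ's frozen codebook
    # passed through a learned linear projection
    indices = torch.cdist(x, codebook).argmin(dim = -1)
    quantized = codebook[indices]

    # commit loss from quantize to input, now computed in both branches -
    # it is what carries gradients into the codebook
    commit_loss = F.mse_loss(x.detach(), quantized)

    if rotation_trick:
        quantized = rotate_from_to(quantized, x)
    else:
        # paper setup: add the down-weighted input-to-quantize term and
        # use the straight-through estimator
        commit_loss = commit_loss + F.mse_loss(x, quantized.detach()) * input_to_quantize_commit_loss_weight
        quantized = (quantized - x).detach() + x

    return quantized, indices, commit_loss

x = torch.randn(2, 16, 8, requires_grad = True)
codebook = torch.nn.Parameter(torch.randn(64, 8))

quantized, indices, commit_loss = quantize(x, codebook)
(quantized.sum() + commit_loss).backward()

assert x.grad is not None and codebook.grad is not None  # gradients reach encoder input and codebook

With rotation_trick = True the commit loss is the only term that updates the codebook, which matches the commit message's note that it is still needed to optimize the linear projection.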
