Commit fa2211d

do not apply the straight-through estimator nor the rotation trick if the input does not require grad, to make Genie2 cleaner

1 parent c243e83
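
For context, the straight-through estimator (STE) being gated here returns the quantized values in the forward pass while copying gradients onto the input, which only matters when the input participates in autograd. A minimal standalone sketch of that behavior (not the library's code path):

import torch

x = torch.randn(4, 8, requires_grad = True)
quantize = torch.randn(4, 8)    # stand-in for the nearest codebook vectors

# straight-through estimator: the forward value is `quantize`, but the
# Jacobian w.r.t. `x` is the identity, so gradients skip the
# non-differentiable codebook lookup
out = x + (quantize - x).detach()
out.sum().backward()

assert torch.allclose(out, quantize)
assert torch.allclose(x.grad, torch.ones_like(x))

When `x.requires_grad` is False, this reparameterization (and the heavier rotation trick) buys nothing, which is what the change below exploits.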

File tree

3 files changed: +17 −8 lines

pyproject.toml

+1 −1

@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.21.0"
+version = "1.21.1"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_readme.py

+7 −1

@@ -6,9 +6,11 @@ def exists(v):
 
 @pytest.mark.parametrize('use_cosine_sim', (True, False))
 @pytest.mark.parametrize('rotation_trick', (True, False))
+@pytest.mark.parametrize('input_requires_grad', (True, False))
 def test_vq(
     use_cosine_sim,
-    rotation_trick
+    rotation_trick,
+    input_requires_grad
 ):
     from vector_quantize_pytorch import VectorQuantize
 
@@ -22,6 +24,10 @@ def test_vq(
     )
 
     x = torch.randn(1, 1024, 256)
+
+    if input_requires_grad:
+        x.requires_grad_()
+
     quantized, indices, commit_loss = vq(x)
 
 def test_vq_eval():
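
The new parametrization exercises both the grad and no-grad paths. A minimal standalone reproduction of the grad case (the `rotation_trick` constructor argument is assumed from the parametrize names above):

import torch
from vector_quantize_pytorch import VectorQuantize

# `rotation_trick` mirrors the parametrize flag above (assumed constructor
# argument); `dim` and `codebook_size` follow the repo's README usage
vq = VectorQuantize(dim = 256, codebook_size = 512, rotation_trick = True)

x = torch.randn(1, 1024, 256)
x.requires_grad_()    # the newly tested case

quantized, indices, commit_loss = vq(x)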

vector_quantize_pytorch/vector_quantize_pytorch.py

+9 −6

@@ -1023,7 +1023,7 @@ def forward(
         return_loss_breakdown = False,
         codebook_transform_fn: Callable | None = None
     ):
-        orig_input = x
+        orig_input, input_requires_grad = x, x.requires_grad
 
         # handle masking, either passed in as `mask` or `lens`
 
@@ -1117,11 +1117,14 @@ def forward(
 
         commit_quantize = maybe_detach(quantize)
 
-        if self.rotation_trick:
-            quantize = rotate_to(x, quantize)
-        else:
-            # standard STE to get gradients through VQ layer.
-            quantize = x + (quantize - x).detach()
+        # spare rotation trick calculation if inputs do not need gradients
+
+        if input_requires_grad:
+            if self.rotation_trick:
+                quantize = rotate_to(x, quantize)
+            else:
+                # standard STE to get gradients through VQ layer.
+                quantize = x + (quantize - x).detach()
 
         if self.sync_update_v > 0.:
             # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
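
The gate is safe for inference because, when the input does not require grad, the STE branch is a value-level no-op: `x + (quantize - x)` is just `quantize`, and no autograd graph is built. A quick standalone check of that claim:

import torch

x = torch.randn(4, 8)           # inference-style input, requires_grad == False
quantize = torch.randn(4, 8)    # stand-in for the nearest codebook vectors

ste = x + (quantize - x).detach()

assert torch.allclose(ste, quantize)    # same values as skipping the STE entirely
assert not ste.requires_grad            # and no graph to build or free

The rotation trick branch likewise only changes gradient flow (its forward value is the quantized vector by construction), so skipping both saves work without changing inference outputs.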
