File tree 3 files changed +17
-8
lines changed
3 files changed +17
-8
lines changed Original file line number Diff line number Diff line change 1
1
[project ]
2
2
name = " vector-quantize-pytorch"
3
- version = " 1.21.0 "
3
+ version = " 1.21.1 "
4
4
description = " Vector Quantization - Pytorch"
5
5
authors = [
6
6
{ name = " Phil Wang" , email = " lucidrains@gmail.com" }
Original file line number Diff line number Diff line change @@ -6,9 +6,11 @@ def exists(v):
6
6
7
7
@pytest .mark .parametrize ('use_cosine_sim' , (True , False ))
8
8
@pytest .mark .parametrize ('rotation_trick' , (True , False ))
9
+ @pytest .mark .parametrize ('input_requires_grad' , (True , False ))
9
10
def test_vq (
10
11
use_cosine_sim ,
11
- rotation_trick
12
+ rotation_trick ,
13
+ input_requires_grad
12
14
):
13
15
from vector_quantize_pytorch import VectorQuantize
14
16
@@ -22,6 +24,10 @@ def test_vq(
22
24
)
23
25
24
26
x = torch .randn (1 , 1024 , 256 )
27
+
28
+ if input_requires_grad :
29
+ x .requires_grad_ ()
30
+
25
31
quantized , indices , commit_loss = vq (x )
26
32
27
33
def test_vq_eval ():
Original file line number Diff line number Diff line change @@ -1023,7 +1023,7 @@ def forward(
1023
1023
return_loss_breakdown = False ,
1024
1024
codebook_transform_fn : Callable | None = None
1025
1025
):
1026
- orig_input = x
1026
+ orig_input , input_requires_grad = x , x . requires_grad
1027
1027
1028
1028
# handle masking, either passed in as `mask` or `lens`
1029
1029
@@ -1117,11 +1117,14 @@ def forward(
1117
1117
1118
1118
commit_quantize = maybe_detach (quantize )
1119
1119
1120
- if self .rotation_trick :
1121
- quantize = rotate_to (x , quantize )
1122
- else :
1123
- # standard STE to get gradients through VQ layer.
1124
- quantize = x + (quantize - x ).detach ()
1120
+ # spare rotation trick calculation if inputs do not need gradients
1121
+
1122
+ if input_requires_grad :
1123
+ if self .rotation_trick :
1124
+ quantize = rotate_to (x , quantize )
1125
+ else :
1126
+ # standard STE to get gradients through VQ layer.
1127
+ quantize = x + (quantize - x ).detach ()
1125
1128
1126
1129
if self .sync_update_v > 0. :
1127
1130
# (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
You can’t perform that action at this time.
0 commit comments