@@ -40,8 +40,8 @@ def __init__(
         codebook_size,
         init_fn: Callable = identity,
         accept_image_fmap = False,
-        rotation_trick = True, # works even better with rotation trick turned on, with no asymmetric commit loss or straight through
-        commit_loss_input_to_quantize_weight = 0.25,
+        rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
+        input_to_quantize_commit_loss_weight = 0.25,
     ):
         super().__init__()
         self.accept_image_fmap = accept_image_fmap
@@ -59,11 +59,10 @@ def __init__(
         # https://arxiv.org/abs/2410.06424

         self.rotation_trick = rotation_trick
-        self.register_buffer('zero', torch.tensor(0.), persistent = False)

         # commit loss weighting - weighing input to quantize a bit less is crucial for it to work

-        self.commit_loss_input_to_quantize_weight = commit_loss_input_to_quantize_weight
+        self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight

     def forward(
         self,
@@ -83,18 +82,18 @@ def forward(

         quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)

+        # commit loss and straight through, as was done in the paper
+
+        commit_loss = F.mse_loss(x.detach(), quantized)
+
         if self.rotation_trick:
             # rotation trick from @cfifty
-
             quantized = rotate_from_to(quantized, x)
-
-            commit_loss = self.zero
         else:
-            # commit loss and straight through, as was done in the paper

             commit_loss = (
-                F.mse_loss(x, quantized.detach()) * self.commit_loss_input_to_quantize_weight +
-                F.mse_loss(x.detach(), quantized)
+                commit_loss +
+                F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
             )

             quantized = (quantized - x).detach() + x
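
For context, here is a minimal standalone sketch of what the rotation trick referenced above (https://arxiv.org/abs/2410.06424) does in place of the straight through estimator. The helper name rotation_trick_transform is hypothetical and is not the repo's rotate_from_to; it only illustrates the idea: the forward value equals the quantized vector, while the gradient reaches x through a detached rotation (two Householder reflections) and rescaling rather than being copied through unchanged.

import torch
import torch.nn.functional as F

def rotation_trick_transform(x, quantized, eps = 1e-6):
    # hypothetical helper, sketching section 4.2 of https://arxiv.org/abs/2410.06424
    # forward output equals `quantized`, but gradients flow back to `x`
    # through a rotation and scale that are treated as constants (detached)

    x_hat = F.normalize(x, dim = -1, eps = eps).detach()
    q_hat = F.normalize(quantized, dim = -1, eps = eps).detach()

    # reflection axis halfway between the two unit directions

    r = F.normalize(x_hat + q_hat, dim = -1, eps = eps).detach()

    # R x = x - 2 (x . r) r + 2 (x . x_hat) q_hat rotates x_hat onto q_hat

    rotated = (
        x
        - 2 * (x * r).sum(dim = -1, keepdim = True) * r
        + 2 * (x * x_hat).sum(dim = -1, keepdim = True) * q_hat
    )

    # rescale so the forward value matches the quantized vector's norm

    scale = (quantized.norm(dim = -1, keepdim = True) / x.norm(dim = -1, keepdim = True).clamp(min = eps)).detach()

    return rotated * scale

In the forward pass rotated * scale reduces to the quantized vector exactly; only the backward path differs from the straight through line quantized = (quantized - x).detach() + x.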