16
16
17
17
from einops import rearrange , pack , unpack
18
18
19
+ import random
20
+
19
21
# helper functions
20
22
21
23
def exists (v ):
@@ -62,9 +64,12 @@ def __init__(
62
64
channel_first : bool = False ,
63
65
projection_has_bias : bool = True ,
64
66
return_indices = True ,
65
- force_quantization_f32 = True
67
+ force_quantization_f32 = True ,
68
+ preserve_symmetry : bool = False ,
69
+ noise_approx_prob = 0.0 ,
66
70
):
67
71
super ().__init__ ()
72
+
68
73
_levels = torch .tensor (levels , dtype = int32 )
69
74
self .register_buffer ("_levels" , _levels , persistent = False )
70
75
@@ -73,6 +78,9 @@ def __init__(
73
78
74
79
self .scale = scale
75
80
81
+ self .preserve_symmetry = preserve_symmetry
82
+ self .noise_approx_prob = noise_approx_prob
83
+
76
84
codebook_dim = len (levels )
77
85
self .codebook_dim = codebook_dim
78
86
@@ -110,12 +118,36 @@ def bound(self, z, eps: float = 1e-3):
110
118
shift = (offset / half_l ).atanh ()
111
119
return (z + shift ).tanh () * half_l - offset
112
120
113
- def quantize (self , z ):
121
+ # symmetry-preserving and noise-approximated quantization, section 3.2 in https://arxiv.org/abs/2411.19842
122
+
123
def symmetry_preserving_bound(self, z):
    """
    Squash z with tanh and rescale it per dimension using the level counts.

    Computes, with L = self._levels:

        2 / (L - 1) * ((L - 1) * (tanh(z) + 1) / 2 + 0.5) - 1

    Returns a tensor broadcast from z and self._levels.

    NOTE(review): no floor / rounding is applied inside this method; the
    caller is expected to round. If the bracket in the referenced formula
    (section 3.2 of https://arxiv.org/abs/2411.19842) denotes a floor,
    this is not a literal implementation of it - confirm.
    """
    num_steps = self._levels - 1

    # tanh(z) + 1 lies in (0, 2); halving after scaling maps it onto (0, L - 1)
    inner = num_steps * (z.tanh() + 1) / 2.0 + 0.5

    return 2.0 / num_steps * inner - 1.0
131
+
132
def noise_approx_bound(self, z):
    """
    Noise-based stand-in for quantization: tanh(z) + u / (L - 1).

    u is sampled independently per element from a uniform distribution on
    [-1, 1) and L is self._levels, so the perturbation is at most one
    level spacing 1 / (L - 1) in magnitude for each dimension.
    """
    level_spacing_inv = self._levels - 1

    # fresh uniform noise with the same shape / dtype / device as z
    jitter = torch.empty_like(z).uniform_(-1, 1)

    return z.tanh() + jitter / level_spacing_inv
138
+
139
def quantize(self, z, preserve_symmetry = False):
    """
    Quantizes z, returns quantized zhat, same shape as z.

    The pre-rounding transform is picked as follows:
      - in training mode, with probability self.noise_approx_prob, the
        noise approximation (self.noise_approx_bound) is used
      - otherwise, if preserve_symmetry is set, self.symmetry_preserving_bound
      - otherwise the default self.bound

    The chosen output is rounded with round_ste and divided by
    self._levels // 2.

    NOTE(review): rounding and the half-width division are applied to all
    three variants, although the noise and symmetric bounds already look
    like they produce values on a [-1, 1] scale - confirm this is intended.
    """
    # random.random() is only consulted in training mode (short-circuit)
    use_noise = self.training and random.random() < self.noise_approx_prob

    if use_noise:
        pre_round = self.noise_approx_bound(z)
    else:
        bound_fn = self.symmetry_preserving_bound if preserve_symmetry else self.bound
        pre_round = bound_fn(z)

    # divide the rounded values by the per-dimension half width
    return round_ste(pre_round) / (self._levels // 2)
118
-
150
+
119
151
def _scale_and_shift(self, zhat_normalized):
    """
    Multiply zhat_normalized by the per-dimension half width
    (self._levels // 2) and then add that same half width.
    """
    offset = self._levels // 2
    return zhat_normalized * offset + offset
@@ -194,7 +226,7 @@ def forward(self, z):
194
226
if force_f32 and orig_dtype not in self .allowed_dtypes :
195
227
z = z .float ()
196
228
197
- codes = self .quantize (z )
229
+ codes = self .quantize (z , preserve_symmetry = self . preserve_symmetry )
198
230
199
231
# returning indices could be optional
200
232
@@ -205,7 +237,7 @@ def forward(self, z):
205
237
206
238
codes = rearrange (codes , 'b n c d -> b n (c d)' )
207
239
208
- codes = codes .type (orig_dtype )
240
+ codes = codes .to (orig_dtype )
209
241
210
242
# project out
211
243
0 commit comments