```diff
@@ -14,6 +14,7 @@
 from torch import Tensor, int32
 from torch.amp import autocast
 
+import einx
 from einops import rearrange, pack, unpack
 
 import random
@@ -45,11 +46,15 @@ def unpack_one(t, ps, pattern):
 
 # tensor helpers
 
-def round_ste(z: Tensor) -> Tensor:
+def round_ste(z):
     """Round with straight through gradients."""
     zhat = z.round()
     return z + (zhat - z).detach()
 
+def floor_ste(z):
+    zhat = z.floor()
+    return z + (zhat - z).detach()
+
 # main class
 
 class FSQ(Module):
```
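Note on the new helper: `floor_ste` is the floor counterpart of `round_ste`, a straight-through estimator that floors on the forward pass while letting gradients flow through unchanged. A minimal standalone sketch of the behavior (plain PyTorch, helper copied verbatim from the diff):

```python
import torch

def floor_ste(z):
    zhat = z.floor()
    return z + (zhat - z).detach()

z = torch.tensor([0.3, 1.7], requires_grad = True)
out = floor_ste(z)       # forward pass: tensor([0., 1.]), the floored values
out.sum().backward()
print(z.grad)            # tensor([1., 1.]): backward is the identity, as if floor were skipped
```

This is what lets the `bracket = floor_ste(bracket)` call added to `symmetry_preserving_bound` in the next hunk stay differentiable.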
```diff
@@ -127,41 +132,43 @@ def symmetry_preserving_bound(self, z):
         levels_minus_1 = (self._levels - 1)
         scale = 2.0 / levels_minus_1
         bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.0) + 0.5
+        bracket = floor_ste(bracket)
         return scale * bracket - 1.0
 
-    def quantize(self, z, preserve_symmetry = False):
+    def quantize(self, z):
         """ Quantizes z, returns quantized zhat, same shape as z. """
 
+        preserve_symmetry = self.preserve_symmetry
         half_width = self._levels // 2
 
-        if self.training:
-            unquantized = z
-
-            # determine where to quantize elementwise
-
-            quantize_mask = torch.bernoulli(
-                torch.full([z.shape[0], 1, 1, 1], self.noise_dropout, device = z.device)
-            ).bool().expand_as(z)
-
-            if preserve_symmetry:
-                quantized = round_ste(self.symmetry_preserving_bound(z)) / half_width
-            else:
-                quantized = round_ste(self.bound(z)) / half_width
-            quantized = torch.where(quantize_mask, unquantized, quantized)
-
-            # determine where to add a random offset elementwise
-
-            offset_mask = torch.bernoulli(
-                torch.full([z.shape[0], 1, 1, 1], self.noise_dropout, device = z.device)
-            ).bool().expand_as(z)
-
-            offset = (torch.rand_like(z) - 0.5) / half_width
-            quantized = torch.where(offset_mask, unquantized + offset, quantized)
-        elif preserve_symmetry:
+        if preserve_symmetry:
             quantized = round_ste(self.symmetry_preserving_bound(z)) / half_width
         else:
             quantized = round_ste(self.bound(z)) / half_width
 
+        if not self.training:
+            return quantized
+
+        batch, device, noise_dropout = z.shape[0], z.device, self.noise_dropout
+        unquantized = z
+
+        # determine where to quantize elementwise
+
+        quantize_mask = torch.bernoulli(
+            torch.full((batch,), noise_dropout, device = device)
+        ).bool()
+
+        quantized = einx.where('b, b ..., b ...', quantize_mask, unquantized, quantized)
+
+        # determine where to add a random offset elementwise
+
+        offset_mask = torch.bernoulli(
+            torch.full((batch,), noise_dropout, device = device)
+        ).bool()
+
+        offset = (torch.rand_like(z) - 0.5) / half_width
+        quantized = einx.where('b, b ..., b ...', offset_mask, unquantized + offset, quantized)
+
         return quantized
 
     def _scale_and_shift(self, zhat_normalized):
```
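The rewritten noise dropout draws its Bernoulli masks per batch sample (shape `(batch,)`) instead of materializing a `[b, 1, 1, 1]` mask and `expand_as`-ing it, which had hard-coded a 4-D input. `einx.where('b, b ..., b ...')` then broadcasts the per-sample mask over arbitrary trailing dimensions (plain `torch.where` would mis-align a `(batch,)` mask, since PyTorch broadcasts from the trailing dimension). A toy sketch of that masking pattern (shapes and values illustrative, not the FSQ module itself):

```python
import torch
import einx

batch, dim = 4, 3
noise_dropout = 0.5

unquantized = torch.randn(batch, dim)
quantized = unquantized.round()

# one Bernoulli draw per sample: True means "leave this sample unquantized"
mask = torch.bernoulli(torch.full((batch,), noise_dropout)).bool()

# broadcast the (batch,) mask over all trailing dims of the (batch, dim) tensors
mixed = einx.where('b, b ..., b ...', mask, unquantized, quantized)
print(mixed.shape)   # torch.Size([4, 3])
```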
```diff
@@ -242,7 +249,7 @@ def forward(self, z):
         if force_f32 and orig_dtype not in self.allowed_dtypes:
             z = z.float()
 
-        codes = self.quantize(z, preserve_symmetry=self.preserve_symmetry)
+        codes = self.quantize(z)
 
         # returning indices could be optional
 
```
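For reference, calling code no longer passes `preserve_symmetry` per call; `quantize` reads it from the module. A minimal usage sketch, assuming this diff is against the `FSQ` class exported by vector-quantize-pytorch (levels and shapes here are illustrative):

```python
import torch
from vector_quantize_pytorch import FSQ  # assumed package; adjust to the actual import path

quantizer = FSQ(levels = [8, 5, 5, 5])   # one quantization level count per code dimension

x = torch.randn(1, 1024, 4)              # (batch, seq, dim) with dim == len(levels)
xhat, indices = quantizer(x)             # forward() returns the codes and their indices

assert xhat.shape == x.shape
```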
|