Skip to content

Commit 6898026

Browse files
committed
addressing missing floor #200
1 parent 4380fe8 commit 6898026

File tree

2 files changed

+35
-28
lines changed

2 files changed

+35
-28
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.21.9"
3+
version = "1.22.0"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/finite_scalar_quantization.py

+34-27
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from torch import Tensor, int32
1515
from torch.amp import autocast
1616

17+
import einx
1718
from einops import rearrange, pack, unpack
1819

1920
import random
@@ -45,11 +46,15 @@ def unpack_one(t, ps, pattern):
4546

4647
# tensor helpers
4748

def round_ste(z: Tensor) -> Tensor:
    """Round with straight through gradients.

    Forward pass rounds ``z`` to the nearest integer; the backward pass
    treats the rounding as identity, so gradients flow through unchanged
    (the straight-through estimator).
    """
    zhat = z.round()
    # value equals zhat in the forward pass, but autograd only sees `z`
    return z + (zhat - z).detach()
def floor_ste(z: Tensor) -> Tensor:
    """Floor with straight through gradients.

    Forward pass floors ``z``; the backward pass treats the floor as
    identity so gradients flow through unchanged, mirroring ``round_ste``.
    """
    zhat = z.floor()
    # value equals zhat in the forward pass, but autograd only sees `z`
    return z + (zhat - z).detach()
5358
# main class
5459

5560
class FSQ(Module):
def symmetry_preserving_bound(self, z):
    """Bound `z` onto symmetric quantization levels in [-1, 1].

    tanh squashes z into (-1, 1); that range is mapped onto the level
    positions 0..L-1, snapped with a straight-through floor (so the
    forward value is a discrete level but gradients pass through), then
    rescaled back so the output is symmetric around 0.

    NOTE(review): assumes self._levels holds the per-dimension level
    counts (broadcastable against z) — confirm against __init__, which is
    not visible in this diff.
    """
    levels_minus_1 = (self._levels - 1)
    scale = 2.0 / levels_minus_1
    # map tanh's (-1, 1) onto (0, L-1), plus 0.5 so flooring rounds to
    # the nearest level rather than truncating down
    bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.0) + 0.5
    # straight-through floor: discrete level index forward, identity grad
    bracket = floor_ste(bracket)
    # rescale the level index back into the symmetric range [-1, 1]
    return scale * bracket - 1.0
131137

def quantize(self, z):
    """Quantizes z, returns quantized zhat, same shape as z.

    During training, two independent per-sample Bernoulli(noise_dropout)
    masks are drawn: one leaves samples entirely unquantized, the other
    replaces samples with the unquantized input plus uniform noise of
    one quantization-bin width.
    """
    preserve_symmetry = self.preserve_symmetry
    half_width = self._levels // 2

    if preserve_symmetry:
        quantized = round_ste(self.symmetry_preserving_bound(z)) / half_width
    else:
        quantized = round_ste(self.bound(z)) / half_width

    # eval mode: deterministic quantization, no noise dropout
    if not self.training:
        return quantized

    batch, device, noise_dropout = z.shape[0], z.device, self.noise_dropout
    unquantized = z

    # determine which samples to leave unquantized (per-batch mask)

    quantize_mask = torch.bernoulli(
        torch.full((batch,), noise_dropout, device = device)
    ).bool()

    # fix: torch.where with a (batch,) mask broadcasts against the LAST
    # dim of z, not the batch dim — align explicitly on batch with einx,
    # matching the offset_mask branch below
    quantized = einx.where('b, b ..., b ...', quantize_mask, unquantized, quantized)

    # determine which samples get a random sub-bin offset instead

    offset_mask = torch.bernoulli(
        torch.full((batch,), noise_dropout, device = device)
    ).bool()

    # uniform noise in [-0.5, 0.5) of one quantization bin
    offset = (torch.rand_like(z) - 0.5) / half_width
    quantized = einx.where('b, b ..., b ...', offset_mask, unquantized + offset, quantized)

    return quantized
166173

167174
def _scale_and_shift(self, zhat_normalized):
@@ -242,7 +249,7 @@ def forward(self, z):
242249
if force_f32 and orig_dtype not in self.allowed_dtypes:
243250
z = z.float()
244251

245-
codes = self.quantize(z, preserve_symmetry=self.preserve_symmetry)
252+
codes = self.quantize(z)
246253

247254
# returning indices could be optional
248255

0 commit comments

Comments
 (0)