Skip to content

Commit e2abbe5

Browse files
authored
Merge pull request #185 from lucasnewman/fsq_preserve_symmetry
Add symmetry-preserving and noise-approximated quantization for FSQ from arXiv:2411.19842
2 parents 8c27a0e + 09c33f3 commit e2abbe5

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

tests/test_readme.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -219,11 +219,14 @@ def test_tiger():
219219

220220
assert torch.allclose(quantized, quantized_out, atol = 1e-5)
221221

222-
def test_fsq():
222+
@pytest.mark.parametrize('preserve_symmetry', (True, False))
223+
def test_fsq(
224+
preserve_symmetry
225+
):
223226
from vector_quantize_pytorch import FSQ
224227

225228
levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper
226-
quantizer = FSQ(levels)
229+
quantizer = FSQ(levels, preserve_symmetry = preserve_symmetry)
227230

228231
x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
229232
xhat, indices = quantizer(x)

vector_quantize_pytorch/finite_scalar_quantization.py

+38-6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from einops import rearrange, pack, unpack
1818

19+
import random
20+
1921
# helper functions
2022

2123
def exists(v):
@@ -62,9 +64,12 @@ def __init__(
6264
channel_first: bool = False,
6365
projection_has_bias: bool = True,
6466
return_indices = True,
65-
force_quantization_f32 = True
67+
force_quantization_f32 = True,
68+
preserve_symmetry: bool = False,
69+
noise_approx_prob = 0.0,
6670
):
6771
super().__init__()
72+
6873
_levels = torch.tensor(levels, dtype=int32)
6974
self.register_buffer("_levels", _levels, persistent = False)
7075

@@ -73,6 +78,9 @@ def __init__(
7378

7479
self.scale = scale
7580

81+
self.preserve_symmetry = preserve_symmetry
82+
self.noise_approx_prob = noise_approx_prob
83+
7684
codebook_dim = len(levels)
7785
self.codebook_dim = codebook_dim
7886

@@ -110,12 +118,36 @@ def bound(self, z, eps: float = 1e-3):
110118
shift = (offset / half_l).atanh()
111119
return (z + shift).tanh() * half_l - offset
112120

113-
def quantize(self, z):
121+
# symmetry-preserving and noise-approximated quantization, section 3.2 in https://arxiv.org/abs/2411.19842
122+
123+
def symmetry_preserving_bound(self, z):
    """
    Symmetry-preserving bound, section 3.2 of https://arxiv.org/abs/2411.19842

    QL(x) = 2 / (L - 1) * [(L - 1) * (tanh(x) + 1) / 2 + 0.5] - 1

    `self._levels` broadcasts over the last axis of z, so each codebook
    dimension is bounded by its own number of levels L.
    """
    levels_minus_1 = self._levels - 1
    # squash z into [0.5, L - 0.5] ...
    shifted = (torch.tanh(z) + 1.0) * levels_minus_1 / 2.0 + 0.5
    # ... then rescale so the level grid is symmetric about zero
    return shifted * 2.0 / levels_minus_1 - 1.0
131+
132+
def noise_approx_bound(self, z):
    """
    Noise-approximated quantization, section 3.2 of https://arxiv.org/abs/2411.19842

    Simulates quantization with uniform noise: Q_L(x) ~= tanh(x) + U{-1, 1} / (L - 1)
    """
    # uniform noise in [-1, 1], matching z's shape / device / dtype
    noise = torch.empty_like(z).uniform_(-1, 1)
    # scale the noise per dimension by its level count (broadcast over last axis)
    scaled_noise = noise / (self._levels - 1)
    return torch.tanh(z) + scaled_noise
138+
139+
def quantize(self, z, preserve_symmetry = False):
    """ Quantizes z, returns quantized zhat, same shape as z. """
    # during training, with probability `noise_approx_prob`, replace the hard
    # quantization bound with its noise-based approximation (arXiv 2411.19842, 3.2)
    # NOTE(review): round_ste is still applied on the noise path below — presumably
    # intentional upstream, but worth confirming against the paper
    use_noise = self.training and random.random() < self.noise_approx_prob
    if use_noise:
        bounded = self.noise_approx_bound(z)
    else:
        bound_fn = self.symmetry_preserving_bound if preserve_symmetry else self.bound
        bounded = bound_fn(z)
    quantized = round_ste(bounded)
    # renormalize the integer code to [-1, 1]
    half_width = self._levels // 2
    return quantized / half_width
118-
150+
119151
def _scale_and_shift(self, zhat_normalized):
120152
half_width = self._levels // 2
121153
return (zhat_normalized * half_width) + half_width
@@ -194,7 +226,7 @@ def forward(self, z):
194226
if force_f32 and orig_dtype not in self.allowed_dtypes:
195227
z = z.float()
196228

197-
codes = self.quantize(z)
229+
codes = self.quantize(z, preserve_symmetry=self.preserve_symmetry)
198230

199231
# returning indices could be optional
200232

@@ -205,7 +237,7 @@ def forward(self, z):
205237

206238
codes = rearrange(codes, 'b n c d -> b n (c d)')
207239

208-
codes = codes.type(orig_dtype)
240+
codes = codes.to(orig_dtype)
209241

210242
# project out
211243

0 commit comments

Comments
 (0)