Commit 7de0d36

add ability to softclamp the initial input going into residual fsq
1 parent 5ae9c79 commit 7de0d36

File tree

2 files changed: +16 -1 lines changed

pyproject.toml

+1 -1

@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.21.7"
+version = "1.21.8"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/residual_fsq.py

+15
@@ -59,6 +59,7 @@ def __init__(
         quantize_dropout = False,
         quantize_dropout_cutoff_index = 0,
         quantize_dropout_multiple_of = 1,
+        soft_clamp_input_value = None,
         **kwargs
     ):
         super().__init__()
@@ -73,6 +74,12 @@ def __init__(
         self.is_channel_first = is_channel_first
         self.num_quantizers = num_quantizers

+        # soft clamping the input value
+
+        self.soft_clamp_input_value = soft_clamp_input_value
+
+        # layers
+
         self.levels = levels
         self.layers = nn.ModuleList([])

@@ -170,6 +177,14 @@ def forward(

         x = self.project_in(x)

+        # maybe softclamp input before residual layers
+
+        if exists(self.soft_clamp_input_value):
+            clamp_value = self.soft_clamp_input_value
+            x = (x / clamp_value).tanh() * clamp_value
+
+        # ready some variables to be accumulated
+
         quantized_out = 0.
         residual = x
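For illustration, a minimal usage sketch of the new option. Only the soft_clamp_input_value keyword and the tanh-based clamp come from this commit; the rest of the call (dim, levels, num_quantizers, the import, the tensor shape and the returned values) is assumed from the package's usual ResidualFSQ interface, and the concrete numbers are hypothetical.

import torch
from vector_quantize_pytorch import ResidualFSQ

# assumed ResidualFSQ interface; soft_clamp_input_value is the only argument added by this commit
residual_fsq = ResidualFSQ(
    dim = 256,
    levels = [8, 5, 5, 5],
    num_quantizers = 4,
    soft_clamp_input_value = 10.  # projected input becomes tanh(x / 10) * 10, so it stays in (-10, 10)
)

x = torch.randn(1, 1024, 256)

quantized, indices = residual_fsq(x)  # quantized has the same shape as x

The clamp is smooth: near zero it is roughly the identity (tanh(t) ≈ t for small t), while large activations saturate at ±soft_clamp_input_value before entering the residual FSQ layers.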
0 commit comments