2 files changed: 16 additions, 1 deletion

pyproject.toml:

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.21.7"
+version = "1.21.8"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```
The second file, a residual quantizer module:

```diff
@@ -59,6 +59,7 @@ def __init__(
         quantize_dropout = False,
         quantize_dropout_cutoff_index = 0,
         quantize_dropout_multiple_of = 1,
+        soft_clamp_input_value = None,
         **kwargs
     ):
         super().__init__()
@@ -73,6 +74,12 @@ def __init__(
         self.is_channel_first = is_channel_first
         self.num_quantizers = num_quantizers

+        # soft clamping the input value
+
+        self.soft_clamp_input_value = soft_clamp_input_value
+
+        # layers
+
         self.levels = levels
         self.layers = nn.ModuleList([])
@@ -170,6 +177,14 @@ def forward(

         x = self.project_in(x)

+        # maybe softclamp input before residual layers
+
+        if exists(self.soft_clamp_input_value):
+            clamp_value = self.soft_clamp_input_value
+            x = (x / clamp_value).tanh() * clamp_value
+
+        # ready some variables to be accumulated
+
         quantized_out = 0.
         residual = x
```
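For reference, a minimal standalone sketch of what the new soft clamp does. The `soft_clamp` helper name is illustrative, not part of the library; only the `(x / clamp_value).tanh() * clamp_value` expression comes from the diff above:

```python
import torch

def soft_clamp(x, clamp_value):
    # smoothly bounds x to (-clamp_value, clamp_value):
    # near zero tanh(t) ~= t, so small inputs pass through almost unchanged,
    # while large inputs saturate toward +/- clamp_value
    return (x / clamp_value).tanh() * clamp_value

x = torch.tensor([-100., -1., 0., 1., 100.])
print(soft_clamp(x, 10.))
# tensor([-10.0000,  -0.9967,   0.0000,   0.9967,  10.0000])
```

Unlike a hard `clamp`, the tanh form is differentiable with a nonzero gradient everywhere, so out-of-range activations still receive a corrective training signal. Assuming the class being modified is the library's `ResidualFSQ` (the `levels` / `num_quantizers` attributes above fit that module), enabling the option might look like the following; the hyperparameter values are illustrative only:

```python
from vector_quantize_pytorch import ResidualFSQ

residual_fsq = ResidualFSQ(
    dim = 256,
    levels = [8, 5, 5, 5],
    num_quantizers = 8,
    soft_clamp_input_value = 10.   # the option introduced in this change
)
```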