File tree 2 files changed +13
-2
lines changed
2 files changed +13
-2
lines changed Original file line number Diff line number Diff line change 1
1
[project ]
2
2
name = " vector-quantize-pytorch"
3
- version = " 1.20.1 "
3
+ version = " 1.20.2 "
4
4
description = " Vector Quantization - Pytorch"
5
5
authors = [
6
6
{ name = " Phil Wang" , email = " lucidrains@gmail.com" }
Original file line number Diff line number Diff line change
1
+ from __future__ import annotations
1
2
from typing import Callable
2
3
3
4
import torch
@@ -38,6 +39,7 @@ def __init__(
38
39
self ,
39
40
dim ,
40
41
codebook_size ,
42
+ codebook_transform : Module | None = None ,
41
43
init_fn : Callable = identity ,
42
44
accept_image_fmap = False ,
43
45
rotation_trick = True , # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
@@ -51,7 +53,11 @@ def __init__(
51
53
52
54
# the codebook is actually implicit from a linear layer from frozen gaussian or uniform
53
55
54
- self .codebook_to_codes = nn .Linear (dim , dim , bias = False )
56
+ if not exists (codebook_transform ):
57
+ codebook_transform = nn .Linear (dim , dim , bias = False )
58
+
59
+ self .codebook_to_codes = codebook_transform
60
+
55
61
self .register_buffer ('codebook' , codebook )
56
62
57
63
@@ -114,6 +120,11 @@ def forward(
114
120
115
121
sim_vq = SimVQ (
116
122
dim = 512 ,
123
+ codebook_transform = nn .Sequential (
124
+ nn .Linear (512 , 1024 ),
125
+ nn .ReLU (),
126
+ nn .Linear (1024 , 512 )
127
+ ),
117
128
codebook_size = 1024 ,
118
129
accept_image_fmap = True
119
130
)
You can’t perform that action at this time.
0 commit comments