
Commit 80f4e84

allow for an MLP implicitly parameterized codebook instead of just a single Linear, for SimVQ. seems to converge just fine with rotation trick
1 parent cd0fa8e commit 80f4e84

File tree

2 files changed, +13 -2 lines changed


pyproject.toml

+1 -1
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.20.1"
+version = "1.20.2"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/sim_vq.py

+12 -1
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from typing import Callable

 import torch
@@ -38,6 +39,7 @@ def __init__(
         self,
         dim,
         codebook_size,
+        codebook_transform: Module | None = None,
         init_fn: Callable = identity,
         accept_image_fmap = False,
         rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
@@ -51,7 +53,11 @@ def __init__(

         # the codebook is actually implicit from a linear layer from frozen gaussian or uniform

-        self.codebook_to_codes = nn.Linear(dim, dim, bias = False)
+        if not exists(codebook_transform):
+            codebook_transform = nn.Linear(dim, dim, bias = False)
+
+        self.codebook_to_codes = codebook_transform
+
         self.register_buffer('codebook', codebook)


@@ -114,6 +120,11 @@ def forward(

     sim_vq = SimVQ(
         dim = 512,
+        codebook_transform = nn.Sequential(
+            nn.Linear(512, 1024),
+            nn.ReLU(),
+            nn.Linear(1024, 512)
+        ),
         codebook_size = 1024,
         accept_image_fmap = True
     )
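
Putting the change together: SimVQ now accepts an optional codebook_transform module, and falls back to the original frozen-input Linear when none is given. Below is a minimal usage sketch based on the diff above; the top-level SimVQ import, the (quantized, indices, commit_loss) return signature, and the (batch, dim, height, width) input layout for accept_image_fmap = True are assumptions taken from the library's README rather than shown in this commit.

import torch
from torch import nn
from vector_quantize_pytorch import SimVQ

# implicit codebook parameterized by an MLP instead of a single Linear
sim_vq = SimVQ(
    dim = 512,
    codebook_size = 1024,
    codebook_transform = nn.Sequential(
        nn.Linear(512, 1024),
        nn.ReLU(),
        nn.Linear(1024, 512)
    ),
    accept_image_fmap = True,   # assumed: input is a (batch, dim, height, width) feature map
    rotation_trick = True       # default per the diff; commit message notes it converges fine with the MLP codebook
)

x = torch.randn(1, 512, 32, 32)
quantized, indices, commit_loss = sim_vq(x)   # assumed return signature; quantized keeps the shape of x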
