Skip to content

Commit 81d0f3b

Browse files
committed
complete residual sim vq
1 parent 8f5b428 commit 81d0f3b

File tree

6 files changed

+253
-4
lines changed

6 files changed

+253
-4
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.20.4"
3+
version = "1.20.5"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_readme.py

+17
Original file line numberDiff line numberDiff line change
@@ -376,3 +376,20 @@ def test_sim_vq():
376376

377377
assert x.shape == quantized.shape
378378
assert torch.allclose(quantized, sim_vq.indices_to_codes(indices), atol = 1e-6)
379+
380+
def test_residual_sim_vq():
    """Round-trip check for ResidualSimVQ on an image feature map.

    Quantizes a random (1, 512, 32, 32) input and verifies that the output
    keeps the input shape and that rebuilding the output from the returned
    indices reproduces it within tolerance.
    """
    from vector_quantize_pytorch import ResidualSimVQ

    residual_sim_vq = ResidualSimVQ(
        dim = 512,
        num_quantizers = 4,
        codebook_size = 1024,
        accept_image_fmap = True
    )

    x = torch.randn(1, 512, 32, 32)
    quantized, indices, commit_loss = residual_sim_vq(x)

    assert x.shape == quantized.shape
    assert torch.allclose(quantized, residual_sim_vq.get_output_from_indices(indices), atol = 1e-5)

vector_quantize_pytorch/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ
77
from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ
88
from vector_quantize_pytorch.latent_quantization import LatentQuantize
9+
910
from vector_quantize_pytorch.sim_vq import SimVQ
11+
from vector_quantize_pytorch.residual_sim_vq import ResidualSimVQ
1012

1113
from vector_quantize_pytorch.utils import Sequential
+226
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
from __future__ import annotations
2+
3+
import random
4+
from math import ceil
5+
from functools import partial, cache
6+
from itertools import zip_longest
7+
8+
import torch
9+
from torch import nn, Tensor
10+
from torch.nn import Module, ModuleList
11+
import torch.nn.functional as F
12+
import torch.distributed as dist
13+
14+
from vector_quantize_pytorch.sim_vq import SimVQ, pack_one
15+
16+
from einx import get_at
17+
from einops import rearrange, repeat, reduce, pack, unpack
18+
19+
# helper functions
20+
21+
def exists(val):
    """Return True when `val` is anything other than None."""
    return val is not None
23+
24+
def first(it):
    """Return the item at index 0 of an indexable container."""
    return it[0]
26+
27+
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`.

    Only None triggers the fallback; falsy values such as 0 or '' are kept.
    """
    return val if val is not None else d
29+
30+
def round_up_multiple(num, mult):
    """Round `num` up to the nearest multiple of `mult`.

    A value that is already a multiple of `mult` is returned unchanged.
    """
    return ceil(num / mult) * mult
32+
33+
# distributed helpers
34+
35+
def is_distributed():
    """Return True only when a process group is initialized with more than one rank."""
    # torch builds without distributed support do not export `is_initialized`,
    # so probe availability first instead of raising AttributeError
    if not dist.is_available():
        return False

    return dist.is_initialized() and dist.get_world_size() > 1
37+
38+
def get_maybe_sync_seed(device, max_size = 10_000):
    """Draw a random integer seed, made identical across ranks when distributed.

    Args:
        device: device on which the random scalar tensor is created.
        max_size: exclusive upper bound of the local draw.

    Returns:
        A Python int. In the single process case it lies in [0, max_size).
    """
    rand_int = torch.randint(0, max_size, (), device = device)

    if is_distributed():
        # all_reduce works in place and sums by default, so every rank ends up
        # holding the same value (which may exceed max_size, harmless for a seed)
        dist.all_reduce(rand_int)

    return rand_int.item()
45+
46+
# main class
47+
48+
class ResidualSimVQ(Module):
    """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

    Stacks `num_quantizers` SimVQ layers. Each layer quantizes the residual
    left over by the layers before it, and the per layer quantized outputs
    are summed to form the final output.
    """

    def __init__(
        self,
        *,
        dim,
        num_quantizers,
        codebook_size,
        heads = 1,
        quantize_dropout = False,
        quantize_dropout_cutoff_index = 0,
        quantize_dropout_multiple_of = 1,
        accept_image_fmap = False,
        rotation_trick = True, # rotation trick from @cfifty, on top of sim vq
        **sim_vq_kwargs
    ):
        """
        Args:
            dim: feature dimension handed to every SimVQ layer.
            num_quantizers: number of residual SimVQ layers.
            codebook_size: codebook size handed to every SimVQ layer.
            heads: must be 1, multi-headed codes are rejected.
            quantize_dropout: enable dropping of trailing layers while training
                (only takes effect when there is more than one quantizer).
            quantize_dropout_cutoff_index: lowest layer index that may be
                sampled as the last active layer.
            quantize_dropout_multiple_of: the number of active layers is
                rounded up to a multiple of this value.
            accept_image_fmap: forwarded to SimVQ, also switches the index and
                code layouts used by this module to the image layout.
            rotation_trick: forwarded to SimVQ.
            **sim_vq_kwargs: any further keyword arguments for SimVQ.
        """
        super().__init__()
        assert heads == 1, 'residual vq is not compatible with multi-headed codes'

        self.accept_image_fmap = accept_image_fmap

        self.num_quantizers = num_quantizers

        # define sim vq across layers

        self.layers = ModuleList([
            SimVQ(
                dim = dim,
                codebook_size = codebook_size,
                rotation_trick = rotation_trick,
                accept_image_fmap = accept_image_fmap,
                **sim_vq_kwargs
            )
            for _ in range(num_quantizers)
        ])

        # quantize dropout

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        assert quantize_dropout_cutoff_index >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4

    @property
    def codebook_size(self):
        """Codebook size reported by the first layer."""
        return first(self.layers).codebook_size

    @property
    def codebook_dim(self):
        """Codebook dimension reported by the first layer."""
        return first(self.layers).codebook_dim

    @property
    def codebooks(self):
        """Codebooks of all layers stacked along a new leading (quantizer) axis."""
        codebooks = [layer.codebook for layer in self.layers]
        codebooks = torch.stack(codebooks)
        return codebooks

    def get_codes_from_indices(self, indices):
        """Look up the code of every layer for the given indices.

        `indices` carries the quantizer axis last ('b n q', or 'b h w q' with
        accept_image_fmap). Entries equal to -1 mark dropped out layers and
        produce all-zero codes.

        Returns:
            Codes laid out as 'q b n d', or 'q b d h w' with accept_image_fmap.
        """
        quantize_dim = indices.shape[-1]

        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)

        indices, inverse = pack_one(indices, 'b * q')

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct

        if quantize_dim < self.num_quantizers:
            assert self.quantize_dropout > 0., 'quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations'
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)

        # take care of quantizer dropout

        mask = indices == -1
        indices = indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later

        all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)

        # mask out any codes that were dropout-ed

        all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)

        # restore the packed spatial axes, then (accept_image_fmap = True) move
        # the feature dimension in front of them: (quantize, batch, dimension, height, width)

        all_codes = inverse(all_codes, 'q b * d')

        if self.accept_image_fmap:
            all_codes = rearrange(all_codes, 'q b ... d -> q b d ...')

        return all_codes

    def get_output_from_indices(self, indices):
        """Rebuild the quantized output by summing the codes of all layers."""
        all_codes = self.get_codes_from_indices(indices)
        summed_residual_codes = reduce(all_codes, 'q ... -> ...', 'sum')
        return summed_residual_codes

    def forward(
        self,
        x,
        indices: Tensor | list[Tensor] | None = None,
        return_all_codes = False,
        rand_quantize_dropout_fixed_seed = None
    ):
        """Residually quantize `x`.

        Args:
            x: input to quantize.
            indices: optional indices; passing them only disables quantize
                dropout here and is not allowed together with
                accept_image_fmap. A list is stacked into one tensor.
            return_all_codes: also return the per layer codes.
            rand_quantize_dropout_fixed_seed: seed used to sample the quantize
                dropout layer; drawn (and synced across ranks) when omitted.

        Returns:
            (quantized_out, all_indices, all_losses), with indices and losses
            stacked on a trailing quantizer axis. With `return_all_codes` the
            tuple gains a fourth element holding the codes of every layer.
        """
        num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device
        return_loss = exists(indices)

        assert not (self.accept_image_fmap and exists(indices))

        quantized_out = 0.
        residual = x

        all_losses = []
        all_indices = []

        if isinstance(indices, list):
            indices = torch.stack(indices)

        should_quantize_dropout = self.training and self.quantize_dropout and not return_loss

        # sample a layer index at which to dropout further residual quantization
        # also prepare null indices, the null loss is created once a real loss is known

        null_loss = None

        if should_quantize_dropout:

            # check if seed is manually passed in

            if not exists(rand_quantize_dropout_fixed_seed):
                rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)

            rand = random.Random(rand_quantize_dropout_fixed_seed)

            rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1

            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
            null_indices = torch.full(null_indices_shape, -1, device = device, dtype = torch.long)

        # go through the layers

        for quantizer_index, sim_vq in enumerate(self.layers):

            if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:

                # torch.stack below needs every loss to share one shape, so the
                # zero placeholder copies the shape of a real layer loss rather
                # than assuming one. layer 0 is never dropped (the sampled index
                # is >= 0), hence all_losses is non-empty at this point
                # NOTE(review): SimVQ's loss shape is not visible from this file - confirm

                if not exists(null_loss):
                    null_loss = torch.zeros(all_losses[0].shape, device = device, dtype = x.dtype)

                all_indices.append(null_indices)
                all_losses.append(null_loss)
                continue

            # sim vq forward

            quantized, embed_indices, loss = sim_vq(residual)

            # the next layer sees what this one failed to capture, detached so
            # the residual itself carries no gradient from the quantized value

            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized

            all_indices.append(embed_indices)
            all_losses.append(loss)

        # stack all losses and indices

        all_losses = torch.stack(all_losses, dim = -1)
        all_indices = torch.stack(all_indices, dim = -1)

        ret = (quantized_out, all_indices, all_losses)

        if not return_all_codes:
            return ret

        # whether to return all codes from all codebooks across layers

        all_codes = self.get_codes_from_indices(all_indices)

        # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)

        return (*ret, all_codes)

vector_quantize_pytorch/residual_vq.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def __init__(
156156
manual_in_place_optimizer_update = True
157157
)
158158

159-
# take care of maybe different codebook sizes across depth, used in TIGER paper https://arxiv.org/abs/2305.05065
159+
# take care of maybe different codebook sizes across depth
160160

161161
codebook_sizes = cast_tuple(codebook_size, num_quantizers)
162162

vector_quantize_pytorch/sim_vq.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,21 @@ def __init__(
4444
accept_image_fmap = False,
4545
rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
4646
input_to_quantize_commit_loss_weight = 0.25,
47+
frozen_codebook_dim = None # frozen codebook dim could have different dimensions than projection
4748
):
4849
super().__init__()
50+
self.codebook_size = codebook_size
4951
self.accept_image_fmap = accept_image_fmap
5052

51-
codebook = torch.randn(codebook_size, dim) * (dim ** -0.5)
53+
frozen_codebook_dim = default(frozen_codebook_dim, dim)
54+
codebook = torch.randn(codebook_size, frozen_codebook_dim) * (frozen_codebook_dim ** -0.5)
5255
codebook = init_fn(codebook)
5356

5457
# the codebook is actually implicit from a linear layer from frozen gaussian or uniform
5558

59+
5660
if not exists(codebook_transform):
57-
codebook_transform = nn.Linear(dim, dim, bias = False)
61+
codebook_transform = nn.Linear(frozen_codebook_dim, dim, bias = False)
5862

5963
self.code_transform = codebook_transform
6064

0 commit comments

Comments
 (0)