from __future__ import annotations

import random
from math import ceil
from functools import partial

import torch
from torch import Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F
import torch.distributed as dist

from vector_quantize_pytorch.sim_vq import SimVQ, pack_one

from einx import get_at
from einops import rearrange, reduce

# helper functions

def exists(val):
    return val is not None

def first(it):
    return it[0]

def default(val, d):
    return val if exists(val) else d

def round_up_multiple(num, mult):
    return ceil(num / mult) * mult

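# quick check: round_up_multiple(5, 4) == 8 and round_up_multiple(8, 4) == 8
# (used in forward to round the quantize-dropout index up to a structured multiple)
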
# distributed helpers

def is_distributed():
    return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1

def get_maybe_sync_seed(device, max_size = 10_000):
    rand_int = torch.randint(0, max_size, (), device = device)

    # all_reduce sums the rank-local random ints, so every rank ends up with
    # the same value and agrees on the quantize-dropout depth

    if is_distributed():
        dist.all_reduce(rand_int)

    return rand_int.item()

# main class

class ResidualSimVQ(Module):
    """ Follows Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf (SoundStream's residual vector quantization) """

    def __init__(
        self,
        *,
        dim,
        num_quantizers,
        codebook_size,
        heads = 1,
        quantize_dropout = False,
        quantize_dropout_cutoff_index = 0,
        quantize_dropout_multiple_of = 1,
        accept_image_fmap = False,
        rotation_trick = True, # rotation trick from @cfifty, on top of sim vq
        **sim_vq_kwargs
    ):
        super().__init__()
        assert heads == 1, 'residual vq is not compatible with multi-headed codes'

        self.accept_image_fmap = accept_image_fmap

        self.num_quantizers = num_quantizers

        # define sim vq across layers

        self.layers = ModuleList([
            SimVQ(
                dim = dim,
                codebook_size = codebook_size,
                rotation_trick = rotation_trick,
                accept_image_fmap = accept_image_fmap,
                **sim_vq_kwargs
            ) for _ in range(num_quantizers)
        ])

        # quantize dropout

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        assert quantize_dropout_cutoff_index >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # the encodec paper proposes structured dropout, believed to have been set to 4

    @property
    def codebook_size(self):
        return first(self.layers).codebook_size

    @property
    def codebook_dim(self):
        return first(self.layers).codebook_dim

    @property
    def codebooks(self):
        # stacked shape: (num_quantizers, codebook_size, dim)
        codebooks = [layer.codebook for layer in self.layers]
        codebooks = torch.stack(codebooks)
        return codebooks

    def get_codes_from_indices(self, indices):

        quantize_dim = indices.shape[-1]

        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)

        indices, inverse = pack_one(indices, 'b * q')

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct

        if quantize_dim < self.num_quantizers:
            assert self.quantize_dropout, 'quantize dropout must be enabled if you wish to reconstruct from a signal with fewer fine quantizations'
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)

        # take care of quantizer dropout

        mask = indices == -1
        indices = indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later

        all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)

        # mask out any codes that were dropped out

        all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)

        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)

        all_codes = inverse(all_codes, 'q b * d')

        if self.accept_image_fmap:
            all_codes = rearrange(all_codes, 'q b ... d -> q b d ...')

        return all_codes

    def get_output_from_indices(self, indices):
        all_codes = self.get_codes_from_indices(indices)
        summed_residual_codes = reduce(all_codes, 'q ... -> ...', 'sum')
        return summed_residual_codes
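
    # usage sketch (illustrative shapes): indices from a forward pass have shape
    # (batch, seq, num_quantizers); this sums the per-layer codes back into a
    # (batch, seq, dim) tensor matching the quantized output of forward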

    def forward(
        self,
        x,
        indices: Tensor | list[Tensor] | None = None,
        return_all_codes = False,
        rand_quantize_dropout_fixed_seed = None
    ):
        num_quant, quant_dropout_multiple_of, return_loss, device = self.num_quantizers, self.quantize_dropout_multiple_of, exists(indices), x.device

        assert not (self.accept_image_fmap and exists(indices))

        quantized_out = 0.
        residual = x

        all_losses = []
        all_indices = []

        if isinstance(indices, list):
            indices = torch.stack(indices)

        should_quantize_dropout = self.training and self.quantize_dropout and not return_loss

        # sample a layer index at which to dropout further residual quantization
        # also prepare null indices and loss

        if should_quantize_dropout:

            # check if seed is manually passed in

            if not exists(rand_quantize_dropout_fixed_seed):
                rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)

            rand = random.Random(rand_quantize_dropout_fixed_seed)

            rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1
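
                # e.g. with quantize_dropout_multiple_of = 4, a sampled index of 2
                # becomes round_up_multiple(3, 4) - 1 = 3, keeping layers 0-3, so the
                # number of active quantizers is always a multiple of 4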

            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
            null_indices = torch.full(null_indices_shape, -1, device = device, dtype = torch.long)
            null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)

        # save all inputs across layers, for use during expiration at end under shared codebook setting

        all_residuals = []

        # go through the layers

        for quantizer_index, sim_vq in enumerate(self.layers):

            if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                continue

            # save for expiration

            all_residuals.append(residual)

            # sim vq forward

            quantized, *rest = sim_vq(residual)

            # subtract the (detached) quantized output so the next layer fits what remains

            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized

            embed_indices, loss = rest

            all_indices.append(embed_indices)
            all_losses.append(loss)

        # stack all losses and indices

        all_losses, all_indices = map(partial(torch.stack, dim = -1), (all_losses, all_indices))

        ret = (quantized_out, all_indices, all_losses)

        if not return_all_codes:
            return ret

        # whether to return all codes from all codebooks across layers

        all_codes = self.get_codes_from_indices(all_indices)

        # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)

        return (*ret, all_codes)
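
# minimal usage sketch (not part of the module above) - values are illustrative
# and assume SimVQ's remaining hyperparameters can be left at their defaults

if __name__ == '__main__':
    residual_sim_vq = ResidualSimVQ(
        dim = 256,
        num_quantizers = 4,
        codebook_size = 512
    )

    x = torch.randn(1, 1024, 256)

    # one index column and one auxiliary loss per quantizer layer

    quantized, indices, commit_losses = residual_sim_vq(x)

    print(quantized.shape) # (1, 1024, 256)
    print(indices.shape)   # (1, 1024, 4)

    # reconstruct the summed residual codes from indices alone

    recon = residual_sim_vq.get_output_from_indices(indices)
    print(recon.shape)     # (1, 1024, 256)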