@@ -58,20 +58,20 @@ def __init__(
         quantize_dropout = False,
         quantize_dropout_cutoff_index = 0,
         quantize_dropout_multiple_of = 1,
-        accept_image_fmap = False,
+        channel_first = False,
         rotation_trick = True, # rotation trick from @cfifty, on top of sim vq
         **sim_vq_kwargs
     ):
         super().__init__()
         assert heads == 1, 'residual vq is not compatible with multi-headed codes'
 
-        self.accept_image_fmap = accept_image_fmap
+        self.channel_first = channel_first
 
         self.num_quantizers = num_quantizers
 
         # define sim vq across layers
 
-        self.layers = ModuleList([SimVQ(dim = dim, codebook_size = codebook_size, rotation_trick = rotation_trick, accept_image_fmap = accept_image_fmap, **sim_vq_kwargs) for _ in range(num_quantizers)])
+        self.layers = ModuleList([SimVQ(dim = dim, codebook_size = codebook_size, rotation_trick = rotation_trick, channel_first = channel_first, **sim_vq_kwargs) for _ in range(num_quantizers)])
 
         # quantize dropout
 
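The rename only touches the constructor surface; behavior is unchanged. A minimal usage sketch of the new flag, assuming the package-level `ResidualSimVQ` export and return values following the `ResidualVQ` convention:

```python
import torch
from vector_quantize_pytorch import ResidualSimVQ  # assumed export path

# with channel_first = True the quantizer consumes (batch, dim, height, width)
# feature maps directly, as accept_image_fmap did before the rename
residual_sim_vq = ResidualSimVQ(
    dim = 512,
    num_quantizers = 4,
    codebook_size = 1024,
    channel_first = True
)

x = torch.randn(1, 512, 32, 32)

# return values assumed to follow the ResidualVQ convention
quantized, indices, commit_loss = residual_sim_vq(x)
```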
@@ -100,7 +100,7 @@ def get_codes_from_indices(self, indices):
 
         batch, quantize_dim = indices.shape[0], indices.shape[-1]
 
-        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)
+        # may also receive indices in the shape of 'b h w q' (images)
 
         indices, inverse = pack_one(indices, 'b * q')
 
@@ -122,11 +122,11 @@ def get_codes_from_indices(self, indices):
 
         all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)
 
-        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
+        # if (channel_first = True) then return shape (quantize, batch, dimension, height, width)
 
         all_codes = inverse(all_codes, 'q b * d')
 
-        if self.accept_image_fmap:
+        if self.channel_first:
             all_codes = rearrange(all_codes, 'q b ... d -> q b d ...')
 
         return all_codes
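For reference, the final rearrange in isolation; a standalone sketch with made-up sizes showing how channel-last codes become channel-first:

```python
import torch
from einops import rearrange

# (quantize, batch, height, width, dim) -> (quantize, batch, dim, height, width)
all_codes = torch.randn(4, 2, 8, 8, 512)
channel_first_codes = rearrange(all_codes, 'q b ... d -> q b d ...')

print(channel_first_codes.shape)  # torch.Size([4, 2, 512, 8, 8])
```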
@@ -139,23 +139,17 @@ def get_output_from_indices(self, indices):
     def forward(
         self,
         x,
-        indices: Tensor | list[Tensor] | None = None,
         return_all_codes = False,
         rand_quantize_dropout_fixed_seed = None
     ):
-        num_quant, quant_dropout_multiple_of, return_loss, device = self.num_quantizers, self.quantize_dropout_multiple_of, exists(indices), x.device
-
-        assert not (self.accept_image_fmap and exists(indices))
+        num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device
 
         quantized_out = 0.
         residual = x
 
         all_losses = []
         all_indices = []
 
-        if isinstance(indices, list):
-            indices = torch.stack(indices)
-
-        should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
+        should_quantize_dropout = self.training and self.quantize_dropout
 
         # sample a layer index at which to dropout further residual quantization
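Since `indices` can no longer be passed into `forward`, rebuilding outputs from stored indices goes through `get_codes_from_indices` / `get_output_from_indices` instead. A sketch of that round trip, reusing the `residual_sim_vq` instance and `x` from the earlier example; the closeness check assumes eval mode so quantize dropout never fires:

```python
residual_sim_vq.eval()  # disable quantize dropout so every layer contributes

quantized, indices, losses = residual_sim_vq(x)

# rebuild the quantized output purely from the saved indices
recon = residual_sim_vq.get_output_from_indices(indices)
print(torch.allclose(quantized, recon, atol = 1e-5))
```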
@@ -175,7 +169,7 @@ def forward(
             if quant_dropout_multiple_of != 1:
                 rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1
 
-            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
+            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.channel_first else tuple(x.shape[:2])
             null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
             null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)
 
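The ternary above is the only shape-sensitive spot in the dropout path; a standalone check of what `null_indices_shape` evaluates to in each layout (dummy tensors, not the module):

```python
import torch

x_channel_first = torch.randn(2, 512, 8, 8)  # (batch, dim, height, width)
x_channel_last = torch.randn(2, 1024, 512)   # (batch, seq, dim)

# channel_first branch: batch plus trailing spatial dims
print((x_channel_first.shape[0], *x_channel_first.shape[-2:]))  # (2, 8, 8)

# default branch: batch and sequence length
print(tuple(x_channel_last.shape[:2]))  # (2, 1024)
```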