-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmodel.py
252 lines (227 loc) · 9.12 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import numpy as np
import torch
from torch import nn
class FFTNetQueue(object):
    """Fixed-length FIFO holding the most recent inputs of one layer.

    The buffer is a tensor of shape ``[batch_size, num_channels, size]``
    whose last axis runs from the oldest step (index 0) to the newest step
    (index -1). It starts out filled with zeros.

    Args:
        batch_size: size of the first buffer dimension.
        size: number of time steps kept; `enqueue` returns the value that
            was pushed `size` calls earlier.
        num_channels: size of the second buffer dimension.
        cuda: when True the buffer is moved to the default CUDA device.
    """

    def __init__(self, batch_size, size, num_channels, cuda=True):
        super(FFTNetQueue, self).__init__()
        self.size = size
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.cuda = cuda
        self.queue = None
        self.reset()

    def reset(self):
        """Replace the buffer with a fresh all-zero tensor."""
        self.queue = torch.zeros(
            [self.batch_size, self.num_channels, self.size])
        if self.cuda:
            self.queue = self.queue.cuda()

    def enqueue(self, x_push):
        """Push one time step and return the step that falls out.

        Args:
            x_push: tensor that can be reshaped to
                ``[x_push.shape[0], x_push.shape[1]]`` (e.g. batch x
                channels x 1).

        Returns:
            A detached copy of the oldest stored step, shaped
            ``[batch_size, num_channels]``.
        """
        # Copy the oldest step *before* mutating the buffer: a plain view
        # would be overwritten by the shift/push below.
        x_pop = self.queue[:, :, 0].detach().clone()
        # Source and destination slices overlap, so copy from a clone.
        self.queue[:, :, :-1] = self.queue[:, :, 1:].clone()
        # detach() keeps the buffer out of any autograd graph of the caller.
        self.queue[:, :, -1] = x_push.detach().reshape(
            x_push.shape[0], x_push.shape[1])
        return x_pop
class FFTNet(nn.Module):
    """One FFTNet layer: two 1x1 convolutions over inputs K steps apart.

    With ``K = 2**layer_id // 2`` the layer computes, for every pair of
    inputs that are K steps apart::

        relu(conv2(relu(conv1_1(x[t - K]) + conv1_2(x[t]) [+ conditioning])))

    `forward` processes a whole sequence; `forward_step` processes a single
    time step and keeps the delayed inputs in `FFTNetQueue` buffers.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        hid_channels: channels after the first pair of convolutions.
        layer_id: sets the receptive field to ``2**layer_id``.
        cond_channels: channels of the optional conditioning input; when
            None no conditioning convolutions are created.
        std_f: factor used by `init_weights`.
    """

    def __init__(self, in_channels, out_channels, hid_channels, layer_id,
                 cond_channels=None, std_f=0.5):
        super(FFTNet, self).__init__()
        self.layer_id = layer_id
        self.receptive_field = 2**layer_id
        self.K = self.receptive_field // 2
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hid_channels = hid_channels
        self.cond_channels = cond_channels
        self.conv1_1 = nn.Conv1d(in_channels, hid_channels, 1, stride=1)
        self.conv1_2 = nn.Conv1d(in_channels, hid_channels, 1, stride=1)
        if cond_channels is not None:
            self.convc1 = nn.Conv1d(cond_channels, hid_channels, 1)
            self.convc2 = nn.Conv1d(cond_channels, hid_channels, 1)
        self.conv2 = nn.Conv1d(hid_channels, out_channels, 1)
        self.relu = nn.ReLU()
        self.init_weights(std_f)
        # step-wise inference state, created lazily by forward_step
        self.buffer = None
        self.cond_buffer = None
        # inference params for linear operations (lazy views of conv weights)
        self.w1_1 = None
        self.w1_2 = None
        self.w2 = None
        self.wc1_1 = None
        self.wc1_2 = None

    def init_weights(self, std_f):
        """Draw first-stage weights from N(0, std_f / in_channels), zero biases.

        `conv2` is not touched and keeps the default `nn.Conv1d` init.
        """
        std = np.sqrt(std_f / self.in_channels)
        first_stage = [self.conv1_1, self.conv1_2]
        if self.cond_channels is not None:
            first_stage += [self.convc1, self.convc2]
        for conv in first_stage:
            conv.weight.data.normal_(mean=0, std=std)
            conv.bias.data.zero_()

    def _require_conditioning(self):
        """Fail clearly when `cx` is passed to a layer built without it."""
        if self.cond_channels is None:
            raise ValueError(
                "cx was given but this FFTNet layer has cond_channels=None")

    def forward(self, x, cx=None):
        """
        Shapes:
            inputs: batch x channels x time
            cx: batch x cond_channels x time
            out: batch x out_channels x time - receptive_field/2
        """
        # delayed half and current half of every K-spaced input pair
        x_past = x[:, :, :-self.K]
        x_now = x[:, :, self.K:]
        z = self.conv1_1(x_past) + self.conv1_2(x_now)
        # conditional input
        if cx is not None:
            self._require_conditioning()
            cz_past = self.convc1(cx[:, :, :-self.K])
            cz_now = self.convc2(cx[:, :, self.K:])
            z = z + cz_past + cz_now
        out = self.relu(z)
        out = self.conv2(out)
        out = self.relu(out)
        return out

    def forward_step(self, x, cx=None):
        """Process one time step, mirroring `forward`.

        The value returned by the buffer (the delayed input) goes through the
        `conv1_1` / `convc1` weights and the current input through the
        `conv1_2` / `convc2` weights, the same pairing `forward` uses.

        Shapes:
            x: batch x channels x 1
            cx: batch x cond_channels x 1
            out: batch x out_channels x 1
        """
        B = x.shape[0]
        # linear weights
        if self.w1_1 is None:
            self.w1_1 = self._convert_to_fc_weights(self.conv1_1)
            self.w1_2 = self._convert_to_fc_weights(self.conv1_2)
        if self.w2 is None:
            self.w2 = self._convert_to_fc_weights(self.conv2)
        # create buffer queue
        if self.buffer is None:
            self.buffer = FFTNetQueue(B, self.K, self.in_channels, x.is_cuda)
        # queue inputs; reshape() also accepts non-contiguous slices
        x_now = x.reshape(B, -1)
        x_past = self.buffer.enqueue(x).reshape(B, -1)
        # perform first set of convs
        z_past = torch.nn.functional.linear(
            x_past, self.w1_1, self.conv1_1.bias)
        z_now = torch.nn.functional.linear(
            x_now, self.w1_2, self.conv1_2.bias)
        z = z_past + z_now
        if cx is not None:
            self._require_conditioning()
            if self.wc1_1 is None:
                self.wc1_1 = self._convert_to_fc_weights(self.convc1)
                self.wc1_2 = self._convert_to_fc_weights(self.convc2)
            if self.cond_buffer is None:
                self.cond_buffer = FFTNetQueue(
                    B, self.K, self.cond_channels, x.is_cuda)
            cx_now = cx.reshape(B, -1)
            cx_past = self.cond_buffer.enqueue(cx).reshape(B, -1)
            cz_past = torch.nn.functional.linear(
                cx_past, self.wc1_1, self.convc1.bias)
            cz_now = torch.nn.functional.linear(
                cx_now, self.wc1_2, self.convc2.bias)
            z = z + cz_past + cz_now
        # second conv
        z = self.relu(z)
        z = torch.nn.functional.linear(z, self.w2, self.conv2.bias)
        z = self.relu(z)
        z = z.view(B, -1, 1)
        return z

    def _convert_to_fc_weights(self, conv):
        """Flatten a Conv1d weight to ``[out_channels, kernel * in_channels]``."""
        w = conv.weight
        out_channels = w.shape[0]
        return w.transpose(1, 2).reshape(out_channels, -1)
class FFTNetModel(nn.Module):
    """Stack of `FFTNet` layers followed by a linear output layer.

    Layer ids run from ``n_layers`` down to 1. Only the first layer takes a
    single input channel and the optional conditioning input; every other
    layer maps ``hid_channels`` to ``hid_channels``.

    Args:
        hid_channels: channel width of every FFTNet layer.
        out_channels: output size of the final linear layer.
        n_layers: number of FFTNet layers; receptive field is 2**n_layers.
        cond_channels: conditioning channels of the first layer, or None.
    """

    def __init__(self, hid_channels=256, out_channels=256, n_layers=11,
                 cond_channels=None):
        super(FFTNetModel, self).__init__()
        self.cond_channels = cond_channels
        self.hid_channels = hid_channels
        self.out_channels = out_channels
        self.n_layers = n_layers
        self.receptive_field = 2 ** n_layers
        self.layers = nn.ModuleList(
            [self._make_layer(depth) for depth in range(self.n_layers)])
        self.fc = nn.Linear(hid_channels, out_channels)

    def _make_layer(self, depth):
        """Build the FFTNet layer that sits at position `depth` in the stack."""
        layer_id = self.n_layers - depth
        width = self.hid_channels
        if depth == 0:
            return FFTNet(1, width, width, layer_id=layer_id,
                          cond_channels=self.cond_channels)
        return FFTNet(width, width, width, layer_id=layer_id)

    def _run_layers(self, x, cx, pick):
        """Chain all layers, then apply `fc` over the channel axis.

        `pick` maps a layer to the callable to invoke on it; `cx` is handed
        to the first layer only, and only when it is not None.
        """
        out = x
        for position, layer in enumerate(self.layers):
            call = pick(layer)
            if position == 0 and cx is not None:
                out = call(out, cx)
            else:
                out = call(out)
        # fc acts on the last axis, so move channels there first
        return self.fc(out.transpose(1, 2))

    def forward(self, x, cx=None):
        """
        Shapes:
            x: batch x 1 x time
            cx: batch x dim x time
        """
        return self._run_layers(x, cx, lambda layer: layer)

    def forward_step(self, x, cx=None):
        """Same pipeline as `forward`, using each layer's `forward_step`."""
        return self._run_layers(x, cx, lambda layer: layer.forward_step)
def sequence_mask(sequence_length):
    """Build a float mask marking the valid steps of each sequence.

    Args:
        sequence_length: 1-D integer tensor with one length per batch item.

    Returns:
        Float tensor of shape ``[batch, max(sequence_length)]`` on the same
        device as `sequence_length`; entry ``[b, t]`` is 1.0 when
        ``t < sequence_length[b]`` and 0.0 otherwise. An empty input gives
        an empty ``[0, 0]`` mask.
    """
    # max() is undefined for an empty tensor, so treat that case explicitly
    if sequence_length.numel() == 0:
        max_len = 0
    else:
        # plain int: tensor-valued sizes are fragile across torch versions
        max_len = int(sequence_length.max())
    # create the range on the lengths' own device, not the default GPU
    steps = torch.arange(0, max_len, device=sequence_length.device)
    # [1, max_len] < [batch, 1] broadcasts to [batch, max_len]
    return (steps.unsqueeze(0) < sequence_length.unsqueeze(1)).float()
class MaskedCrossEntropyLoss(nn.Module):
    """Cross entropy averaged over the unmasked steps of each sequence."""

    def __init__(self):
        super(MaskedCrossEntropyLoss, self).__init__()
        # per-element losses are needed so padding can be masked out;
        # `reduction="none"` replaces the deprecated `reduce=False`
        self.criterion = nn.CrossEntropyLoss(reduction="none")

    def forward(self, input, target, lengths=None):
        """Compute the masked loss and masked error / hit counts.

        Args:
            input: scores whose first two dimensions are flattened together
                (presumably batch x time x classes - confirm with caller).
            target: class indices whose first two dimensions are flattened
                together.
            lengths: valid length of every sequence; passed to
                `sequence_mask`, whose flattened result must line up with
                the flattened target.

        Returns:
            Tuple ``(loss, n_wrong, n_right)``: the masked mean loss as a
            tensor, and the masked counts of wrong and correct argmax
            predictions as Python floats. The loss is NaN when the mask is
            all zeros.

        Raises:
            RuntimeError: if `lengths` is None.
        """
        if lengths is None:
            raise RuntimeError(" > Provide lengths for the loss function")
        # keep the mask on the same device as the scores (works for any GPU)
        mask = sequence_mask(lengths).to(input.device)
        # reshape() also accepts non-contiguous tensors, unlike view()
        logits = input.reshape(input.shape[0] * input.shape[1], -1)
        labels = target.reshape(target.shape[0] * target.shape[1])
        mask_ = mask.reshape(mask.shape[0] * mask.shape[1])
        losses = self.criterion(logits, labels)
        _, pred = torch.max(logits, 1)
        # .float() stays on the current device; no CPU round trip
        wrong = (pred != labels).float()
        right = (pred == labels).float()
        f = (wrong * mask_).sum()
        t = (right * mask_).sum()
        return (losses * mask_).sum() / mask_.sum(), f.item(), t.item()
# https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/4
# https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
class EMA(object):
    """Exponential moving average of named tensors.

    Args:
        decay: weight kept from the running average on every `update`.
    """

    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}

    def register(self, name, val):
        """Start tracking `name` with a detached copy of `val`.

        detach() matters when `val` is a Parameter: otherwise the shadow
        would carry autograd history and every in-place `update` would
        extend that graph.
        """
        self.shadow[name] = val.detach().clone()

    def update(self, name, x):
        """Move the average of `name` towards `x`, in place.

        Computes ``shadow = decay * shadow + (1 - decay) * x``.
        """
        assert name in self.shadow, "EMA.update: unregistered name %r" % (name,)
        if torch.is_tensor(x):
            # keep the running average out of the caller's autograd graph
            x = x.detach()
        update_delta = self.shadow[name] - x
        self.shadow[name] -= (1.0 - self.decay) * update_delta

    def assign_ema_model(self, model, new_model, cuda):
        """Copy `model` into `new_model`, then overwrite tracked parameters.

        Args:
            model: source of the initial state dict.
            new_model: module that receives the state and the averages.
            cuda: when True `new_model` is moved to CUDA afterwards.

        Returns:
            `new_model`.
        """
        new_model.load_state_dict(model.state_dict())
        for name, param in new_model.named_parameters():
            if name in self.shadow:
                param.data = self.shadow[name].clone()
        if cuda:
            new_model.cuda()
        return new_model