-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutils.py
51 lines (32 loc) · 1.16 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""
Various utilities used by the original code in
https://github.com/timbmg/Sentence-VAE
"""
import torch
import numpy as np
from torch.autograd import Variable
from collections import defaultdict, Counter, OrderedDict
class OrderedCounter(Counter, OrderedDict):
    """Counter whose keys stay in the order in which they were first seen."""

    def __repr__(self):
        # Render the contents through a plain OrderedDict so the counts are
        # shown in insertion order, prefixed by the concrete class name.
        cls_name = type(self).__name__
        ordered_view = OrderedDict(self)
        return '{}({!r})'.format(cls_name, ordered_view)

    def __reduce__(self):
        # Copy/pickle support: rebuild the instance by calling its class on
        # an OrderedDict snapshot of the current contents.
        snapshot = OrderedDict(self)
        return type(self), (snapshot,)
def to_var(x, volatile=False):
    """Wrap ``x`` in an autograd ``Variable``, moving it to CUDA if available.

    Args:
        x: Object exposing ``.cuda()`` that ``Variable`` accepts
            (presumably a ``torch.Tensor`` -- confirm against callers).
        volatile: Passed through unchanged as ``Variable``'s ``volatile``
            keyword.

    Returns:
        ``Variable(x, volatile=volatile)``, built from the CUDA copy of ``x``
        when ``torch.cuda.is_available()`` is true and from ``x`` itself
        otherwise.
    """
    # NOTE(review): ``Variable`` / ``volatile`` look like the pre-0.4 PyTorch
    # API; newer releases are documented to ignore ``volatile`` in favour of
    # ``torch.no_grad()`` -- confirm before relying on the flag.
    placed = x.cuda() if torch.cuda.is_available() else x
    return Variable(placed, volatile=volatile)
def idx2word(idx, i2w, pad_idx):
    """Decode rows of token ids into space-separated sentences.

    Each row of ``idx`` is read left to right; decoding of a row stops at the
    first id equal to ``pad_idx`` (that id itself is not emitted).

    Args:
        idx: Iterable of id sequences, one per sentence. Ids may be scalar
            objects exposing ``.item()`` (e.g. the elements obtained when
            iterating a tensor) or plain Python ints.
        i2w: Mapping from the *string* form of an id (``str(id)``) to a word.
        pad_idx: Id that terminates a sentence.

    Returns:
        A list with one string per row of ``idx``: the looked-up words joined
        by single spaces, with leading/trailing whitespace stripped.

    Raises:
        KeyError: For a dict-like ``i2w``, when an id that occurs before the
            padding has no entry.
    """
    sentences = []
    for sent in idx:
        words = []
        for word_id in sent:
            if word_id == pad_idx:
                break
            # Unwrap scalar containers so the key is e.g. "5" rather than the
            # repr of the wrapper; plain ints are used as they are.
            token = word_id.item() if hasattr(word_id, "item") else word_id
            words.append(i2w[str(token)])
        # Join once per sentence instead of growing a string with ``+=``.
        sentences.append(" ".join(words).strip())
    return sentences
def interpolate(start, end, steps):
    """Linearly interpolate between ``start`` and ``end``, coordinate by coordinate.

    Args:
        start: Sequence with a ``.shape`` attribute; ``start.shape[0]`` sets
            the number of coordinates (presumably a 1-D NumPy array --
            confirm against callers).
        end: Sequence paired element-wise with ``start``.
        steps: Number of intermediate points; the two endpoints are added on
            top, so ``steps + 2`` points are generated per coordinate.

    Returns:
        A float array of shape ``(steps + 2, start.shape[0])`` whose column
        ``d`` is ``np.linspace(start[d], end[d], steps + 2)``.
    """
    n_points = steps + 2
    # One row per coordinate while filling; transposed on return so that each
    # row of the result is one interpolation point.
    grid = np.zeros((start.shape[0], n_points))
    for row, (lo, hi) in zip(grid, zip(start, end)):
        row[:] = np.linspace(lo, hi, n_points)
    return grid.T