-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsampler.py
271 lines (217 loc) · 9.43 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
from __future__ import division
import math
import torch
import numpy as np
from mmcv.runner import get_dist_info
from torch.utils.data import Sampler
from torch.utils.data import DistributedSampler as _DistributedSampler
import random
import itertools
class DistributedSampler(_DistributedSampler):
    """Distributed sampler whose shuffle order is seeded by the epoch only.

    Wraps ``torch.utils.data.DistributedSampler``: the parent computes
    ``num_samples`` / ``total_size``; this class builds a (optionally
    shuffled) index list, pads it to ``total_size`` and hands every
    ``num_replicas``-th index, starting at ``rank``, to this process.

    Args:
        dataset: sized dataset to sample from.
        num_replicas (int, optional): number of participating processes.
        rank (int, optional): rank of this process.
        shuffle (bool): shuffle with ``torch.randperm`` seeded by
            ``self.epoch`` (set via ``set_epoch``) when True.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle

    def __iter__(self):
        # Deterministically shuffle based on epoch so all ranks agree.
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Add extra samples to make it evenly divisible. Repeat the whole list
        # as often as needed: a single ``indices[:padding_size]`` is too short
        # when the padding exceeds the dataset size (e.g. fewer samples than
        # replicas). padding_size > 0 implies a non-empty dataset.
        padding_size = self.total_size - len(indices)
        if padding_size > 0:
            repeats = math.ceil(padding_size / len(indices))
            indices += (indices * repeats)[:padding_size]
        else:
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # Subsample: this rank takes every num_replicas-th index.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples
        return iter(indices)
class GroupSampler(Sampler):
    """Non-distributed sampler that keeps every chunk inside one flag group.

    ``dataset.flag`` assigns each sample an integer group id. Each group's
    indices are shuffled with the global NumPy RNG and padded (by cyclic
    repetition) to a multiple of ``samples_per_gpu``. The resulting
    ``samples_per_gpu``-sized chunks are emitted in random order, so each
    chunk holds indices from a single group.

    If any entry of ``np.bincount(flag)`` is smaller than
    ``samples_per_gpu``, the grouping is discarded and samples are
    re-flagged by index parity (``i % 2``). ``set_epoch`` only stores the
    epoch; it does not seed the shuffling here.

    Args:
        dataset: object with a ``flag`` NumPy array of non-negative ints.
        samples_per_gpu (int): chunk (per-GPU batch) size.
    """

    def __init__(self, dataset, samples_per_gpu=1):
        assert hasattr(dataset, 'flag')
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        # astype returns a copy, so re-flagging below never touches dataset.flag.
        self.flag = dataset.flag.astype(np.int64)
        self.epoch = 0
        self.group_sizes = np.bincount(self.flag)
        # NOTE(review): bincount also yields 0 for unused flag values below the
        # maximum, and such an empty group triggers this fallback as well —
        # confirm that is intended.
        if min(self.group_sizes) < self.samples_per_gpu:
            self.flag = np.arange(len(self.flag), dtype=np.int64) % 2
            self.group_sizes = np.bincount(self.flag)
            print('\033[1;35m >>> group sampler randomly aranged!\033[0;0m')
        # Each group is padded up to a whole number of chunks.
        self.num_samples = sum(
            int(np.ceil(size / self.samples_per_gpu)) * self.samples_per_gpu
            for size in self.group_sizes)

    def __iter__(self):
        indices = []
        for i, size in enumerate(self.group_sizes):
            if size == 0:
                continue
            indice = np.where(self.flag == i)[0]
            assert len(indice) == size
            np.random.shuffle(indice)
            num_extra = int(np.ceil(size / self.samples_per_gpu)
                            ) * self.samples_per_gpu - len(indice)
            # np.resize repeats ``indice`` cyclically, so the padding is long
            # enough even when num_extra exceeds the group size (a plain
            # ``indice[:num_extra]`` would come up short there).
            indice = np.concatenate([indice, np.resize(indice, num_extra)])
            indices.append(indice)
        indices = np.concatenate(indices)
        # Shuffle whole chunks so each chunk stays within one group.
        indices = [
            indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu]
            for i in np.random.permutation(
                range(len(indices) // self.samples_per_gpu))
        ]
        indices = np.concatenate(indices)
        indices = indices.astype(np.int64).tolist()
        assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
class DistributedGroupSampler(Sampler):
    """Distributed sampler whose per-GPU chunks stay inside one flag group.

    ``dataset.flag`` assigns every sample an integer group id. Each group is
    shuffled with a ``torch.Generator`` seeded by ``self.epoch`` and padded
    (by repeating the group) to a multiple of
    ``samples_per_gpu * num_replicas``. The ``samples_per_gpu``-sized chunks
    are then shuffled with the same generator, and each rank takes a
    contiguous slice of ``num_samples`` indices.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling; must expose a ``flag`` NumPy
            array of non-negative ints.
        samples_per_gpu (optional): Number of samples per chunk.
        num_replicas (optional): Number of processes participating in
            distributed training. Queried via ``get_dist_info`` when None.
        rank (optional): Rank of the current process within num_replicas.
            Queried via ``get_dist_info`` when None.
    """

    def __init__(self,
                 dataset,
                 samples_per_gpu=1,
                 num_replicas=None,
                 rank=None):
        # Only consult the distributed backend for values the caller omitted.
        if num_replicas is None or rank is None:
            _rank, _num_replicas = get_dist_info()
            if num_replicas is None:
                num_replicas = _num_replicas
            if rank is None:
                rank = _rank
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0

        assert hasattr(self.dataset, 'flag')
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)

        # Per rank, every group contributes a whole number of chunks.
        self.num_samples = 0
        for size in self.group_sizes:
            self.num_samples += int(
                math.ceil(size * 1.0 / self.samples_per_gpu /
                          self.num_replicas)) * self.samples_per_gpu
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # Deterministically shuffle based on epoch so all ranks agree.
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        for i, size in enumerate(self.group_sizes):
            if size == 0:
                continue
            indice = np.where(self.flag == i)[0]
            assert len(indice) == size
            perm = torch.randperm(int(size), generator=g).numpy()
            indice = indice[perm].tolist()
            extra = int(
                math.ceil(
                    size * 1.0 / self.samples_per_gpu / self.num_replicas)
            ) * self.samples_per_gpu * self.num_replicas - len(indice)
            # Pad by repeating the whole group as often as necessary: a single
            # ``indice[:extra]`` is too short when extra exceeds the group size.
            tmp = indice.copy()
            for _ in range(extra // size):
                indice.extend(tmp)
            indice.extend(tmp[:extra % size])
            indices.extend(indice)

        assert len(indices) == self.total_size

        # Shuffle whole chunks so each chunk stays within one group.
        batch_order = torch.randperm(
            len(indices) // self.samples_per_gpu, generator=g).tolist()
        indices = [
            indices[j] for i in batch_order
            for j in range(i * self.samples_per_gpu,
                           (i + 1) * self.samples_per_gpu)
        ]

        # Subsample: each rank takes a contiguous slice.
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
class RandomCycleIter:
    """Endless iterator over ``data`` that reshuffles on every pass.

    Items are returned in list order; each time the cursor wraps back to the
    start, the list is shuffled in place with ``random.shuffle`` unless
    ``test_mode`` is set. The cursor starts one step before the end, so the
    very first ``next()`` already wraps (and shuffles).

    Args:
        data: iterable whose items are cycled over.
        test_mode (bool): keep the original order (no shuffling) when True.
    """

    def __init__(self, data, test_mode=False):
        self.data_list = list(data)
        self.length = len(self.data_list)
        self.i = self.length - 1
        self.test_mode = test_mode

    def __iter__(self):
        return self

    def __next__(self):
        cursor = self.i + 1
        wrapped = cursor == self.length
        if wrapped:
            cursor = 0
            if not self.test_mode:
                random.shuffle(self.data_list)
        self.i = cursor
        return self.data_list[cursor]
def class_aware_sample_generator(cls_iter, data_iter_list, n, num_samples_cls=1):
    """Yield ``n`` items, drawing ``num_samples_cls`` at a time per class.

    For each chunk, the next class position comes from ``cls_iter`` and
    ``num_samples_cls`` consecutive items are pulled from
    ``data_iter_list[position]``. The chunk is emitted item by item. The last
    chunk is cut off once ``n`` items have been yielded, but its remaining
    items have already been consumed from the class iterator.

    Args:
        cls_iter: iterator of positions into ``data_iter_list``.
        data_iter_list: per-class iterators of dataset indices.
        n (int): total number of items to yield.
        num_samples_cls (int): items drawn from a class per visit.
    """
    emitted = 0
    while emitted < n:
        class_stream = data_iter_list[next(cls_iter)]
        # zip() over k references to the *same* iterator pulls k consecutive items.
        chunk = next(zip(*[class_stream] * num_samples_cls))
        for sample in chunk:
            if emitted >= n:
                return
            yield sample
            emitted += 1
class ClassAwareSampler(Sampler):
    """Sampler that visits classes in turn and samples within each class.

    Distinct labels of ``data_source.targets`` are mapped to slots
    ``0..C-1`` (in ``np.unique`` sorted order). Slots are visited through a
    reshuffling ``RandomCycleIter``. Each visit draws ``num_samples_cls``
    indices from that class's own reshuffling cycle. One pass yields
    ``max(class size) * C`` indices.

    Args:
        data_source: dataset exposing a 1-D ``targets`` sequence of labels.
        num_samples_cls (int): indices drawn per class visit.
    """

    def __init__(self, data_source, num_samples_cls=1):
        # Map labels to contiguous slots so arbitrary (e.g. non-contiguous)
        # labels work; for labels that are already 0..C-1 the slot equals the
        # label, so grouping is unchanged in that case.
        labels, slots = np.unique(np.asarray(data_source.targets),
                                  return_inverse=True)
        num_classes = len(labels)
        self.class_iter = RandomCycleIter(range(num_classes))
        cls_data_list = [[] for _ in range(num_classes)]
        for i, slot in enumerate(slots.reshape(-1).tolist()):
            cls_data_list[slot].append(i)
        self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list]
        # default=0 keeps an empty target list from raising in max().
        self.num_samples = max([len(x) for x in cls_data_list],
                               default=0) * len(cls_data_list)
        self.num_samples_cls = num_samples_cls
        self.epoch = 0
        print('>>> Class Aware Sampler Built! Class number: {}'.format(num_classes))

    def __iter__(self):
        return class_aware_sample_generator(self.class_iter, self.data_iter_list,
                                            self.num_samples, self.num_samples_cls)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
class MixSampler(Sampler):
    """Sampler that interleaves uniform-random and class-aware indices.

    On each iteration it builds an epoch-seeded random permutation of
    ``range(len(data_source))`` and ``len(data_source)`` class-aware indices
    (see ``class_aware_sample_generator``). It then yields them alternately:
    random, class-aware, random, and so on.

    Args:
        data_source: sized dataset exposing a 1-D ``targets`` sequence.
        num_samples_cls (int): indices drawn per class visit in the
            class-aware stream.
    """

    def __init__(self, data_source, num_samples_cls=1):
        # Map labels to contiguous slots so non-contiguous labels work; for
        # labels already in 0..C-1 the slot equals the label.
        labels, slots = np.unique(np.asarray(data_source.targets),
                                  return_inverse=True)
        num_classes = len(labels)
        self.class_iter = RandomCycleIter(range(num_classes))
        cls_data_list = [[] for _ in range(num_classes)]
        for i, slot in enumerate(slots.reshape(-1).tolist()):
            cls_data_list[slot].append(i)
        self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list]
        self.num_samples = len(data_source)
        self.num_samples_cls = num_samples_cls
        self.epoch = 0
        print('>>> Class Aware Sampler Built! Class number: {}'.format(num_classes))

    def __iter__(self):
        # Deterministically shuffle the random half based on epoch.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        random_indices = torch.randperm(self.num_samples, generator=g).tolist()
        class_aware_indices = list(
            class_aware_sample_generator(self.class_iter, self.data_iter_list,
                                         self.num_samples, self.num_samples_cls))
        # NOTE(review): this yields 2 * num_samples indices while __len__
        # reports num_samples — confirm that is intended.
        indices = list(itertools.chain.from_iterable(
            zip(random_indices, class_aware_indices)))
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch