-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathImageDataLoader.py
72 lines (63 loc) · 2.51 KB
/
ImageDataLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from PIL import Image
import os
import os.path
import torch.utils.data
import torchvision.transforms as transforms
import numpy as np
import torch
def default_image_loader(path):
    """Load the image stored at ``path`` and return it in RGB mode.

    Args:
        path: Location of the image file, passed straight to ``Image.open``.

    Returns:
        The image returned by ``convert('RGB')``.
    """
    # Image.open is lazy and keeps the file open until the pixel data has
    # been read.  The context manager closes the source image as soon as the
    # converted image has been produced, instead of leaving that to garbage
    # collection.
    with Image.open(path) as img:
        return img.convert('RGB')
class TransformTwice:
    """Callable wrapper that applies one transform to the same input twice."""

    def __init__(self, transform):
        """Remember ``transform``; it is invoked on every call of this object."""
        self.transform = transform

    def __call__(self, inp):
        """Return a 2-tuple holding two successive ``transform(inp)`` results.

        The transform is called twice, in order, on the very same ``inp``
        object (presumably so that a stochastic transform yields two
        different views - confirm against the transform that is passed in).
        """
        apply = self.transform
        first_view = apply(inp)
        second_view = apply(inp)
        return first_view, second_view
class SimpleImageLoader(torch.utils.data.Dataset):
    """Dataset over the image files listed in a whitespace-separated meta file.

    The first line of the meta file is skipped; every following non-blank
    line must hold exactly three fields: ``instance_id label file_name``.
    A label of ``-1`` marks a row without a class.

    * ``split == 'test'`` reads ``<rootdir>/test_data/test_meta.txt`` and
      loads images from ``<rootdir>/test_data``.
    * Any other split reads ``<rootdir>/train/train_label`` and loads images
      from ``<rootdir>/train/train_data``.
    * The ``'unlabel'`` and ``'test'`` splits keep only rows labeled ``-1``;
      every other split keeps only rows with a label other than ``-1``.

    Rows whose image file does not exist on disk are dropped.

    Args:
        rootdir: Directory containing ``test_data`` and/or ``train``.
        split: Name of the split, see above.
        ids: Optional collection of integer instance ids; when given, only
            rows whose id is in the collection are kept.
        transform: Optional callable applied to each loaded image.
        loader: Callable that turns a file path into an image object.

    Items returned by indexing:
        * ``'test'``: the (transformed) image.
        * ``'unlabel'``: a pair of images, each produced by one application
          of ``transform`` (the untransformed image twice when ``transform``
          is ``None``).
        * any other split: ``(image, label)`` with ``label`` an ``int``.
    """

    # Splits whose rows carry the placeholder label -1 instead of a class.
    _UNLABELED_SPLITS = ('unlabel', 'test')

    def __init__(self, rootdir, split, ids=None, transform=None, loader=default_image_loader):
        if split == 'test':
            self.impath = os.path.join(rootdir, 'test_data')
            meta_file = os.path.join(self.impath, 'test_meta.txt')
        else:
            self.impath = os.path.join(rootdir, 'train/train_data')
            meta_file = os.path.join(rootdir, 'train/train_label')

        wants_unlabeled = split in self._UNLABELED_SPLITS
        # Hash the ids once so the per-row membership test below does not
        # rescan a list/array for every line of the meta file.
        wanted_ids = self._as_lookup(ids)

        imnames = []
        imclasses = []
        with open(meta_file, 'r') as rf:
            next(rf, None)  # the first line is never treated as a data row
            for line in rf:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline) instead
                    # of failing on the three-way unpack below.
                    continue
                instance_id, label, file_name = fields
                label = int(label)
                # Keep -1 rows only for unlabeled splits and real labels
                # only for labeled ones.
                if (label == -1) != wants_unlabeled:
                    continue
                if wanted_ids is not None and int(instance_id) not in wanted_ids:
                    continue
                if not os.path.exists(os.path.join(self.impath, file_name)):
                    continue
                imnames.append(file_name)
                # Store a label for every labeled split, not just
                # 'train'/'val': __getitem__ reads imclasses for any split
                # that is neither 'test' nor 'unlabel'.
                if not wants_unlabeled:
                    imclasses.append(label)

        self.transform = transform
        self.TransformTwice = TransformTwice(transform)
        self.loader = loader
        self.split = split
        self.imnames = imnames
        self.imclasses = imclasses

    @staticmethod
    def _as_lookup(ids):
        """Return ``ids`` in a form with fast ``in`` tests, or ``ids`` itself."""
        if ids is None or isinstance(ids, (set, frozenset, dict)):
            return ids
        # Array-like objects exposing tolist() are turned into plain Python
        # values first so that hashing agrees with equality.
        values = ids.tolist() if hasattr(ids, 'tolist') else ids
        try:
            return frozenset(values)
        except TypeError:
            # Not iterable or not hashable: keep the caller's container and
            # its own ``in`` semantics.
            return ids

    def __getitem__(self, index):
        """Load the image at ``index`` and return it in the split's format."""
        filename = self.imnames[index]
        img = self.loader(os.path.join(self.impath, filename))

        if self.split == 'unlabel':
            if self.transform is None:
                # Mirror the other splits, which pass the image through
                # untouched when no transform was supplied.
                return img, img
            img1, img2 = self.TransformTwice(img)
            return img1, img2

        if self.transform is not None:
            img = self.transform(img)
        if self.split == 'test':
            return img
        return img, self.imclasses[index]

    def __len__(self):
        """Return the number of image files kept for this split."""
        return len(self.imnames)