-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtta_remap.py
156 lines (140 loc) · 5.42 KB
/
tta_remap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#!/usr/bin/env python3
# This file is covered by the LICENSE file in the root of this project.
import argparse
import os
import yaml
import numpy as np
import pdb
from pcseg.data.dataset.ceph import PetrelBackend
import multiprocessing
# possible splits
splits = ["train", "valid", "trainval", "test"]
if __name__ == '__main__':
parser = argparse.ArgumentParser("./remap_semantic_labels.py")
parser.add_argument(
'--dataset', '-d',
type=str,
required=False,
default=None,
help='Dataset dir. WARNING: This file remaps the labels in place, so the original labels will be lost. Cannot be used together with -predictions- flag.'
)
parser.add_argument(
'--predictions', '-p',
type=str,
required=False,
default=None,
help='Prediction dir. WARNING: This file remaps the predictions in place, so the original predictions will be lost. Cannot be used together with -dataset- flag.'
)
parser.add_argument(
'--split', '-s',
type=str,
required=False,
default="valid",
help='Split to evaluate on. One of ' +
str(splits) + '. Defaults to %(default)s',
)
parser.add_argument(
'--datacfg', '-dc',
type=str,
required=False,
default="semantic-kitti-all.yaml",
help='Dataset config file. Defaults to %(default)s',
)
parser.add_argument(
'--inverse',
dest='inverse',
default=False,
action='store_true',
help='Map from xentropy to original, instead of original to xentropy. '
'Defaults to %(default)s',
)
FLAGS, unparsed = parser.parse_known_args()
# print summary of what we will do
print("*" * 80)
print("INTERFACE:")
print("Data: ", FLAGS.dataset)
print("Predictions: ", FLAGS.predictions)
print("Split: ", FLAGS.split)
print("Config: ", FLAGS.datacfg)
print("Inverse: ", FLAGS.inverse)
print("*" * 80)
# only predictions or dataset can be handled at once and one MUST be given (xor)
assert((FLAGS.dataset is not None) != (FLAGS.predictions is not None))
# check name
root_directory = ""
label_directory = ""
if(FLAGS.dataset is not None):
root_directory = FLAGS.dataset
label_directory = "labels"
elif(FLAGS.predictions is not None):
root_directory = FLAGS.predictions
label_directory = "predictions"
else:
print("I don't even know how I got here")
quit()
# assert split
assert(FLAGS.split in splits)
print("Opening data config file %s" % FLAGS.datacfg)
DATA = yaml.safe_load(open(FLAGS.datacfg, 'r'))
# get number of interest classes, and the label mappings
if FLAGS.inverse:
print("Mapping xentropy to original labels")
remapdict = DATA["learning_map_inv"]
else:
remapdict = DATA["learning_map"]
nr_classes = len(remapdict)
# make lookup table for mapping
maxkey = max(remapdict.keys())
# +100 hack making lut bigger just in case there are unknown labels
remap_lut = np.zeros((maxkey + 100), dtype=np.int32)
remap_lut[list(remapdict.keys())] = list(remapdict.values())
# print(remap_lut)
# get wanted set
sequences = []
sequences.extend(DATA["split"][FLAGS.split])
if 's3://' in root_directory:
petrel_client = PetrelBackend()
for sequence in sequences:
sequence = '{0:02d}'.format(int(sequence))
label_dir = os.path.join(root_directory, "sequences",
sequence, label_directory)
seq_label_names = petrel_client.list_dir_one_depth(label_dir)
seq_label_names = [seq_label_name for seq_label_name in seq_label_names if seq_label_name.endswith('.label')]
seq_label_names.sort()
remap_dir = os.path.join(root_directory+'_remap', "sequences",
sequence, label_directory)
def remap_single_frame(label_name):
label_file = os.path.join(label_dir, label_name)
print(label_file)
label = np.copy(petrel_client.load_bin(label_file, dtype='uint32').reshape((-1)))
label = label.reshape((-1))
upper_half = label >> 16 # get upper half for instances
lower_half = label & 0xFFFF # get lower half for semantics
lower_half = remap_lut[lower_half] # do the remapping of semantics
label = (upper_half << 16) + lower_half # reconstruct full label
label = label.astype(np.uint32)
remap_file = os.path.join(remap_dir, label_name)
petrel_client.save_bin(remap_file, label)
return remap_file
with multiprocessing.Pool(32) as p:
remap_path_list = list(p.map(remap_single_frame, seq_label_names))
else:
label_names = []
for sequence in sequences:
sequence = '{0:02d}'.format(int(sequence))
label_paths = os.path.join(root_directory, "sequences",
sequence, label_directory)
seq_label_names = [os.path.join(dp, f) for dp, dn, fn in os.walk(
os.path.expanduser(label_paths)) for f in fn if ".label" in f]
seq_label_names.sort()
label_names.extend(seq_label_names)
for label_file in label_names:
print(label_file)
label = np.fromfile(label_file, dtype=np.uint32)
label = label.reshape((-1))
upper_half = label >> 16 # get upper half for instances
lower_half = label & 0xFFFF # get lower half for semantics
lower_half = remap_lut[lower_half] # do the remapping of semantics
label = (upper_half << 16) + lower_half # reconstruct full label
label = label.astype(np.uint32)
label.tofile(label_file)