-
Notifications
You must be signed in to change notification settings - Fork 66
/
Copy pathtest.py
66 lines (49 loc) · 1.83 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from __future__ import print_function
import argparse
import cv2
import numpy as np
from model.unet import unet
from model.fcn import fcn_8s
from model.pspnet import pspnet50
def result_map_to_img(res_map):
    """Convert a per-pixel class-score map into a colour mask image.

    Each pixel is assigned the class with the highest score along the last
    axis, and that class is painted into one output channel:

    * class 1 (person) -> channel 0
    * class 2 (car)    -> channel 1
    * class 3 (road)   -> channel 2

    Every other class (e.g. 0, background) stays black.  Since OpenCV treats
    channel 0 as blue, ``cv2.imshow`` renders person / car / road as
    blue / green / red.

    Args:
        res_map: array of class scores, expected shaped ``(H, W, C)``.
            Size-1 axes (such as a leading batch axis of 1) are squeezed
            away first.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)`` whose values are 0 or 255.
    """
    res_map = np.squeeze(res_map)
    # Winning class index per pixel; the class axis is the last one.
    argmax_idx = np.argmax(res_map, axis=-1)
    # Size the output from the prediction instead of hard-coding 256x512,
    # so models run at other resolutions are handled too.
    img = np.zeros(argmax_idx.shape + (3,), dtype=np.uint8)
    for channel, class_id in enumerate((1, 2, 3)):
        img[..., channel] = np.where(argmax_idx == class_id, 255, 0)
    return img
# Command-line entry point: run a trained segmentation model on one image and
# display both the (resized) input and the colour-coded prediction.

# Parse Options
parser = argparse.ArgumentParser()
parser.add_argument("-M", "--model", required=True, choices=['fcn', 'unet', 'pspnet'],
                    help="Model to test. 'fcn', 'unet', 'pspnet' is available.")
parser.add_argument("-P", "--img_path", required=True, help="The image path you want to test")
args = parser.parse_args()

model_name = args.model
img_path = args.img_path

# Network input size (height, width, channels). Presumably the saved weights
# were trained at this resolution - confirm against the training script.
input_shape = (256, 512, 3)

# Four output classes: background plus three foreground classes.
labels = ['background', 'person', 'car', 'road']

# Choose the model architecture to test. argparse `choices` guarantees that
# model_name is one of these keys.
model_builders = {
    'fcn': fcn_8s,
    'unet': unet,
    'pspnet': pspnet50,
}
model = model_builders[model_name](input_shape=input_shape, num_classes=len(labels),
                                   lr_init=1e-3, lr_decay=5e-4)

try:
    model.load_weights(model_name + '_model_weight.h5')
except (IOError, OSError, ValueError) as err:
    # Without trained weights the prediction is meaningless, so stop here
    # instead of silently running a randomly initialised network.
    print("You must train model and get weight before test.")
    print(err)
    raise SystemExit(1)

x_img = cv2.imread(img_path)
if x_img is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None.
    print("Could not read image: " + img_path)
    raise SystemExit(1)

# The network has a fixed input size; note cv2.resize takes (width, height).
height, width = input_shape[:2]
if x_img.shape[:2] != (height, width):
    x_img = cv2.resize(x_img, (width, height))
# Show the resized input so it lines up with the predicted mask.
cv2.imshow('x_img', x_img)

x_img = cv2.cvtColor(x_img, cv2.COLOR_BGR2RGB)
# Scale pixel values from [0, 255] to [-1, 1].
x_img = x_img / 127.5 - 1
# Add the batch dimension expected by model.predict.
x_img = np.expand_dims(x_img, 0)

pred = model.predict(x_img)
res = result_map_to_img(pred[0])
cv2.imshow('res', res)
cv2.waitKey(0)
cv2.destroyAllWindows()