-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgridsearch_hyperpars.py
107 lines (92 loc) · 3.12 KB
/
gridsearch_hyperpars.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import json
import itertools
from utils.templates_hyperpars import *
def grid_btree(
    nl=(2, 3),
    ns=(20, 40),
    ni=(20, 40),
):
    """Build one btree backend config per point of the ``nl x ns x ni`` grid.

    Each argument is an iterable of candidate values.  Every combination, in
    ``itertools.product`` order, is passed positionally to ``template_btree``.

    Returns:
        list: one ``template_btree(nl, ns, ni)`` result per grid point; empty
        if any argument is empty.
    """
    # Tuples as defaults: immutable, so no state can leak between calls.
    return [template_btree(*p) for p in itertools.product(nl, ns, ni)]
def grid_vtree(
    nl=(2, 3),
    ns=(20, 40),
    ni=(20, 40),
):
    """Build one vtree backend config per point of the ``nl x ns x ni`` grid.

    Each argument is an iterable of candidate values.  Every combination, in
    ``itertools.product`` order, is passed positionally to ``template_vtree``.

    Returns:
        list: one ``template_vtree(nl, ns, ni)`` result per grid point; empty
        if any argument is empty.
    """
    # Tuples as defaults: immutable, so no state can leak between calls.
    return [template_vtree(*p) for p in itertools.product(nl, ns, ni)]
def grid_rtree(
    nl=(2, 3),
    nr=(1, 10),
    ns=(20, 40),
    ni=(20, 40),
):
    """Build one rtree backend config per point of the ``nl x nr x ns x ni`` grid.

    Each argument is an iterable of candidate values.  Every combination, in
    ``itertools.product`` order, is passed positionally to ``template_rtree``.

    Returns:
        list: one ``template_rtree(nl, nr, ns, ni)`` result per grid point;
        empty if any argument is empty.
    """
    # Tuples as defaults: immutable, so no state can leak between calls.
    return [template_rtree(*p) for p in itertools.product(nl, nr, ns, ni)]
def grid_ptree(
    nl=(2, 3),
    ns=(20, 40),
    ni=(20, 40),
):
    """Build one ptree backend config per point of the ``nl x ns x ni`` grid.

    Each argument is an iterable of candidate values.  Every combination, in
    ``itertools.product`` order, is passed positionally to ``template_ptree``.

    Returns:
        list: one ``template_ptree(nl, ns, ni)`` result per grid point; empty
        if any argument is empty.
    """
    # Tuples as defaults: immutable, so no state can leak between calls.
    return [template_ptree(*p) for p in itertools.product(nl, ns, ni)]
def grid_ctree(
    nh=(2, 3),
):
    """Build one ctree backend config per candidate value in *nh*.

    Unlike the other ``grid_*`` builders there is a single axis, so no
    Cartesian product is needed.

    Returns:
        list: ``[template_ctree(v) for v in nh]``; empty if *nh* is empty.
    """
    # Tuple as default: immutable, so no state can leak between calls.
    return [template_ctree(p) for p in nh]
def grid_sort(dataset, model):
order = ['canonical', 'bft', 'dft', 'rcm', 'unordered']
nc = [1]
backend_name = ['btree', 'vtree', 'rtree', 'ptree', 'ctree']
backend_grid = [grid_btree, grid_vtree, grid_rtree, grid_ptree, grid_ctree]
match dataset:
case 'qm9':
backend_xpar = [
{"nl":[3], "ns":[32], "ni":[32]},
{"nl":[3], "ns":[32], "ni":[32]},
{"nl":[3], "nr":[16], "ns":[32], "ni":[32]},
{"nl":[3], "ns":[32], "ni":[32]},
{"nh":[256]}
]
backend_apar = [
{"nl":[5], "ns":[32], "ni":[32]},
{"nl":[5], "ns":[32], "ni":[32]},
{"nl":[5], "nr":[16], "ns":[32], "ni":[32]},
{"nl":[5], "ns":[32], "ni":[32]},
{"nh":[256]}
]
case 'zinc250k':
backend_xpar = [
{"nl":[4], "ns":[32], "ni":[32]},
{"nl":[4], "ns":[32], "ni":[32]},
{"nl":[4], "nr":[16], "ns":[32], "ni":[32]},
{"nl":[4], "ns":[32], "ni":[32]},
{"nh":[256]}
]
backend_apar = [
{"nl":[6], "ns":[32], "ni":[32]},
{"nl":[6], "ns":[32], "ni":[32]},
{"nl":[6], "nr":[16], "ns":[32], "ni":[32]},
{"nl":[6], "ns":[32], "ni":[32]},
{"nh":[256]}
]
case _:
raise 'Unknown dataset'
backend_nr = [
[None],
[None],
[None],
[16],
[None]
]
batch_size = [256]
lr = [0.05]
seed = [0]
hyperpars = []
for b_name, b_grid, b_xpar, b_apar, b_nr in zip(backend_name, backend_grid, backend_xpar, backend_apar, backend_nr):
grid = itertools.product(order, nc, b_nr, [b_name], b_grid(**b_xpar), b_grid(**b_apar), batch_size, lr, seed)
hyperpars.extend([template_sort(dataset, model, *p) for p in list(grid)])
return hyperpars
# Lookup table from sweep name to the function that builds its list of
# configurations (called as ``fn(dataset, model)``).
GRIDS = dict(marg_sort=grid_sort)
if __name__ == "__main__":
    # Quick sanity check: report how many configurations the QM9 sweep yields.
    qm9_configs = grid_sort('qm9', 'marg_sort')
    print(len(qm9_configs))