# model.py (forked from txie-93/cgcnn)
from __future__ import print_function, division
import torch
import torch.nn as nn


class ConvLayer(nn.Module):
"""
Convolutional operation on graphs
"""
def __init__(self, atom_fea_len, nbr_fea_len):
"""
Initialize ConvLayer.
Parameters
----------
atom_fea_len: int
Number of atom hidden features.
nbr_fea_len: int
Number of bond features.
"""
super(ConvLayer, self).__init__()
self.atom_fea_len = atom_fea_len
self.nbr_fea_len = nbr_fea_len
self.fc_full = nn.Linear(2*self.atom_fea_len+self.nbr_fea_len,
2*self.atom_fea_len)
self.sigmoid = nn.Sigmoid()
self.softplus1 = nn.Softplus()
self.bn1 = nn.BatchNorm1d(2*self.atom_fea_len)
self.bn2 = nn.BatchNorm1d(self.atom_fea_len)
self.softplus2 = nn.Softplus()

    def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx):
"""
Forward pass
N: Total number of atoms in the batch
M: Max number of neighbors
Parameters
----------
        atom_in_fea: torch.Tensor shape (N, atom_fea_len)
            Atom hidden features before convolution
        nbr_fea: torch.Tensor shape (N, M, nbr_fea_len)
            Bond features of each atom's M neighbors
        nbr_fea_idx: torch.LongTensor shape (N, M)
            Indices of M neighbors of each atom

        Returns
        -------
        atom_out_fea: torch.Tensor shape (N, atom_fea_len)
Atom hidden features after convolution
"""
# TODO will there be problems with the index zero padding?
        N, M = nbr_fea_idx.shape
        # gather the features of each atom's M neighbors: (N, M, atom_fea_len)
        atom_nbr_fea = atom_in_fea[nbr_fea_idx, :]
        # concatenate [center atom, neighbor atom, bond] features along dim 2
        total_nbr_fea = torch.cat(
            [atom_in_fea.unsqueeze(1).expand(N, M, self.atom_fea_len),
             atom_nbr_fea, nbr_fea], dim=2)
        total_gated_fea = self.fc_full(total_nbr_fea)
        # BatchNorm1d expects a 2-D input, so flatten the neighbor dimension
        total_gated_fea = self.bn1(total_gated_fea.view(
            -1, self.atom_fea_len*2)).view(N, M, self.atom_fea_len*2)
        # split into a sigmoid gate (filter) and a softplus-activated core
        nbr_filter, nbr_core = total_gated_fea.chunk(2, dim=2)
        nbr_filter = self.sigmoid(nbr_filter)
        nbr_core = self.softplus1(nbr_core)
        # gated sum over neighbors, then a residual connection to the input
        nbr_sumed = torch.sum(nbr_filter * nbr_core, dim=1)
        nbr_sumed = self.bn2(nbr_sumed)
        out = self.softplus2(atom_in_fea + nbr_sumed)
return out
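

# Illustrative usage sketch (not part of the original file): a minimal
# shape check for ConvLayer on random tensors. All sizes below (5 atoms,
# 3 neighbors, feature lengths 8 and 4) are arbitrary placeholders chosen
# only to exercise the expected shapes.
def _convlayer_demo():
    N, M, atom_fea_len, nbr_fea_len = 5, 3, 8, 4
    layer = ConvLayer(atom_fea_len, nbr_fea_len)
    atom_fea = torch.randn(N, atom_fea_len)
    nbr_fea = torch.randn(N, M, nbr_fea_len)
    nbr_fea_idx = torch.randint(0, N, (N, M))
    out = layer(atom_fea, nbr_fea, nbr_fea_idx)
    assert out.shape == (N, atom_fea_len)  # residual update preserves shape
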
class OrbitalCrystalGraphConvNet(nn.Module):
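    """
    Crystal graph convolutional network that maps a crystal graph to
    target properties, using a two-stage atom-feature embedding followed
    by graph convolutions and pooling.
    """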
def __init__(self, orig_atom_fea_len, nbr_fea_len, orig_hot_fea_len,
atom_fea_len, hot_fea_len, h_fea_len, n_conv=3, n_h=1,
classification=False):
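        """
        Initialize OrbitalCrystalGraphConvNet.

        Parameters
        ----------
        orig_atom_fea_len: int
            Number of atom features in the input.
        nbr_fea_len: int
            Number of bond features.
        orig_hot_fea_len: int
            Length of the raw one-hot feature vector (not used in the
            current implementation).
        atom_fea_len: int
            Number of hidden atom features in the convolutional layers.
        hot_fea_len: int
            Size of the intermediate embedding between embedding1 and
            embedding2.
        h_fea_len: int
            Number of hidden features after pooling.
        n_conv: int
            Number of convolutional layers.
        n_h: int
            Number of hidden layers after pooling.
        classification: bool
            If True, output two-class log-probabilities; otherwise a
            single regression value.
        """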
super(OrbitalCrystalGraphConvNet, self).__init__()
self.classification = classification
self.embedding1 = nn.Linear(orig_atom_fea_len, hot_fea_len)
        self.relu = nn.Softplus()  # named 'relu' upstream, but applies Softplus
self.embedding2 = nn.Linear(hot_fea_len, atom_fea_len)
self.convs = nn.ModuleList([ConvLayer(atom_fea_len=atom_fea_len,
nbr_fea_len=nbr_fea_len)
for _ in range(n_conv)])
self.conv_to_fc = nn.Linear(atom_fea_len, h_fea_len)
self.conv_to_fc_softplus = nn.Softplus()
if n_h > 1:
self.fcs = nn.ModuleList([nn.Linear(h_fea_len, h_fea_len)
for _ in range(n_h-1)])
self.softpluses = nn.ModuleList([nn.Softplus()
for _ in range(n_h-1)])
if self.classification:
self.fc_out = nn.Linear(h_fea_len, 2)
else:
self.fc_out = nn.Linear(h_fea_len, 1)
if self.classification:
self.logsoftmax = nn.LogSoftmax(dim=1)
self.dropout = nn.Dropout()

    def forward(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx):
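        """
        Forward pass

        N: Total number of atoms in the batch
        M: Max number of neighbors
        N0: Total number of crystals in the batch

        Parameters
        ----------
        atom_fea: torch.Tensor shape (N, orig_atom_fea_len)
            Atom features from atom type
        nbr_fea: torch.Tensor shape (N, M, nbr_fea_len)
            Bond features of each atom's M neighbors
        nbr_fea_idx: torch.LongTensor shape (N, M)
            Indices of M neighbors of each atom
        crystal_atom_idx: list of torch.LongTensor of length N0
            Mapping from the crystal idx to atom idx

        Returns
        -------
        out: torch.Tensor shape (N0, 1) or (N0, 2)
            Predicted property (regression) or class log-probabilities
            (classification)
        """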
atom_fea = self.embedding1(atom_fea)
atom_fea = self.relu(atom_fea)
atom_fea = self.embedding2(atom_fea)
for conv_func in self.convs:
atom_fea = conv_func(atom_fea, nbr_fea, nbr_fea_idx)
crys_fea = self.pooling(atom_fea, crystal_atom_idx)
crys_fea = self.conv_to_fc(self.conv_to_fc_softplus(crys_fea))
crys_fea = self.conv_to_fc_softplus(crys_fea)
if self.classification:
crys_fea = self.dropout(crys_fea)
if hasattr(self, 'fcs') and hasattr(self, 'softpluses'):
for fc, softplus in zip(self.fcs, self.softpluses):
crys_fea = softplus(fc(crys_fea))
out = self.fc_out(crys_fea)
if self.classification:
out = self.logsoftmax(out)
return out

    def pooling(self, atom_fea, crystal_atom_idx):
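        """
        Pool atom features to crystal features. Despite the variable name
        `summed_fea`, this is mean pooling: each crystal's feature vector
        is the mean over its atoms' features.

        Parameters
        ----------
        atom_fea: torch.Tensor shape (N, atom_fea_len)
            Atom feature vectors of the batch
        crystal_atom_idx: list of torch.LongTensor of length N0
            Mapping from the crystal idx to atom idx
        """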
        assert sum([len(idx_map) for idx_map in crystal_atom_idx]) == \
            atom_fea.shape[0]
        summed_fea = [torch.mean(atom_fea[idx_map], dim=0, keepdim=True)
                      for idx_map in crystal_atom_idx]
        return torch.cat(summed_fea, dim=0)
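

if __name__ == '__main__':
    # Minimal smoke test (not part of the original file): run a forward
    # pass on random data. Every size below is an arbitrary placeholder
    # chosen only to exercise the shapes; real feature lengths come from
    # the dataset preprocessing.
    _convlayer_demo()
    n_atoms, max_nbrs = 6, 4
    orig_atom_fea_len, nbr_fea_len = 92, 41
    atom_fea = torch.randn(n_atoms, orig_atom_fea_len)
    nbr_fea = torch.randn(n_atoms, max_nbrs, nbr_fea_len)
    nbr_fea_idx = torch.randint(0, n_atoms, (n_atoms, max_nbrs))
    # two crystals: atoms 0-2 form the first, atoms 3-5 the second
    crystal_atom_idx = [torch.LongTensor([0, 1, 2]),
                        torch.LongTensor([3, 4, 5])]
    model = OrbitalCrystalGraphConvNet(
        orig_atom_fea_len=orig_atom_fea_len, nbr_fea_len=nbr_fea_len,
        orig_hot_fea_len=32, atom_fea_len=64, hot_fea_len=32, h_fea_len=128)
    out = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
    print(out.shape)  # expected: torch.Size([2, 1]) for regression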