Skip to content

Commit 5b9e577

Browse files
committed
fix: cast model
1 parent 3a6133d commit 5b9e577

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

toxsmi/models/mca.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ def __init__(self, params: dict, *args, **kwargs):
9494

9595
# Build the model. First the embeddings
9696
if params.get("embedding", "learned") == "learned":
97-
9897
self.smiles_embedding = nn.Embedding(
9998
self.params["smiles_vocabulary_size"],
10099
self.params["smiles_embedding_size"],
@@ -252,6 +251,8 @@ def __init__(self, params: dict, *args, **kwargs):
252251
):
253252
self.loss_fn.class_weights = params.get("class_weights", [1, 1])
254253

254+
self.to(self.device)
255+
255256
def forward(self, smiles: torch.Tensor) -> Tuple[torch.Tensor, dict]:
256257
"""Forward pass through the MCA.
257258
@@ -279,7 +280,6 @@ def forward(self, smiles: torch.Tensor) -> Tuple[torch.Tensor, dict]:
279280
smiles_alphas, encodings = [], []
280281
for layer in range(len(self.multiheads)):
281282
for head in range(self.multiheads[layer]):
282-
283283
ind = self.multiheads[0] * layer + head
284284
smiles_alphas.append(
285285
self.alpha_projections[ind](

0 commit comments

Comments
 (0)