Commit 6887c80 (1 parent: 448022e)

device-agnostic model loading + formatting

2 files changed: +18 -20 lines

scripts/train_tox.py (+1 -4)
@@ -86,7 +86,6 @@ def main(
     embedding_path: str,
     finetune_path: str,
 ):
-
     logging.basicConfig(level=logging.INFO, format="%(message)s")
     logger = logging.getLogger(f"{training_name}")
     logger.setLevel(logging.INFO)
@@ -282,7 +281,7 @@ def smiles_tensor_batch_to_fp(smiles):
     if finetune_path:
         if os.path.isfile(finetune_path):
             try:
-                model.load(finetune_path)
+                model.load(finetune_path, map_location=device)
                 logger.info(f"Restored pretrained model {finetune_path}")
             except Exception:
                 raise KeyError(f"Could not restore model from {finetune_path}")
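
This one-line change is the substance of the commit: a checkpoint serialized on a CUDA machine cannot be restored on a CPU-only host unless the stored tensors are remapped at load time. A minimal sketch of the underlying PyTorch pattern, assuming (as the call above suggests) that model.load forwards map_location to torch.load; the checkpoint path and module are illustrative:

import torch

# Use whichever accelerator is available on this host.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(2048, 1)  # illustrative stand-in for the toxsmi model

# map_location remaps every stored tensor onto `device` while deserializing,
# so a GPU-trained checkpoint restores cleanly on a CPU-only machine.
state_dict = torch.load("finetune_checkpoint.pt", map_location=device)
model.load_state_dict(state_dict)
model.to(device)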
@@ -313,7 +312,6 @@ def smiles_tensor_batch_to_fp(smiles):
     )

     for epoch in range(params["epochs"]):
-
         performer.epoch += 1
         model.train()
         logger.info(params_filepath.split("/")[-1])
@@ -349,7 +347,6 @@ def smiles_tensor_batch_to_fp(smiles):
         predictions = []
         labels = []
         for ind, (smiles, y) in enumerate(test_loader):
-
             smiles = torch.squeeze(smiles.to(device))
             # Transform smiles to FP if needed
             if params.get("model_fn", "mca") == "dense":

toxsmi/utils/performance.py (+17 -16)
@@ -56,9 +56,8 @@ def __init__(
         train_batches: int,
         test_batches: int,
         task_names: List[str],
-        beta: float=1
+        beta: float = 1,
     ):
-
         if task == "binary_classification":
             self.report = self.performance_report_binary_classification
             self.inference_report = self.inference_report_binary_classification
@@ -70,14 +69,14 @@ def __init__(
         elif task == "regression":
             self.report = self.performance_report_regression
             self.inference_report = self.inference_report_regression
-            self.metric_initializer("rmse", 10 ** 9)
-            self.metric_initializer("mae", 10 ** 9)
+            self.metric_initializer("rmse", 10**9)
+            self.metric_initializer("mae", 10**9)
             self.metric_initializer("pearson", -1)
             self.metric_initializer("spearman", -1)
             self.task_final_report = self.final_report_regression
         else:
             raise ValueError(f"Unknown task {task}")
-        self.metric_initializer("loss", 10 ** 9)
+        self.metric_initializer("loss", 10**9)

         self.task = task
         self.task_names = task_names
@@ -130,7 +129,9 @@ def performance_report_binary_classification(
         negative_recall = report["0.0"]["recall"]
         positive_precision = report["1.0"]["precision"]
         positive_recall = report["1.0"]["recall"]
-        f1 = fbeta_score(labels, bin_preds, beta=self.beta,pos_label=1, average='binary')
+        f1 = fbeta_score(
+            labels, bin_preds, beta=self.beta, pos_label=1, average="binary"
+        )

         logger.info(
             f"\t **** TEST **** Epoch [{self.epoch + 1}/{self.epochs}], "
@@ -148,9 +149,9 @@ def performance_report_binary_classification(
             "test_auc": roc_auc,
             "best_auc": self.roc_auc,
             "test_precision_recall": precision_recall,
-            'best_precision_recall': self.precision_recall,
+            "best_precision_recall": self.precision_recall,
             "test_f1": f1,
-            'best_f1': self.f1
+            "best_f1": self.f1,
         }
         self.metrics.append(info)
         if roc_auc > self.roc_auc:
@@ -163,8 +164,8 @@ def performance_report_binary_classification(
             best = "Precision-Recall"
         if f1 > self.f1:
             self.f1 = f1
-            self.save_model(model, 'F1', 'best', value=f1)
-            best = 'F1'
+            self.save_model(model, "F1", "best", value=f1)
+            best = "F1"
         if loss_a < self.loss:
             self.loss = loss_a
             self.save_model(model, "loss", "best", value=loss_a)
@@ -236,13 +237,13 @@ def inference_report_binary_classification(
         precision, recall, _ = precision_recall_curve(labels, preds)
         precision_recall = average_precision_score(labels, preds)
         report = classification_report(labels, bin_preds, output_dict=True)
-        negative_precision = report["0"]["precision"]
-        negative_recall = report["0"]["recall"]
-        positive_precision = report["1"]["precision"]
-        positive_recall = report["1"]["recall"]
+        negative_precision = report["0.0"]["precision"]
+        negative_recall = report["0.0"]["recall"]
+        positive_precision = report["1.0"]["precision"]
+        positive_recall = report["1.0"]["recall"]
         accuracy = accuracy_score(labels, bin_preds)
         bal_accuracy = balanced_accuracy_score(labels, bin_preds)
-        f1 = fbeta_score(labels, bin_preds, beta=0.5,pos_label=1, average='binary')
+        f1 = fbeta_score(labels, bin_preds, beta=0.5, pos_label=1, average="binary")

         info = {
             "roc_auc": roc_auc,
@@ -281,7 +282,7 @@ def inference_report_binarized_regression(
         positive_recall = report.get("1", {"recall": 0.0})["recall"]
         accuracy = accuracy_score(bin_labels, bin_preds)
         bal_accuracy = balanced_accuracy_score(bin_labels, bin_preds)
-        f1 = fbeta_score(bin_labels, bin_preds, beta=0.5 ,pos_label=1, average='binary')
+        f1 = fbeta_score(bin_labels, bin_preds, beta=0.5, pos_label=1, average="binary")

         info = {
             "f1": f1,
