
Commit 7653dae

Commit message: :q

1 parent 6887c80 commit 7653dae

File tree: README.md, toxsmi/models/mca.py, toxsmi/utils/performance.py

3 files changed: +11 −11


README.md

+1-1
@@ -82,7 +82,7 @@ where `-checkpoint` specifies which `.pt` file to pick for the evaluation (based
 
 ## Attention visualization
 The model uses a self-attention mechanism that can highlight chemical motifs used for the predictions.
-In [notebooks/toxicity_attention.ipynb](notebooks/toxicity_attention.ipynb) we share a tutorial on how to create such plots:
+In [notebooks/toxicity_attention_plot.ipynb](notebooks/toxicity_attention_plot.ipynb) we share a tutorial on how to create such plots:
 ![Attention](assets/attention.gif "toxicophore attention")
toxsmi/models/mca.py

+1-1
@@ -11,7 +11,6 @@
     smiles_projection,
 )
 from paccmann_predictor.utils.utils import get_device
-
 from toxsmi.utils.hyperparams import ACTIVATION_FN_FACTORY, LOSS_FN_FACTORY
 from toxsmi.utils.layers import EnsembleLayer
 
@@ -231,6 +230,7 @@ def __init__(self, params: dict, *args, **kwargs):
         self.loss_name = params.get(
             "loss_fn", "binary_cross_entropy_ignore_nan_and_sum"
         )
+
         final_activation = (
             ACTIVATION_FN_FACTORY["sigmoid"]
             if "cross" in self.loss_name

toxsmi/utils/performance.py

+9-9
@@ -106,7 +106,7 @@ def process_data(self, labels: np.array, preds: np.array):
         preds = preds[~np.isnan(labels)]
         labels = labels[~np.isnan(labels)]
 
-        return labels, preds
+        return labels.astype(float), preds.astype(float)
 
     def performance_report_binary_classification(
         self, labels: np.array, preds: np.array, loss: float, model: Callable
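Note on the `process_data` hunk above: after masking out entries with NaN labels, the filtered arrays are now cast to float before being returned. A minimal sketch of that behaviour with made-up values (not the repository's actual data), assuming plain NumPy inputs:

```python
import numpy as np

# Hypothetical labels/preds; entries whose label is NaN (e.g. a missing
# measurement) are dropped from both arrays before metrics are computed.
labels = np.array([1.0, np.nan, 0.0, 1.0])
preds = np.array([0.9, 0.4, 0.2, 0.7])

mask = ~np.isnan(labels)
preds = preds[mask]
labels = labels[mask]

# Casting to float guards against integer/object dtypes reaching the
# sklearn metric functions downstream.
labels, preds = labels.astype(float), preds.astype(float)

print(labels)  # [1. 0. 1.]
print(preds)   # [0.9 0.2 0.7]
```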
@@ -125,10 +125,10 @@ def performance_report_binary_classification(
 
         bin_preds, youden = binarize_predictions(preds, labels, return_youden=True)
         report = classification_report(labels, bin_preds, output_dict=True)
-        negative_precision = report["0.0"]["precision"]
-        negative_recall = report["0.0"]["recall"]
-        positive_precision = report["1.0"]["precision"]
-        positive_recall = report["1.0"]["recall"]
+        negative_precision = report.get("0.0", {}).get("precision", -1)
+        negative_recall = report.get("0.0", {}).get("recall", -1)
+        positive_precision = report.get("1.0", {}).get("precision", -1)
+        positive_recall = report.get("1.0", {}).get("recall", -1)
         f1 = fbeta_score(
             labels, bin_preds, beta=self.beta, pos_label=1, average="binary"
         )
@@ -237,10 +237,10 @@ def inference_report_binary_classification(
         precision, recall, _ = precision_recall_curve(labels, preds)
         precision_recall = average_precision_score(labels, preds)
         report = classification_report(labels, bin_preds, output_dict=True)
-        negative_precision = report["0.0"]["precision"]
-        negative_recall = report["0.0"]["recall"]
-        positive_precision = report["1.0"]["precision"]
-        positive_recall = report["1.0"]["recall"]
+        negative_precision = report.get("0.0", {}).get("precision", -1)
+        negative_recall = report.get("0.0", {}).get("recall", -1)
+        positive_precision = report.get("1.0", {}).get("precision", -1)
+        positive_recall = report.get("1.0", {}).get("recall", -1)
         accuracy = accuracy_score(labels, bin_preds)
         bal_accuracy = balanced_accuracy_score(labels, bin_preds)
         f1 = fbeta_score(labels, bin_preds, beta=0.5, pos_label=1, average="binary")
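Both performance.py hunks replace direct key lookups with chained `.get(...)` calls. As far as the diff suggests, the motivation is that `sklearn.metrics.classification_report` only emits a per-class entry for classes that actually occur in the labels or predictions, so a single-class batch would make `report["0.0"]` raise a `KeyError`; the `-1` serves as a sentinel value. A small sketch with hypothetical inputs:

```python
from sklearn.metrics import classification_report

# Hypothetical degenerate batch: every label and every binarized
# prediction is the positive class, so the report contains a "1.0"
# entry but no "0.0" entry at all.
labels = [1.0, 1.0, 1.0]
bin_preds = [1.0, 1.0, 1.0]

report = classification_report(labels, bin_preds, output_dict=True)

# Direct indexing would fail here:
#   report["0.0"]["precision"]  -> KeyError: '0.0'
# The chained .get(...) falls back to -1 instead of crashing.
negative_precision = report.get("0.0", {}).get("precision", -1)
positive_precision = report.get("1.0", {}).get("precision", -1)

print(negative_precision)  # -1
print(positive_precision)  # 1.0
```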
