Skip to content

Commit 448022e

Browse files
committed
feat: support fbeta in logger
1 parent 5b9e577 commit 448022e

File tree

1 file changed

+33
-7
lines changed

1 file changed

+33
-7
lines changed

toxsmi/utils/performance.py

+33-7
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
average_precision_score,
1313
balanced_accuracy_score,
1414
classification_report,
15-
f1_score,
15+
fbeta_score,
1616
mean_absolute_error,
1717
mean_squared_error,
1818
precision_recall_curve,
@@ -56,6 +56,7 @@ def __init__(
5656
train_batches: int,
5757
test_batches: int,
5858
task_names: List[str],
59+
beta: float=1
5960
):
6061

6162
if task == "binary_classification":
@@ -64,6 +65,7 @@ def __init__(
6465
self.metric_initializer("roc_auc", 0)
6566
self.metric_initializer("accuracy", 0)
6667
self.metric_initializer("precision_recall", 0)
68+
self.metric_initializer("f1", 0)
6769
self.task_final_report = self.final_report_binary_classification
6870
elif task == "regression":
6971
self.report = self.performance_report_regression
@@ -86,6 +88,8 @@ def __init__(
8688
self.train_batches = train_batches
8789
self.test_batches = test_batches
8890
self.metrics = []
91+
# for Fbeta score, only used in classification mode
92+
self.beta = beta
8993

9094
def metric_initializer(self, metric: str, value: float):
9195
setattr(self, metric, value)
@@ -120,18 +124,33 @@ def performance_report_binary_classification(
120124
# score for precision vs accuracy
121125
precision_recall = average_precision_score(labels, preds)
122126

127+
bin_preds, youden = binarize_predictions(preds, labels, return_youden=True)
128+
report = classification_report(labels, bin_preds, output_dict=True)
129+
negative_precision = report["0.0"]["precision"]
130+
negative_recall = report["0.0"]["recall"]
131+
positive_precision = report["1.0"]["precision"]
132+
positive_recall = report["1.0"]["recall"]
133+
f1 = fbeta_score(labels, bin_preds, beta=self.beta,pos_label=1, average='binary')
134+
123135
logger.info(
124136
f"\t **** TEST **** Epoch [{self.epoch + 1}/{self.epochs}], "
125-
f"loss: {loss_a:.5f}, , roc_auc: {roc_auc:.5f}, "
126-
f"avg precision-recall score: {precision_recall:.5f}"
137+
f"loss: {loss_a:.5f}, roc_auc: {roc_auc:.5f}, "
138+
f"avg precision-recall score: {precision_recall:.5f}, "
139+
f"PosPrecision: {positive_precision:.5f}, "
140+
f"PosRecall: {positive_recall:.5f}, "
141+
f"NegPrecision: {negative_precision:.5f}, "
142+
f"NegRecall: {negative_recall:.5f}, "
143+
f"F1 ({self.beta}): {f1:.5f}"
127144
)
128145
info = {
129146
"test_loss": loss_a,
130147
"best_loss": self.loss,
131148
"test_auc": roc_auc,
132149
"best_auc": self.roc_auc,
133150
"test_precision_recall": precision_recall,
134-
"best_precision_recall": self.precision_recall,
151+
'best_precision_recall': self.precision_recall,
152+
"test_f1": f1,
153+
'best_f1': self.f1
135154
}
136155
self.metrics.append(info)
137156
if roc_auc > self.roc_auc:
@@ -140,8 +159,12 @@ def performance_report_binary_classification(
140159
best = "ROC-AUC"
141160
if precision_recall > self.precision_recall:
142161
self.precision_recall = precision_recall
143-
# self.save_model(model, "Precision-Recall", "best", value=precision_recall)
162+
self.save_model(model, "Precision-Recall", "best", value=precision_recall)
144163
best = "Precision-Recall"
164+
if f1 > self.f1:
165+
self.f1 = f1
166+
self.save_model(model, 'F1', 'best', value=f1)
167+
best = 'F1'
145168
if loss_a < self.loss:
146169
self.loss = loss_a
147170
self.save_model(model, "loss", "best", value=loss_a)
@@ -219,7 +242,7 @@ def inference_report_binary_classification(
219242
positive_recall = report["1"]["recall"]
220243
accuracy = accuracy_score(labels, bin_preds)
221244
bal_accuracy = balanced_accuracy_score(labels, bin_preds)
222-
f1 = f1_score(labels, bin_preds)
245+
f1 = fbeta_score(labels, bin_preds, beta=0.5,pos_label=1, average='binary')
223246

224247
info = {
225248
"roc_auc": roc_auc,
@@ -258,7 +281,7 @@ def inference_report_binarized_regression(
258281
positive_recall = report.get("1", {"recall": 0.0})["recall"]
259282
accuracy = accuracy_score(bin_labels, bin_preds)
260283
bal_accuracy = balanced_accuracy_score(bin_labels, bin_preds)
261-
f1 = f1_score(bin_labels, bin_preds)
284+
f1 = fbeta_score(bin_labels, bin_preds, beta=0.5 ,pos_label=1, average='binary')
262285

263286
info = {
264287
"f1": f1,
@@ -329,6 +352,9 @@ def final_report_binary_classification(self):
329352
logger.info(
330353
f"Precision-Recall = {self.precision_recall:.4f} in epoch {self.metric_df['test_precision_recall'].idxmax()} "
331354
)
355+
logger.info(
356+
f"F1 ({self.beta})= {self.f1:.4f} in epoch {self.metric_df['test_f1'].idxmax()} "
357+
)
332358

333359
def final_report_regression(self):
334360
logger.info(

0 commit comments

Comments (0)