    average_precision_score,
    balanced_accuracy_score,
    classification_report,
-    f1_score,
+    fbeta_score,
    mean_absolute_error,
    mean_squared_error,
    precision_recall_curve,
@@ -56,6 +56,7 @@ def __init__(
        train_batches: int,
        test_batches: int,
        task_names: List[str],
+        beta: float = 1,
    ):

        if task == "binary_classification":
@@ -64,6 +65,7 @@ def __init__(
            self.metric_initializer("roc_auc", 0)
            self.metric_initializer("accuracy", 0)
            self.metric_initializer("precision_recall", 0)
+            self.metric_initializer("f1", 0)
            self.task_final_report = self.final_report_binary_classification
        elif task == "regression":
            self.report = self.performance_report_regression
@@ -86,6 +88,8 @@ def __init__(
        self.train_batches = train_batches
        self.test_batches = test_batches
        self.metrics = []
+        # for Fbeta score, only used in classification mode
+        self.beta = beta

    def metric_initializer(self, metric: str, value: float):
        setattr(self, metric, value)
@@ -120,18 +124,33 @@ def performance_report_binary_classification(
        # score for precision vs accuracy
        precision_recall = average_precision_score(labels, preds)

+        bin_preds, youden = binarize_predictions(preds, labels, return_youden=True)
+        report = classification_report(labels, bin_preds, output_dict=True)
+        negative_precision = report["0.0"]["precision"]
+        negative_recall = report["0.0"]["recall"]
+        positive_precision = report["1.0"]["precision"]
+        positive_recall = report["1.0"]["recall"]
+        f1 = fbeta_score(labels, bin_preds, beta=self.beta, pos_label=1, average='binary')
+
        logger.info(
            f"\t **** TEST **** Epoch [{self.epoch + 1}/{self.epochs}], "
-            f"loss: {loss_a:.5f}, , roc_auc: {roc_auc:.5f}, "
-            f"avg precision-recall score: {precision_recall:.5f}"
+            f"loss: {loss_a:.5f}, roc_auc: {roc_auc:.5f}, "
+            f"avg precision-recall score: {precision_recall:.5f}, "
+            f"PosPrecision: {positive_precision:.5f}, "
+            f"PosRecall: {positive_recall:.5f}, "
+            f"NegPrecision: {negative_precision:.5f}, "
+            f"NegRecall: {negative_recall:.5f}, "
+            f"F1 ({self.beta}): {f1:.5f}"
        )
        info = {
            "test_loss": loss_a,
            "best_loss": self.loss,
            "test_auc": roc_auc,
            "best_auc": self.roc_auc,
            "test_precision_recall": precision_recall,
-            "best_precision_recall": self.precision_recall,
+            'best_precision_recall': self.precision_recall,
+            "test_f1": f1,
+            'best_f1': self.f1
        }
        self.metrics.append(info)
        if roc_auc > self.roc_auc:
@@ -140,8 +159,12 @@ def performance_report_binary_classification(
            best = "ROC-AUC"
        if precision_recall > self.precision_recall:
            self.precision_recall = precision_recall
-            # self.save_model(model, "Precision-Recall", "best", value=precision_recall)
+            self.save_model(model, "Precision-Recall", "best", value=precision_recall)
            best = "Precision-Recall"
+        if f1 > self.f1:
+            self.f1 = f1
+            self.save_model(model, 'F1', 'best', value=f1)
+            best = 'F1'
        if loss_a < self.loss:
            self.loss = loss_a
            self.save_model(model, "loss", "best", value=loss_a)
@@ -219,7 +242,7 @@ def inference_report_binary_classification(
        positive_recall = report["1"]["recall"]
        accuracy = accuracy_score(labels, bin_preds)
        bal_accuracy = balanced_accuracy_score(labels, bin_preds)
-        f1 = f1_score(labels, bin_preds)
+        f1 = fbeta_score(labels, bin_preds, beta=0.5, pos_label=1, average='binary')

        info = {
            "roc_auc": roc_auc,
@@ -258,7 +281,7 @@ def inference_report_binarized_regression(
        positive_recall = report.get("1", {"recall": 0.0})["recall"]
        accuracy = accuracy_score(bin_labels, bin_preds)
        bal_accuracy = balanced_accuracy_score(bin_labels, bin_preds)
-        f1 = f1_score(bin_labels, bin_preds)
+        f1 = fbeta_score(bin_labels, bin_preds, beta=0.5, pos_label=1, average='binary')

        info = {
            "f1": f1,
@@ -329,6 +352,9 @@ def final_report_binary_classification(self):
        logger.info(
            f"Precision-Recall = {self.precision_recall:.4f} in epoch {self.metric_df['test_precision_recall'].idxmax()}"
        )
+        logger.info(
+            f"F1 ({self.beta}) = {self.f1:.4f} in epoch {self.metric_df['test_f1'].idxmax()}"
+        )

    def final_report_regression(self):
        logger.info(
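
For reference (not part of the diff): a minimal sketch of the scikit-learn behaviour this change relies on. fbeta_score generalizes f1_score: with beta=1 it reproduces F1 exactly, while beta=0.5 (hard-coded in the inference reports above) weights precision more heavily than recall. The labels and predictions below are made-up toy values for illustration only.

# Illustrative sketch (not from the commit): how fbeta_score relates to f1_score.
# F-beta = (1 + beta^2) * precision * recall / (beta^2 * precision + recall):
# beta=1 is the usual F1, beta<1 favours precision, beta>1 favours recall.
from sklearn.metrics import f1_score, fbeta_score

labels = [0, 1, 1, 0, 1, 0, 1, 1]     # toy ground-truth labels (made up)
bin_preds = [0, 1, 0, 0, 1, 0, 1, 1]  # toy binarized predictions (made up)

f1 = f1_score(labels, bin_preds)
f1_via_beta = fbeta_score(labels, bin_preds, beta=1, pos_label=1, average="binary")
assert abs(f1 - f1_via_beta) < 1e-12  # beta=1 reproduces f1_score

f_half = fbeta_score(labels, bin_preds, beta=0.5, pos_label=1, average="binary")
print(f"F1 = {f1:.4f}, F0.5 = {f_half:.4f}")  # here precision=1.0, recall=0.8, so F0.5 > F1

Because the new constructor argument defaults to beta=1, the epoch-level score logged during testing matches the old f1_score behaviour unless a different beta is passed in.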