@@ -56,9 +56,8 @@ def __init__(
56
56
train_batches : int ,
57
57
test_batches : int ,
58
58
task_names : List [str ],
59
- beta : float = 1
59
+ beta : float = 1 ,
60
60
):
61
-
62
61
if task == "binary_classification" :
63
62
self .report = self .performance_report_binary_classification
64
63
self .inference_report = self .inference_report_binary_classification
@@ -70,14 +69,14 @@ def __init__(
70
69
elif task == "regression" :
71
70
self .report = self .performance_report_regression
72
71
self .inference_report = self .inference_report_regression
73
- self .metric_initializer ("rmse" , 10 ** 9 )
74
- self .metric_initializer ("mae" , 10 ** 9 )
72
+ self .metric_initializer ("rmse" , 10 ** 9 )
73
+ self .metric_initializer ("mae" , 10 ** 9 )
75
74
self .metric_initializer ("pearson" , - 1 )
76
75
self .metric_initializer ("spearman" , - 1 )
77
76
self .task_final_report = self .final_report_regression
78
77
else :
79
78
raise ValueError (f"Unknown task { task } " )
80
- self .metric_initializer ("loss" , 10 ** 9 )
79
+ self .metric_initializer ("loss" , 10 ** 9 )
81
80
82
81
self .task = task
83
82
self .task_names = task_names
@@ -130,7 +129,9 @@ def performance_report_binary_classification(
130
129
negative_recall = report ["0.0" ]["recall" ]
131
130
positive_precision = report ["1.0" ]["precision" ]
132
131
positive_recall = report ["1.0" ]["recall" ]
133
- f1 = fbeta_score (labels , bin_preds , beta = self .beta ,pos_label = 1 , average = 'binary' )
132
+ f1 = fbeta_score (
133
+ labels , bin_preds , beta = self .beta , pos_label = 1 , average = "binary"
134
+ )
134
135
135
136
logger .info (
136
137
f"\t **** TEST **** Epoch [{ self .epoch + 1 } /{ self .epochs } ], "
@@ -148,9 +149,9 @@ def performance_report_binary_classification(
148
149
"test_auc" : roc_auc ,
149
150
"best_auc" : self .roc_auc ,
150
151
"test_precision_recall" : precision_recall ,
151
- ' best_precision_recall' : self .precision_recall ,
152
+ " best_precision_recall" : self .precision_recall ,
152
153
"test_f1" : f1 ,
153
- ' best_f1' : self .f1
154
+ " best_f1" : self .f1 ,
154
155
}
155
156
self .metrics .append (info )
156
157
if roc_auc > self .roc_auc :
@@ -163,8 +164,8 @@ def performance_report_binary_classification(
163
164
best = "Precision-Recall"
164
165
if f1 > self .f1 :
165
166
self .f1 = f1
166
- self .save_model (model , 'F1' , ' best' , value = f1 )
167
- best = 'F1'
167
+ self .save_model (model , "F1" , " best" , value = f1 )
168
+ best = "F1"
168
169
if loss_a < self .loss :
169
170
self .loss = loss_a
170
171
self .save_model (model , "loss" , "best" , value = loss_a )
@@ -236,13 +237,13 @@ def inference_report_binary_classification(
236
237
precision , recall , _ = precision_recall_curve (labels , preds )
237
238
precision_recall = average_precision_score (labels , preds )
238
239
report = classification_report (labels , bin_preds , output_dict = True )
239
- negative_precision = report ["0" ]["precision" ]
240
- negative_recall = report ["0" ]["recall" ]
241
- positive_precision = report ["1" ]["precision" ]
242
- positive_recall = report ["1" ]["recall" ]
240
+ negative_precision = report ["0.0 " ]["precision" ]
241
+ negative_recall = report ["0.0 " ]["recall" ]
242
+ positive_precision = report ["1.0 " ]["precision" ]
243
+ positive_recall = report ["1.0 " ]["recall" ]
243
244
accuracy = accuracy_score (labels , bin_preds )
244
245
bal_accuracy = balanced_accuracy_score (labels , bin_preds )
245
- f1 = fbeta_score (labels , bin_preds , beta = 0.5 ,pos_label = 1 , average = ' binary' )
246
+ f1 = fbeta_score (labels , bin_preds , beta = 0.5 , pos_label = 1 , average = " binary" )
246
247
247
248
info = {
248
249
"roc_auc" : roc_auc ,
@@ -281,7 +282,7 @@ def inference_report_binarized_regression(
281
282
positive_recall = report .get ("1" , {"recall" : 0.0 })["recall" ]
282
283
accuracy = accuracy_score (bin_labels , bin_preds )
283
284
bal_accuracy = balanced_accuracy_score (bin_labels , bin_preds )
284
- f1 = fbeta_score (bin_labels , bin_preds , beta = 0.5 , pos_label = 1 , average = ' binary' )
285
+ f1 = fbeta_score (bin_labels , bin_preds , beta = 0.5 , pos_label = 1 , average = " binary" )
285
286
286
287
info = {
287
288
"f1" : f1 ,
0 commit comments