@@ -106,7 +106,7 @@ def process_data(self, labels: np.array, preds: np.array):
106
106
preds = preds [~ np .isnan (labels )]
107
107
labels = labels [~ np .isnan (labels )]
108
108
109
- return labels , preds
109
+ return labels . astype ( float ) , preds . astype ( float )
110
110
111
111
def performance_report_binary_classification (
112
112
self , labels : np .array , preds : np .array , loss : float , model : Callable
@@ -125,10 +125,10 @@ def performance_report_binary_classification(
125
125
126
126
bin_preds , youden = binarize_predictions (preds , labels , return_youden = True )
127
127
report = classification_report (labels , bin_preds , output_dict = True )
128
- negative_precision = report [ "0.0" ][ "precision" ]
129
- negative_recall = report [ "0.0" ][ "recall" ]
130
- positive_precision = report [ "1.0" ][ "precision" ]
131
- positive_recall = report [ "1.0" ][ "recall" ]
128
+ negative_precision = report . get ( "0.0" , {}). get ( "precision" , - 1 )
129
+ negative_recall = report . get ( "0.0" , {}). get ( "recall" , - 1 )
130
+ positive_precision = report . get ( "1.0" , {}). get ( "precision" , - 1 )
131
+ positive_recall = report . get ( "1.0" , {}). get ( "recall" , - 1 )
132
132
f1 = fbeta_score (
133
133
labels , bin_preds , beta = self .beta , pos_label = 1 , average = "binary"
134
134
)
@@ -237,10 +237,10 @@ def inference_report_binary_classification(
237
237
precision , recall , _ = precision_recall_curve (labels , preds )
238
238
precision_recall = average_precision_score (labels , preds )
239
239
report = classification_report (labels , bin_preds , output_dict = True )
240
- negative_precision = report [ "0.0" ][ "precision" ]
241
- negative_recall = report [ "0.0" ][ "recall" ]
242
- positive_precision = report [ "1.0" ][ "precision" ]
243
- positive_recall = report [ "1.0" ][ "recall" ]
240
+ negative_precision = report . get ( "0.0" , {}). get ( "precision" , - 1 )
241
+ negative_recall = report . get ( "0.0" , {}). get ( "recall" , - 1 )
242
+ positive_precision = report . get ( "1.0" , {}). get ( "precision" , - 1 )
243
+ positive_recall = report . get ( "1.0" , {}). get ( "recall" , - 1 )
244
244
accuracy = accuracy_score (labels , bin_preds )
245
245
bal_accuracy = balanced_accuracy_score (labels , bin_preds )
246
246
f1 = fbeta_score (labels , bin_preds , beta = 0.5 , pos_label = 1 , average = "binary" )
0 commit comments