def test_meter():
    """
    Tests:
    * .reset()
    * .add()
    * .value()
    """
    meter = meters.PrecisionRecallF1ScoreMeter()

    # .reset() is invoked by the constructor, so a fresh meter
    # must already have zeroed tp/fp/fn counts
    for count_key in ("tp", "fp", "fn"):
        assert meter.tp_fp_fn_counts[count_key] == 0, \
            "Counts should be initialized to 0."

    # .add()/.value() with tensors that have no batch dimension
    targets, preds = create_dummy_tensors_single()
    meter.add(preds, targets)
    runs_tests_on_meter_counts_and_value(meter, num_tp_check=1)

    # .add()/.value() with batched tensors
    meter.reset()
    batch_size = 16
    targets, preds = create_dummy_tensors_batched(batch_size)
    meter.add(preds, targets)
    runs_tests_on_meter_counts_and_value(meter, num_tp_check=batch_size)

    # segmentation-style tensors; shape (batch_size, n_channels, h, w)
    # NOTE(review): the 15 * 15 factor presumably matches the h/w of the
    # dummy segmentation tensors — confirm against create_dummy_tensors_seg
    meter.reset()
    batch_size = 16
    targets, preds = create_dummy_tensors_seg(batch_size)
    meter.add(preds, targets)
    runs_tests_on_meter_counts_and_value(
        meter, num_tp_check=batch_size * 15 * 15
    )
def __init__(
    self,
    input_key: str = "targets",
    output_key: str = "logits",
    class_names: List[str] = None,
    num_classes: int = 2,
    threshold: float = 0.5,
    activation: str = "Sigmoid",
):
    """
    Args:
        input_key (str): input key to use for metric calculation
            specifies our ``y_true``.
        output_key (str): output key to use for metric calculation;
            specifies our ``y_pred``
        class_names (List[str]): class names to display in the logs.
            If None, defaults to indices for each class, starting from 0.
        num_classes (int): Number of classes; must be > 1
        threshold (float): threshold for outputs binarization
        activation (str): An torch.nn activation applied to the outputs.
            Must be one of ['none', 'Sigmoid', 'Softmax2d']
    """
    # when explicit class names are given, they dictate the class count
    if class_names is not None:
        num_classes = len(class_names)

    # one precision/recall/F1 meter per class, all sharing the threshold
    per_class_meters = [
        meters.PrecisionRecallF1ScoreMeter(threshold)
        for _ in range(num_classes)
    ]

    super().__init__(
        metric_names=["ppv", "tpr", "f1"],
        meter_list=per_class_meters,
        input_key=input_key,
        output_key=output_key,
        class_names=class_names,
        num_classes=num_classes,
        activation=activation,
    )