Example #1
0
def test_meter():
    """
    Exercise ``PrecisionRecallF1ScoreMeter`` end to end.

    Covers:
        * .reset() -- invoked implicitly by the constructor, and explicitly
          between scenarios
        * .add()
        * .value()

    Scenarios (each one starts from zeroed counts):
        1. a single sample with no batch-size dimension
        2. a batch of samples with the batch-size dimension
        3. segmentation-style masks, shape (batch_size, n_channels, h, w)
    """
    meter = meters.PrecisionRecallF1ScoreMeter()

    # The constructor calls .reset(), so every counter must start at zero.
    for count_name in ("tp", "fp", "fn"):
        assert meter.tp_fp_fn_counts[count_name] == 0, \
            "Counts should be initialized to 0."

    batch_size = 16
    # Each entry: (tensor factory, factory args, expected true-positive count)
    scenarios = [
        (create_dummy_tensors_single, (), 1),
        (create_dummy_tensors_batched, (batch_size,), batch_size),
        # NOTE(review): 15 * 15 presumably matches the spatial size of the
        # dummy seg masks -- confirm against create_dummy_tensors_seg.
        (create_dummy_tensors_seg, (batch_size,), batch_size * 15 * 15),
    ]

    for idx, (make_tensors, make_args, expected_tp) in enumerate(scenarios):
        if idx > 0:
            # The freshly built meter is already reset; later runs are not.
            meter.reset()
        # Factories return (targets, predictions); .add() takes them reversed.
        targets, preds = make_tensors(*make_args)
        meter.add(preds, targets)
        runs_tests_on_meter_counts_and_value(meter, num_tp_check=expected_tp)
Example #2
0
    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        class_names: List[str] = None,
        num_classes: int = 2,
        threshold: float = 0.5,
        activation: str = "Sigmoid",
    ):
        """
        Build one precision / recall / F1 meter per class and register them
        with the parent meter-metrics machinery under the metric names
        ``"ppv"`` (precision), ``"tpr"`` (recall) and ``"f1"``.

        Args:
            input_key (str): input key to use for metric calculation;
                specifies our ``y_true``.
            output_key (str): output key to use for metric calculation;
                specifies our ``y_pred``.
            class_names (List[str]): class names to display in the logs.
                If None, defaults to indices for each class, starting from 0.
                When given, ``len(class_names)`` replaces ``num_classes``.
            num_classes (int): number of classes; expected to be > 1
                (not validated here). Ignored if ``class_names`` is given.
            threshold (float): threshold for outputs binarization; passed
                to every per-class meter.
            activation (str): a torch.nn activation applied to the outputs.
                Must be one of ['none', 'Sigmoid', 'Softmax2d'].
        """
        # Explicit class names take precedence over the num_classes argument.
        if class_names is not None:
            num_classes = len(class_names)

        super().__init__(
            # Independent meters, one per class, all sharing one threshold.
            meter_list=[
                meters.PrecisionRecallF1ScoreMeter(threshold)
                for _class_idx in range(num_classes)
            ],
            metric_names=["ppv", "tpr", "f1"],
            input_key=input_key,
            output_key=output_key,
            class_names=class_names,
            num_classes=num_classes,
            activation=activation,
        )