Ejemplo n.º 1
0
def test_label_indices_unknown():
    """
    An unrecognised ``sample_label`` string must make
    ``get_label_indices`` raise ``ValueError``.
    """
    num_labels = 3
    bad_option = "random_str"
    with pytest.raises(ValueError):
        util.get_label_indices(num_labels, bad_option)
Ejemplo n.º 2
0
def test_label_indices_all():
    """
    The ``"all"`` option must return every label index, in order.
    """
    num_labels = 3
    got = util.get_label_indices(num_labels, "all")
    assert list(range(num_labels)) == got
Ejemplo n.º 3
0
def test_label_indices_sample():
    """
    Assert that "sample" returns a non-empty selection of valid label indices.
    """
    num_labels = 4
    valid = set(range(num_labels))
    actual = set(util.get_label_indices(num_labels, "sample"))
    # A non-empty intersection alone would also accept results containing
    # out-of-range indices, so require every returned index to be valid.
    assert actual, "sample should select at least one label index"
    assert actual <= valid, f"out-of-range label indices: {actual - valid}"
Ejemplo n.º 4
0
def test_label_indices_first():
    """
    Assert that the "first" option returns a list containing only index 0.
    """
    expected = [0]
    actual = util.get_label_indices(5, "first")
    assert expected == actual
Ejemplo n.º 5
0
    def sample_image_label(
        self,
        moving_image: np.ndarray,
        fixed_image: np.ndarray,
        moving_label: Optional[np.ndarray],
        fixed_label: Optional[np.ndarray],
        image_indices: list,
    ):
        """
        Sample the image labels, only used in data_generator.

        Yields one dict per selected label. Every dict carries the moving and
        fixed images and ``indices``, a float32 array made of
        ``image_indices`` followed by a label index:

        - if either label is None, one dict without labels is yielded and the
          label index is -1 (no label);
        - if the labels have 4 dimensions, the label index runs over the last
          axis; the indices to use come from ``get_label_indices`` applied to
          ``self.sample_label``, and one dict is yielded per index with the
          matching label slices;
        - otherwise the labels are yielded unchanged with label index 0.

        :param moving_image: moving image array.
        :param fixed_image: fixed image array.
        :param moving_label: moving label array, or None if unlabeled.
        :param fixed_label: fixed label array, or None if unlabeled.
        :param image_indices: indices identifying this image pair; converted
            with ``list()`` before the label index is appended, so the
            caller's object is never modified.
        """
        self.validate_images_and_labels(moving_image, fixed_image,
                                        moving_label, fixed_label,
                                        image_indices)

        def _indices(label_index: int) -> np.ndarray:
            # list() so that "+" always concatenates (a tuple would raise, an
            # ndarray would broadcast-add) rather than depending on the type
            return np.asarray(list(image_indices) + [label_index],
                              dtype=np.float32)

        # unlabeled
        if moving_label is None or fixed_label is None:
            yield dict(moving_image=moving_image,
                       fixed_image=fixed_image,
                       indices=_indices(-1))  # -1 means no label
            return

        # labeled
        if len(moving_label.shape) == 4:  # multiple labels on the last axis
            label_indices = get_label_indices(
                moving_label.shape[3],
                self.sample_label  # type:ignore
            )
            for label_index in label_indices:
                yield dict(
                    moving_image=moving_image,
                    fixed_image=fixed_image,
                    indices=_indices(label_index),
                    moving_label=moving_label[..., label_index],
                    fixed_label=fixed_label[..., label_index],
                )
        else:  # only one label
            yield dict(
                moving_image=moving_image,
                fixed_image=fixed_image,
                moving_label=moving_label,
                fixed_label=fixed_label,
                indices=_indices(0),
            )