def test_label_indices_unknown():
    """
    An unrecognised sample_label string must make get_label_indices raise ValueError.
    """
    bad_option = "random_str"
    with pytest.raises(ValueError):
        util.get_label_indices(3, bad_option)
def test_label_indices_all():
    """
    With sample_label="all", every label index 0..num_labels-1 is returned, in order.
    """
    num_labels = 3
    assert list(range(num_labels)) == util.get_label_indices(num_labels, "all")
def test_label_indices_sample():
    """
    Assert sample_label="sample" returns a non-empty selection of valid label indices.

    Every returned index must lie in [0, num_labels).
    """
    num_labels = 4
    valid_indices = set(range(num_labels))
    actual = util.get_label_indices(num_labels, "sample")
    assert len(actual) > 0
    # a mere non-empty intersection would accept out-of-range indices as long
    # as one valid index were present; require every index to be valid
    assert set(actual) <= valid_indices
def test_label_indices_first():
    """
    Assert the list [0] is returned when sample_label is "first".
    """
    expected = [0]
    actual = util.get_label_indices(5, "first")
    assert expected == actual
def sample_image_label(
    self,
    moving_image: np.ndarray,
    fixed_image: np.ndarray,
    moving_label: Optional[np.ndarray],
    fixed_label: Optional[np.ndarray],
    image_indices: list,
):
    """
    Yield the sample dict(s) for one image pair; only used in data_generator.

    The inputs are first passed to ``self.validate_images_and_labels``.

    :param moving_image: moving image array.
    :param fixed_image: fixed image array.
    :param moving_label: moving label array, or None if unlabeled.
    :param fixed_label: fixed label array, or None if unlabeled.
    :param image_indices: list of indices identifying the image pair; the
        label index is appended to it to build the ``indices`` entry.
    :return: generator of dicts with keys ``moving_image``, ``fixed_image``
        and ``indices``, plus ``moving_label`` / ``fixed_label`` when both
        labels are given. ``indices`` is a float32 array of
        ``image_indices + [label_index]``, with label_index -1 when unlabeled.
        A 4-D label is treated as a stack of labels along its last axis and
        yields one dict per index chosen by ``get_label_indices`` with
        ``self.sample_label``; any other label rank yields one dict with
        label index 0.
    """
    self.validate_images_and_labels(
        moving_image, fixed_image, moving_label, fixed_label, image_indices
    )

    def _indices_for(label_idx) -> np.ndarray:
        # the label index rides along as the final entry of the index vector
        return np.asarray(image_indices + [label_idx], dtype=np.float32)

    if moving_label is not None and fixed_label is not None:
        if len(moving_label.shape) != 4:
            # a single label (no label axis): its index is always 0
            yield dict(
                moving_image=moving_image,
                fixed_image=fixed_image,
                moving_label=moving_label,
                fixed_label=fixed_label,
                indices=_indices_for(0),
            )
            return

        # several labels stacked along axis 3; self.sample_label picks which
        num_labels = moving_label.shape[3]
        chosen = get_label_indices(num_labels, self.sample_label)  # type:ignore
        for label_idx in chosen:
            yield dict(
                moving_image=moving_image,
                fixed_image=fixed_image,
                indices=_indices_for(label_idx),
                moving_label=moving_label[..., label_idx],
                fixed_label=fixed_label[..., label_idx],
            )
        return

    # at least one label missing: treat the pair as unlabeled (-1 = no label)
    yield dict(
        moving_image=moving_image,
        fixed_image=fixed_image,
        indices=_indices_for(-1),
    )