Exemple #1
0
def test_resize_inputs():
    """
    Check the output shapes of resize_inputs on a small example.

    Two cases are covered, both expected to pass:

    - labeled data: images and labels for moving and fixed
    - unlabeled data: images only

    In both cases every "moving" entry must come back with
    moving_image_size and every "fixed" entry with fixed_image_size.
    """
    source_shape = (2, 2, 2)
    moving_image_size = (1, 3, 5)
    fixed_image_size = (2, 4, 6)

    # keys present in the input dict (besides "indices") for each case
    labeled_keys = ("moving_image", "fixed_image", "moving_label", "fixed_label")
    unlabeled_keys = ("moving_image", "fixed_image")

    for keys in (labeled_keys, unlabeled_keys):
        inputs = {key: tf.ones(source_shape) for key in keys}
        inputs["indices"] = tf.ones((2, ))

        outputs = preprocess.resize_inputs(
            inputs=inputs,
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
        )

        for key in keys:
            if key.startswith("moving"):
                expected_shape = moving_image_size
            else:
                expected_shape = fixed_image_size
            assert outputs[key].shape == expected_shape
Exemple #2
0
    def get_dataset_and_preprocess(
        self,
        training: bool,
        batch_size: int,
        repeat: bool,
        shuffle_buffer_num_batch: int,
        data_augmentation: Optional[Union[List, Dict]] = None,
    ) -> tf.data.Dataset:
        """
        Build the input pipeline on top of self.get_dataset().

        The steps are applied in this order: resize, shuffle (training only),
        repeat, batch, prefetch, data augmentation (training only).

        :param training: bool, indicating if it's training or not;
            shuffling, augmentation and dropping the last incomplete
            batch only happen when it is True
        :param batch_size: int, size of mini batch
        :param repeat: bool, indicating if we need to repeat the dataset
        :param shuffle_buffer_num_batch: int, when shuffling,
            the shuffle_buffer_size = batch_size * shuffle_buffer_num_batch;
            no shuffling is done if it is not positive
        :param data_augmentation: augmentation config, can be a list of dict or dict.
        :returns dataset:
        """
        autotune = tf.data.experimental.AUTOTUNE

        def _resize(sample):
            # bring every sample to the shapes configured on the loader
            return resize_inputs(
                inputs=sample,
                moving_image_size=self.moving_image_shape,
                fixed_image_size=self.fixed_image_shape,
            )

        dataset = self.get_dataset()
        dataset = dataset.map(_resize, num_parallel_calls=autotune)

        # shuffle / repeat / batch / prefetch
        if training and shuffle_buffer_num_batch > 0:
            shuffle_buffer_size = batch_size * shuffle_buffer_num_batch
            dataset = dataset.shuffle(
                buffer_size=shuffle_buffer_size,
                reshuffle_each_iteration=True,
            )
        if repeat:
            dataset = dataset.repeat()

        dataset = dataset.batch(batch_size=batch_size, drop_remainder=training)
        dataset = dataset.prefetch(autotune)

        # augmentation is only applied to batched training data
        if not training or data_augmentation is None:
            return dataset

        # a single config dict is treated as a one-element list
        if isinstance(data_augmentation, dict):
            augmentation_configs = [data_augmentation]
        else:
            augmentation_configs = data_augmentation

        for augmentation_config in augmentation_configs:
            augment = REGISTRY.build_data_augmentation(
                config=augmentation_config,
                default_args=dict(
                    moving_image_size=self.moving_image_shape,
                    fixed_image_size=self.fixed_image_shape,
                    batch_size=batch_size,
                ),
            )
            dataset = dataset.map(augment, num_parallel_calls=autotune)

        return dataset
Exemple #3
0
    def get_dataset_and_preprocess(
        self,
        training: bool,
        batch_size: int,
        repeat: bool,
        shuffle_buffer_num_batch: int,
    ) -> tf.data.Dataset:
        """
        Build the input pipeline on top of self.get_dataset().

        The steps are applied in this order: resize, shuffle (training only),
        repeat, batch, prefetch, affine transformation (training only).

        :param training: bool, indicating if it's training or not;
            shuffling, the affine transformation and dropping the last
            incomplete batch only happen when it is True
        :param batch_size: int, size of mini batch
        :param repeat: bool, indicating if we need to repeat the dataset
        :param shuffle_buffer_num_batch: int, when shuffling,
            the shuffle_buffer_size = batch_size * shuffle_buffer_num_batch;
            no shuffling is done if it is not positive

        :returns dataset:
        """
        autotune = tf.data.experimental.AUTOTUNE
        dataset = self.get_dataset()

        # resize every sample to the shapes configured on the loader
        dataset = dataset.map(
            lambda sample: resize_inputs(
                inputs=sample,
                moving_image_size=self.moving_image_shape,
                fixed_image_size=self.fixed_image_shape,
            ),
            num_parallel_calls=autotune,
        )

        # shuffle / repeat / batch / prefetch
        should_shuffle = training and shuffle_buffer_num_batch > 0
        if should_shuffle:
            dataset = dataset.shuffle(
                buffer_size=batch_size * shuffle_buffer_num_batch,
                reshuffle_each_iteration=True,
            )
        if repeat:
            dataset = dataset.repeat()
        dataset = dataset.batch(batch_size=batch_size, drop_remainder=training)
        dataset = dataset.prefetch(autotune)

        # nothing more to do outside of training
        if not training:
            return dataset

        # TODO add cropping, but crop first or rotation first?
        affine_transform = AffineTransformation3D(
            moving_image_size=self.moving_image_shape,
            fixed_image_size=self.fixed_image_shape,
            batch_size=batch_size,
        )
        return dataset.map(
            affine_transform.transform,
            num_parallel_calls=autotune,
        )
Exemple #4
0
def test_resize_inputs(
    moving_input_size: tuple,
    fixed_input_size: tuple,
    moving_image_size: tuple,
    fixed_image_size: tuple,
    labeled: bool,
):
    """
    Verify the shapes of the tensors returned by resize_inputs.

    The indices must keep their input shape, every "moving" entry must
    have moving_image_size and every other entry fixed_image_size.

    :param moving_input_size: input moving image/label shape
    :param fixed_input_size: input fixed image/label shape
    :param moving_image_size: output moving image/label shape
    :param fixed_image_size: output fixed image/label shape
    :param labeled: if data is labeled
    """
    num_indices = 2

    inputs = {
        "moving_image": tf.random.uniform(moving_input_size),
        "fixed_image": tf.random.uniform(fixed_input_size),
        "indices": tf.ones((num_indices, )),
    }
    if labeled:
        # labels share the input shapes of their images
        inputs["moving_label"] = tf.random.uniform(moving_input_size)
        inputs["fixed_label"] = tf.random.uniform(fixed_input_size)

    outputs = preprocess.resize_inputs(inputs, moving_image_size,
                                       fixed_image_size)

    # indices are expected to keep their shape
    assert inputs["indices"].shape == outputs["indices"].shape

    resized_keys = [key for key in inputs if key != "indices"]
    for key in resized_keys:
        if "moving" in key:
            expected_shape = moving_image_size
        else:
            expected_shape = fixed_image_size
        assert outputs[key].shape == expected_shape
Exemple #5
0
    def get_dataset_and_preprocess(
        self,
        training: bool,
        batch_size: int,
        repeat: bool,
        shuffle_buffer_num_batch: int,
        data_augmentation: Optional[Union[List, Dict]] = None,
        num_parallel_calls: int = tf.data.experimental.AUTOTUNE,
    ) -> tf.data.Dataset:
        """
        Generate tf.data.dataset.

        The steps are applied in this order: resize, shuffle (training only),
        repeat, batch, prefetch, data augmentation (training only).

        Reference:

            - https://www.tensorflow.org/guide/data_performance#parallelizing_data_transformation
            - https://www.tensorflow.org/api_docs/python/tf/data/Dataset

        :param training: indicating if it's training or not; shuffling,
            augmentation and dropping the last incomplete batch only
            happen when it is True
        :param batch_size: size of mini batch
        :param repeat: indicating if we need to repeat the dataset
        :param shuffle_buffer_num_batch: when shuffling,
            the shuffle_buffer_size = batch_size * shuffle_buffer_num_batch;
            no shuffling is done if it is not positive
        :param data_augmentation: augmentation config, can be a list of dict or dict.
        :param num_parallel_calls: forwarded to the dataset.map calls used for
            resizing and augmentation, i.e. the number of elements to process
            asynchronously in parallel; heuristically it should be set to the
            number of CPU cores available. Defaults to
            tf.data.experimental.AUTOTUNE.
        :returns dataset:
        """

        def _resize(sample):
            # bring every sample to the shapes configured on the loader
            return resize_inputs(
                inputs=sample,
                moving_image_size=self.moving_image_shape,
                fixed_image_size=self.fixed_image_shape,
            )

        dataset = self.get_dataset()
        dataset = dataset.map(_resize, num_parallel_calls=num_parallel_calls)

        # shuffle / repeat / batch / prefetch
        if training and shuffle_buffer_num_batch > 0:
            shuffle_buffer_size = batch_size * shuffle_buffer_num_batch
            dataset = dataset.shuffle(
                buffer_size=shuffle_buffer_size,
                reshuffle_each_iteration=True,
            )
        if repeat:
            dataset = dataset.repeat()

        dataset = dataset.batch(batch_size=batch_size, drop_remainder=training)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        # augmentation is only applied to batched training data
        if not training or data_augmentation is None:
            return dataset

        # a single config dict is treated as a one-element list
        if isinstance(data_augmentation, dict):
            augmentation_configs = [data_augmentation]
        else:
            augmentation_configs = data_augmentation

        for augmentation_config in augmentation_configs:
            augment = REGISTRY.build_data_augmentation(
                config=augmentation_config,
                default_args={
                    "moving_image_size": self.moving_image_shape,
                    "fixed_image_size": self.fixed_image_shape,
                    "batch_size": batch_size,
                },
            )
            dataset = dataset.map(
                augment,
                num_parallel_calls=num_parallel_calls,
            )

        return dataset