Code Example #1
File: interface.py  Project: vsaase/DeepReg
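    # Module-level imports this method relies on (assumed from interface.py
    # in vsaase/DeepReg; the deepreg module paths are a best guess):
    #   from typing import Dict, List, Optional, Union
    #   import tensorflow as tf
    #   from deepreg.dataset.preprocess import resize_inputs
    #   from deepreg.registry import REGISTRY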
    def get_dataset_and_preprocess(
        self,
        training: bool,
        batch_size: int,
        repeat: bool,
        shuffle_buffer_num_batch: int,
        data_augmentation: Optional[Union[List, Dict]] = None,
    ) -> tf.data.Dataset:
        """
        :param training: bool, indicating if it's training or not
        :param batch_size: int, size of mini batch
        :param repeat: bool, indicating if we need to repeat the dataset
        :param shuffle_buffer_num_batch: int, when shuffling,
            the shuffle_buffer_size = batch_size * shuffle_buffer_num_batch
        :param data_augmentation: augmentation config; can be a dict or a list of dicts.
        :return: preprocessed dataset
        """

        dataset = self.get_dataset()

        # resize
        dataset = dataset.map(
            lambda x: resize_inputs(
                inputs=x,
                moving_image_size=self.moving_image_shape,
                fixed_image_size=self.fixed_image_shape,
            ),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )

        # shuffle / repeat / batch / prefetch
        if training and shuffle_buffer_num_batch > 0:
            dataset = dataset.shuffle(
                buffer_size=batch_size * shuffle_buffer_num_batch,
                reshuffle_each_iteration=True,
            )
        if repeat:
            dataset = dataset.repeat()

        dataset = dataset.batch(batch_size=batch_size, drop_remainder=training)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        # data augmentation: applied after batching, so each mapped function
        # transforms a full batch at once
        if training and data_augmentation is not None:
            if isinstance(data_augmentation, dict):
                data_augmentation = [data_augmentation]
            for config in data_augmentation:
                da_fn = REGISTRY.build_data_augmentation(
                    config=config,
                    default_args={
                        "moving_image_size": self.moving_image_shape,
                        "fixed_image_size": self.fixed_image_shape,
                        "batch_size": batch_size,
                    },
                )
                dataset = dataset.map(
                    da_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
                )

        return dataset
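
A minimal usage sketch (not from the source): `loader` stands in for any data loader subclass that implements get_dataset() and defines moving_image_shape / fixed_image_shape, and the augmentation name "affine" is an assumption about the registry key:

    # Hypothetical usage: loader provides get_dataset(),
    # moving_image_shape, and fixed_image_shape.
    dataset = loader.get_dataset_and_preprocess(
        training=True,
        batch_size=4,
        repeat=True,                 # loop the dataset indefinitely for training
        shuffle_buffer_num_batch=2,  # shuffle buffer holds 4 * 2 = 8 elements
        data_augmentation={"name": "affine"},  # registry key is an assumption
    )
    for inputs in dataset.take(1):
        pass  # one batch of resized, shuffled, augmented inputs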
Code Example #2
    def get_dataset_and_preprocess(
        self,
        training: bool,
        batch_size: int,
        repeat: bool,
        shuffle_buffer_num_batch: int,
        data_augmentation: Optional[Union[List, Dict]] = None,
        num_parallel_calls: int = tf.data.experimental.AUTOTUNE,
    ) -> tf.data.Dataset:
        """
        Generate a tf.data.Dataset.

        Reference:

            - https://www.tensorflow.org/guide/data_performance#parallelizing_data_transformation
            - https://www.tensorflow.org/api_docs/python/tf/data/Dataset

        :param training: indicating if it's training or not
        :param batch_size: size of mini batch
        :param repeat: indicating if we need to repeat the dataset
        :param shuffle_buffer_num_batch: when shuffling,
            the shuffle_buffer_size = batch_size * shuffle_buffer_num_batch
        :param data_augmentation: augmentation config; can be a dict or a list of dicts.
        :param num_parallel_calls: number of elements to process asynchronously in
            parallel during preprocessing; heuristically it should be set to the
            number of CPU cores available. AUTOTUNE (-1) leaves the level of
            parallelism to TensorFlow.
        :return: preprocessed dataset
        """

        dataset = self.get_dataset()

        # resize
        dataset = dataset.map(
            lambda x: resize_inputs(
                inputs=x,
                moving_image_size=self.moving_image_shape,
                fixed_image_size=self.fixed_image_shape,
            ),
            num_parallel_calls=num_parallel_calls,
        )

        # shuffle / repeat / batch / prefetch
        if training and shuffle_buffer_num_batch > 0:
            dataset = dataset.shuffle(
                buffer_size=batch_size * shuffle_buffer_num_batch,
                reshuffle_each_iteration=True,
            )
        if repeat:
            dataset = dataset.repeat()

        dataset = dataset.batch(batch_size=batch_size, drop_remainder=training)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        # data augmentation: applied after batching, so each mapped function
        # transforms a full batch at once
        if training and data_augmentation is not None:
            if isinstance(data_augmentation, dict):
                data_augmentation = [data_augmentation]
            for config in data_augmentation:
                da_fn = REGISTRY.build_data_augmentation(
                    config=config,
                    default_args={
                        "moving_image_size": self.moving_image_shape,
                        "fixed_image_size": self.fixed_image_shape,
                        "batch_size": batch_size,
                    },
                )
                dataset = dataset.map(
                    da_fn, num_parallel_calls=num_parallel_calls
                )

        return dataset
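
A hedged sketch of the docstring's heuristic for num_parallel_calls, pinning it to the CPU count instead of AUTOTUNE (reuses the hypothetical `loader` from the sketch above):

    import os

    # Bound preprocessing parallelism by the available CPU cores rather
    # than letting AUTOTUNE (-1) choose the level automatically.
    n_cores = os.cpu_count() or 1
    dataset = loader.get_dataset_and_preprocess(
        training=False,
        batch_size=2,
        repeat=False,
        shuffle_buffer_num_batch=0,  # no shuffling at evaluation time
        num_parallel_calls=n_cores,
    )

Fixing the parallelism level trades AUTOTUNE's adaptive scheduling for more predictable resource usage, which can help on shared machines.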