def test_resize_inputs():
    """
    Check that resize_inputs returns tensors of the requested sizes.

    Two cases are exercised on constant inputs of shape (2, 2, 2):
    labeled data (images and labels) and unlabeled data (images only).
    In both cases every moving tensor must come out with
    moving_image_size and every fixed tensor with fixed_image_size.
    """
    source_shape = (2, 2, 2)
    moving_image_size = (1, 3, 5)
    fixed_image_size = (2, 4, 6)

    labeled_keys = ("moving_image", "fixed_image", "moving_label", "fixed_label")
    unlabeled_keys = ("moving_image", "fixed_image")

    # labeled data first, then unlabeled data - both are expected to pass
    for tensor_keys in (labeled_keys, unlabeled_keys):
        inputs = {key: tf.ones(source_shape) for key in tensor_keys}
        inputs["indices"] = tf.ones((2,))

        outputs = preprocess.resize_inputs(
            inputs=inputs,
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
        )

        for key in tensor_keys:
            if key.startswith("moving"):
                expected_shape = moving_image_size
            else:
                expected_shape = fixed_image_size
            assert outputs[key].shape == expected_shape
def get_dataset_and_preprocess(
    self,
    training: bool,
    batch_size: int,
    repeat: bool,
    shuffle_buffer_num_batch: int,
    data_augmentation: Optional[Union[List, Dict]] = None,
) -> tf.data.Dataset:
    """
    Build the input pipeline on top of self.get_dataset().

    The stages are applied in this order: resize, shuffle, repeat,
    batch, prefetch, data augmentation.

    :param training: bool, indicating if it's training or not;
        shuffling, augmentation and dropping the last incomplete batch
        only happen when it is True
    :param batch_size: int, size of mini batch
    :param repeat: bool, indicating if we need to repeat the dataset
    :param shuffle_buffer_num_batch: int, when shuffling,
        the shuffle_buffer_size = batch_size * shuffle_buffer_num_batch;
        no shuffling is done if it is not positive
    :param data_augmentation: augmentation config,
        can be a list of dict or dict.
    :return: the preprocessed dataset
    """
    autotune = tf.data.experimental.AUTOTUNE

    def _resize(sample):
        # target shapes are read from self each time the function is called
        return resize_inputs(
            inputs=sample,
            moving_image_size=self.moving_image_shape,
            fixed_image_size=self.fixed_image_shape,
        )

    # resize
    pipeline = self.get_dataset().map(_resize, num_parallel_calls=autotune)

    # shuffle
    if training and shuffle_buffer_num_batch > 0:
        pipeline = pipeline.shuffle(
            buffer_size=batch_size * shuffle_buffer_num_batch,
            reshuffle_each_iteration=True,
        )

    # repeat
    if repeat:
        pipeline = pipeline.repeat()

    # batch, the incomplete last batch is dropped for training only
    pipeline = pipeline.batch(batch_size=batch_size, drop_remainder=training)
    pipeline = pipeline.prefetch(autotune)

    # data augmentation is for training only
    if not training or data_augmentation is None:
        return pipeline

    if isinstance(data_augmentation, dict):
        # a single config is handled as a list with one element
        augmentation_configs = [data_augmentation]
    else:
        augmentation_configs = data_augmentation

    for augmentation_config in augmentation_configs:
        augment = REGISTRY.build_data_augmentation(
            config=augmentation_config,
            default_args=dict(
                moving_image_size=self.moving_image_shape,
                fixed_image_size=self.fixed_image_shape,
                batch_size=batch_size,
            ),
        )
        pipeline = pipeline.map(augment, num_parallel_calls=autotune)

    return pipeline
def get_dataset_and_preprocess(
    self,
    training: bool,
    batch_size: int,
    repeat: bool,
    shuffle_buffer_num_batch: int,
) -> tf.data.Dataset:
    """
    Build the input pipeline on top of self.get_dataset().

    The stages are applied in this order: resize, shuffle, repeat,
    batch, prefetch, affine transformation.

    :param training: bool, indicating if it's training or not;
        shuffling, the affine transformation and dropping the last
        incomplete batch only happen when it is True
    :param batch_size: int, size of mini batch
    :param repeat: bool, indicating if we need to repeat the dataset
    :param shuffle_buffer_num_batch: int, when shuffling,
        the shuffle_buffer_size = batch_size * shuffle_buffer_num_batch;
        no shuffling is done if it is not positive
    :return: the preprocessed dataset
    """
    autotune = tf.data.experimental.AUTOTUNE

    def _resize(sample):
        # target shapes are read from self each time the function is called
        return resize_inputs(
            inputs=sample,
            moving_image_size=self.moving_image_shape,
            fixed_image_size=self.fixed_image_shape,
        )

    # resize
    pipeline = self.get_dataset().map(_resize, num_parallel_calls=autotune)

    # shuffle
    if training and shuffle_buffer_num_batch > 0:
        pipeline = pipeline.shuffle(
            buffer_size=batch_size * shuffle_buffer_num_batch,
            reshuffle_each_iteration=True,
        )

    # repeat
    if repeat:
        pipeline = pipeline.repeat()

    # batch, the incomplete last batch is dropped for training only
    pipeline = pipeline.batch(batch_size=batch_size, drop_remainder=training)
    pipeline = pipeline.prefetch(autotune)

    if not training:
        return pipeline

    # TODO add cropping, but crop first or rotation first?
    augmenter = AffineTransformation3D(
        moving_image_size=self.moving_image_shape,
        fixed_image_size=self.fixed_image_shape,
        batch_size=batch_size,
    )
    return pipeline.map(augmenter.transform, num_parallel_calls=autotune)
def test_resize_inputs(
    moving_input_size: tuple,
    fixed_input_size: tuple,
    moving_image_size: tuple,
    fixed_image_size: tuple,
    labeled: bool,
):
    """
    Check the shapes returned by resize_inputs.

    Images (and labels, when labeled) must come out with the requested
    moving/fixed sizes, while indices must keep their input shape.

    :param moving_input_size: input moving image/label shape
    :param fixed_input_size: input fixed image/label shape
    :param moving_image_size: output moving image/label shape
    :param fixed_image_size: output fixed image/label shape
    :param labeled: if data is labeled
    """
    num_indices = 2
    inputs = {
        "moving_image": tf.random.uniform(moving_input_size),
        "fixed_image": tf.random.uniform(fixed_input_size),
        "indices": tf.ones((num_indices,)),
    }
    if labeled:
        inputs.update(
            moving_label=tf.random.uniform(moving_input_size),
            fixed_label=tf.random.uniform(fixed_input_size),
        )

    outputs = preprocess.resize_inputs(inputs, moving_image_size, fixed_image_size)

    # indices are expected to go through untouched
    assert inputs["indices"].shape == outputs["indices"].shape

    for key, tensor in inputs.items():
        if key == "indices":
            expected_shape = tensor.shape
        elif "moving" in key:
            expected_shape = moving_image_size
        else:
            expected_shape = fixed_image_size
        assert outputs[key].shape == expected_shape
def get_dataset_and_preprocess(
    self,
    training: bool,
    batch_size: int,
    repeat: bool,
    shuffle_buffer_num_batch: int,
    data_augmentation: Optional[Union[List, Dict]] = None,
    num_parallel_calls: int = tf.data.experimental.AUTOTUNE,
) -> tf.data.Dataset:
    """
    Generate tf.data.dataset.

    The stages are applied on top of self.get_dataset() in this order:
    resize, shuffle, repeat, batch, prefetch, data augmentation.

    Reference:

        - https://www.tensorflow.org/guide/data_performance#parallelizing_data_transformation
        - https://www.tensorflow.org/api_docs/python/tf/data/Dataset

    :param training: indicating if it's training or not;
        shuffling, augmentation and dropping the last incomplete batch
        only happen when it is True
    :param batch_size: size of mini batch
    :param repeat: indicating if we need to repeat the dataset
    :param shuffle_buffer_num_batch: when shuffling,
        the shuffle_buffer_size = batch_size * shuffle_buffer_num_batch;
        no shuffling is done if it is not positive
    :param data_augmentation: augmentation config,
        can be a list of dict or dict.
    :param num_parallel_calls: number of elements to process
        asynchronously in parallel by the resize and augmentation maps;
        heuristically it should be set to the number of CPU cores
        available. The default, AUTOTUNE, leaves the choice to tf.data.
        Prefetching always uses AUTOTUNE, regardless of this value.
    :return: the preprocessed dataset
    """

    def _resize(sample):
        # target shapes are read from self each time the function is called
        return resize_inputs(
            inputs=sample,
            moving_image_size=self.moving_image_shape,
            fixed_image_size=self.fixed_image_shape,
        )

    # resize
    pipeline = self.get_dataset().map(
        _resize, num_parallel_calls=num_parallel_calls
    )

    # shuffle
    if training and shuffle_buffer_num_batch > 0:
        pipeline = pipeline.shuffle(
            buffer_size=batch_size * shuffle_buffer_num_batch,
            reshuffle_each_iteration=True,
        )

    # repeat
    if repeat:
        pipeline = pipeline.repeat()

    # batch, the incomplete last batch is dropped for training only
    pipeline = pipeline.batch(batch_size=batch_size, drop_remainder=training)
    pipeline = pipeline.prefetch(tf.data.experimental.AUTOTUNE)

    # data augmentation is for training only
    if not training or data_augmentation is None:
        return pipeline

    if isinstance(data_augmentation, dict):
        # a single config is handled as a list with one element
        augmentation_configs = [data_augmentation]
    else:
        augmentation_configs = data_augmentation

    for augmentation_config in augmentation_configs:
        augment = REGISTRY.build_data_augmentation(
            config=augmentation_config,
            default_args=dict(
                moving_image_size=self.moving_image_shape,
                fixed_image_size=self.fixed_image_shape,
                batch_size=batch_size,
            ),
        )
        pipeline = pipeline.map(augment, num_parallel_calls=num_parallel_calls)

    return pipeline