def test_flip(self):
  """Checks maybe_flip at flip probability 0.0 and 1.0 on each spatial axis.

  With probability 0.0 every axis must return the input unchanged. With
  probability 1.0 the flips are chained: each step flips the previous result
  along the next axis and is compared against the NumPy array reversed along
  the same sequence of axes.
  """
  # Reference reversal for each axis, mirroring the slicing used by the test.
  reverse_along = {
      0: lambda arr: arr[::-1, ...],
      1: lambda arr: arr[:, ::-1, :],
      2: lambda arr: arr[..., ::-1],
  }
  with tf.Session() as sess:
    source = self.constant_3d_image()
    expected = sess.run(source)

    # Probability 0.0: nothing may change, whatever the axis.
    for axis in [0, 1, 2]:
      untouched, _ = data_aug_lib.maybe_flip(
          source, tf.zeros_like(source), axis, 0.0)
      self.assertAllClose(sess.run(untouched), expected)

    # Probability 1.0: flips accumulate axis after axis.
    flipped = source
    for axis in [0, 1, 2]:
      expected = reverse_along[axis](expected)
      flipped, _ = data_aug_lib.maybe_flip(
          flipped, tf.zeros_like(flipped), axis, 1.0)
      self.assertAllClose(sess.run(flipped), expected)
def _parser_fn(serialized_example):
  """Decodes one serialized tf.Example into an (image, label) tensor pair.

  The example holds two raw float32 byte strings, 'image/ct_image' and
  'image/label'. Both are decoded and reshaped to a cube of side
  FLAGS.ct_resolution. The remaining steps depend on `dataset_str` (not a
  parameter; it is resolved from the surrounding scope) and on FLAGS:

    * non-'train' splits get intensity-normalized: x / 1024 + 0.5, clipped
      to [0, 1];
    * with FLAGS.sampled_2d_slices, 'eval' returns the result of
      _get_stacked_2d_slices immediately; other splits keep a random window
      of FLAGS.image_c consecutive slices along the last axis;
    * the 'train' split runs the data_aug_lib augmentation chain;
    * in 2D mode only the centre slice of the label window is kept;
    * spatial axes are split into (block, within-block) pairs, the label is
      one-hot encoded into FLAGS.label_c classes, and both tensors are cast
      to FLAGS.mtf_dtype.

  Args:
    serialized_example: string tensor holding one serialized tf.Example.

  Returns:
    An (image, label) tuple, or whatever _get_stacked_2d_slices returns for
    the 'eval' split in 2D mode.
  """
  parsed = tf.parse_single_example(
      serialized_example,
      features={
          'image/ct_image': tf.FixedLenFeature([], tf.string),
          'image/label': tf.FixedLenFeature([], tf.string),
      })

  resolution = FLAGS.ct_resolution
  use_2d = FLAGS.sampled_2d_slices
  is_training = dataset_str == 'train'
  cube_shape = [resolution] * 3

  image = tf.decode_raw(parsed['image/ct_image'], tf.float32)
  label = tf.decode_raw(parsed['image/label'], tf.float32)
  if not is_training:
    # Normalize intensity into [0, 1]; the training set is stored already
    # preprocessed.
    image = tf.clip_by_value(image / 1024.0 + 0.5, 0, 1)
  image = tf.reshape(image, cube_shape)
  label = tf.reshape(label, cube_shape)

  if use_2d:
    if dataset_str == 'eval':
      return _get_stacked_2d_slices(image, label)
    # Random window of FLAGS.image_c consecutive slices on the last axis.
    # maxval is exclusive, so the window always stays inside the volume.
    window_start = tf.random_uniform(
        shape=[], minval=0, maxval=resolution - FLAGS.image_c + 1,
        dtype=tf.int32)
    window_begin = [0, 0, window_start]
    window_size = [resolution, resolution, FLAGS.image_c]
    image = tf.slice(image, window_begin, window_size)
    label = tf.slice(label, window_begin, window_size)

  if is_training:
    # Shape of the noise field handed to maybe_add_noise.
    noise_shape = ([resolution, resolution, FLAGS.image_c] if use_2d
                   else [resolution] * 3)
    for axis in range(3):
      image, label = data_aug_lib.maybe_flip(image, label, axis)
    image, label = data_aug_lib.maybe_rot180(image, label, static_axis=2)
    image = data_aug_lib.intensity_shift(
        image, label,
        FLAGS.per_class_intensity_scale, FLAGS.per_class_intensity_shift)
    image = data_aug_lib.image_corruption(
        image, label, resolution,
        FLAGS.image_corrupt_ratio_mean, FLAGS.image_corrupt_ratio_stddev)
    image = data_aug_lib.maybe_add_noise(
        image, noise_shape, 1, 4,
        FLAGS.image_noise_probability, FLAGS.image_noise_ratio)
    image, label = data_aug_lib.projective_transform(
        image, label, resolution,
        FLAGS.image_translate_ratio, FLAGS.image_transform_ratio, use_2d)

  if use_2d:
    # Keep only the centre slice of the label window.
    label = tf.slice(label, [0, 0, FLAGS.image_c // 2],
                     [resolution, resolution, 1])

  # Split the two leading spatial axes into (block, within-block) pairs; the
  # 3D path keeps the full depth axis afterwards.
  block_shape = [
      FLAGS.image_nx_block, resolution // FLAGS.image_nx_block,
      FLAGS.image_ny_block, resolution // FLAGS.image_ny_block,
  ]
  if not use_2d:
    block_shape.append(resolution)

  image = tf.reshape(image, block_shape + [FLAGS.image_c])
  label = tf.one_hot(
      tf.cast(tf.reshape(label, block_shape), tf.int32), FLAGS.label_c)

  out_dtype = tf.as_dtype(FLAGS.mtf_dtype)
  return tf.cast(image, out_dtype), tf.cast(label, out_dtype)