Example #1
0
  def test_flip(self):
    """Checks maybe_flip: identity at probability 0.0, a flip at 1.0."""
    with tf.Session() as sess:
      image_3d = self.constant_3d_image()
      expected = sess.run(image_3d)

      # With flip probability 0.0 the op must leave the image untouched,
      # whichever axis is requested.
      for axis in (0, 1, 2):
        unflipped, _ = data_aug_lib.maybe_flip(
            image_3d, tf.zeros_like(image_3d), axis, 0.0)
        self.assertAllClose(sess.run(unflipped), expected)

      # With flip probability 1.0 each call must reverse the image along
      # the requested axis. Flips are chained, so the numpy reference is
      # reversed cumulatively as well.
      reversers = {
          0: lambda a: a[::-1, ...],
          1: lambda a: a[:, ::-1, :],
          2: lambda a: a[..., ::-1],
      }
      flipped = image_3d
      for axis in (0, 1, 2):
        expected = reversers[axis](expected)
        flipped, _ = data_aug_lib.maybe_flip(
            flipped, tf.zeros_like(flipped), axis, 1.0)
        self.assertAllClose(sess.run(flipped), expected)
Example #2
0
        def _parser_fn(serialized_example):
            """Parses a single tf.Example into image and label tensors.

            Decodes the raw CT image and label volumes, optionally samples a
            stack of 2D slices, applies train-time augmentation, and reshapes
            both tensors into a blocked spatial layout before casting them to
            the model dtype.

            Note: `dataset_str`, `FLAGS`, `data_aug_lib` and
            `_get_stacked_2d_slices` are captured from the enclosing scope.

            Args:
              serialized_example: Scalar string tensor holding one serialized
                tf.Example with raw-bytes features 'image/ct_image' and
                'image/label'.

            Returns:
              An (image, label) tuple cast to FLAGS.mtf_dtype, with label
              one-hot encoded over FLAGS.label_c classes — except when
              dataset_str == 'eval' and FLAGS.sampled_2d_slices is set, in
              which case the result of _get_stacked_2d_slices is returned
              directly (shape presumably differs; verify against callers).
            """
            # Both features are stored as raw byte strings of float32 voxels.
            features = {}
            features['image/ct_image'] = tf.FixedLenFeature([], tf.string)
            features['image/label'] = tf.FixedLenFeature([], tf.string)
            parsed = tf.parse_single_example(serialized_example,
                                             features=features)

            # Full volumes are cubes of side FLAGS.ct_resolution. Noise is
            # drawn per-slice-stack in 2D mode, per-volume otherwise.
            spatial_dims = [FLAGS.ct_resolution] * 3
            if FLAGS.sampled_2d_slices:
                noise_shape = [FLAGS.ct_resolution] * 2 + [FLAGS.image_c]
            else:
                noise_shape = [FLAGS.ct_resolution] * 3

            # decode_raw yields flat 1-D float32 tensors; reshaped below.
            image = tf.decode_raw(parsed['image/ct_image'], tf.float32)
            label = tf.decode_raw(parsed['image/label'], tf.float32)

            if dataset_str != 'train':
                # Preprocess intensity, clip to 0 ~ 1.
                # The training set is already preprocessed.
                image = tf.clip_by_value(image / 1024.0 + 0.5, 0, 1)

            image = tf.reshape(image, spatial_dims)
            label = tf.reshape(label, spatial_dims)

            if dataset_str == 'eval' and FLAGS.sampled_2d_slices:
                return _get_stacked_2d_slices(image, label)

            if FLAGS.sampled_2d_slices:
                # Take random slices of images and label
                # maxval guarantees the FLAGS.image_c-deep window starting at
                # begin_idx fits inside the volume's last axis.
                begin_idx = tf.random_uniform(shape=[],
                                              minval=0,
                                              maxval=FLAGS.ct_resolution -
                                              FLAGS.image_c + 1,
                                              dtype=tf.int32)
                slice_begin = [0, 0, begin_idx]
                slice_size = [
                    FLAGS.ct_resolution, FLAGS.ct_resolution, FLAGS.image_c
                ]

                image = tf.slice(image, slice_begin, slice_size)
                label = tf.slice(label, slice_begin, slice_size)

            if dataset_str == 'train':
                # Random augmentation, applied to training examples only:
                # flips, 180-degree rotation, intensity jitter, corruption,
                # additive noise, and a projective warp.
                for flip_axis in [0, 1, 2]:
                    image, label = data_aug_lib.maybe_flip(
                        image, label, flip_axis)
                image, label = data_aug_lib.maybe_rot180(image,
                                                         label,
                                                         static_axis=2)
                image = data_aug_lib.intensity_shift(
                    image, label, FLAGS.per_class_intensity_scale,
                    FLAGS.per_class_intensity_shift)
                image = data_aug_lib.image_corruption(
                    image, label, FLAGS.ct_resolution,
                    FLAGS.image_corrupt_ratio_mean,
                    FLAGS.image_corrupt_ratio_stddev)
                image = data_aug_lib.maybe_add_noise(
                    image, noise_shape, 1, 4, FLAGS.image_noise_probability,
                    FLAGS.image_noise_ratio)
                image, label = data_aug_lib.projective_transform(
                    image, label, FLAGS.ct_resolution,
                    FLAGS.image_translate_ratio, FLAGS.image_transform_ratio,
                    FLAGS.sampled_2d_slices)

            if FLAGS.sampled_2d_slices:
                # Only get the center slice of label.
                label = tf.slice(label, [0, 0, FLAGS.image_c // 2],
                                 [FLAGS.ct_resolution, FLAGS.ct_resolution, 1])

            # Split each in-plane axis into (num_blocks, block_size) pairs;
            # assumes ct_resolution is divisible by image_nx/ny_block --
            # TODO confirm against flag validation elsewhere.
            spatial_dims_w_blocks = [
                FLAGS.image_nx_block,
                FLAGS.ct_resolution // FLAGS.image_nx_block,
                FLAGS.image_ny_block,
                FLAGS.ct_resolution // FLAGS.image_ny_block
            ]
            if not FLAGS.sampled_2d_slices:
                spatial_dims_w_blocks += [FLAGS.ct_resolution]

            # NOTE(review): in the 3-D (not sampled_2d_slices) path this
            # reshape only balances element counts if FLAGS.image_c == 1 --
            # verify that invariant where the flags are defined.
            image = tf.reshape(image, spatial_dims_w_blocks + [FLAGS.image_c])
            label = tf.reshape(label, spatial_dims_w_blocks)

            # Labels are stored as float class indices; cast before one-hot.
            label = tf.cast(label, tf.int32)
            label = tf.one_hot(label, FLAGS.label_c)

            data_dtype = tf.as_dtype(FLAGS.mtf_dtype)
            image = tf.cast(image, data_dtype)
            label = tf.cast(label, data_dtype)
            return image, label