Beispiel #1
0
 def test_idw_no_broadcasting(self):
     """IDW resampling with a separate coordinate set per batch element.

     For spatial ranks 3, 2 and 1 the input (batch 2, spatial size 3, two
     channels) holds 1.0 at the origin of channel 0 in both batch elements.
     Batch element 0 is sampled with every coordinate equal to 0.2 and
     element 1 with every coordinate equal to 1.2. Channel 0 of element 0
     must match the hard-coded constant for that rank, channel 0 of element 1
     must be ~0, and the output shape follows the coordinate grid.
     """
     # (spatial rank, expected channel-0 value for batch element 0)
     cases = [
         (3, 1.0 / (1. + 1. / 2. + 36. / 132. + 12. / 192.)),
         (2, 1.0 / (1.0 + 1.0 / 16.0 + 16. / 68.)),
         (1, 1.0 / (1.0 + 1 / 16.0)),
     ]
     for rank, expected_first in cases:
         volume = np.zeros((2,) + (3,) * rank + (2,))
         # volume[:, 0, ..., 0] = 1.0 for this rank
         volume[(slice(None),) + (0,) * (rank + 1)] = 1.0
         volume = tf.constant(volume)
         coords_shape = (1,) + (5,) * rank + (rank,)
         coords = tf.concat(
             [tf.ones(coords_shape) * 0.2,
              tf.ones(coords_shape) * 1.2],
             axis=0)
         resampled = ResamplerLayer("IDW")(volume, coords)
         with self.cached_session() as session:
             values = session.run(resampled)
             self.assertTrue(
                 np.all(
                     np.isclose(values[0, ..., 0],
                                expected_first,
                                atol=1e-5)))
             self.assertTrue(
                 np.all(np.isclose(values[1, ..., 0], 0.0, atol=1e-5)))
             self.assertEqual(values.shape, (2,) + (5,) * rank + (2,))
Beispiel #2
0
    def test_idw_shape(self):
        """IDW resampling broadcasts a single coordinate set over the batch.

        For spatial ranks 3, 2 and 1 the input (batch 2, spatial size 8, two
        channels) holds 1.0 only at the origin of channel 0 in batch element 0.
        One coordinate tensor (batch size 1) filled with 0.1 is used for both
        batch elements. Element 0 must match the hard-coded constant, element 1
        must stay exactly zero, and the output batch size is that of the input.
        """
        # (spatial rank, expected channel-0 value for batch element 0)
        cases = [
            (3, 1.0 / (1. + 9. / 83 + 9. / 163 + 3. / 243)),
            (2, 1. / (2. / 41. + 1. / 81.0 + 1.0)),
            (1, 100.0 / (100.0 + 1 / 0.81)),
        ]
        for rank, expected_first in cases:
            volume = np.zeros((2,) + (8,) * rank + (2,))
            volume[(0,) * (rank + 2)] = 1.0
            volume = tf.constant(volume)
            coords = tf.ones((1,) + (5,) * rank + (rank,)) * 0.1
            resampled = ResamplerLayer("IDW")(volume, coords)
            with self.cached_session() as session:
                values = session.run(resampled)
                self.assertTrue(
                    np.all(
                        np.isclose(values[0, ..., 0],
                                   expected_first,
                                   atol=1e-5)))
                self.assertTrue(np.all(values[1, ...] == 0))
                self.assertEqual(values.shape, (2,) + (5,) * rank + (2,))
Beispiel #3
0
    def _test_partial_shape_correctness(self,
                                        input,
                                        rank,
                                        batch_size,
                                        grid,
                                        interpolation,
                                        boundary,
                                        expected_value=None):
        """Resample ``input`` fed through a placeholder of partial shape.

        Args:
            input: array fed into the placeholder at run time; its ``shape``
                is also used for the placeholder's random default value.
            rank: spatial rank. If positive, the placeholder has ``rank + 2``
                dimensions. Otherwise its shape is left fully unknown.
            batch_size: if positive (and ``rank`` is positive), the first
                dimension is fixed to this value. Otherwise it is unknown.
            grid: sampling coordinates passed to the resampler.
            interpolation: interpolation type for ``ResamplerLayer``.
            boundary: boundary type for ``ResamplerLayer``.
            expected_value: if not None, the resampled output is compared to
                it with ``assertAllClose``.
        """
        resampler = ResamplerLayer(interpolation=interpolation,
                                   boundary=boundary)
        input_default = tf.random_uniform(input.shape)
        if rank > 0:
            # Only the first dimension may be static; the remaining
            # rank + 1 dimensions are always left unknown.
            leading_dim = batch_size if batch_size > 0 else None
            placeholder_shape = [leading_dim] + [None] * (rank + 1)
        else:
            # Rank unknown: leave the whole shape unspecified whatever
            # batch_size is (batch_size > 0 with rank <= 0 had no branch
            # and left the placeholder unbound).
            placeholder_shape = None
        input_placeholder = tf.placeholder_with_default(
            input_default, shape=placeholder_shape)

        out = resampler(input_placeholder, grid)
        with self.test_session() as sess:
            out_value = sess.run(out, feed_dict={input_placeholder: input})
            if expected_value is not None:
                self.assertAllClose(expected_value, out_value)
Beispiel #4
0
 def test_linear_no_broadcasting(self):
     """LINEAR resampling with a separate coordinate set per batch element.

     For spatial ranks 3, 2 and 1 the input (batch 2, spatial size 8, two
     channels) holds 1.0 at the origin of channel 0 in both batch elements.
     Element 0 is sampled at coordinate 0.1 everywhere and element 1 at 0.2.
     Channel 0 must then equal 0.9**rank and 0.8**rank respectively.
     """
     for rank in (3, 2, 1):
         volume = np.zeros((2,) + (8,) * rank + (2,))
         # volume[:, 0, ..., 0] = 1.0 for this rank
         volume[(slice(None),) + (0,) * (rank + 1)] = 1.0
         volume = tf.constant(volume)
         coords_shape = (1,) + (5,) * rank + (rank,)
         coords = tf.concat(
             [tf.ones(coords_shape) * 0.1,
              tf.ones(coords_shape) * 0.2],
             axis=0)
         resampled = ResamplerLayer("LINEAR")(volume, coords)
         with self.cached_session() as session:
             values = session.run(resampled)
             self.assertTrue(
                 np.all(np.isclose(values[0, ..., 0], 0.9**rank, atol=1e-5)))
             self.assertTrue(
                 np.all(np.isclose(values[1, ..., 0], 0.8**rank, atol=1e-5)))
             self.assertEqual(values.shape, (2,) + (5,) * rank + (2,))
 def _test_correctness(self, inputs, grid, interpolation, boundary,
                       expected_value):
     """Resample ``inputs`` at ``grid`` and compare with ``expected_value``."""
     layer = ResamplerLayer(interpolation=interpolation, boundary=boundary)
     resampled = layer(inputs, grid)
     with self.test_session() as session:
         actual = session.run(resampled)
         self.assertAllClose(expected_value, actual)
Beispiel #6
0
 def _test_correctness(self, input, grid, interpolation, boundary,
                       expected_value):
     """Resample ``input`` at ``grid`` and compare with ``expected_value``.

     Global variables are initialised before the output is evaluated.
     """
     layer = ResamplerLayer(interpolation=interpolation, boundary=boundary)
     resampled = layer(input, grid)
     with self.test_session() as session:
         session.run(tf.global_variables_initializer())
         actual = session.run(resampled)
         self.assertAllClose(expected_value, actual)
Beispiel #7
0
    def test_shape_interface(self):
        """ResamplerLayer raises ValueError on incompatible input/coords shapes.

        Uses ``assertRaises``. The deprecated ``assertRaisesRegexp`` alias was
        removed in Python 3.12, and the empty pattern it was given matched any
        message anyway.
        """
        # Batch sizes differ: 2 input images vs 3 coordinate sets.
        test_input = tf.zeros((2, 10, 10, 10, 3))
        test_coords = tf.zeros((3, 5, 5, 5, 3))
        with self.assertRaises(ValueError):
            ResamplerLayer()(test_input, test_coords)

        # Coordinates have rank 4 against a rank-5 input (no batch dimension
        # matching the input's batch of 2).
        test_input = tf.zeros((2, 10, 10, 10, 3))
        test_coords = tf.zeros((5, 5, 5, 3))
        with self.assertRaises(ValueError):
            ResamplerLayer()(test_input, test_coords)

        # Bad number of coordinates: 3 per point for a 2D input.
        test_input = tf.zeros((1, 10, 10, 3))
        test_coords = tf.zeros((1, 5, 5, 3))
        with self.assertRaises(ValueError):
            ResamplerLayer()(test_input, test_coords)
Beispiel #8
0
    def _test_simple_2d_images(self,
                               interpolation='linear',
                               boundary='replicate'):
        """Fit an affine warp of 2D test images to rotated targets.

        Starting from identity parameters, Adagrad minimises the mean squared
        difference between the warped images and the targets for 500 steps.
        The test asserts that both the loss and the distance of the first
        image's parameters from the reference rotation have decreased.
        """
        # Reference parameters: rotation by 15 degrees around the centre
        # (8, 8), as a flattened 2x3 affine matrix.
        reference = np.asarray(
            [[0.96592583, -0.25881905, 2.34314575],
             [0.25881905, 0.96592583, -1.79795897]]).flatten()
        test_image, input_shape = get_multiple_2d_images()
        test_target, target_shape = get_multiple_2d_rotated_targets()

        start_params = [[1., 0., 0., 0., 1., 0.]] * 4
        affine_var = tf.get_variable('affine', initializer=start_params)
        warper = AffineGridWarperLayer(source_shape=input_shape[1:-1],
                                       output_shape=target_shape[1:-1],
                                       constraints=None)
        warp_coords = warper(affine_var)
        resampler = ResamplerLayer(interpolation, boundary=boundary)
        warped = resampler(tf.constant(test_image, dtype=tf.float32),
                           warp_coords)

        loss = tf.reduce_mean(
            tf.squared_difference(warped,
                                  tf.constant(test_target, dtype=tf.float32)))
        if interpolation == 'linear' and boundary == 'zero':
            learning_rate = 0.0003
        else:
            learning_rate = 0.05
        optimiser = tf.train.AdagradOptimizer(learning_rate)
        train_op = optimiser.apply_gradients(optimiser.compute_gradients(loss))
        with self.cached_session() as sess:
            sess.run(tf.global_variables_initializer())
            initial_loss, params = sess.run([loss, affine_var])
            # Sum of absolute differences between the first image's
            # parameters and the reference, before optimisation.
            initial_param_error = np.sum(np.abs(params[0] - reference))
            for _ in range(500):
                _, final_loss, params = sess.run([train_op, loss, affine_var])
            self.assertGreater(initial_loss, final_loss)
            # Same error measure after optimisation.
            final_param_error = np.sum(np.abs(params[0] - reference))
            self.assertGreater(initial_param_error, final_param_error)
            print('{} {} -- diff {}'.format(interpolation, boundary,
                                            final_param_error))
            print('{}'.format(params[0]))
Beispiel #9
0
    def test_nearest_shape(self):
        """NEAREST resampling broadcasts a single coordinate set over the batch.

        For spatial ranks 3, 2 and 1 the input (batch 2, spatial size 8, two
        channels) holds 1.0 only at the origin of channel 0 in batch element 0.
        One coordinate tensor (batch size 1) filled with 0.1 is used for both
        batch elements. Element 0 must read back 1.0, element 1 must stay
        exactly zero, and the output batch size is that of the input.
        """
        for rank in (3, 2, 1):
            volume = np.zeros((2,) + (8,) * rank + (2,))
            volume[(0,) * (rank + 2)] = 1.0
            volume = tf.constant(volume)
            coords = tf.ones((1,) + (5,) * rank + (rank,)) * 0.1
            resampled = ResamplerLayer("NEAREST")(volume, coords)
            with self.cached_session() as session:
                values = session.run(resampled)
                self.assertTrue(
                    np.all(np.isclose(values[0, ..., 0], 1.0, atol=1e-5)))
                self.assertTrue(np.all(values[1, ...] == 0))
                self.assertEqual(values.shape, (2,) + (5,) * rank + (2,))
    def __init__(self,
                 source_shape,
                 output_shape,
                 coeff_shape,
                 field_transform=None,
                 resampler=None,
                 name='resampling_interpolated_spline_grid_warper'):
        """Constructs a ResampledFieldGridWarperLayer.

        Args:
          source_shape: Iterable of integers determining the size of the
            source signal domain.
          output_shape: Iterable of integers determining the size of the
            destination resampled signal domain.
          coeff_shape: Shape of the displacement field.
          field_transform: an object defining the spatial relationship
            between the output grid and the field. One of:
              a batch_size x 4 x 4 tensor: per-image transform matrix from
                output coords to field coords;
              None (default): corners of the output map to corners of the
                field, with an allowance for interpolation (1 for bspline,
                0 for linear).
          resampler: a ResamplerLayer used to interpolate the deformation
            field. If None, ``ResamplerLayer(interpolation='LINEAR',
            boundary='REPLICATE')`` is used.
          name: Name of the module.

        Raises:
          TypeError: If output_shape and source_shape are not both iterable.
        """
        if resampler is None:
            self._resampler = ResamplerLayer(interpolation='LINEAR',
                                             boundary='REPLICATE')
            self._interpolation = 'LINEAR'
        else:
            self._resampler = resampler
            # Keep the interpolation mode in sync with the supplied resampler.
            self._interpolation = self._resampler.interpolation

        self._field_transform = field_transform

        super(ResampledFieldGridWarperLayer,
              self).__init__(source_shape=source_shape,
                             output_shape=output_shape,
                             coeff_shape=coeff_shape,
                             name=name)
    def layer_op(self, input_tensor):
        """Warp ``input_tensor`` with this layer's affine transform.

        If ``self._transform`` is None, a transform is drawn with
        ``self._random_transform(batch_size, spatial_rank)`` and stored in
        ``self._transform``. Otherwise the stored transform is reused. The
        transform drives an AffineGridWarperLayer whose grid is resampled
        with ``self.interpolation`` / ``self.boundary``.
        """
        dims = input_tensor.shape.as_list()
        batch_size, spatial_shape = dims[0], dims[1:-1]
        spatial_rank = infer_spatial_rank(input_tensor)

        if self._transform is None:
            self._transform = self._random_transform(batch_size, spatial_rank)
        transform = self._transform

        warper = AffineGridWarperLayer(spatial_shape, spatial_shape)
        sampler = ResamplerLayer(interpolation=self.interpolation,
                                 boundary=self.boundary)
        # Keep the first spatial_rank entries along axis 1 and flatten the
        # remainder per batch element into the warper's parameter vector.
        params = tf.reshape(transform[:, :spatial_rank, :], [batch_size, -1])
        return sampler(input_tensor, warper(params))
Beispiel #12
0
    def _test_grads_images(self,
                           interpolation='linear',
                           boundary='replicate',
                           ndim=2):
        """Check that five Adagrad steps on an affine warp reduce the loss.

        Test images are warped by an affine grid built from trainable
        parameters (four parameter rows) and resampled. The mean squared
        difference to the targets must be lower after five optimisation steps
        than at the start.
        """
        if ndim != 2:
            test_image, input_shape = get_multiple_3d_images()
            test_target, target_shape = get_multiple_3d_targets()
            # NOTE(review): this row is the 2D identity repeated twice. Read
            # as a 3x4 matrix it is not the 3D identity
            # (1,0,0,0, 0,1,0,0, 0,0,1,0). Confirm that this is intended.
            start_params = [[
                1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0.
            ]] * 4
        else:
            test_image, input_shape = get_multiple_2d_images()
            test_target, target_shape = get_multiple_2d_targets()
            start_params = [[1., 0., 0., 0., 1., 0.]] * 4
        affine_var = tf.get_variable('affine', initializer=start_params)
        warper = AffineGridWarperLayer(source_shape=input_shape[1:-1],
                                       output_shape=target_shape[1:-1],
                                       constraints=None)
        warp_coords = warper(affine_var)
        resampler = ResamplerLayer(interpolation, boundary=boundary)
        warped = resampler(tf.constant(test_image, dtype=tf.float32),
                           warp_coords)

        loss = tf.reduce_mean(
            tf.squared_difference(warped,
                                  tf.constant(test_target, dtype=tf.float32)))
        optimiser = tf.train.AdagradOptimizer(0.01)
        train_op = optimiser.apply_gradients(optimiser.compute_gradients(loss))
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            initial_loss, params = sess.run([loss, affine_var])
            for _ in range(5):
                _, current_loss, params = sess.run([train_op, loss, affine_var])
                print('{}, {}'.format(current_loss, params[0]))
            self.assertGreater(initial_loss, current_loss)
Beispiel #13
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the registration graph and register its outputs.

        Training: sample fixed/moving image and label windows and estimate a
        dense displacement field with ``self.net``. The moving label is warped
        with that field and compared to the fixed label by a ``LossFunction``
        of type ``self.action_param.loss_type``. If the 'bending_energy'
        collection is non-empty, its mean scaled by ``self.net_param.decay``
        is added to the loss. Gradients go to ``gradients_collector``;
        console and TensorBoard values go to ``outputs_collector``.

        Inference: warp the moving image and label with the estimated field
        and register the windows and their locations as network outputs. The
        outputs are decoded by a ``ResizeSamplesAggregator`` on the first
        (fixed image) reader.

        Args:
            outputs_collector: collector for console, TensorBoard and
                network-output variables.
            gradients_collector: collector for training gradients (training
                only).
        """
        def switch_samplers(for_training):
            # First sampler feeds training, last feeds validation; the sampler
            # returns a (windows, locations) pair.
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0 if for_training else -1]
                return sampler()

        if self.is_training:
            self.patience = self.action_param.patience
            self.mode = self.action_param.early_stopping_mode
            if self.action_param.validation_every_n > 0:
                sampler_window = \
                    tf.cond(tf.logical_not(self.is_validation),
                            lambda: switch_samplers(True),
                            lambda: switch_samplers(False))
            else:
                sampler_window = switch_samplers(True)

            # Locations are not needed for training.
            image_windows, _ = sampler_window

            # Decode channels: fixed image, fixed label, moving image,
            # moving label.
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)
            ]
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            # Estimate the dense displacement field; some networks return a
            # tuple whose first element is the field.
            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # Transform the moving labels.
            resampler = ResamplerLayer(interpolation='linear',
                                       boundary='replicate')
            resampled_moving_label = resampler(moving_label, dense_field)

            # Label loss (foreground only).
            loss_func = LossFunction(n_class=1,
                                     loss_type=self.action_param.loss_type,
                                     softmax=False)
            label_loss = loss_func(prediction=resampled_moving_label,
                                   ground_truth=fixed_label)

            dice_fg = 1.0 - label_loss
            # Append the regularisation loss. tf.reduce_mean of an empty list
            # is NaN, so the mean is only built (and reported) when the
            # collection is non-empty.
            total_loss = label_loss
            reg_loss = tf.get_collection('bending_energy')
            mean_reg_loss = None
            if reg_loss:
                mean_reg_loss = tf.reduce_mean(reg_loss)
                total_loss = total_loss + \
                    self.net_param.decay * mean_reg_loss

            self.total_loss = total_loss

            # Compute training gradients.
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            grads = self.optimiser.compute_gradients(
                total_loss, colocate_gradients_with_ops=True)
            gradients_collector.add_to_collection(grads)

            # Dice on binarised (>= 0.5) labels, reported as 1 - loss.
            metrics_dice = loss_func(
                prediction=tf.to_float(resampled_moving_label >= 0.5),
                ground_truth=tf.to_float(fixed_label >= 0.5))
            metrics_dice = 1.0 - metrics_dice

            # Command line output.
            outputs_collector.add_to_collection(var=dice_fg,
                                                name='one_minus_data_loss',
                                                collection=CONSOLE)
            if mean_reg_loss is not None:
                outputs_collector.add_to_collection(var=mean_reg_loss,
                                                    name='bending_energy',
                                                    collection=CONSOLE)
            outputs_collector.add_to_collection(var=total_loss,
                                                name='total_loss',
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=metrics_dice,
                                                name='ave_fg_dice',
                                                collection=CONSOLE)

            # For TensorBoard.
            outputs_collector.add_to_collection(var=dice_fg,
                                                name='data_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=total_loss,
                                                name='total_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=metrics_dice,
                name='averaged_foreground_Dice',
                average_over_devices=True,
                summary_type='scalar',
                collection=TF_SUMMARIES)
        else:
            image_windows, locations = self.sampler()
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)
            ]
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # Transform the moving image and labels.
            resampler = ResamplerLayer(interpolation='linear',
                                       boundary='replicate')
            resampled_moving_image = resampler(moving_image, dense_field)
            resampled_moving_label = resampler(moving_label, dense_field)

            outputs_collector.add_to_collection(var=fixed_image,
                                                name='fixed_image',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=moving_image,
                                                name='moving_image',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=resampled_moving_image,
                                                name='resampled_moving_image',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=resampled_moving_label,
                                                name='resampled_moving_label',
                                                collection=NETWORK_OUTPUT)

            outputs_collector.add_to_collection(var=fixed_label,
                                                name='fixed_label',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=moving_label,
                                                name='moving_label',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=locations,
                                                name='locations',
                                                collection=NETWORK_OUTPUT)

            self.output_decoder = ResizeSamplesAggregator(
                image_reader=self.readers[0],  # fixed image reader
                name='fixed_image',
                output_path=self.action_param.save_seg_dir,
                interp_order=self.action_param.output_interp_order)
    def layer_op(self):
        """Sample a batch of paired fixed/moving windows.

        Fixed and moving volumes are read from ``self.iterator``. Each has two
        channels in its last dimension: 1 image + 1 label. The moving volume
        is resized to the spatial shape of the first fixed volume in the batch
        and concatenated with the fixed volume along the last dimension. A
        window of ``self.window_size`` is then resampled from the combined
        volume at a random per-batch translation.

        Returns:
            windows: tensor of static shape
                ``[batch_size] + window_size[:spatial_rank] + [window_channels]``.
            locations: ``(batch_size, 1 + 2 * spatial_rank)`` tensor holding
                the image id, a zero start location and the random shift.

        Raises:
            NotImplementedError: if ``self.spatial_rank`` is not 3. Only the
                3D path builds ``windows`` and ``locations``, so other ranks
                would otherwise fail with an UnboundLocalError at the return.
        """
        if self.spatial_rank != 3:
            raise NotImplementedError(
                'Window sampling is only implemented for spatial_rank == 3, '
                'got spatial_rank={}.'.format(self.spatial_rank))

        image_id, fixed_inputs, moving_inputs, fixed_shape, moving_shape = \
            self.iterator.get_next()
        # TODO preprocessing layer modifying
        #      image shapes will not be supported
        # assuming the same shape across modalities, using the first
        image_id.set_shape((self.batch_size, ))
        image_id = tf.to_float(image_id)

        fixed_inputs.set_shape((self.batch_size, ) +
                               (None, ) * self.spatial_rank + (2, ))
        # last dim is 1 image + 1 label
        moving_inputs.set_shape((self.batch_size, ) + self.moving_image_shape +
                                (2, ))
        fixed_shape.set_shape((self.batch_size, self.spatial_rank + 1))
        moving_shape.set_shape((self.batch_size, self.spatial_rank + 1))

        # Resize the moving inputs to match the target. Only the first
        # fixed_shape in the batch is used, i.e. this assumes the same shape
        # across the batch.
        target_spatial_shape = \
            tf.unstack(fixed_shape[0], axis=0)[:self.spatial_rank]
        moving_inputs = Resize(new_size=target_spatial_shape)(moving_inputs)
        combined_volume = tf.concat([fixed_inputs, moving_inputs], axis=-1)

        # TODO affine data augmentation here
        # The combined volume has 2 + 2 = 4 channels; the product over the
        # non-spatial window_size entries (presumably 1) scales that.
        window_channels = np.prod(self.window_size[self.spatial_rank:]) * 4
        # TODO if no affine augmentation:
        img_spatial_shape = target_spatial_shape
        win_spatial_shape = [
            tf.constant(dim) for dim in self.window_size[:self.spatial_rank]
        ]
        # Per-axis shift drawn from [0, img - win - 1). The upper bound is
        # floored at 0.01 so that when img == win the shift stays ~0.0;
        # otherwise interpolation would go out of bounds.
        batch_shift = [
            tf.random_uniform(shape=(self.batch_size, 1),
                              minval=0,
                              maxval=tf.maximum(tf.to_float(img - win - 1),
                                                0.01))
            for (win, img) in zip(win_spatial_shape, img_spatial_shape)
        ]
        batch_shift = tf.concat(batch_shift, axis=1)
        # Identity scaling/rotation; only the translation column is free.
        affine_constraints = ((1.0, 0.0, 0.0, None), (0.0, 1.0, 0.0, None),
                              (0.0, 0.0, 1.0, None))
        computed_grid = AffineGridWarperLayer(
            source_shape=(None, None, None),
            output_shape=self.window_size[:self.spatial_rank],
            constraints=affine_constraints)(batch_shift)
        computed_grid.set_shape((self.batch_size, ) +
                                self.window_size[:self.spatial_rank] +
                                (self.spatial_rank, ))
        resampler = ResamplerLayer(interpolation='linear',
                                   boundary='replicate')
        windows = resampler(combined_volume, computed_grid)
        out_shape = [self.batch_size] + \
                    list(self.window_size[:self.spatial_rank]) + \
                    [window_channels]
        windows.set_shape(out_shape)

        image_id = tf.reshape(image_id, (self.batch_size, 1))
        start_location = tf.zeros((self.batch_size, self.spatial_rank))
        locations = tf.concat([image_id, start_location, batch_shift],
                              axis=1)
        return windows, locations