def _test_partial_shape_correctness(self,
                                        input,
                                        rank,
                                        batch_size,
                                        grid,
                                        interpolation,
                                        boundary,
                                        expected_value=None):
        """Run the resampler on an input whose static shape is partly unknown.

        ``input`` is fed through a ``tf.placeholder_with_default`` so that
        the layer only sees the static shape selected by ``rank`` and
        ``batch_size`` while the graph is being built.

        :param input: value fed to the placeholder at run time; its
            ``.shape`` is also used to build the random default value.
        :param rank: when positive, the placeholder is declared with
            ``rank + 2`` dimensions; otherwise its static shape is left
            completely unknown.
        :param batch_size: when positive (and ``rank`` is positive), used
            as the static size of the first dimension; otherwise the first
            dimension is unknown as well.
        :param grid: second argument passed to the resampler layer.
        :param interpolation: forwarded to the layer constructor.
        :param boundary: forwarded to the layer constructor.
        :param expected_value: optional reference output; when given, the
            evaluated output is compared to it with ``assertAllClose``.
        :raises ValueError: if ``batch_size`` is positive while ``rank`` is
            not; a static shape with a known batch size but an unknown
            rank cannot be declared.
        """
        # This combination used to fall through every branch and surface
        # later as an unrelated UnboundLocalError; reject it explicitly.
        if batch_size > 0 and rank <= 0:
            raise ValueError(
                'batch_size > 0 requires rank > 0 '
                '(got batch_size=%r, rank=%r)' % (batch_size, rank))

        resampler = ResamplerOptionalNiftyRegLayer(interpolation=interpolation,
                                                   boundary=boundary)
        input_default = tf.random_uniform(input.shape)

        if rank > 0:
            # Only the leading (batch) dimension may be statically known.
            batch_dim = batch_size if batch_size > 0 else None
            static_shape = [batch_dim] + [None] * (rank + 1)
        else:
            static_shape = None
        input_placeholder = tf.placeholder_with_default(input_default,
                                                        shape=static_shape)

        out = resampler(input_placeholder, grid)
        with self.cached_session() as sess:
            out_value = sess.run(out, feed_dict={input_placeholder: input})
            if expected_value is not None:
                self.assertAllClose(expected_value, out_value)
    def _test_correctness(self, input, grid, interpolation, boundary,
                          expected_value):
        """Resample ``input`` at ``grid`` and compare with ``expected_value``.

        The output tensor is built once and then evaluated in a cached
        session for every entry returned by ``self._get_devs()``; each entry
        is passed to the session as its ``use_gpu`` flag.

        :param input: first argument passed to the resampler layer.
        :param grid: second argument passed to the resampler layer.
        :param interpolation: forwarded to the layer constructor.
        :param boundary: forwarded to the layer constructor.
        :param expected_value: reference output checked with
            ``assertAllClose``.
        """
        layer = ResamplerOptionalNiftyRegLayer(boundary=boundary,
                                               interpolation=interpolation)
        resampled = layer(input, grid)

        for device_flag in self._get_devs():
            with self.cached_session(use_gpu=device_flag) as session:
                self.assertAllClose(expected_value, session.run(resampled))
    def test_gradient_correctness(self):
        """Check the resampler gradient with respect to the sampling grid.

        Skipped unless ``resampler_module.HAS_NIFTYREG_RESAMPLING`` is set.
        For every interpolation / boundary / device combination, a 3-D and a
        2-D image are resampled and the two Jacobians returned by
        ``tft.compute_gradient`` are compared: their summed squared
        difference must not exceed 1% of the summed squared magnitude of
        the second Jacobian.
        """
        if not resampler_module.HAS_NIFTYREG_RESAMPLING:
            self.skipTest('Using native NiftyNet resampler; skipping test')
            return

        # Sampling coordinates paired with the 3-D and the 2-D test image.
        coords_3d = [[[-5.2, .25, .25], [.25, .95, .25]],
                     [[.75, .25, .25], [.25, .25, .75]]]
        coords_2d = [[[.25, .25], [.25, .78]],
                     [[.62, .25], [.25, .28]]]

        for interpolation in ('LINEAR', 'BSPLINE'):
            for boundary in ('ZERO', 'REPLICATE', 'SYMMETRIC'):
                for device_flag in self._get_devs():
                    cases = (
                        (self.get_3d_input1(as_tensor=False), coords_3d),
                        (self.get_2d_input(as_tensor=False), coords_2d),
                    )

                    for image_np, coords in cases:
                        # Insert singleton axes at position 2 until the
                        # coordinates have as many axes as the image.
                        disp_np = np.array(coords)
                        missing_axes = len(image_np.shape) - disp_np.ndim
                        for _ in range(missing_axes):
                            disp_np = np.expand_dims(disp_np, axis=2)

                        with self.session(use_gpu=device_flag):
                            image = tf.constant(image_np, dtype=tf.float32)
                            displacement = tf.constant(disp_np,
                                                       dtype=tf.float32)

                            # Multi-channel images are not handled here yet:
                            # keep only the first channel, as a size-1 axis.
                            image_dims = image.shape.as_list()
                            if image_dims[-1] > 1:
                                image = tf.reshape(image[..., 0],
                                                   image_dims[:-1] + [1])

                            layer = ResamplerOptionalNiftyRegLayer(
                                interpolation=interpolation,
                                boundary=boundary)
                            warped = layer(image, displacement)

                            jacobian, ref_jacobian = tft.compute_gradient(
                                displacement, displacement.shape,
                                warped, warped.shape)

                            squared_error = np.sum(
                                np.power(jacobian - ref_jacobian, 2))
                            ref_energy = np.sum(np.power(ref_jacobian, 2))

                            self.assertLessEqual(squared_error,
                                                 1e-2 * ref_energy)