Exemplo n.º 1
0
    def run_comparison_torch(self, shape, points, order=3, sigma=25, crop=None, mode='constant', axis=None):
        """Compare the PyTorch wrapper against the NumPy reference implementation.

        A random input of ``shape`` and a random displacement grid of
        ``points`` control points (scaled by ``sigma``) are deformed both with
        ``elasticdeform.deform_grid`` and with ``etorch.deform_grid``. The
        gradient with respect to the input, for a random upstream gradient, is
        compared against ``elasticdeform.deform_grid_gradient``.

        ``order``, ``crop``, ``mode`` and ``axis`` are forwarded unchanged to
        every deform call.

        Raises:
            unittest.SkipTest: if PyTorch could not be imported.
        """
        if torch is None:
            raise unittest.SkipTest("PyTorch was not loaded.")

        deform_kwargs = dict(order=order, crop=crop, mode=mode, axis=axis)
        # One displacement component per deformed axis.
        n_dims = len(shape) if axis is None else len(axis)

        # Random draws happen in a fixed order: displacement, input, upstream gradient.
        disp_np = np.random.randn(n_dims, *points) * sigma
        x_np = np.random.rand(*shape)

        # NumPy reference: forward pass...
        y_ref = elasticdeform.deform_grid(x_np, disp_np, **deform_kwargs)
        grad_y_np = np.random.rand(*y_ref.shape)
        # ...and backward pass for the random upstream gradient.
        grad_x_ref = elasticdeform.deform_grid_gradient(grad_y_np, disp_np, X_shape=shape, **deform_kwargs)

        # Same computation through the PyTorch autograd wrapper.
        x_t = torch.tensor(x_np, requires_grad=True)
        disp_t = torch.tensor(disp_np)
        grad_y_t = torch.tensor(grad_y_np)
        y_t = etorch.deform_grid(x_t, disp_t, **deform_kwargs)
        y_t.backward(grad_y_t)

        y_out = y_t.detach().numpy()
        grad_x_out = x_t.grad.detach().numpy()

        np.testing.assert_almost_equal(y_ref, y_out)
        np.testing.assert_almost_equal(grad_x_ref, grad_x_out)
Exemplo n.º 2
0
 def grad(*dys_disp_xs):
     """Gradient callback for the elasticdeform wrapper.

     ``dys_disp_xs`` is the concatenation of
       * the upstream gradients, one per entry of ``xs`` (enclosing scope),
       * the displacement grid,
       * the original inputs ``xs`` (only their ``.shape`` is used).

     The values are handed to elasticdeform unchanged, so they presumably
     arrive as NumPy arrays (e.g. through a py_func-style op). TODO confirm
     against the caller.

     Returns a list: a NaN-filled array shaped like the displacement (the
     displacement gradient is not computed), followed by the gradients with
     respect to each input.
     """
     n_inputs = len(xs)
     upstream = dys_disp_xs[:n_inputs]
     displacement = dys_disp_xs[n_inputs]
     originals = dys_disp_xs[n_inputs + 1:]

     input_shapes = [orig.shape for orig in originals]
     input_grads = elasticdeform.deform_grid_gradient(list(upstream),
                                                      displacement,
                                                      *args,
                                                      X_shape=input_shapes,
                                                      **kwargs)
     # The displacement is not differentiated; NaN presumably flags any accidental use.
     return [numpy.nan * displacement] + input_grads
Exemplo n.º 3
0
    def run_comparison_tensorflow(self,
                                  shape,
                                  points,
                                  order=3,
                                  sigma=25,
                                  crop=None,
                                  mode='constant',
                                  axis=None):
        """Compare the TensorFlow 1 graph-mode wrapper against NumPy.

        Deforms a random input of ``shape`` with a random displacement grid
        of ``points`` control points (scaled by ``sigma``) using both
        ``elasticdeform.deform_grid`` and ``etf.deform_grid``. It then compares
        the forward outputs and the gradients with respect to the input
        (obtained through ``tf.gradients``) for a random upstream gradient.

        ``order``, ``crop``, ``mode`` and ``axis`` are forwarded unchanged to
        every deform call.

        Raises:
            unittest.SkipTest: if TensorFlow is missing or has no ``tf.py_func``.
        """
        tf1_available = tf is not None and hasattr(tf, 'py_func')
        if not tf1_available:
            raise unittest.SkipTest("TensorFlow 1 was not loaded.")

        deform_kwargs = dict(order=order, crop=crop, mode=mode, axis=axis)
        n_dims = len(shape) if axis is None else len(axis)

        # Random draws: displacement, input, then upstream gradient.
        disp_np = np.random.randn(n_dims, *points) * sigma
        x_np = np.random.rand(*shape)

        # NumPy reference values.
        y_ref = elasticdeform.deform_grid(x_np, disp_np, **deform_kwargs)
        grad_y_np = np.random.rand(*y_ref.shape)
        grad_x_ref = elasticdeform.deform_grid_gradient(grad_y_np,
                                                        disp_np,
                                                        X_shape=shape,
                                                        **deform_kwargs)

        # Build the graph: forward op plus its gradient w.r.t. the input.
        x_var = tf.Variable(x_np)
        grad_y_var = tf.Variable(grad_y_np)
        y_op = etf.deform_grid(x_var, disp_np, **deform_kwargs)
        grad_x_op, = tf.gradients(y_op, x_var, grad_y_var)

        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            y_out, grad_x_out = session.run([y_op, grad_x_op])

            np.testing.assert_almost_equal(y_ref, y_out)
            np.testing.assert_almost_equal(grad_x_ref, grad_x_out)
Exemplo n.º 4
0
    def backward(ctx, *dys):
        """Autograd backward pass: delegate to ``elasticdeform.deform_grid_gradient``.

        Reads from ``ctx`` the single saved displacement tensor and the
        ``deform_args``, ``deform_kwargs`` and ``x_shapes`` attributes. It
        moves the upstream gradients and the displacement to host NumPy
        arrays and computes the input gradients outside autograd.

        Returns three leading ``None`` entries, then one gradient tensor per
        upstream gradient, each placed on that upstream gradient's device.
        The three ``None`` entries presumably correspond to the forward's
        non-differentiable leading arguments. TODO confirm against forward().
        """
        (disp_t,) = ctx.saved_tensors

        upstream_np = [g.detach().cpu().numpy() for g in dys]
        disp_np = disp_t.detach().cpu().numpy()
        grads_np = elasticdeform.deform_grid_gradient(upstream_np,
                                                      disp_np,
                                                      *ctx.deform_args,
                                                      X_shape=ctx.x_shapes,
                                                      **ctx.deform_kwargs)

        input_grads = tuple(torch.tensor(g, device=dy.device)
                            for g, dy in zip(grads_np, dys))
        return (None,) * 3 + input_grads
Exemplo n.º 5
0
def deform_grid_gradient_c(X_in,
                           displacement,
                           order=3,
                           mode='constant',
                           cval=0.0,
                           crop=None,
                           prefilter=True,
                           axis=None,
                           X_shape=None):
    """Thin pass-through to ``elasticdeform.deform_grid_gradient``.

    Every argument is forwarded unchanged and the result is returned as-is.
    The optional arguments are passed by keyword so the mapping does not
    depend on the callee's parameter order.
    """
    return elasticdeform.deform_grid_gradient(
        X_in,
        displacement,
        order=order,
        mode=mode,
        cval=cval,
        crop=crop,
        prefilter=prefilter,
        axis=axis,
        X_shape=X_shape,
    )
Exemplo n.º 6
0
 def grad(*dys_disp_xs):
     """Gradient callback for the elasticdeform TensorFlow wrapper.

     ``dys_disp_xs`` is the concatenation of
       * the upstream gradients, one per entry of ``xs`` (enclosing scope),
       * the displacement grid,
       * the original inputs ``xs`` (only their ``.shape`` is used).

     When ``use_tf_v1`` is false, the upstream gradients and the displacement
     are converted with ``.numpy()`` before calling elasticdeform. Otherwise
     they are passed through unchanged.

     Returns a list: a NaN-filled array shaped like the displacement (the
     displacement gradient is not computed), followed by the gradients with
     respect to each input.
     """
     n_inputs = len(xs)
     displacement = dys_disp_xs[n_inputs]
     input_shapes = [orig.shape for orig in dys_disp_xs[n_inputs + 1:]]

     if use_tf_v1:
         upstream = list(dys_disp_xs[:n_inputs])
     else:
         # Non-v1 values expose .numpy(); elasticdeform works on NumPy arrays.
         upstream = [dy.numpy() for dy in dys_disp_xs[:n_inputs]]
         displacement = displacement.numpy()

     input_grads = elasticdeform.deform_grid_gradient(upstream,
                                                      displacement,
                                                      *args,
                                                      X_shape=input_shapes,
                                                      **kwargs)
     # The displacement is not differentiated; NaN presumably flags any accidental use.
     return [numpy.nan * displacement] + input_grads
Exemplo n.º 7
0
    def run_comparison_tensorflow_multi(self, shape, points, order=3, sigma=25, crop=None, mode='constant', axis=None):
        """Compare the TensorFlow 2 wrapper on two inputs against NumPy.

        Two random inputs of ``shape`` are deformed together with one random
        displacement grid of ``points`` control points (scaled by ``sigma``).
        Both forward outputs and both input gradients (via
        ``tf.GradientTape``) are compared against ``elasticdeform``'s
        reference functions.

        ``order``, ``crop``, ``mode`` and ``axis`` are forwarded unchanged to
        every deform call.

        Raises:
            unittest.SkipTest: unless TensorFlow is present, has
                ``tf.py_function`` and lacks ``tf.py_func``.
        """
        tf2_available = (tf is not None
                         and hasattr(tf, 'py_function')
                         and not hasattr(tf, 'py_func'))
        if not tf2_available:
            raise unittest.SkipTest("TensorFlow 2 was not loaded.")

        deform_kwargs = dict(order=order, crop=crop, mode=mode, axis=axis)
        n_dims = len(shape) if axis is None else len(axis)

        # Random draws, in order: displacement, both inputs, both upstream gradients.
        disp_np = np.random.randn(n_dims, *points) * sigma
        x_np = np.random.rand(*shape)
        y_np = np.random.rand(*shape)

        x_ref, y_ref = elasticdeform.deform_grid([x_np, y_np], disp_np, **deform_kwargs)

        grad_x_np = np.random.rand(*x_ref.shape)
        grad_y_np = np.random.rand(*y_ref.shape)

        grad_x_ref, grad_y_ref = elasticdeform.deform_grid_gradient(
            [grad_x_np, grad_y_np], disp_np, X_shape=[shape, shape], **deform_kwargs)

        # Eager computation through the TensorFlow wrapper.
        x_var = tf.Variable(x_np)
        y_var = tf.Variable(y_np)
        grad_x_var = tf.Variable(grad_x_np)
        grad_y_var = tf.Variable(grad_y_np)
        # persistent=True: the tape is queried twice below.
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(x_var)
            tape.watch(y_var)
            x_out, y_out = etf.deform_grid([x_var, y_var], disp_np, **deform_kwargs)
        grad_x = tape.gradient(x_out, x_var, grad_x_var)
        grad_y = tape.gradient(y_out, y_var, grad_y_var)

        comparisons = (
            (x_ref, x_out),
            (y_ref, y_out),
            (grad_x_ref, grad_x),
            (grad_y_ref, grad_y),
        )
        for expected, actual in comparisons:
            np.testing.assert_almost_equal(expected, actual)