Example #1
0
def get_batch_transform_displacement_map(displacement, interpolation_order = 3):
    """
    Build a callable that deforms a batch with ``etf.deform_grid``.

    The returned callable pairs its input ``x`` with ``displacement`` and
    hands both to ``tf.map_fn``. Every pair is deformed over axes ``(1, 2)``
    with interpolation order ``interpolation_order``. Only the deformed part
    of the mapped result is returned.

    """

    def _deform_pair(pair):
        # pair[0] is one slice of the input, pair[1] the matching displacement slice
        deformed = etf.deform_grid(pair[0], pair[1], order = interpolation_order, axis = (1, 2))
        # the displacement is passed through untouched, so the result has the
        # same two-element structure as ``elems``
        return (deformed, pair[1])

    def _apply_to_batch(x):
        mapped = tf.map_fn(_deform_pair, elems = (x, displacement))
        # drop the passed-through displacement, keep the deformed data
        return mapped[0]

    return _apply_to_batch
Example #2
0
    def run_comparison_tensorflow(self, shape, points, order=3, sigma=25, crop=None, mode='constant', axis=None):
        """
        Check ``etf.deform_grid`` against the ``elasticdeform`` reference.

        Random data of the given ``shape`` and a random displacement grid
        (scaled by ``sigma``) are generated. The forward result and the
        gradient with respect to the input are computed through TensorFlow
        and compared with ``elasticdeform.deform_grid`` and
        ``elasticdeform.deform_grid_gradient``.

        Raises ``unittest.SkipTest`` when ``tf`` is ``None`` or offers neither
        ``py_func`` nor ``py_function``.
        """
        tf_usable = tf is not None and (hasattr(tf, 'py_func') or hasattr(tf, 'py_function'))
        if not tf_usable:
            raise unittest.SkipTest("TensorFlow was not loaded.")

        # keyword arguments shared by every deform call below
        deform_kwargs = dict(order=order, crop=crop, mode=mode, axis=axis)

        # one displacement component per deformed axis
        if axis is None:
            n_deformed_axes = len(shape)
        else:
            n_deformed_axes = len(axis)

        # random inputs; the order of the draws is kept so a seeded run is unchanged
        displacement = sigma * np.random.randn(n_deformed_axes, *points)
        X_val = np.random.rand(*shape)

        # reference forward pass
        X_deformed_ref = elasticdeform.deform_grid(X_val, displacement, **deform_kwargs)

        # random upstream gradient with the shape of the deformed output
        dX_deformed_val = np.random.rand(*X_deformed_ref.shape)

        # reference backward pass
        dX_ref = elasticdeform.deform_grid_gradient(dX_deformed_val, displacement, X_shape=shape, **deform_kwargs)

        if hasattr(tf, 'py_func'):
            # TensorFlow 1: build a graph and evaluate it in a session
            X = tf.Variable(X_val)
            dX_deformed = tf.Variable(dX_deformed_val)
            deformed_op = etf.deform_grid(X, displacement, **deform_kwargs)
            (grad_op,) = tf.gradients(deformed_op, X, dX_deformed)

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                X_deformed, dX = sess.run([deformed_op, grad_op])
        else:
            # TensorFlow 2: record the forward pass on a gradient tape
            X = tf.Variable(X_val)
            dX_deformed = tf.Variable(dX_deformed_val)
            with tf.GradientTape() as tape:
                tape.watch(X)
                X_deformed = etf.deform_grid(X, displacement, **deform_kwargs)
            dX = tape.gradient(X_deformed, X, dX_deformed)

        # forward result first, then the gradient
        for expected, actual in ((X_deformed_ref, X_deformed), (dX_ref, dX)):
            np.testing.assert_almost_equal(expected, actual)
Example #3
0
def get_batch_transform_displacement_rotation_map(displacement, interpolation_order = 3):
    """
    Build a callable that deforms and randomly rotates a batch.

    The returned callable pairs its input ``x`` with ``displacement`` and
    hands both to ``tf.map_fn``. Every pair goes through ``etf.deform_grid``
    over axes ``(1, 2)`` with interpolation order ``interpolation_order`` and
    a ``rotate`` value drawn from ``np.random.uniform(low=-45, high=45)``.
    Only the deformed part of the mapped result is returned.

    """

    def _deform_and_rotate_pair(pair):
        # a fresh one-element array is drawn each time this function is invoked
        # NOTE(review): how often tf.map_fn invokes it (per element vs. once
        # when traced) is not visible here -- confirm the angles really differ
        angle = np.random.uniform(low=-45, high=45, size=(1))
        deformed = etf.deform_grid(pair[0], pair[1], order = interpolation_order, rotate=angle, axis = (1, 2))
        # the displacement is passed through untouched, so the result has the
        # same two-element structure as ``elems``
        return (deformed, pair[1])

    def _apply_to_batch(x):
        mapped = tf.map_fn(_deform_and_rotate_pair, elems = (x, displacement))
        # drop the passed-through displacement, keep the deformed data
        return mapped[0]

    return _apply_to_batch
Example #4
0
    def run_comparison_tensorflow_multi(self, shape, points, order=3, sigma=25, crop=None, mode='constant', axis=None):
        """
        Check multi-input ``etf.deform_grid`` against ``elasticdeform``.

        Two random arrays of the given ``shape`` are deformed together with
        one random displacement grid (scaled by ``sigma``). The forward
        results and both input gradients computed through TensorFlow are
        compared with ``elasticdeform.deform_grid`` and
        ``elasticdeform.deform_grid_gradient``.

        Raises ``unittest.SkipTest`` unless ``tf`` has ``py_function`` and
        lacks ``py_func``.
        """
        is_tf2 = tf is not None and hasattr(tf, 'py_function') and not hasattr(tf, 'py_func')
        if not is_tf2:
            raise unittest.SkipTest("TensorFlow 2 was not loaded.")

        # keyword arguments shared by every deform call below
        deform_kwargs = dict(order=order, crop=crop, mode=mode, axis=axis)

        # one displacement component per deformed axis
        if axis is None:
            n_deformed_axes = len(shape)
        else:
            n_deformed_axes = len(axis)

        # random inputs; the order of the draws is kept so a seeded run is unchanged
        displacement = sigma * np.random.randn(n_deformed_axes, *points)
        X_val = np.random.rand(*shape)
        Y_val = np.random.rand(*shape)

        # reference forward pass for both inputs at once
        X_deformed_ref, Y_deformed_ref = elasticdeform.deform_grid(
            [X_val, Y_val], displacement, **deform_kwargs)

        # random upstream gradients with the shapes of the deformed outputs
        dX_deformed_val = np.random.rand(*X_deformed_ref.shape)
        dY_deformed_val = np.random.rand(*Y_deformed_ref.shape)

        # reference backward pass
        dX_ref, dY_ref = elasticdeform.deform_grid_gradient(
            [dX_deformed_val, dY_deformed_val], displacement,
            X_shape=[shape, shape], **deform_kwargs)

        # tensorflow forward pass recorded on a tape
        X = tf.Variable(X_val)
        Y = tf.Variable(Y_val)
        dX_deformed = tf.Variable(dX_deformed_val)
        dY_deformed = tf.Variable(dY_deformed_val)
        # persistent, because gradient() is called twice on the same tape
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(X)
            tape.watch(Y)
            X_deformed, Y_deformed = etf.deform_grid([X, Y], displacement, **deform_kwargs)
        dX = tape.gradient(X_deformed, X, dX_deformed)
        dY = tape.gradient(Y_deformed, Y, dY_deformed)

        # forward results first, then the gradients
        comparisons = (
            (X_deformed_ref, X_deformed),
            (Y_deformed_ref, Y_deformed),
            (dX_ref, dX),
            (dY_ref, dY),
        )
        for expected, actual in comparisons:
            np.testing.assert_almost_equal(expected, actual)