def get_batch_transform_displacement_map(displacement, interpolation_order = 3):
    """Build a batch transform that elastically deforms every sample.

    The returned callable pairs each element of its input batch with the
    element at the same index of ``displacement`` (``tf.map_fn`` unstacks both
    along their first dimension). It then warps the sample with
    ``etf.deform_grid``.

    Parameters
    ----------
    displacement
        Displacement fields, one per batch element, stacked along the first
        dimension.
    interpolation_order : int
        Spline order forwarded to ``etf.deform_grid`` as ``order``.

    Returns
    -------
    callable
        Function taking a batch ``x`` and returning the deformed batch.
    """
    def _deform_pair(pair):
        sample, field = pair
        # axis=(1, 2) refers to a single sample (map_fn strips the batch
        # dimension); presumably samples are laid out (C, H, W) -- confirm.
        warped = etf.deform_grid(sample, field, order=interpolation_order, axis=(1, 2))
        # Without an explicit output signature, map_fn expects fn to return the
        # same structure as elems, so the field is passed through unchanged.
        return warped, field

    def _transform(batch):
        mapped = tf.map_fn(_deform_pair, elems=(batch, displacement))
        return mapped[0]

    return _transform
def run_comparison_tensorflow(self, shape, points, order=3, sigma=25, crop=None, mode='constant', axis=None):
    """Check etf.deform_grid and its gradient against the NumPy reference.

    A random displacement field and random input are generated. The forward
    result and the backpropagated gradient from the TensorFlow wrapper are
    compared with ``elasticdeform.deform_grid`` and
    ``elasticdeform.deform_grid_gradient``. The TF1 path (graph + session) is
    used when ``tf.py_func`` exists; otherwise the TF2 path (GradientTape) is
    used.

    Raises
    ------
    unittest.SkipTest
        If TensorFlow is unavailable or exposes neither ``py_func`` nor
        ``py_function``.
    """
    tf_usable = tf is not None and (hasattr(tf, 'py_func') or hasattr(tf, 'py_function'))
    if not tf_usable:
        raise unittest.SkipTest("TensorFlow was not loaded.")

    # Keyword options shared by every deform call in this comparison.
    deform_kwargs = dict(order=order, crop=crop, mode=mode, axis=axis)

    # Random inputs; the draw order (displacement, data, output gradient)
    # matches the reference computations below.
    n_disp_axes = len(axis) if axis is not None else len(shape)
    displacement = np.random.randn(n_disp_axes, *points) * sigma
    x_input = np.random.rand(*shape)

    # NumPy reference: forward pass, then gradient for a random upstream grad.
    expected_forward = elasticdeform.deform_grid(x_input, displacement, **deform_kwargs)
    upstream_grad = np.random.rand(*expected_forward.shape)
    expected_grad = elasticdeform.deform_grid_gradient(
        upstream_grad, displacement, X_shape=shape, **deform_kwargs)

    x_var = tf.Variable(x_input)
    upstream_var = tf.Variable(upstream_grad)

    if hasattr(tf, 'py_func'):
        # TensorFlow 1: build a static graph and evaluate it in a session.
        forward_op = etf.deform_grid(x_var, displacement, **deform_kwargs)
        [grad_op] = tf.gradients(forward_op, x_var, upstream_var)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            actual_forward, actual_grad = sess.run([forward_op, grad_op])
    else:
        # TensorFlow 2: eager execution recorded on a gradient tape.
        with tf.GradientTape() as tape:
            tape.watch(x_var)
            actual_forward = etf.deform_grid(x_var, displacement, **deform_kwargs)
        actual_grad = tape.gradient(actual_forward, x_var, upstream_var)

    np.testing.assert_almost_equal(expected_forward, actual_forward)
    np.testing.assert_almost_equal(expected_grad, actual_grad)
def get_batch_transform_displacement_rotation_map(displacement, interpolation_order = 3, max_rotation = 45):
    """Build a batch transform that elastically deforms and randomly rotates samples.

    The returned callable pairs each element of its input batch with the
    element at the same index of ``displacement`` (``tf.map_fn`` unstacks both
    along their first dimension). It then warps the sample with
    ``etf.deform_grid``, also passing a rotation angle drawn uniformly from
    ``[-max_rotation, max_rotation)``.

    Parameters
    ----------
    displacement
        Displacement fields, one per batch element, stacked along the first
        dimension.
    interpolation_order : int
        Spline order forwarded to ``etf.deform_grid`` as ``order``.
    max_rotation : float
        Half-width of the uniform range the rotation angle is drawn from.
        Defaults to 45, the previously hard-coded value. Presumably in
        degrees, as elasticdeform's ``rotate`` argument -- confirm.

    Returns
    -------
    callable
        Function taking a batch ``x`` and returning the deformed batch.
    """
    def _deform_and_rotate(pair):
        sample, field = pair
        # size=1 yields a one-element array (the original spelled it size=(1),
        # which is the same int 1).
        # NOTE(review): this draw happens in Python. When the transform runs
        # inside a traced graph (tf.function, tf.data.Dataset.map), fn is
        # traced once, so the angle is likely baked in as a constant for every
        # sample and batch. Per-sample randomness would need the draw to happen
        # at run time -- confirm against the actual call site.
        angle = np.random.uniform(low=-max_rotation, high=max_rotation, size=1)
        warped = etf.deform_grid(sample, field, order=interpolation_order, rotate=angle, axis=(1, 2))
        # map_fn's default output signature mirrors elems, so pass the field through.
        return warped, field

    return lambda x: tf.map_fn(_deform_and_rotate, elems=(x, displacement))[0]
def run_comparison_tensorflow_multi(self, shape, points, order=3, sigma=25, crop=None, mode='constant', axis=None):
    """Check etf.deform_grid on two inputs sharing one displacement field.

    Two random arrays of the same ``shape`` are deformed together. Both forward
    outputs and both backpropagated gradients from the TensorFlow wrapper are
    compared with the NumPy ``elasticdeform`` reference. Only the TensorFlow 2
    (GradientTape) path is exercised.

    Raises
    ------
    unittest.SkipTest
        Unless TensorFlow is loaded, has ``py_function``, and lacks ``py_func``.
    """
    running_tf2 = tf is not None and hasattr(tf, 'py_function') and not hasattr(tf, 'py_func')
    if not running_tf2:
        raise unittest.SkipTest("TensorFlow 2 was not loaded.")

    # Keyword options shared by every deform call in this comparison.
    deform_kwargs = dict(order=order, crop=crop, mode=mode, axis=axis)

    # Random inputs, drawn in the order: displacement, first data, second data.
    n_disp_axes = len(axis) if axis is not None else len(shape)
    displacement = np.random.randn(n_disp_axes, *points) * sigma
    first_input = np.random.rand(*shape)
    second_input = np.random.rand(*shape)

    # NumPy reference forward pass on both inputs at once.
    first_expected, second_expected = elasticdeform.deform_grid(
        [first_input, second_input], displacement, **deform_kwargs)

    # Random upstream gradients, then the NumPy reference backward pass.
    first_upstream = np.random.rand(*first_expected.shape)
    second_upstream = np.random.rand(*second_expected.shape)
    first_grad_expected, second_grad_expected = elasticdeform.deform_grid_gradient(
        [first_upstream, second_upstream], displacement,
        X_shape=[shape, shape], **deform_kwargs)

    first_var = tf.Variable(first_input)
    second_var = tf.Variable(second_input)
    first_upstream_var = tf.Variable(first_upstream)
    second_upstream_var = tf.Variable(second_upstream)

    # Persistent tape: two separate gradient() calls are made from it.
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(first_var)
        tape.watch(second_var)
        first_actual, second_actual = etf.deform_grid(
            [first_var, second_var], displacement, **deform_kwargs)
    first_grad_actual = tape.gradient(first_actual, first_var, first_upstream_var)
    second_grad_actual = tape.gradient(second_actual, second_var, second_upstream_var)

    comparisons = (
        (first_expected, first_actual),
        (second_expected, second_actual),
        (first_grad_expected, first_grad_actual),
        (second_grad_expected, second_grad_actual),
    )
    for expected, actual in comparisons:
        np.testing.assert_almost_equal(expected, actual)