def run_comparison_torch(self, shape, points, order=3, sigma=25, crop=None, mode='constant', axis=None):
    """Compare the PyTorch deform_grid against the NumPy reference implementation.

    Random input data and a random displacement field are generated. The forward
    output of ``etorch.deform_grid`` and the gradient it propagates back to the
    input are checked against ``elasticdeform.deform_grid`` and
    ``elasticdeform.deform_grid_gradient``.

    Args:
        shape: shape of the random input array.
        points: grid dimensions of the displacement field; the field has shape
            ``(n, *points)``, with ``n = len(shape)`` or ``len(axis)`` if given.
        order, crop, mode, axis: forwarded unchanged to every deform call.
        sigma: scale factor applied to the standard-normal displacements.

    Raises:
        unittest.SkipTest: if PyTorch could not be imported (``torch is None``).
    """
    if torch is None:
        raise unittest.SkipTest("PyTorch was not loaded.")

    deform_opts = dict(order=order, crop=crop, mode=mode, axis=axis)
    n_disp_dims = len(shape) if axis is None else len(axis)

    # Random inputs. The order of the np.random calls is kept fixed.
    disp_np = np.random.randn(n_disp_dims, *points) * sigma
    x_np = np.random.rand(*shape)

    # NumPy reference: forward result, then a random upstream gradient and the
    # resulting gradient with respect to the input.
    expected_out = elasticdeform.deform_grid(x_np, disp_np, **deform_opts)
    out_grad_np = np.random.rand(*expected_out.shape)
    expected_grad = elasticdeform.deform_grid_gradient(out_grad_np, disp_np, X_shape=shape, **deform_opts)

    # PyTorch: forward, then backpropagate the same upstream gradient.
    x_t = torch.tensor(x_np, requires_grad=True)
    disp_t = torch.tensor(disp_np)
    out_grad_t = torch.tensor(out_grad_np)
    out_t = etorch.deform_grid(x_t, disp_t, **deform_opts)
    out_t.backward(out_grad_t)

    actual_out = out_t.detach().numpy()
    actual_grad = x_t.grad.detach().numpy()
    np.testing.assert_almost_equal(expected_out, actual_out)
    np.testing.assert_almost_equal(expected_grad, actual_grad)
def grad(*dys_disp_xs):
    """Gradient function for the deform op.

    The positional arguments are, in order: one upstream gradient per output
    (``len(xs)`` of them), the displacement, then the original inputs. Only the
    inputs' shapes are used. ``xs``, ``args`` and ``kwargs`` come from the
    enclosing scope.

    Returns a list whose first entry is a NaN-filled array shaped like the
    displacement. No gradient is computed for the displacement here. The
    remaining entries are the input gradients from
    ``elasticdeform.deform_grid_gradient``.
    """
    n_outputs = len(xs)
    output_grads = list(dys_disp_xs[:n_outputs])
    displacement = dys_disp_xs[n_outputs]
    input_shapes = [x.shape for x in dys_disp_xs[n_outputs + 1:]]

    input_grads = elasticdeform.deform_grid_gradient(output_grads, displacement, *args, X_shape=input_shapes, **kwargs)

    # NaN marks the displacement gradient as "not provided".
    displacement_grad = numpy.nan * displacement
    return [displacement_grad] + input_grads
def run_comparison_tensorflow(self, shape, points, order=3, sigma=25, crop=None, mode='constant', axis=None):
    """Compare the TensorFlow 1 (graph mode) deform_grid against the NumPy reference.

    Builds a graph that deforms a random input and computes ``tf.gradients``
    with a random upstream gradient. Both results are checked against
    ``elasticdeform.deform_grid`` and ``elasticdeform.deform_grid_gradient``.

    Args:
        shape: shape of the random input array.
        points: grid dimensions of the displacement field; the field has shape
            ``(n, *points)``, with ``n = len(shape)`` or ``len(axis)`` if given.
        order, crop, mode, axis: forwarded unchanged to every deform call.
        sigma: scale factor applied to the standard-normal displacements.

    Raises:
        unittest.SkipTest: if ``tf`` is None or has no ``py_func`` attribute.
    """
    tf1_available = tf is not None and hasattr(tf, 'py_func')
    if not tf1_available:
        raise unittest.SkipTest("TensorFlow 1 was not loaded.")

    deform_opts = dict(order=order, crop=crop, mode=mode, axis=axis)
    n_disp_dims = len(shape) if axis is None else len(axis)

    # Random inputs. The order of the np.random calls is kept fixed.
    disp_np = np.random.randn(n_disp_dims, *points) * sigma
    x_np = np.random.rand(*shape)

    # NumPy reference results.
    expected_out = elasticdeform.deform_grid(x_np, disp_np, **deform_opts)
    out_grad_np = np.random.rand(*expected_out.shape)
    expected_grad = elasticdeform.deform_grid_gradient(out_grad_np, disp_np, X_shape=shape, **deform_opts)

    # Graph: deform the variable, then differentiate with the upstream gradient.
    x_var = tf.Variable(x_np)
    out_grad_var = tf.Variable(out_grad_np)
    out_node = etf.deform_grid(x_var, disp_np, **deform_opts)
    (grad_node,) = tf.gradients(out_node, x_var, out_grad_var)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        actual_out, actual_grad = sess.run([out_node, grad_node])

    np.testing.assert_almost_equal(expected_out, actual_out)
    np.testing.assert_almost_equal(expected_grad, actual_grad)
def backward(ctx, *dys):
    """Autograd backward: compute input gradients via the NumPy implementation.

    The upstream gradients and the single saved tensor (the displacement) are
    moved to CPU NumPy arrays. The input gradients are then computed with
    ``elasticdeform.deform_grid_gradient``, using the ``deform_args``,
    ``deform_kwargs`` and ``x_shapes`` stored on ``ctx``. Each result is
    returned as a tensor on the same device as its upstream gradient.
    """
    # Exactly one tensor is expected to be saved.
    (saved_displacement,) = ctx.saved_tensors

    upstream_np = [g.detach().cpu().numpy() for g in dys]
    displacement_np = saved_displacement.detach().cpu().numpy()

    input_grads = elasticdeform.deform_grid_gradient(
        upstream_np,
        displacement_np,
        *ctx.deform_args,
        X_shape=ctx.x_shapes,
        **ctx.deform_kwargs
    )

    # NOTE(review): the three leading Nones presumably correspond to the
    # non-differentiable forward() inputs (displacement, deform args, deform
    # kwargs) -- confirm against forward().
    input_grad_tensors = tuple(
        torch.tensor(dx, device=g.device) for dx, g in zip(input_grads, dys)
    )
    return (None, None, None) + input_grad_tensors
def deform_grid_gradient_c(X_in, displacement, order=3, mode='constant', cval=0.0, crop=None, prefilter=True, axis=None, X_shape=None):
    """Forward every argument unchanged to ``elasticdeform.deform_grid_gradient``.

    The options are passed by keyword, so this wrapper does not depend on the
    library's parameter order.
    """
    options = dict(
        order=order,
        mode=mode,
        cval=cval,
        crop=crop,
        prefilter=prefilter,
        axis=axis,
        X_shape=X_shape,
    )
    return elasticdeform.deform_grid_gradient(X_in, displacement, **options)
def grad(*dys_disp_xs):
    """Gradient function for the deform op, for TF1 or TF2.

    The positional arguments are, in order: one upstream gradient per output
    (``len(xs)`` of them), the displacement, then the original inputs. Only the
    inputs' shapes are used. When ``use_tf_v1`` is false, the upstream gradients
    and the displacement are converted with ``.numpy()`` before the NumPy
    gradient is computed. ``xs``, ``args``, ``kwargs`` and ``use_tf_v1`` come
    from the enclosing scope.

    Returns a NaN-filled array shaped like the displacement, followed by the
    input gradients. No displacement gradient is computed here.
    """
    n_outputs = len(xs)
    output_grads = list(dys_disp_xs[:n_outputs])
    displacement = dys_disp_xs[n_outputs]
    input_shapes = [x.shape for x in dys_disp_xs[n_outputs + 1:]]

    if not use_tf_v1:
        # Outside TF1 the values arrive as tensors; convert them to NumPy.
        output_grads = [g.numpy() for g in output_grads]
        displacement = displacement.numpy()

    input_grads = elasticdeform.deform_grid_gradient(output_grads, displacement, *args, X_shape=input_shapes, **kwargs)

    # NaN marks the displacement gradient as "not provided".
    displacement_grad = numpy.nan * displacement
    return [displacement_grad] + input_grads
def run_comparison_tensorflow_multi(self, shape, points, order=3, sigma=25, crop=None, mode='constant', axis=None):
    """Compare the TensorFlow 2 deform_grid on two inputs against the NumPy reference.

    Two random inputs share one random displacement field. Both outputs and
    both input gradients, taken from a persistent ``tf.GradientTape``, are
    checked against ``elasticdeform.deform_grid`` and
    ``elasticdeform.deform_grid_gradient``.

    Args:
        shape: shape of each random input array.
        points: grid dimensions of the displacement field; the field has shape
            ``(n, *points)``, with ``n = len(shape)`` or ``len(axis)`` if given.
        order, crop, mode, axis: forwarded unchanged to every deform call.
        sigma: scale factor applied to the standard-normal displacements.

    Raises:
        unittest.SkipTest: unless ``tf`` is loaded, has ``py_function`` and
            lacks ``py_func``.
    """
    tf2_available = tf is not None and hasattr(tf, 'py_function') and not hasattr(tf, 'py_func')
    if not tf2_available:
        raise unittest.SkipTest("TensorFlow 2 was not loaded.")

    deform_opts = dict(order=order, crop=crop, mode=mode, axis=axis)
    n_disp_dims = len(shape) if axis is None else len(axis)

    # Random inputs. The order of the np.random calls is kept fixed.
    disp_np = np.random.randn(n_disp_dims, *points) * sigma
    x_np = np.random.rand(*shape)
    y_np = np.random.rand(*shape)

    # NumPy reference results for both inputs.
    x_ref, y_ref = elasticdeform.deform_grid([x_np, y_np], disp_np, **deform_opts)
    dx_out_np = np.random.rand(*x_ref.shape)
    dy_out_np = np.random.rand(*y_ref.shape)
    dx_ref, dy_ref = elasticdeform.deform_grid_gradient([dx_out_np, dy_out_np], disp_np, X_shape=[shape, shape], **deform_opts)

    # TensorFlow: record the deformation on a persistent tape so that two
    # gradients can be taken from it.
    x_var = tf.Variable(x_np)
    y_var = tf.Variable(y_np)
    dx_out = tf.Variable(dx_out_np)
    dy_out = tf.Variable(dy_out_np)
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(x_var)
        tape.watch(y_var)
        x_out, y_out = etf.deform_grid([x_var, y_var], disp_np, **deform_opts)
    dx = tape.gradient(x_out, x_var, dx_out)
    dy = tape.gradient(y_out, y_var, dy_out)

    comparisons = (
        (x_ref, x_out),
        (y_ref, y_out),
        (dx_ref, dx),
        (dy_ref, dy),
    )
    for expected, actual in comparisons:
        np.testing.assert_almost_equal(expected, actual)