def testGradient(self):
  x = tf.convert_to_tensor(
      np.arange(10)[np.newaxis, ...] / 10.0 - 0.5, dtype=tf.float64)
  jac_naive = batch_jacobian(lambda t: tf.cumsum(tf.exp(t), axis=-1), x)
  jac_fused = batch_jacobian(
      lambda t: tf.exp(tfp.math.log_cumsum_exp(t, axis=-1)), x)
  self.assertAllClose(jac_naive, jac_fused)
def test_batch_jacobian_larger_rank_and_dtype(self):
  w1 = tf.reshape(tf.range(24., dtype=tf.float64) + 1., (4, 2, 3))
  w2 = tf.reshape(
      tf.range(24., dtype=tf.float32) * 0.5 - 6., (4, 2, 1, 3))

  def f(x, y):  # [4, 2, 3], [4, 2, 1, 3] -> [4, 3, 2]
    return tf.transpose(
        tf.cast(tf.math.cumsum(w1 * x, axis=-1), dtype=tf.float32) +
        tf.square(tf.reverse(w2 * y, axis=[-3]))[..., 0, :],
        perm=[0, 2, 1])

  x = tf.cast(np.random.uniform(size=(4, 2, 3)), dtype=tf.float64)
  y = tf.cast(np.random.uniform(size=(4, 2, 1, 3)), dtype=tf.float32)
  jac = batch_jacobian(f, [x, y])

  # Check shapes.
  self.assertLen(jac, 2)
  self.assertAllEqual([4, 3, 2, 2, 3], jac[0].shape)
  self.assertAllEqual([4, 3, 2, 2, 1, 3], jac[1].shape)
  self.assertEqual(tf.float64, jac[0].dtype)
  self.assertEqual(tf.float32, jac[1].dtype)

  # Check results against `value_and_gradient`.
  out_shape = f(x, y).shape[1:]
  for i in range(np.prod(out_shape)):
    idx = (slice(None),) + np.unravel_index(i, out_shape)
    # pylint: disable=cell-var-from-loop
    _, grad = tfp.math.value_and_gradient(lambda x, y: f(x, y)[idx], [x, y])
    print(grad[0].shape, jac[0].shape, jac[0][idx].shape)
    self.assertAllClose(grad[0], jac[0][idx])
    self.assertAllClose(grad[1], jac[1][idx])
def __call__(self, time, state_vec):
  jacobian_mat = tfp_gradient.batch_jacobian(
      lambda state_vec: self._ode_fn_vec(time, state_vec[0])[tf.newaxis],
      state_vec[tf.newaxis])
  if jacobian_mat is None:
    return tf.zeros([tf.size(state_vec)] * 2, dtype=state_vec.dtype)
  return jacobian_mat[0]
def testJacobianConsistent(self):
  bijector = tfb.IteratedSigmoidCentered()
  x = tf.constant((60 * np.random.rand(10) - 30).reshape(5, 2))
  jacobian_matrix = batch_jacobian(bijector.forward, x)
  # In our case, y[-1] is determined by all the other y, so we can drop it
  # for the jacobian calculation.
  jacobian_matrix = jacobian_matrix[..., :-1, :]
  self.assertAllClose(
      tf.linalg.slogdet(jacobian_matrix).log_abs_determinant,
      bijector.forward_log_det_jacobian(x, event_ndims=1),
      atol=0.,
      rtol=1e-7)
def test_batch_jacobian(self):
  w = tf.reshape(tf.range(12.) + 1., (4, 3))

  def f(x):
    return tf.math.cumsum(w * x, axis=-1)

  self.assertAllEqual(
      tf.convert_to_tensor(
          [[[1., 0., 0.], [1., 2., 0.], [1., 2., 3.]],
           [[4., 0., 0.], [4., 5., 0.], [4., 5., 6.]],
           [[7., 0., 0.], [7., 8., 0.], [7., 8., 9.]],
           [[10., 0., 0.], [10., 11., 0.], [10., 11., 12.]]]),
      batch_jacobian(f, tf.ones((4, 3))))
def test_aux(self):
  x = tf.constant([[2.]])

  def f(x):
    return x**2, x

  (y, aux), dx = tfm.value_and_gradient(f, x, has_aux=True)
  self.assertAllClose(x**2, y)
  self.assertAllClose(2 * x, dx)
  self.assertAllClose(x, aux)

  dx, aux = batch_jacobian(f, x, has_aux=True)
  self.assertAllClose((2 * x)[..., tf.newaxis], dx)
  self.assertAllClose(x, aux)
def testInverseLogDetJacobian(self):
  """Test if log-det-jacobian agrees with numerical computation."""
  if not tf.executing_eagerly():
    self.skipTest(
        'There is a problem with the numerical computation of the Jacobian '
        'in graph mode. The bijector Jacobian implementation still returns '
        'roughly the same values as it does in eager mode, so the '
        'computation is believed to be correct.')
  x = self._make_images()
  x = tf.constant(x, tf.float32)
  bijection = self._create_glow(actnorm=True)
  self.evaluate([v.initializer for v in bijection.variables])
  jacob_manual = self.evaluate(batch_jacobian(bijection.inverse, x))
  _, ldj_manual = np.linalg.slogdet(jacob_manual.reshape([5, 768, 768]))
  jacob = self.evaluate(bijection.inverse_log_det_jacobian(x, 3))
  self.assertAllClose(ldj_manual, jacob, rtol=1.e-5)
def test_aux_multi_arg(self):
  x = tf.constant([[2.]])
  z = tf.constant([[3.]])

  def f(x, z):
    return x**2 + z**2, (x, z)

  (y, aux), (dx, dz) = tfm.value_and_gradient(f, (x, z), has_aux=True)
  self.assertAllClose(x**2 + z**2, y)
  self.assertAllClose(2 * x, dx)
  self.assertAllClose(2 * z, dz)
  self.assertAllClose(x, aux[0])
  self.assertAllClose(z, aux[1])

  (dx, dz), aux = batch_jacobian(f, (x, z), has_aux=True)
  self.assertAllClose((2 * x)[..., tf.newaxis], dx)
  self.assertAllClose((2 * z)[..., tf.newaxis], dz)
  self.assertAllClose(x, aux[0])
  self.assertAllClose(z, aux[1])
def get_fldj_theoretical(bijector,
                         x,
                         event_ndims,
                         inverse_event_ndims=None,
                         input_to_unconstrained=None,
                         output_to_unconstrained=None):
  """Numerically approximate the forward log det Jacobian of a bijector.

  We compute the Jacobian of the chain
  output_to_unconst_vec(bijector(inverse(input_to_unconst_vec))) so that we're
  working with a full rank matrix. We then adjust the resulting Jacobian for
  the unconstraining bijectors.

  Bijectors that constrain / unconstrain their inputs/outputs may not be
  testable with this method, since the composition above may reduce the test
  to something trivial. However, bijectors that map within constrained spaces
  should be fine.

  Args:
    bijector: the bijector whose Jacobian we wish to approximate
    x: the value for which we want to approximate the Jacobian. Must have rank
      at least `event_ndims`.
    event_ndims: number of dimensions in an event
    inverse_event_ndims: Integer describing the number of event dimensions for
      the bijector codomain. If None, then the value of `event_ndims` is used.
    input_to_unconstrained: bijector that maps the input to the above bijector
      to an unconstrained 1-D vector. If unspecified, flatten the input into a
      1-D vector according to its event_ndims.
    output_to_unconstrained: bijector that maps the output of the above
      bijector to an unconstrained 1-D vector. If unspecified, flatten the
      input into a 1-D vector according to its event_ndims.

  Returns:
    fldj: A gradient-based evaluation of the log det Jacobian of
      `bijector.forward` at `x`.
  """
  if inverse_event_ndims is None:
    inverse_event_ndims = event_ndims
  if input_to_unconstrained is None:
    input_to_unconstrained = reshape_bijector.Reshape(
        event_shape_in=x.shape[tensorshape_util.rank(x.shape) - event_ndims:],
        event_shape_out=[-1])
  if output_to_unconstrained is None:
    f_x_shape = bijector.forward_event_shape(x.shape)
    output_to_unconstrained = reshape_bijector.Reshape(
        event_shape_in=f_x_shape[tensorshape_util.rank(f_x_shape) -
                                 inverse_event_ndims:],
        event_shape_out=[-1])

  x = tf.convert_to_tensor(x)
  x_unconstrained = 1 * input_to_unconstrained.forward(x)
  # Collapse any batch dimensions (including scalar) to a single axis.
  batch_shape = x_unconstrained.shape[:-1]
  x_unconstrained = tf.reshape(
      x_unconstrained,
      [int(np.prod(batch_shape)), x_unconstrained.shape[-1]])

  def f(x_unconstrained, batch_shape=batch_shape):
    # Unflatten any batch dimensions now under the tape.
    unflattened_x_unconstrained = tf.reshape(
        x_unconstrained,
        tensorshape_util.concatenate(batch_shape, x_unconstrained.shape[-1:]))
    f_x = bijector.forward(
        input_to_unconstrained.inverse(unflattened_x_unconstrained))
    return f_x

  def f_unconstrained(x_unconstrained, batch_shape=batch_shape):
    f_x_unconstrained = output_to_unconstrained.forward(
        f(x_unconstrained, batch_shape=batch_shape))
    # Flatten any batch dimensions to a single axis.
    return tf.reshape(
        f_x_unconstrained,
        [int(np.prod(batch_shape)), f_x_unconstrained.shape[-1]])

  if JAX_MODE:
    f_unconstrained = functools.partial(f_unconstrained, batch_shape=[])

  jacobian = batch_jacobian(f_unconstrained, x_unconstrained)
  jacobian = tf.reshape(
      jacobian, tensorshape_util.concatenate(batch_shape, jacobian.shape[-2:]))
  logging.vlog(1, 'Jacobian: %s', jacobian)

  log_det_jacobian = 0.5 * tf.linalg.slogdet(
      tf.matmul(jacobian, jacobian, adjoint_a=True)).log_abs_determinant

  input_correction = input_to_unconstrained.forward_log_det_jacobian(
      x, event_ndims=event_ndims)
  output_correction = output_to_unconstrained.forward_log_det_jacobian(
      f(x_unconstrained), event_ndims=inverse_event_ndims)
  return (log_det_jacobian +
          tf.cast(input_correction, log_det_jacobian.dtype) -
          tf.cast(output_correction, log_det_jacobian.dtype))
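# A minimal usage sketch, not part of the library source: it shows how the
# numerical estimate from `get_fldj_theoretical` can be compared against a
# bijector's analytic `forward_log_det_jacobian`. The choice of `tfb.Softplus`
# and the input values are illustrative assumptions only.
bijector = tfb.Softplus()
x = tf.constant([[0.5, 1.0, 2.0]])
fldj_numerical = get_fldj_theoretical(bijector, x, event_ndims=1)
fldj_analytic = bijector.forward_log_det_jacobian(x, event_ndims=1)
np.testing.assert_allclose(fldj_numerical, fldj_analytic, rtol=1e-5)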
def test_vimco_and_gradient(self):
  dims = 5  # Dimension
  num_draws = int(1e3)
  num_batch_draws = int(3)
  seed = test_util.test_seed(sampler_type='stateless')

  f = lambda logu: tfp.vi.kl_reverse(logu, self_normalized=False)
  np_f = lambda logu: -logu

  p = tfd.MultivariateNormalFullCovariance(
      covariance_matrix=tridiag(dims, diag_value=1, offdiag_value=0.5))

  # Variance is very high when approximating Forward KL, so we make
  # scale_diag large. This ensures q "covers" p and thus Var_q[p/q] is
  # smaller.
  build_q = (
      lambda s: tfd.MultivariateNormalDiag(scale_diag=tf.tile([s], [dims])))

  def vimco_loss(s):
    return tfp.vi.monte_carlo_variational_loss(
        p.log_prob,
        surrogate_posterior=build_q(s),
        importance_sample_size=num_draws,
        sample_size=num_batch_draws,
        gradient_estimator=tfp.vi.GradientEstimators.VIMCO,
        discrepancy_fn=f,
        seed=seed)

  def logu(s):
    q = build_q(s)
    x = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
    x = tf.stop_gradient(x)
    return p.log_prob(x) - q.log_prob(x)

  def f_log_sum_u(s):
    return f(tfp.stats.log_soomean_exp(logu(s), axis=0)[::-1][0])

  def q_log_prob_x(s):
    q = build_q(s)
    x = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
    x = tf.stop_gradient(x)
    return q.log_prob(x)

  s = tf.constant(1.)
  logu_ = self.evaluate(logu(s))
  vimco_, grad_vimco_ = self.evaluate(
      tfp.math.value_and_gradient(vimco_loss, s))
  f_log_sum_u_, grad_mean_f_log_sum_u_ = self.evaluate(
      tfp.math.value_and_gradient(f_log_sum_u, s))
  grad_mean_f_log_sum_u_ /= num_batch_draws
  jacobian_logqx_ = self.evaluate(
      # Compute `jacobian(q_log_prob_x, s)` using `batch_jacobian` and messy
      # indexing.
      gradient.batch_jacobian(
          lambda s: q_log_prob_x(s[0, 0, ...])[None, ...],
          s[tf.newaxis, tf.newaxis, ...])[0, ..., 0])

  np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu_)

  # Test VIMCO loss is correct.
  self.assertAllClose(np_f(np_log_avg_u).mean(axis=0), vimco_,
                      rtol=1e-4, atol=1e-5)

  # Test gradient of VIMCO loss is correct.
  #
  # To make this computation we'll inject two gradients from TF:
  # - grad[mean(f(log(sum(p(x)/q(x)))))]
  # - jacobian[log(q(x))].
  #
  # We now justify why using these (and only these) TF values for
  # ground-truth does not undermine the completeness of this test.
  #
  # Regarding `grad_mean_f_log_sum_u_`, note that we validate the
  # correctness of the zero-th order derivative (for each batch member).
  # Since `tfp.vi.csiszar_vimco_helper` itself does not manipulate any
  # gradient information, we can safely rely on TF.
  self.assertAllClose(np_f(np_log_avg_u), f_log_sum_u_,
                      rtol=1e-4, atol=1e-5)
  #
  # Regarding `jacobian_logqx_`, note that testing the gradient of
  # `q.log_prob` is outside the scope of this unit-test thus we may safely
  # use TF to find it.

  # The `mean` is across batches and the `sum` is across iid samples.
  np_grad_vimco = (
      grad_mean_f_log_sum_u_ +
      np.mean(np.sum(
          jacobian_logqx_ * (np_f(np_log_avg_u) - np_f(np_log_sooavg_u)),
          axis=0), axis=0))

  self.assertAllClose(np_grad_vimco, grad_vimco_, rtol=0.03, atol=1e-3)
def get_jacobian(f, x):
  return tfp_gradient.batch_jacobian(
      lambda x: f(x[0])[tf.newaxis], x[tf.newaxis])[0]
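# A minimal sketch of the wrapping pattern above; the example function and
# input below are illustrative assumptions, not from the source.
# `batch_jacobian` expects a leading batch dimension on both input and output,
# so an unbatched function is wrapped by adding a dummy batch axis and
# stripping it from the result.
# For f(v) = v**2 with v of shape [3], the Jacobian is diag(2 * v), shape [3, 3].
jac = get_jacobian(lambda v: v**2, tf.constant([1., 2., 3.]))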