Beispiel #1
0
 def f(t):
     """Evaluates the polynomial `poly` and its derivative at `t`.

     Args:
       t: Value convertible to a `float32` `Tensor`. A scalar is tiled to
         length `batches` when `batches` is not `None`.

     Returns:
       A `ValueAndGradient` whose `x` is the (possibly tiled) input and whose
       `f` and `df` are the squeezed polynomial value and derivative.
     """
     point = tf.convert_to_tensor(t, dtype=tf.float32)
     if batches is not None:
         is_scalar = not tuple(point.shape)
         if is_scalar:
             # Repeat the scalar once per batch member.
             point = tf.tile(point[..., tf.newaxis], [batches])

     def polynomial(t_):
         return tf.math.polyval(poly, t_)

     value, derivative = value_and_gradient(polynomial, point)
     return ValueAndGradient(
         x=point, f=tf.squeeze(value), df=tf.squeeze(derivative))
Beispiel #2
0
def right_mult_by_jacobian_mat(jacobian_fn_mat, ode_fn_vec, time, state_vec,
                               vec):
    """Multiplies the row vector `vec` by the Jacobian of the ODE function.

    When `jacobian_fn_mat` is an `_AutomaticJacobian`, the product is obtained
    by differentiating `ode_fn_vec` with `vec` supplied as the output
    gradients, so the Jacobian matrix is never built. Otherwise the matrix
    returned by `jacobian_fn_mat(time, state_vec)` is multiplied explicitly.

    Args:
      jacobian_fn_mat: Result of `get_jacobian_fn_mat`.
      ode_fn_vec: Result of `get_ode_fn_vec`.
      time: Scalar float `Tensor` time at which to evaluate the Jacobian.
      state_vec: `Tensor` state at which to evaluate the Jacobian.
      vec: `Tensor` whose shape is compatible with the Jacobian.

    Returns:
      `Tensor` representing the dot product.
    """
    if not isinstance(jacobian_fn_mat, _AutomaticJacobian):
        # An explicit Jacobian is available: build it and multiply directly.
        explicit_jacobian = jacobian_fn_mat(time, state_vec)
        row_vector = vec[tf.newaxis, :]
        return tf.reshape(tf.matmul(row_vector, explicit_jacobian), [-1])

    # No explicit Jacobian: let automatic differentiation apply the chain rule.
    def ode_at_fixed_time(x):
        return ode_fn_vec(time, x)

    _, product = value_and_gradient(
        ode_at_fixed_time, state_vec, output_gradients=vec)
    return product
Beispiel #3
0
def _von_mises_cdf_normal(x, concentration, dtype):
    """Computes the von Mises CDF and its derivative via Normal approximation.

    Args:
      x: Point(s) at which the CDF is evaluated.
      concentration: Concentration parameter; the gradient is taken with
        respect to this argument.
      dtype: dtype used for the location and scale of the standard Normal.

    Returns:
      The result of `value_and_gradient`: the approximate CDF and its gradient
      with respect to `concentration`.
    """
    def cdf_func(concentration):
        """Approximate CDF as a function of concentration only."""
        # z is an "almost Normally distributed" random variable.
        scale = np.sqrt(2. / np.pi) / tf.math.bessel_i0e(concentration)
        z = scale * tf.sin(.5 * x)

        # Correction term (cited as [1] in the original comments) that
        # reduces the error of the Normal approximation.
        z_sq = z**2
        z_cubed = z_sq * z
        z_fourth = z_sq**2
        c = 24. * concentration

        leading = (c - 2. * z_sq - 16.) / 3.
        numerator = z_fourth + (7. / 4.) * z_sq + 167. / 2.
        denominator = c - 56. - z_sq + 3.
        xi = z - z_cubed / (leading - numerator / denominator)**2

        standard_normal = normal.Normal(
            tf.cast(0., dtype), tf.cast(1., dtype))
        return standard_normal.cdf(xi)

    return value_and_gradient(cdf_func, concentration)
Beispiel #4
0
    def test_linear_ode_dense(self, solver):
        """Checks the gradient of a dense linear ODE solution w.r.t. its start.

        The state at the intermediate time is differentiated with respect to
        the initial state and compared against hard-coded reference values.
        """
        jacobian = -np.float64([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
        num_odes = jacobian.shape[0]
        initial_time = 0.
        intermediate_time = 1.
        final_time = 2.
        start_value = np.float64([1.] * num_odes)
        initial_state = tf.constant(start_value, dtype=tf.float64)
        solver_instance = solver(rtol=_RTOL, atol=_ATOL)

        def ode_fn(_, state):
            # Linear dynamics: d(state)/dt = jacobian @ state.
            column = state[:, tf.newaxis]
            return tf.squeeze(tf.matmul(jacobian, column))

        def grad_fn(initial_state):
            solution = solver_instance.solve(
                ode_fn,
                initial_time,
                initial_state,
                solution_times=[intermediate_time, final_time])
            # Only the state at the intermediate time is differentiated.
            return solution.states[0]

        _, grad_tensor = tfp_gradient.value_and_gradient(
            grad_fn, initial_state)
        grad = self.evaluate(grad_tensor)

        # Reference values for the matrix exponential of `jacobian`.
        matrix_exponential_of_jacobian = np.float64([
            [+2.3703878775011322, +0.2645063729368097, -0.8413751316275110],
            [-0.0900545996427410, +0.7326649140674798, -0.4446155722222950],
            [-1.5504970767866180, -0.7991765448018465, +0.9521439871829201],
        ])
        ones = np.ones([num_odes])
        grad_exact = np.dot(ones, matrix_exponential_of_jacobian)
        self.assertAllClose(grad, grad_exact)
Beispiel #5
0
 def _fn(j, result):
   """Loop body that stores the `j`-th gradient component in `result`.

   Args:
     j: Index of the component handled in this iteration.
     result: Accumulator exposing `write(index, value)` (presumably a
       `TensorArray` -- confirm against the caller).

   Returns:
     The pair `(j + 1, updated_result)`.
   """
   grad_wrt_input = value_and_gradient(fn_slice(i, j), xs)[1][i]
   if grad_wrt_input is not None:
     flattened = tf.reshape(
         grad_wrt_input, tf.concat([sample_shape, [-1]], -1))
     component = flattened[..., j]
   else:
     # No gradient was produced for this input; substitute zeros.
     component = tf.zeros(tf.concat([sample_shape, [1]], -1), dtype=x.dtype)
   return j + 1, result.write(j, component)
Beispiel #6
0
 def test_clip_by_value_preserve_grad(self, x, lo, hi, expected_y):
     """Checks that clipping yields `expected_y` while the gradient stays 1."""
     def clip(x_):
         return numeric.clip_by_value_preserve_gradient(x_, lo, hi)

     # The gradient is expected to be one everywhere, clipped or not.
     unit_gradient = np.ones_like(x)
     x_tensor = tf.convert_to_tensor(value=x, dtype=self.dtype)
     clipped, gradient = tfp_math_gradient.value_and_gradient(clip, x_tensor)
     clipped_, gradient_ = self.evaluate([clipped, gradient])
     self.assertAllClose(expected_y, clipped_)
     self.assertAllClose(unit_gradient, gradient_)
Beispiel #7
0
 def _call(self, r):
     """Computes the mean, variance and mean-gradient for predictor `r`.

     Args:
       r: Input passed through the inverse link function.

     Returns:
       Tuple `(mean, variance, grad_mean)`, where `grad_mean` is the gradient
       of the inverse link at `r` and `variance` comes from the distribution
       built from `mean`.
     """
     inverse_link = self._inverse_link_fn
     if inverse_link is not None:
         mean, grad_mean = value_and_gradient(inverse_link, r)
     else:
         # `None` stands for the identity link, whose gradient is one.
         mean = r
         grad_mean = tf.ones_like(r)
     variance = self._distribution_fn(mean).variance()
     return mean, variance, grad_mean
 def _test_grad_finite(self, dtype):
     """Checks that the function value and gradient are finite at extremes.

     The function under test is `log_ndtr` when `self._use_log` is set and
     `ndtr` otherwise; it is evaluated at `[-100., 0., 100.]`.

     Args:
       dtype: dtype of the test points.
     """
     x = tf.constant([-100., 0., 100.], dtype=dtype)
     fn = special_math.log_ndtr if self._use_log else special_math.ndtr
     # Not having the lambda sanitizer means we'd get an `IndexError` whenever
     # the user supplied function has default args.
     output, grad_output = value_and_gradient(fn, x)
     # isfinite checks for NaN and Inf.
     output_, grad_output_ = self.evaluate([output, grad_output])
     self.assert_all_true(np.isfinite(output_))
     # Check every gradient entry. Indexing with `[0]` would inspect only the
     # gradient at the first point when the gradient comes back as a single
     # array, and is unnecessary when it comes back as a one-element list.
     self.assert_all_true(np.isfinite(grad_output_))
Beispiel #9
0
    def test_riccati_custom_adjoint_solver(self, solver, solution_times_fn):
        """Checks gradients of a Riccati ODE when a custom adjoint solver is used.

        The adjoint solver is wrapped so that the internal state it receives is
        recorded, which lets the test verify that the step size seen by the
        adjoint solve is smaller than the deliberately large initial one.

        Args:
          solver: Solver class (or factory) under test.
          solution_times_fn: Callable mapping the final time to the
            `solution_times` argument passed to `solve`.
        """
        # Riccati equation y' = (y - t)**2 + 1 and its derivative w.r.t. y.
        ode_fn = lambda time, state: (state - time)**2 + 1.
        initial_time = 0.
        initial_state_value = 0.5
        initial_state = tf.constant(initial_state_value, dtype=tf.float64)
        final_time = 1.
        solution_times = solution_times_fn(final_time)
        jacobian_fn = lambda time, state: 2. * (state - time)

        # Only the `ChosenBySolver` variant is exercised; other cases are
        # skipped with a reference to the tracking bug.
        if not isinstance(solution_times, tfp.math.ode.ChosenBySolver):
            self.skipTest('b/194468619')

        # Instrument the adjoint solver for testing. We have to do this because the
        # API doesn't provide access to the adjoint solver's diagnostics.
        first_step_size = np.float64(1.)
        last_initial_step_size = tf.Variable(0., dtype=tf.float64)
        self.evaluate(last_initial_step_size.initializer)

        class _InstrumentedSolver(StepSizeHeuristicAdjointSolver):
            def solve(self, **kwargs):
                # Record the incoming internal state before delegating, so the
                # test can read it back after the gradient is evaluated.
                with tf.control_dependencies([
                        last_initial_step_size.assign(
                            kwargs['previous_solver_internal_state'])
                ]):
                    return super(_InstrumentedSolver, self).solve(**kwargs)

        adjoint_solver = _InstrumentedSolver(
            make_solver_fn=lambda step_size: solver(  # pylint: disable=g-long-lambda
                rtol=_RTOL,
                atol=_ATOL,
                first_step_size=step_size),
            first_step_size=first_step_size)

        solver_instance = solver(rtol=_RTOL,
                                 atol=_ATOL,
                                 make_adjoint_solver_fn=lambda: adjoint_solver)

        def grad_fn(initial_state):
            # Solve forward and return the state at the last solution time.
            results = solver_instance.solve(ode_fn,
                                            initial_time,
                                            initial_state,
                                            solution_times=solution_times,
                                            jacobian_fn=jacobian_fn)
            final_state = results.states[-1]
            return final_state

        _, grad = tfp_gradient.value_and_gradient(grad_fn, initial_state)
        grad, last_initial_step_size = self.evaluate(
            (grad, last_initial_step_size))
        # Closed-form derivative of the final state w.r.t. the initial state.
        grad_exact = 1. / (1. - initial_state_value * final_time)**2
        self.assertAllClose(grad, grad_exact, rtol=1e-3, atol=1e-3)
        # This indicates that the adaptation carried over to the final solve. We
        # expect the step size to decrease because we purposefully made the initial
        # step size way too large.
        self.assertLess(last_initial_step_size, first_step_size)
Beispiel #10
0
    def testSqrtWithFiniteGradsBackpropsCorrectly(self):
        """Checks that the safe sqrt backpropagates like `tf.sqrt` off zero.

        A custom gradient must handle the upstream gradient flowing back from
        downstream ops, so the sqrt is sandwiched between an inner `sin(x)**2`
        and an outer squaring, and both variants are differentiated.
        """
        def square(x):
            return x**2

        def sin_squared(x):
            return tf.sin(x)**2

        def with_plain_sqrt(xs_):
            return square(tf.sqrt(sin_squared(xs_)))

        def with_safe_sqrt(xs_):
            return square(util.sqrt_with_finite_grads(sin_squared(xs_)))

        # We only test away from zero, since we know the values don't match
        # there.
        xs = tf.constant(np.linspace(1e-10, 10., 100))
        _, grad_tf_sqrt = value_and_gradient(with_plain_sqrt, xs)
        _, grad_safe_sqrt = value_and_gradient(with_safe_sqrt, xs)
        reference, candidate = self.evaluate([grad_tf_sqrt, grad_safe_sqrt])
        self.assertAllClose(reference, candidate, rtol=1e-10)
def inputwise_condition_numbers(f, *args):
  """Computes the condition number of `f(*args)` for each argument separately.

  `f(*args)` must return a scalar; batches of condition numbers and
  vector-valued functions are not yet supported.

  Numerical error from evaluating `f` in float64 is assumed to be negligible.
  This requires `f` to be _dtype-polymorphic_: its internal computations must
  be carried out in the dtype of its arguments.

  Args:
    f: Function whose accuracy to evaluate.  Must be differentiable
      and dtype-polymorphic.
    *args: Arguments at which to test the accuracy of `f`.

  Returns:
    condition_numbers: The condition number of `f` with respect to each input.
      The returned structure is parallel to `*args`.

  Raises:
    ValueError: If `f` is found not to be dtype-polymorphic.

  """
  # TODO(b/181967692): Compute multivariate condition numbers.
  # TODO(b/181967437): To support batch condition numbers, need batch gradients.
  # Then can infer the "event shape" of the arguments by subtracting
  # off the number of dimensions in f(*args).
  # To also support vector outputs, need to know the "event_ndims" in
  # the output f(*args), and need full Jacobians of f underneath.
  non_finite_msg = (
      'Cannot check accuracy if ground truth or derivatives are not finite')

  def require_finite(x):
    # Missing gradients (`None`) pass through untouched.
    if x is not None:
      return tf.debugging.check_numerics(x, message=non_finite_msg)
    return None

  # Round the inputs to float32 first, then lift them back to float64, so
  # the float64 evaluation happens at values float32 can represent.
  args_32 = tf.nest.map_structure(floating_tensor_to_f32, args)
  logging.vlog(1, '32-bit arguments: %s', args_32)
  args_64 = tf.nest.map_structure(floating_tensor_to_f64, args_32)

  truth, derivatives = gradient.value_and_gradient(f, args_64)
  logging.vlog(1, 'Correct answer: %s', truth)
  logging.vlog(1, 'Argument gradient: %s', derivatives)

  truth = require_finite(truth)
  derivatives = tf.nest.map_structure(require_finite, derivatives)
  if truth.dtype != tf.float64:
    raise ValueError('Evaluating on {} produced non-64-bit result {}'.format(
        args_64, truth))

  def condition_number(arg, derivative):
    return condition_number_one_input(truth, arg, derivative)

  # For some reason, value_and_gradient casts the outer structure to list in
  # jax.  Is that an oversight?
  return tf.nest.map_structure(
      condition_number, tuple(args_64), tuple(derivatives))
  def test_vimco_helper_gradient_using_finite_difference_1(self):
    """Tests that gradient calculation correctly handles batches."""
    logu = tf.constant(np.linspace(-100., 100., 100).reshape([10, 2, 5]))

    def evaluated_gradient(output_index):
      """Gradient of one output of `csiszar_vimco_helper` w.r.t. `logu`."""
      def helper_output(logu):
        return tfp.vi.csiszar_vimco_helper(logu)[output_index]
      _, gradient_value = self.evaluate(
          value_and_gradient(helper_output, logu))
      return gradient_value

    grad_log_avg_u = evaluated_gradient(0)
    grad_log_sooavg_u = evaluated_gradient(1)

    # We skip checking against finite-difference approximation since it
    # doesn't support batches.

    # Verify claim in docstring.
    avg_grad_sum = grad_log_avg_u.sum(axis=0)
    self.assertAllClose(np.ones_like(avg_grad_sum), avg_grad_sum)
    sooavg_grad_mean = grad_log_sooavg_u.mean(axis=0)
    self.assertAllClose(np.ones_like(sooavg_grad_mean), sooavg_grad_mean)
 def grad(dy):
   """The gradient of the von Mises samples w.r.t. concentration.

   Args:
     dy: Upstream gradient with respect to the samples `x`.

   Returns:
     Gradient with respect to `concentration`, summed over the leading
     sample dimensions.
   """
   broadcast_concentration = concentration + tf.zeros_like(x)

   def cdf_of_concentration(conc):
     return von_mises_cdf(x, conc)

   _, dcdf_dconcentration = value_and_gradient(
       cdf_of_concentration, broadcast_concentration)

   # Reciprocal of the von Mises density at `x`.
   exp_term = tf.exp(-broadcast_concentration * (tf.cos(x) - 1.))
   normalizer = (2. * np.pi) * tf.math.bessel_i0e(broadcast_concentration)
   inv_prob = exp_term * normalizer

   # Compute the implicit reparameterization gradient [2],
   # dz/dconc = -(dF(z; conc) / dconc) / p(z; conc)
   weighted = dy * (-inv_prob * dcdf_dconcentration)

   # Sum over the sample dimensions. Assume that they are always the first
   # ones.
   num_sample_dimensions = (
       tf.rank(broadcast_concentration) - tf.rank(concentration))
   sample_axes = tf.range(num_sample_dimensions)
   return tf.reduce_sum(weighted, axis=sample_axes)
Beispiel #14
0
    def test_vimco_helper_gradient_using_finite_difference_3(self):
        """Tests that gradient calculation correctly handles underflow."""
        delta = 1e-3
        logu_ = np.float32([0., -1000, -1, 1])
        logu = tf.constant(logu_)

        # Finite-difference reference gradients.
        np_grad_log_avg_u, np_grad_log_sooavg_u = (
            self._csiszar_vimco_helper_grad(logu_, delta))

        def evaluated_gradient(output_index):
            """Gradient of one `csiszar_vimco_helper` output w.r.t. `logu`."""
            def helper_output(logu):
                return tfp.vi.csiszar_vimco_helper(logu)[output_index]
            _, gradient_value = self.evaluate(
                value_and_gradient(helper_output, logu))
            return gradient_value

        grad_log_avg_u = evaluated_gradient(0)
        grad_log_sooavg_u = evaluated_gradient(1)

        reference_and_actual = (
            (np_grad_log_avg_u, grad_log_avg_u),
            (np_grad_log_sooavg_u, grad_log_sooavg_u),
        )
        for reference, actual in reference_and_actual:
            self.assertAllClose(reference, actual, rtol=delta, atol=0.)

        # Verify claim in docstring.
        avg_grad_sum = grad_log_avg_u.sum(axis=0)
        self.assertAllClose(np.ones_like(avg_grad_sum), avg_grad_sum)
        sooavg_grad_mean = grad_log_sooavg_u.mean(axis=0)
        self.assertAllClose(np.ones_like(sooavg_grad_mean), sooavg_grad_mean)
Beispiel #15
0
  def test_float64_extreme_values_result_and_gradient_finite_and_nonzero(self):
    """Checks `log_cdf_laplace` and its gradient on a wide float64 grid."""
    # On the lower branch, log_cdf_laplace(x) = x, so we know this will be
    # fine, but test to -200 anyways.
    spec = GridSpec(min=-200, max=700, shape=[20, 100])
    grid = tf.convert_to_tensor(value=_make_grid(np.float64, spec))

    actual, grad = value_and_gradient(special_math.log_cdf_laplace, grid)
    evaluated = self.evaluate([actual, grad])

    # isfinite checks for NaN and Inf.
    for array in evaluated:
      self.assertAllTrue(np.isfinite(array))
    # Neither the values nor the gradients may collapse to exactly zero.
    for array in evaluated:
      self.assertFalse(np.any(array == 0))
Beispiel #16
0
 def _baseNdtriFiniteGradientTest(self, dtype):
   """Verifies that ndtri has finite gradients at interesting points.

   Args:
     dtype: NumPy dtype the test probabilities are cast to.
   """
   # Tests gradients at 0, 1, and piece-wise boundaries.
   p = tf.constant(
       np.array([
           0.,
           np.exp(-32.),
           np.exp(-2.),
           1. - np.exp(-2.),
           1. - np.exp(-32.),
           1.,
       ]).astype(dtype))
   # Not having the lambda sanitizer means we'd get an `IndexError` whenever
   # the user supplied function has default args.
   _, grads = value_and_gradient(special_math.ndtri, p)
   # Check the gradient at every point. Indexing with `[0]` would inspect only
   # the gradient at p = 0 when the gradient comes back as a single tensor,
   # and is unnecessary when it comes back as a one-element list.
   self.assertAllFinite(self.evaluate(grads))
Beispiel #17
0
def _von_mises_sample_bwd(_, aux, dy):
    """The gradient of the von Mises samples w.r.t. concentration.

    Args:
      _: Unused first argument.
      aux: Pair `(concentration, samples)`.
      dy: Upstream gradient with respect to the samples.

    Returns:
      Pair of the gradient with respect to `concentration` (summed over the
      leading sample dimensions) and `None` for the seed.
    """
    concentration, samples = aux
    broadcast_concentration = tf.broadcast_to(
        concentration, ps.shape(samples))

    def cdf_of_concentration(conc):
        return von_mises_cdf(samples, conc)

    _, dcdf_dconcentration = value_and_gradient(
        cdf_of_concentration, broadcast_concentration)

    # Reciprocal of the von Mises density at the samples.
    exp_term = tf.exp(-broadcast_concentration * (tf.cos(samples) - 1.))
    normalizer = (2. * np.pi) * tf.math.bessel_i0e(broadcast_concentration)
    inv_prob = exp_term * normalizer

    # Compute the implicit reparameterization gradient [2],
    # dz/dconc = -(dF(z; conc) / dconc) / p(z; conc)
    weighted = dy * (-dcdf_dconcentration * inv_prob)

    # Sum over the sample dimensions. Assume that they are always the first
    # ones.
    num_sample_dimensions = (
        tf.rank(broadcast_concentration) - tf.rank(concentration))
    concentration_grad = tf.reduce_sum(
        weighted, axis=tf.range(num_sample_dimensions))

    # None gradients for seed
    return concentration_grad, None
Beispiel #18
0
    def testJVP(self):
        """Checks that a custom JVP is used for reverse and forward mode.

        The custom JVP scales the tangent by 7 (instead of the true factor 2),
        so the assertions can tell it apart from automatic differentiation.
        """
        if not JAX_MODE:
            self.skipTest('Custom JVPs are JAX-only.')

        def unused_vjp_fwd(x, y):
            # When a JVP is specified, this function is ignored.
            raise NotImplementedError()

        def unused_vjp_bwd(x_y, dz):
            # When a JVP is specified, this function is ignored.
            raise NotImplementedError()

        def scaled_jvp(x_y, dx_dy):
            x, y = x_y
            dx, dy = dx_dy
            tangent = 7. * (dx * x + dy * y)
            return f(x, y), tangent

        @custom_gradient.custom_gradient(
            vjp_fwd=unused_vjp_fwd,
            vjp_bwd=unused_vjp_bwd,
            jvp_fn=scaled_jvp,
        )
        def f(x, y):
            return x**2 + y**2

        x = tf.constant(2.)
        y = tf.constant(3.)
        dz = tf.constant(5.)

        direct_value = f(x, y)
        value_from_grad, (dx, dy) = tfp_gradient.value_and_gradient(
            f, (x, y), output_gradients=dz)

        self.assertAllClose(x**2 + y**2, direct_value)
        self.assertAllClose(x**2 + y**2, value_from_grad)
        self.assertAllClose(7. * dz * x, dx)
        self.assertAllClose(7. * dz * y, dy)

        import jax  # pylint: disable=g-import-not-at-top

        # Forward mode, reusing the reverse-mode gradients as tangents.
        value_from_jvp, tangent_out = jax.jvp(f, (x, y), (dx, dy))
        self.assertAllClose(x**2 + y**2, value_from_jvp)
        self.assertAllClose(7. * (dx * x + dy * y), tangent_out)
Beispiel #19
0
    def testBijectorConditionKwargs(self, dtype):
        """Checks FFJORD with a conditioning keyword argument `c`.

        The state derivative is the constant `c**2`, so the test checks the
        forward/inverse maps, both log-det-Jacobians, and the gradient of the
        forward output with respect to `c`.
        """
        if not tf2.enabled():
            self.skipTest('b/152464477')

        tf_dtype = tf.as_dtype(dtype)
        batch_shape = (2, 5)

        def conditional_ode_fn(t, z, c):
            del t  # unused.
            return tf.ones_like(z) * c**2

        bijector = tfb.FFJORD(
            trace_augmentation_fn=tfb.ffjord.trace_jacobian_exact,
            state_time_derivative_fn=conditional_ode_fn,
            dtype=tf_dtype)

        x = tf.zeros(batch_shape, dtype=tf_dtype)
        y = tf.ones(batch_shape, dtype=tf_dtype) * 4
        c = tf.ones(batch_shape, dtype=tf_dtype) * 2
        expected_log_det_jacobian = np.zeros(2, dtype=dtype)
        expected_dy_dc = np.ones(batch_shape, dtype=dtype) * 4

        def forward_given_condition(c):
            return bijector.forward(x, c=c)

        _, dy_dc_tensor = tfp_gradient.value_and_gradient(
            forward_given_condition, c)
        dy_dc = self.evaluate(dy_dc_tensor)

        self.assertStartsWith(bijector.name, 'ffjord')

        forward_value = self.evaluate(bijector.forward(x, c=c))
        self.assertAllClose(self.evaluate(y), forward_value, atol=1e-5)

        inverse_value = self.evaluate(bijector.inverse(y, c=c))
        self.assertAllClose(self.evaluate(x), inverse_value, atol=1e-5)

        inverse_ldj = self.evaluate(
            bijector.inverse_log_det_jacobian(y, event_ndims=1, c=c))
        self.assertAllClose(expected_log_det_jacobian, inverse_ldj)

        forward_ldj = self.evaluate(
            bijector.forward_log_det_jacobian(x, event_ndims=1, c=c))
        self.assertAllClose(expected_log_det_jacobian, forward_ldj)

        self.assertAllClose(expected_dy_dc, dy_dc)
Beispiel #20
0
def _von_mises_sample_jvp(shape, primals, tangents):
    """Compute primals and tangents using implicit derivative.

    Args:
      shape: Shape the concentration and its tangent are broadcast to.
      primals: Pair `(concentration, seed)`.
      tangents: Pair `(dconcentration, dseed)`; the seed tangent is ignored.

    Returns:
      Pair `(samples, dsamples)`.
    """
    concentration, seed = primals
    dconcentration, _ = tangents

    dconcentration = tf.broadcast_to(dconcentration, shape)
    broadcast_concentration = tf.broadcast_to(concentration, shape)

    samples = _von_mises_sample_no_gradient(shape, concentration, seed)

    def cdf_of_concentration(conc):
        return von_mises_cdf(samples, conc)

    _, dcdf_dconcentration = value_and_gradient(
        cdf_of_concentration, broadcast_concentration)

    # Reciprocal of the von Mises density at the samples.
    exp_term = tf.exp(-concentration * (tf.cos(samples) - 1.))
    normalizer = (2. * np.pi) * tf.math.bessel_i0e(concentration)
    inv_prob = exp_term * normalizer

    # Compute the implicit derivative,
    #   dz = dconc * -(dF(z; conc) / dconc) / p(z; conc)
    dsamples = dconcentration * (-dcdf_dconcentration * inv_prob)

    return samples, dsamples
Beispiel #21
0
    def test_riccati(self, solver, solution_times_fn):
        """Checks the gradient of a Riccati ODE solution w.r.t. its start.

        Args:
          solver: Solver class (or factory) under test.
          solution_times_fn: Callable mapping the final time to the
            `solution_times` argument passed to `solve`.
        """
        initial_time = 0.
        final_time = 1.
        initial_state_value = 0.5
        initial_state = tf.constant(initial_state_value, dtype=tf.float64)

        def ode_fn(time, state):
            # Riccati equation y' = (y - t)**2 + 1.
            return (state - time)**2 + 1.

        def jacobian_fn(time, state):
            return 2. * (state - time)

        solver_instance = solver(rtol=_RTOL, atol=_ATOL)

        def grad_fn(initial_state):
            solution = solver_instance.solve(
                ode_fn,
                initial_time,
                initial_state,
                solution_times=solution_times_fn(final_time),
                jacobian_fn=jacobian_fn)
            return solution.states[-1]

        _, grad_tensor = tfp_gradient.value_and_gradient(
            grad_fn, initial_state)
        grad = self.evaluate(grad_tensor)
        # Closed-form derivative of the final state w.r.t. the initial state.
        grad_exact = 1. / (1. - initial_state_value * final_time)**2
        self.assertAllClose(grad, grad_exact, rtol=1e-3, atol=1e-3)
    def _test_grad_accuracy(self, dtype, grid_spec, error_spec):
        """Checks sign, finiteness and accuracy of the ndtr/log_ndtr gradient.

        Args:
          dtype: dtype of the evaluation grid.
          grid_spec: Specification passed to `_make_grid`.
          error_spec: Object providing the `rtol` and `atol` tolerances used
            in the comparison against scipy.
        """
        grid = _make_grid(dtype, grid_spec)
        fn = special_math.log_ndtr if self._use_log else special_math.ndtr
        _, actual_grad = self.evaluate(value_and_gradient(fn, grid))

        # Check for NaN separately in order to get informative failures.
        self.assert_all_false(np.isnan(actual_grad))
        if not self._use_log:
            # The ndtr gradient will only be non-zero in the range [-14, 14]
            # for float32 and [-38, 38] for float64.
            self.assert_all_true(actual_grad >= 0.)
        else:
            # First half of the flattened grid must be strictly positive, the
            # remainder only non-negative.
            flat_grad = np.reshape(actual_grad, [-1])
            midpoint = int(np.ceil(len(flat_grad) / 2))
            self.assert_all_true(flat_grad[:midpoint] > 0.)
            self.assert_all_true(flat_grad[midpoint:] >= 0.)
        # isfinite checks for NaN and Inf.
        self.assert_all_true(np.isfinite(actual_grad))

        # Versus scipy.
        if not sp_special or not sp_stats:
            return

        expected_grad = sp_stats.norm.pdf(grid)
        if self._use_log:
            expected_grad /= sp_special.ndtr(grid)
            expected_grad[np.isnan(expected_grad)] = 0.
        # Scipy prematurely goes to zero at some places that we don't.  So
        # don't include these in the comparison.
        # NOTE(review): a pdf (or pdf / cdf ratio) is never negative, so this
        # mask appears to select nothing and the comparison below looks
        # vacuous; `> 0` may have been intended -- confirm before changing.
        comparable = expected_grad < 0
        self.assertAllClose(expected_grad.astype(np.float64)[comparable],
                            actual_grad.astype(np.float64)[comparable],
                            rtol=error_spec.rtol,
                            atol=error_spec.atol)
  def test_score_trick(self):
    """Compares Monte Carlo reverse-KL estimates and gradients to the exact KL.

    Four estimators are built (plain / self-normalized, each with and without
    reparameterization) and their values and gradients with respect to the
    scale `s` are checked against `tfd.kl_divergence` within relative
    tolerances.
    """
    d = 5  # Dimension
    num_draws = int(4.5e5)
    seed = tfp_test_util.test_seed()

    # Variance is very high when approximating Forward KL, so we make
    # scale_diag large. This ensures q "covers" p and thus Var_q[p/q] is
    # smaller.
    s = tf.constant(1.)

    def construct_monte_carlo_csiszar_f_divergence(
        func, use_reparametrization=True):
      # Returns a function of the scale `s` that estimates the divergence
      # defined by `func` between q (diagonal, scale `s`) and p.
      def _fn(s):
        p = tfd.MultivariateNormalFullCovariance(
            covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5))
        q = tfd.MultivariateNormalDiag(scale_diag=tf.tile([s], [d]))
        return tfp.vi.monte_carlo_csiszar_f_divergence(
            f=func,
            p_log_prob=p.log_prob,
            q=q,
            num_draws=num_draws,
            use_reparametrization=use_reparametrization,
            seed=seed)
      return _fn

    approx_kl = construct_monte_carlo_csiszar_f_divergence(
        tfp.vi.kl_reverse)

    approx_kl_self_normalized = construct_monte_carlo_csiszar_f_divergence(
        lambda logu: tfp.vi.kl_reverse(logu, self_normalized=True))

    approx_kl_score_trick = construct_monte_carlo_csiszar_f_divergence(
        tfp.vi.kl_reverse, use_reparametrization=False)

    approx_kl_self_normalized_score_trick = (
        construct_monte_carlo_csiszar_f_divergence(
            lambda logu: tfp.vi.kl_reverse(logu, self_normalized=True),
            use_reparametrization=False))

    def exact_kl(s):
      # Closed-form KL(q || p) for the same p and q as the estimators above.
      p = tfd.MultivariateNormalFullCovariance(
          covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5))
      q = tfd.MultivariateNormalDiag(scale_diag=tf.tile([s], [d]))
      return tfd.kl_divergence(q, p)

    # Evaluate every (value, gradient) pair in a single call; the unpacking
    # order below must match the concatenation order of the lists.
    [
        approx_kl_,
        approx_kl_grad_,
        approx_kl_self_normalized_,
        approx_kl_self_normalized_grad_,
        approx_kl_score_trick_,
        approx_kl_score_trick_grad_,
        approx_kl_self_normalized_score_trick_,
        approx_kl_self_normalized_score_trick_grad_,
        exact_kl_,
        exact_kl_grad_,
    ] = self.evaluate(
        list(value_and_gradient(approx_kl, s)) +
        list(value_and_gradient(approx_kl_self_normalized, s)) +
        list(value_and_gradient(approx_kl_score_trick, s)) +
        list(value_and_gradient(approx_kl_self_normalized_score_trick, s)) +
        list(value_and_gradient(exact_kl, s)))

    # Test average divergence.
    self.assertAllClose(approx_kl_, exact_kl_,
                        rtol=0.04, atol=0.)

    self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
                        rtol=0.08, atol=0.)

    self.assertAllClose(approx_kl_score_trick_, exact_kl_,
                        rtol=0.04, atol=0.)

    self.assertAllClose(approx_kl_self_normalized_score_trick_, exact_kl_,
                        rtol=0.08, atol=0.)

    # Test average gradient-divergence.
    self.assertAllClose(approx_kl_grad_, exact_kl_grad_,
                        rtol=0.04, atol=0.)

    self.assertAllClose(approx_kl_self_normalized_grad_, exact_kl_grad_,
                        rtol=0.04, atol=0.)

    self.assertAllClose(approx_kl_score_trick_grad_, exact_kl_grad_,
                        rtol=0.05, atol=0.)

    self.assertAllClose(
        approx_kl_self_normalized_score_trick_grad_, exact_kl_grad_,
        rtol=0.04, atol=0.)
Beispiel #24
0
def _update_trajectory_grad(previous_kernel_results,
                            previous_state,
                            proposed_state,
                            proposed_velocity,
                            trajectory_jitter,
                            accept_prob,
                            step_size,
                            criterion_fn,
                            max_leapfrog_steps,
                            experimental_shard_axis_names=None,
                            experimental_chain_axis_names=None):
    """Updates the trajectory length.

    The criterion is differentiated with respect to a displacement `dt` along
    the proposed velocity (evaluated at `dt = 0`). The acceptance-weighted
    gradient then drives an Adam/RMSProp-style multiplicative update of the
    maximum trajectory length, alongside an iterate average of that length.

    Args:
      previous_kernel_results: Results tuple providing `adaptation_rate`,
        `step`, `averaged_sq_grad`, `max_trajectory_length` and
        `averaged_max_trajectory_length`, and supporting `_replace`.
      previous_state: State before the proposal.
      proposed_state: Proposed state (possibly nested).
      proposed_velocity: Velocity at the proposed state, same structure.
      trajectory_jitter: Factor multiplied into the criterion gradient.
      accept_prob: Acceptance probabilities used as weights.
      step_size: Step size forwarded to `_clip_max_trajectory_length`.
      criterion_fn: Callable `(previous_state, state, accept_prob)` returning
        the criterion.
      max_leapfrog_steps: Forwarded to `_clip_max_trajectory_length`.
      experimental_shard_axis_names: Optional shard axis names.
      experimental_chain_axis_names: Optional chain axis names used by the
        weighted reductions.

    Returns:
      `previous_kernel_results` with `criterion`, `max_trajectory_length`,
      `averaged_sq_grad` and `averaged_max_trajectory_length` replaced.
    """

    # Compute criterion grads.
    def leapfrog_action(dt):
        # This represents the effect on the criterion value as the state
        # follows the proposed velocity. This implicitly assumes an identity
        # mass matrix.
        def displace(x, v, shard_axes=None):
            expanded_dt = bu.left_justified_expand_dims_like(dt, v)
            broadcasted_dt = distribute_lib.pbroadcast(expanded_dt, shard_axes)
            return x + broadcasted_dt * v

        displaced_state = _map_structure_up_to_with_axes(
            proposed_state,
            displace,
            proposed_state,
            proposed_velocity,
            experimental_shard_axis_names=experimental_shard_axis_names)
        return criterion_fn(previous_state, displaced_state, accept_prob)

    zero_dt = tf.zeros_like(accept_prob)
    criterion, raw_grad = gradient.value_and_gradient(leapfrog_action, zero_dt)
    jittered_grad = raw_grad * trajectory_jitter

    # Weight by acceptance probability.
    chain_axes = distribute_lib.canonicalize_named_axis(
        experimental_chain_axis_names)
    accepted_grad = tf.where(accept_prob > 1e-4, jittered_grad, 0.)
    finite_grad = tf.where(
        tf.math.is_finite(accepted_grad), accepted_grad, 0.)
    weighted_grad_sum = _reduce_sum_with_axes(
        finite_grad * accept_prob, None, chain_axes)
    weight_sum = _reduce_sum_with_axes(accept_prob + 1e-20, None, chain_axes)
    trajectory_grad = weighted_grad_sum / weight_sum

    # Compute Adam/RMSProp step size.
    adaptation_rate = previous_kernel_results.adaptation_rate
    iteration_f = tf.cast(
        previous_kernel_results.step, adaptation_rate.dtype) + 1.
    msg_adaptation_rate = 0.05
    decay = 1 - msg_adaptation_rate
    new_averaged_sq_grad = (
        decay * previous_kernel_results.averaged_sq_grad +
        msg_adaptation_rate * trajectory_grad**2)
    # Bias correction for the running average of squared gradients.
    bias_corrected_sq_grad = new_averaged_sq_grad / (
        1. - decay**iteration_f)
    trajectory_step_size = (
        adaptation_rate / tf.sqrt(bias_corrected_sq_grad + 1e-20))

    # Apply the gradient. Clip absolute value to ~log(2)/2.
    log_update = tf.clip_by_value(
        trajectory_step_size * trajectory_grad, -0.35, 0.35)
    unclipped_length = (
        previous_kernel_results.max_trajectory_length * tf.exp(log_update))

    # Iterate averaging (uses the length before it is clipped below).
    average_weight = iteration_f**(-0.5)
    log_new_length = tf.math.log(unclipped_length)
    log_previous_average = tf.math.log(
        1e-10 + previous_kernel_results.averaged_max_trajectory_length)
    new_averaged_max_trajectory_length = tf.exp(
        average_weight * log_new_length +
        (1 - average_weight) * log_previous_average)

    # Clip the maximum trajectory length.
    new_max_trajectory_length = _clip_max_trajectory_length(
        unclipped_length, step_size, adaptation_rate, max_leapfrog_steps)

    return previous_kernel_results._replace(
        criterion=criterion,
        max_trajectory_length=new_max_trajectory_length,
        averaged_sq_grad=new_averaged_sq_grad,
        averaged_max_trajectory_length=new_averaged_max_trajectory_length)
def _update_trajectory_grad(previous_kernel_results, previous_state,
                            proposed_state, proposed_velocity,
                            trajectory_jitter, accept_prob, step_size,
                            criterion_fn, max_leapfrog_steps):
  """Takes one adaptation step on the maximum trajectory length.

  The gradient of `criterion_fn` with respect to a displacement `dt` along the
  proposed velocity is evaluated at `dt = 0`, averaged with `accept_prob` as
  weights, rescaled by a bias-corrected running average of its square, and
  applied multiplicatively (in log space) to the trajectory length.

  Args:
    previous_kernel_results: Namedtuple-like results; the fields `step`,
      `adaptation_rate`, `averaged_sq_grad`, `max_trajectory_length` and
      `averaged_max_trajectory_length` are read.
    previous_state: Passed through as the first argument of `criterion_fn`.
    proposed_state: Structure of proposed state parts.
    proposed_velocity: Structure matching `proposed_state`.
    trajectory_jitter: Factor multiplied into the raw gradient.
    accept_prob: Acceptance probabilities, used both as the last argument of
      `criterion_fn` and as averaging weights.
    step_size: Forwarded to `_clip_max_trajectory_length`.
    criterion_fn: Callable `(previous_state, state, accept_prob)` whose value
      is differentiated.
    max_leapfrog_steps: Forwarded to `_clip_max_trajectory_length`.

  Returns:
    `previous_kernel_results` with `criterion`, `max_trajectory_length`,
    `averaged_sq_grad` and `averaged_max_trajectory_length` replaced.
  """
  prev = previous_kernel_results

  def leapfrog_action(dt):
    # Effect on the criterion as the state is displaced along the proposed
    # velocity. This implicitly assumes an identity mass matrix.
    def displace(x, v):
      return x + mcmc_util.left_justified_expand_dims_like(dt, v) * v

    displaced_state = tf.nest.map_structure(
        displace, proposed_state, proposed_velocity)
    return criterion_fn(previous_state, displaced_state, accept_prob)

  # Criterion value and its gradient at zero displacement.
  criterion, grad = gradient.value_and_gradient(
      leapfrog_action, tf.zeros_like(accept_prob))
  grad *= trajectory_jitter

  # Discard entries with negligible acceptance probability or non-finite
  # gradients, then take the acceptance-weighted mean.
  is_accepted_enough = accept_prob > 1e-4
  grad = tf.where(is_accepted_enough, grad, 0.)
  grad = tf.where(tf.math.is_finite(grad), grad, 0.)
  weighted_grad_sum = tf.reduce_sum(grad * accept_prob)
  weight_sum = tf.reduce_sum(accept_prob + 1e-20)
  grad = weighted_grad_sum / weight_sum

  # Adam/RMSProp-style step size from a running mean of the squared gradient.
  dtype = prev.adaptation_rate.dtype
  iteration_f = tf.cast(prev.step, dtype) + 1.
  msg_adaptation_rate = 0.05
  msg_decay = 1 - msg_adaptation_rate
  new_averaged_sq_grad = (
      msg_decay * prev.averaged_sq_grad + msg_adaptation_rate * grad**2)
  # Bias correction for the running average.
  bias_correction = 1. - msg_decay**iteration_f
  adjusted_averaged_sq_grad = new_averaged_sq_grad / bias_correction
  trajectory_step_size = (
      prev.adaptation_rate / tf.sqrt(adjusted_averaged_sq_grad + 1e-20))

  # Multiplicative update; the log-space step is clipped to ~log(2)/2.
  log_update = tf.clip_by_value(trajectory_step_size * grad, -0.35, 0.35)
  unclipped_length = prev.max_trajectory_length * tf.exp(log_update)

  # Iterate averaging in log space, using the not-yet-clipped length.
  average_weight = iteration_f**(-0.5)
  log_new_length = tf.math.log(unclipped_length)
  log_old_average = tf.math.log(
      1e-10 + prev.averaged_max_trajectory_length)
  new_averaged_max_trajectory_length = tf.exp(
      average_weight * log_new_length +
      (1 - average_weight) * log_old_average)

  # Clip the maximum trajectory length.
  clipped_length = _clip_max_trajectory_length(
      unclipped_length, step_size, prev.adaptation_rate, max_leapfrog_steps)

  return previous_kernel_results._replace(
      criterion=criterion,
      max_trajectory_length=clipped_length,
      averaged_sq_grad=new_averaged_sq_grad,
      averaged_max_trajectory_length=new_averaged_max_trajectory_length)
Beispiel #26
0
 def test_log1psquare(self, x, expected_y, expected_dydx):
     """Checks `numeric.log1psquare` and its gradient against expected values."""
     inputs = tf.convert_to_tensor(value=x, dtype=self.dtype)
     value, grad = tfp_math_gradient.value_and_gradient(
         numeric.log1psquare, inputs)
     actual_y, actual_dydx = self.evaluate([value, grad])
     self.assertAllClose(expected_y, actual_y)
     self.assertAllClose(expected_dydx, actual_dydx)
    def _sharded_log_prob_parts_bwd(value, gs):
        """Computes gradients of the `log_prob_parts` terms w.r.t. the inputs.

        Presumably the backward rule of a custom gradient; the registration is
        not visible from here.

        Args:
          value: Structure of inputs `log_prob_parts_fn` is evaluated at.
          gs: Structure of output gradients, one entry per output term.

        Returns:
          A 1-tuple holding a structure shaped like `value`; each entry is the
          sum of the valid per-term gradients for that input, or `None` when
          every per-term gradient is `None`.
        """
        map_axes = nest.map_structure_up_to(value, canonicalize_axis_name,
                                            axis_names)

        # Work on flattened lists so that individual outputs are easy to pick
        # out when computing the local gradients.
        flat_value = tf.nest.flatten(value)
        flat_gs = tf.nest.flatten(gs)

        def flat_log_prob_parts_fn(flat_args):
            packed_args = tf.nest.pack_sequence_as(value, flat_args)
            return tf.nest.flatten(log_prob_parts_fn(packed_args))

        def grads_of_term(out_idx):
            # Gradient of output term `out_idx` w.r.t. every flattened input.
            return math_gradient.value_and_gradient(
                lambda *val: flat_log_prob_parts_fn(val)[out_idx],
                flat_value,
                output_gradients=flat_gs[out_idx])[1]

        # One entry per term (as many terms as inputs), each listing the
        # gradients w.r.t. all inputs.
        grads_by_term = [grads_of_term(idx) for idx in range(len(flat_value))]
        # Transpose: one entry per input, listing the gradients of all terms.
        grads_by_input = list(zip(*grads_by_term))
        # Repack.
        local_grads = tf.nest.pack_sequence_as(value, [
            _DummyGrads(tf.nest.pack_sequence_as(value, per_term_grads))
            for per_term_grads in grads_by_input
        ])

        def value_grad(v, value_axis_names, term_grads):
            """Reduces the per-term gradients belonging to one input value.

            A `log_prob_parts` function takes a list of values and returns one
            log density term per input, so its vector-Jacobian product needs
            the gradient of every term w.r.t. every input. The gradient of
            term `j` w.r.t. input `i` is all-reduce-summed when:
            1) that gradient is connected (not `None`), and
            2) `j` is a sharded term while `i` is an unsharded value.

            Otherwise the gradient is used unchanged, which covers:
            1) a sharded term w.r.t. a sharded value,
            2) an unsharded term w.r.t. an unsharded value,
            3) an unsharded term w.r.t. a sharded value.

            Args:
              v: The entry of `value` being processed; not read by the body.
              value_axis_names: Axis names this input is sharded over; falsy
                if it is not sharded.
              term_grads: Wrapper whose `.grads` holds the gradient of each
                term w.r.t. this input.

            Returns:
              The sum of the valid (possibly all-reduced) gradients, or `None`
              if all of them are `None`.
            """
            grads_wrt_value = term_grads.grads

            def psum_grads(term_grad, term_axis_names):
                if term_grad is None:
                    return None
                if value_axis_names or not term_axis_names:
                    return term_grad
                # TODO(https://github.com/google/jax/issues/6022): This cast
                # shouldn't be here.
                return tf.cast(
                    psum(term_grad, axis_name=term_axis_names),
                    term_grad.dtype)

            total_grad = nest.map_structure_up_to(
                grads_wrt_value, psum_grads, grads_wrt_value, map_axes)
            flat_total = tf.nest.flatten(total_grad)
            if all(g is None for g in flat_total):
                return None
            return tf.add_n([
                g for g in flat_total
                if tfp_custom_gradient.is_valid_gradient(g)
            ])

        out = nest.map_structure_up_to(value, value_grad, value, map_axes,
                                       local_grads)
        return (out, )
Beispiel #28
0
 def f(t):
   """Evaluates the enclosing-scope polynomial `poly` and its derivative at `t`.

   Returns a `ValueAndGradient` whose `f` and `df` fields are the squeezed
   value and gradient.
   """
   point = tf.convert_to_tensor(value=t, dtype=tf.float32)

   def evaluate_poly(t_):
     return tf.math.polyval(poly, t_)

   value, slope = value_and_gradient(evaluate_poly, point)
   return ValueAndGradient(f=tf.squeeze(value), df=tf.squeeze(slope))
Beispiel #29
0
                def augmented_ode_fn(backward_time, augmented_state):
                    """Dynamics function for the augmented system.

                    Describes a differential equation that evolves the
                    augmented state backwards in time to compute gradients
                    using the adjoint method. The augmented state has 4
                    components `(state, adjoint_state, vars, constants)`, all
                    evaluated at time `backward_time`:

                    state: the solution of the user provided `ode_fn`. The
                      structure coincides with the `initial_state`.
                    adjoint_state: the solution of the adjoint sensitivity
                      differential equation discussed below. Has the same
                      structure and shape as `state`.
                    vars: the solution of the adjoint equation for variable
                      gradients, a `Tuple(Tensor, ...)` with as many tensors
                      as there are `variables` outside this function.
                    constants: the solution of the adjoint equation for
                      constant gradients. Has the same structure and shape as
                      `constants` outside this function.

                    The adjoint sensitivity equation describes the gradient of
                    the solution with respect to the value of the solution at
                    a previous time t. Its dynamics are given by
                      d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)
                    which is computed as
                      d/dt[adj(t)]_i
                        = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j])
                        = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
                    where the last line moves adj(t)_j under the derivative by
                    stopping gradients through it.

                    The adjoint equation for the gradient with respect to
                    every `tf.Variable` and constant theta follows:
                      d/dt[grad_theta(t)]
                        = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
                        = -1 * d/d theta_i[
                            sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

                    Args:
                      backward_time: Floating `Tensor` representing current
                        time.
                      augmented_state: 4-tuple `(state, adjoint_state,
                        variable_grads, constant_grads)`; only the first two
                        components are read.

                    Returns:
                      negative_derivatives: Structure of `Tensor`s equal to
                        backwards time derivative of the `state` component.
                      adjoint_ode: Structure of `Tensor`s equal to backwards
                        time derivative of the `adjoint_state` component.
                      adjoint_variables_ode: Structure of `Tensor`s equal to
                        backwards time derivative of the `vars` component.
                      adjoint_constants_ode: Structure of `Tensor`s equal to
                        backwards time derivative of the `constants`
                        component.
                    """
                    # The ODE solver cannot handle initial_time > final_time,
                    # hence the change of variables backward_time = -time.
                    # The negative signs disappear after this change.
                    time = -backward_time
                    state, adjoint_state, _unused_vars, _unused_consts = (
                        augmented_state)

                    # TODO(b/152464477): Doesn't work reliably in TF1.
                    def grad_fn(state, variables, constants):
                        # Variable gradients come from the GradientTape
                        # capturing them, so the explicit argument is dropped.
                        del variables
                        state_derivatives = ode_fn(time, state, **constants)
                        negative_derivatives = rk_util.weighted_sum(
                            [-1.0], [state_derivatives])
                        frozen_adjoint = tf.nest.map_structure(
                            tf.stop_gradient, adjoint_state)

                        # See docstring for details.
                        partial_dots = tf.nest.map_structure(
                            lambda adj, deriv: tf.reduce_sum(adj * deriv),
                            frozen_adjoint, state_derivatives)
                        adjoint_dot_derivatives = tf.squeeze(
                            tf.add_n(tf.nest.flatten(partial_dots)))
                        return adjoint_dot_derivatives, negative_derivatives

                    def zeros_if_missing(primal, grad):
                        # A `None` gradient becomes zeros shaped like its
                        # primal.
                        return tf.zeros_like(primal) if grad is None else grad

                    primals = (state, tuple(variables), constants)
                    (_, negative_derivatives), gradients = (
                        tfp_gradient.value_and_gradient(
                            grad_fn,
                            primals,
                            has_aux=True,
                            use_gradient_tape=True))

                    (adjoint_ode, adjoint_variables_ode,
                     adjoint_constants_ode) = tf.nest.map_structure(
                         zeros_if_missing, primals, tuple(gradients))
                    return (negative_derivatives, adjoint_ode,
                            adjoint_variables_ode, adjoint_constants_ode)
Beispiel #30
0
 def testSqrtWithFiniteGradsWithDynamicShape(self):
     """Gradients of `tf.sqrt` and `sqrt_with_finite_grads` agree at x = 1.

     The input is a placeholder whose static shape is `[None]`.
     """
     dynamic_x = tf1.placeholder_with_default([1.], shape=[None])
     _, reference_grad = value_and_gradient(tf.sqrt, dynamic_x)
     _, safe_grad = value_and_gradient(util.sqrt_with_finite_grads,
                                       dynamic_x)
     reference_grad_, safe_grad_ = self.evaluate([reference_grad, safe_grad])
     self.assertAllEqual(reference_grad_, safe_grad_)