Example #1
def prepare_tuple_argument(arg, n, arg_name, validate_args=False):
    """Helper which processes `Tensor`s to tuples in standard form."""
    # Short-circuiting incoming lists and tuples here avoids both
    # Tensor packing / unpacking and numpy 1.20.+ pickiness about
    # np.array(tuple of Tensor).
    if isinstance(arg, (tuple, list)):
        if len(arg) == n:
            return tuple(arg)
        if len(arg) == 1:
            return (arg[0], ) * n

    arg_size = ps.size(arg)
    arg_size_ = tf.get_static_value(arg_size)
    assertions = []
    if arg_size_ is not None:
        if arg_size_ not in (1, n):
            raise ValueError(
                'The size of `{}` must be equal to `1` or to the rank '
                'of the convolution (={}). Saw size = {}'.format(
                    arg_name, n, arg_size_))
    elif validate_args:
        assertions.append(
            assert_util.assert_equal(
                ps.logical_or(arg_size == 1, arg_size == n),
                True,
                message=
                ('The size of `{}` must be equal to `1` or to the rank of the '
                 'convolution (={})'.format(arg_name, n))))

    with tf.control_dependencies(assertions):
        arg = ps.broadcast_to(arg, shape=[n])
        arg = ps.unstack(arg, num=n)
        return arg
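A brief usage sketch (not part of the original example), assuming a rank-2
convolution (`n=2`) and that `ps` is TensorFlow Probability's `prefer_static`
module:

kernel = prepare_tuple_argument([3, 3], n=2, arg_name='kernel_shape')  # (3, 3)
stride = prepare_tuple_argument((2,), n=2, arg_name='strides')         # (2, 2)
dilations = prepare_tuple_argument(tf.constant([1, 2]), n=2, arg_name='dilations')
# -> a list of two scalar Tensors, [1, 2], produced by `broadcast_to` + `unstack`.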
  def _scan(level, elems):
    """Perform scan on `elems`."""
    elem_length = prefer_static.shape(elems[0])[0]

    # Apply `fn` to reduce adjacent pairs to a single entry.
    a = [elem[0:-1:2] for elem in elems]
    b = [elem[1::2] for elem in elems]
    reduced_elems = lowered_fn(a, b)

    def handle_base_case_elem_length_two():
      return [tf.concat([elem[0:1], reduced_elem], axis=0)
              for (reduced_elem, elem) in zip(reduced_elems, elems)]

    def handle_base_case_elem_length_three():
      reduced_reduced_elems = lowered_fn(
          reduced_elems, [elem[2:3] for elem in elems])
      return [
          tf.concat([elem[0:1], reduced_elem, reduced_reduced_elem], axis=0)
          for (reduced_reduced_elem, reduced_elem, elem)
          in zip(reduced_reduced_elems, reduced_elems, elems)]

    # Base case of recursion: assumes `elem_length` is 2 or 3.
    at_base_case = prefer_static.logical_or(
        prefer_static.equal(elem_length, 2),
        prefer_static.equal(elem_length, 3))
    base_value = lambda: prefer_static.cond(  # pylint: disable=g-long-lambda
        prefer_static.equal(elem_length, 2),
        handle_base_case_elem_length_two,
        handle_base_case_elem_length_three)

    if level <= 0:
      return base_value()

    def recursive_case():
      """Evaluate the next step of the recursion."""
      odd_elems = _scan(level - 1, reduced_elems)

      def even_length_case():
        return lowered_fn([odd_elem[:-1] for odd_elem in odd_elems],
                          [elem[2::2] for elem in elems])

      def odd_length_case():
        return lowered_fn([odd_elem for odd_elem in odd_elems],
                          [elem[2::2] for elem in elems])

      results = prefer_static.cond(
          prefer_static.equal(elem_length % 2, 0),
          even_length_case,
          odd_length_case)
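      # `results` holds the scan values for even output positions past the
      # first: each odd-position partial scan is combined with the next
      # original element (`elems[2::2]`).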

      # The first element of a scan is the same as the first element
      # of the original `elems`.
      even_elems = [tf.concat([elem[0:1], result], axis=0)
                    for (elem, result) in zip(elems, results)]
      return list(map(_interleave, even_elems, odd_elems))

    return prefer_static.cond(at_base_case, base_value, recursive_case)
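The recursion above is the pairwise-reduction ("work-efficient") form of an
associative scan. A toy pure-Python analogue (not TFP code; it assumes an
associative `fn` and at least two elements) that mirrors the same even/odd
structure:

def toy_scan(fn, xs):
  """Toy inclusive scan mirroring the pairwise-reduction recursion above."""
  n = len(xs)
  if n in (2, 3):                                   # base case, as in `_scan`
    out = [xs[0], fn(xs[0], xs[1])]
    if n == 3:
      out.append(fn(out[-1], xs[2]))
    return out
  # Reduce adjacent pairs to a single entry.
  reduced = [fn(a, b) for a, b in zip(xs[0:-1:2], xs[1::2])]
  # Recursively scan the reduced sequence: these become the odd-position
  # results of the full scan.
  odd = toy_scan(fn, reduced)
  # Combine each odd-position result with the next original element to get
  # the even-position results (beyond the first element).
  if n % 2 == 0:
    tail = [fn(o, x) for o, x in zip(odd[:-1], xs[2::2])]
  else:
    tail = [fn(o, x) for o, x in zip(odd, xs[2::2])]
  even = [xs[0]] + tail
  # Interleave even- and odd-position results.
  out = []
  for e, o in zip(even, odd):
    out += [e, o]
  return out + even[len(odd):]

toy_scan(lambda a, b: a + b, [1, 2, 3, 4, 5])       # -> [1, 3, 6, 10, 15]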
Example #3
  def test_step_indices_to_trace(self):
    num_particles = 1024
    (particles_1_3,
     log_weights_1_3,
     parent_indices_1_3,
     incremental_log_marginal_likelihood_1_3) = self.evaluate(
         tfp.experimental.mcmc.particle_filter(
             observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]),
             initial_state_prior=tfd.Normal(0., 1.),
             transition_fn=lambda _, state: tfd.Normal(state, 10.),
             observation_fn=lambda _, state: tfd.Normal(state, 0.1),
             num_particles=num_particles,
             trace_criterion_fn=lambda s, r: ps.logical_or(  # pylint: disable=g-long-lambda
                 ps.equal(r.steps, 2),
                 ps.equal(r.steps, 4)),
             static_trace_allocation_size=2,
             seed=test_util.test_seed()))
    self.assertLen(particles_1_3, 2)
    self.assertLen(log_weights_1_3, 2)
    self.assertLen(parent_indices_1_3, 2)
    self.assertLen(incremental_log_marginal_likelihood_1_3, 2)
    means = np.sum(np.exp(log_weights_1_3) * particles_1_3, axis=1)
    self.assertAllClose(means, [3., 7.], atol=1.)

    (final_particles,
     final_log_weights,
     final_cumulative_lp) = self.evaluate(
         tfp.experimental.mcmc.particle_filter(
             observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]),
             initial_state_prior=tfd.Normal(0., 1.),
             transition_fn=lambda _, state: tfd.Normal(state, 10.),
             observation_fn=lambda _, state: tfd.Normal(state, 0.1),
             num_particles=num_particles,
             trace_fn=lambda s, r: (s.particles,  # pylint: disable=g-long-lambda
                                    s.log_weights,
                                    r.accumulated_log_marginal_likelihood),
             trace_criterion_fn=None,
             seed=test_util.test_seed()))
    self.assertLen(final_particles, num_particles)
    self.assertLen(final_log_weights, num_particles)
    self.assertEqual(final_cumulative_lp.shape, ())
    means = np.sum(np.exp(final_log_weights) * final_particles)
    self.assertAllClose(means, 9., atol=1.5)
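For contrast, when neither `trace_criterion_fn` nor `trace_fn` is overridden,
the filter traces every step, so each returned array gains a leading dimension
equal to the number of observations. A minimal sketch under that assumption,
reusing the toy model from the test above:

(particles, log_weights, parent_indices,
 incremental_log_marginal_likelihoods) = tfp.experimental.mcmc.particle_filter(
     observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]),
     initial_state_prior=tfd.Normal(0., 1.),
     transition_fn=lambda _, state: tfd.Normal(state, 10.),
     observation_fn=lambda _, state: tfd.Normal(state, 0.1),
     num_particles=1024,
     seed=test_util.test_seed())
# particles.shape == [5, 1024]: one slice per observation step.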
Example #4
def prepare_tuple_argument(arg, n, arg_name, validate_args=False):
    """Helper which processes `Tensor`s to tuples in standard form."""
    arg_size = ps.size(arg)
    arg_size_ = tf.get_static_value(arg_size)
    assertions = []
    if arg_size_ is not None:
        if arg_size_ not in (1, n):
            raise ValueError(
                'The size of `{}` must be equal to `1` or to the rank '
                'of the convolution (={}). Saw size = {}'.format(
                    arg_name, n, arg_size_))
    elif validate_args:
        assertions.append(
            assert_util.assert_equal(
                ps.logical_or(arg_size == 1, arg_size == n),
                True,
                message=
                ('The size of `{}` must be equal to `1` or to the rank of the '
                 'convolution (={})'.format(arg_name, n))))

    with tf.control_dependencies(assertions):
        arg = ps.broadcast_to(arg, shape=[n])
        arg = ps.unstack(arg, num=n)
        return arg
def fit_one_step(
    model_matrix,
    response,
    model,
    model_coefficients_start=None,
    predicted_linear_response_start=None,
    l2_regularizer=None,
    dispersion=None,
    offset=None,
    learning_rate=None,
    fast_unsafe_numerics=True,
    l2_regularization_penalty_factor=None,
    name=None):
  """Runs one step of Fisher scoring.

  Args:
    model_matrix: (Batch of) `float`-like, matrix-shaped `Tensor` where each row
      represents a sample's features.
    response: (Batch of) vector-shaped `Tensor` where each element represents a
      sample's observed response (to the corresponding row of features). Must
      have same `dtype` as `model_matrix`.
    model: `tfp.glm.ExponentialFamily`-like instance used to construct the
      negative log-likelihood loss, gradient, and expected Hessian (i.e., the
      Fisher information matrix).
    model_coefficients_start: Optional (batch of) vector-shaped `Tensor`
      representing the initial model coefficients, one for each column in
      `model_matrix`. Must have same `dtype` as `model_matrix`.
      Default value: Zeros.
    predicted_linear_response_start: Optional `Tensor` with `shape` and `dtype`
      matching `response`; represents the initial linear predictions, shifted
      by `offset`, based on `model_coefficients_start`.
      Default value: `offset` if `model_coefficients_start is None`, and
      `tf.linalg.matvec(model_matrix, model_coefficients_start) + offset`
      otherwise.
    l2_regularizer: Optional scalar `Tensor` representing L2 regularization
      penalty, i.e.,
      `loss(w) = sum{-log p(y[i]|x[i],w) : i=1..n} + l2_regularizer ||w||_2^2`.
      Default value: `None` (i.e., no L2 regularization).
    dispersion: Optional (batch of) `Tensor` representing `response` dispersion,
      i.e., as in, `p(y|theta) := exp((y theta - A(theta)) / dispersion)`.
      Must broadcast with rows of `model_matrix`.
      Default value: `None` (i.e., "no dispersion").
    offset: Optional `Tensor` representing constant shift applied to
      `predicted_linear_response`.  Must broadcast to `response`.
      Default value: `None` (i.e., `tf.zeros_like(response)`).
    learning_rate: Optional (batch of) scalar `Tensor` used to dampen iterative
      progress. Typically only needed if optimization diverges; should be no
      larger than `1` and typically very close to `1`.
      Default value: `None` (i.e., `1`).
    fast_unsafe_numerics: Optional Python `bool` indicating whether the solve
      should be based on a Cholesky or a QR decomposition.
      Default value: `True` (i.e., "prefer speed via Cholesky decomposition").
    l2_regularization_penalty_factor: Optional (batch of) vector-shaped
      `Tensor` representing a separate penalty factor to apply to each model
      coefficient, with length equal to the number of columns in
      `model_matrix`. Each penalty factor multiplies `l2_regularizer` to allow
      differential regularization; it can be `0` for some coefficients, which
      implies no regularization for those coefficients. The default is `1` for
      all coefficients, giving the loss
      `loss(w) = sum{-log p(y[i]|x[i],w) : i=1..n} + l2_regularizer ||w *
        l2_regularization_penalty_factor||_2^2`
    name: Python `str` used as name prefix to ops created by this function.
      Default value: `"fit_one_step"`.

  Returns:
    model_coefficients: (Batch of) vector-shaped `Tensor`; represents the
      next estimate of the model coefficients, one for each column in
      `model_matrix`.
    predicted_linear_response: `response`-shaped `Tensor` representing linear
      predictions based on new `model_coefficients`, i.e.,
      `tf.linalg.matvec(model_matrix, model_coefficients_next) + offset`.
  """
  with tf.name_scope(name or 'fit_one_step'):

    [
        model_matrix,
        response,
        model_coefficients_start,
        predicted_linear_response_start,
        offset,
    ] = prepare_args(
        model_matrix,
        response,
        model_coefficients_start,
        predicted_linear_response_start,
        offset)

    # Compute: mean, grad(mean, predicted_linear_response_start), and variance.
    mean, variance, grad_mean = model(predicted_linear_response_start)

    # If either `grad_mean` or `variance` is non-finite or zero, then we'll
    # replace it with a value such that the row is zeroed out. Although this
    # procedure may seem circuitous, it is necessary to ensure this algorithm is
    # itself differentiable.
    is_valid = (
        tf.math.is_finite(grad_mean) & tf.not_equal(grad_mean, 0.)
        & tf.math.is_finite(variance) & (variance > 0.))

    def mask_if_invalid(x, mask):
      return tf.where(
          is_valid, x, np.array(mask, dtype_util.as_numpy_dtype(x.dtype)))

    # Run one step of iteratively reweighted least-squares.
    # Compute "`z`", the adjusted predicted linear response:
    # z = predicted_linear_response_start - offset
    #     + learning_rate * (response - mean) / grad_mean
    z = (response - mean) / mask_if_invalid(grad_mean, 1.)
    # TODO(jvdillon): Rather than use learning rate, we should consider using
    # backtracking line search.
    if learning_rate is not None:
      z *= learning_rate[..., tf.newaxis]
    z += predicted_linear_response_start
    if offset is not None:
      z -= offset

    # Compute "`w`", the per-sample weight.
    if dispersion is not None:
      # For convenience, we'll now scale the variance by the dispersion factor.
      variance *= dispersion
    w = (
        mask_if_invalid(grad_mean, 0.) *
        tf.math.rsqrt(mask_if_invalid(variance, np.inf)))
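    # In Fisher-scoring / IRLS terms, `w` is the square root of the usual IRLS
    # weight: `w[i] = mean'(eta[i]) / sqrt(Var(y[i]))`. The least-squares solve
    # below on `a = model_matrix * w` and `b = z * w` therefore reproduces one
    # weighted least-squares update.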

    a = model_matrix * w[..., tf.newaxis]
    b = z * w
    # Solve `min{ || A @ model_coefficients - b ||_2**2 : model_coefficients }`
    # where `@` denotes `matmul`.

    if l2_regularizer is None:
      l2_regularizer = np.array(0, dtype_util.as_numpy_dtype(a.dtype))
    else:
      l2_regularizer_ = distribution_util.maybe_get_static_value(
          l2_regularizer, dtype_util.as_numpy_dtype(a.dtype))
      if l2_regularizer_ is not None:
        l2_regularizer = l2_regularizer_

    def _embed_l2_regularization():
      """Adds synthetic observations to implement L2 regularization."""
      # `tf.matrix_solve_ls` does not respect the `l2_regularizer` argument
      # when `fast_unsafe_numerics` is `False`. This function adds synthetic
      # observations to the data to implement the regularization instead.
      # Adding observations `sqrt(l2_regularizer) * I` is mathematically
      # equivalent to adding the term
      # `-l2_regularizer ||coefficients||_2**2` to the log-likelihood.
      num_model_coefficients = num_cols(model_matrix)
      batch_shape = tf.shape(model_matrix)[:-2]
      if l2_regularization_penalty_factor is None:
        eye = tf.eye(
            num_model_coefficients, batch_shape=batch_shape, dtype=a.dtype)
      else:
        eye = tf.linalg.tensor_diag(
            tf.cast(l2_regularization_penalty_factor, dtype=a.dtype))
        broadcasted_shape = prefer_static.concat(
            [batch_shape, [num_model_coefficients, num_model_coefficients]],
            axis=0)
        eye = tf.broadcast_to(eye, broadcasted_shape)
      a_ = tf.concat([a, tf.sqrt(l2_regularizer) * eye], axis=-2)
      b_ = distribution_util.pad(
          b, count=num_model_coefficients, axis=-1, back=True)
      # Return `l2_regularizer=0` since it's now embedded.
      l2_regularizer_ = np.array(0, dtype_util.as_numpy_dtype(a.dtype))
      return a_, b_, l2_regularizer_

    a, b, l2_regularizer = prefer_static.cond(
        prefer_static.reduce_all([
            prefer_static.logical_or(
                not(fast_unsafe_numerics),
                l2_regularization_penalty_factor is not None),
            l2_regularizer > 0.
        ]),
        _embed_l2_regularization,
        lambda: (a, b, l2_regularizer))

    model_coefficients_next = tf.linalg.lstsq(
        a,
        b[..., tf.newaxis],
        fast=fast_unsafe_numerics,
        l2_regularizer=l2_regularizer,
        name='model_coefficients_next')
    model_coefficients_next = model_coefficients_next[..., 0]

    # TODO(b/79122261): The approach used in `matrix_solve_ls` could be made
    # faster by avoiding explicitly forming Q and instead keeping the
    # factorization in 'implicit' form with stacked (rescaled) Householder
    # vectors underneath the 'R' and then applying the (accumulated)
    # reflectors in the appropriate order to apply Q'. However, we don't
    # presently do this because we lack core TF functionality. For reference,
    # the vanilla QR approach is:
    #   q, r = tf.linalg.qr(a)
    #   c = tf.matmul(q, b, adjoint_a=True)
    #   model_coefficients_next = tf.matrix_triangular_solve(
    #       r, c, lower=False, name='model_coefficients_next')

    predicted_linear_response_next = compute_predicted_linear_response(
        model_matrix,
        model_coefficients_next,
        offset,
        name='predicted_linear_response_next')

    return model_coefficients_next, predicted_linear_response_next
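A hedged usage sketch (assuming this is the `fit_one_step` exported as
`tfp.glm.fit_one_step`, with `tfp.glm.Bernoulli` available as the model): one
Fisher-scoring step on synthetic logistic-regression data.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant(np.random.randn(100, 3), dtype=tf.float32)    # model_matrix
true_coeffs = tf.constant([1., -2., 0.5])
y = tf.cast(                                                   # Bernoulli response
    tf.random.uniform([100]) < tf.math.sigmoid(tf.linalg.matvec(x, true_coeffs)),
    tf.float32)

coeffs_next, linear_response_next = fit_one_step(
    model_matrix=x,
    response=y,
    model=tfp.glm.Bernoulli(),
    model_coefficients_start=tf.zeros([3]))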
  def _scan(level, elems):
    """Perform scan on `elems`."""
    elem_length = ps.shape(elems[0])[axis]

    # Apply `fn` to reduce adjacent pairs to a single entry.
    a = [slice_elem(elem, 0, -1, step=2) for elem in elems]
    b = [slice_elem(elem, 1, None, step=2) for elem in elems]
    reduced_elems = lowered_fn(a, b)

    def handle_base_case_elem_length_two():
      return [tf.concat([slice_elem(elem, 0, 1), reduced_elem], axis=axis)
              for (reduced_elem, elem) in zip(reduced_elems, elems)]

    def handle_base_case_elem_length_three():
      reduced_reduced_elems = lowered_fn(
          reduced_elems,
          [slice_elem(elem, 2, 3) for elem in elems])
      return [
          tf.concat([slice_elem(elem, 0, 1),  # pylint: disable=g-complex-comprehension
                     reduced_elem,
                     reduced_reduced_elem], axis=axis)
          for (reduced_reduced_elem, reduced_elem, elem)
          in zip(reduced_reduced_elems, reduced_elems, elems)]

    # Base case of recursion: assumes `elem_length` is 2 or 3.
    at_base_case = ps.logical_or(
        ps.equal(elem_length, 2),
        ps.equal(elem_length, 3))
    base_value = lambda: ps.cond(  # pylint: disable=g-long-lambda
        ps.equal(elem_length, 2),
        handle_base_case_elem_length_two,
        handle_base_case_elem_length_three)

    if level <= 0:
      return base_value()

    def recursive_case():
      """Evaluate the next step of the recursion."""
      odd_elems = _scan(level - 1, reduced_elems)

      def even_length_case():
        return lowered_fn(
            [slice_elem(odd_elem, 0, -1) for odd_elem in odd_elems],
            [slice_elem(elem, 2, None, 2) for elem in elems])

      def odd_length_case():
        return lowered_fn([odd_elem for odd_elem in odd_elems],
                          [slice_elem(elem, 2, None, 2) for elem in elems])

      results = ps.cond(
          ps.equal(elem_length % 2, 0),
          even_length_case,
          odd_length_case)

      # The first element of a scan is the same as the first element
      # of the original `elems`.
      even_elems = [tf.concat([slice_elem(elem, 0, 1), result], axis=axis)
                    for (elem, result) in zip(elems, results)]
      return list(map(lambda a, b: _interleave(a, b, axis=axis),
                      even_elems,
                      odd_elems))

    return ps.cond(at_base_case, base_value, recursive_case)
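`slice_elem`, `lowered_fn`, and `_interleave` come from the enclosing scan
implementation and are not shown in this listing. A hypothetical sketch of a
slicing helper with the same calling convention, assuming it closes over a
fixed non-negative `axis`:

def make_slice_elem(axis):
  """Hypothetical factory: build a helper that slices a Tensor along `axis`."""
  def slice_elem(elem, start, end, step=1):
    # Leave the leading axes untouched and slice only along `axis`,
    # e.g. axis=1 yields elem[:, start:end:step].
    index = [slice(None)] * axis + [slice(start, end, step)]
    return elem[tuple(index)]
  return slice_elem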