Example #1
    def _body(i, state, num_steps_traced, trace_arrays):
      elem = elems_array.read(i)
      state = loop_fn(state, elem)

      trace_arrays, num_steps_traced = ps.cond(
          trace_criterion_fn(state) if trace_criterion_fn else True,
          lambda: (trace_one_step(num_steps_traced, trace_arrays, state),  # pylint: disable=g-long-lambda
                   num_steps_traced + 1),
          lambda: (trace_arrays, num_steps_traced))

      return i + 1, state, num_steps_traced, trace_arrays
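
All of these examples hinge on `ps.cond` / `prefer_static.cond` from TensorFlow Probability. As a rough, minimal sketch of the dispatch idea (not the library's actual implementation): when the predicate is statically known, the chosen branch simply runs as ordinary Python; otherwise the decision is deferred to `tf.cond`.

import tensorflow as tf

def prefer_static_cond_sketch(pred, true_fn, false_fn):
  """Toy version of the `ps.cond` idea: branch in Python if `pred` is static."""
  static_pred = tf.get_static_value(pred)  # None when unknown at trace time.
  if static_pred is not None:
    return true_fn() if static_pred else false_fn()
  return tf.cond(pred, true_fn, false_fn)

# With a static predicate no graph-level cond is built; the branch just runs.
x = prefer_static_cond_sketch(True, lambda: tf.ones([2]), lambda: tf.zeros([2]))
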
Example #2
def _get_transpose_conv_dilated_padding(filter_dim, stride, dilation, padding):
    """Zero-padding for inputs dilated by strides."""
    tot_filter_dim = filter_dim + (filter_dim - 1) * (dilation - 1)
    if padding == 'VALID':
        tot_pad = tot_filter_dim + stride - 2 + ps.maximum(
            tot_filter_dim - stride, 0)
    elif padding == 'SAME':
        tot_pad = tot_filter_dim + stride - 2
    return ps.cond(filter_dim >= stride, lambda:
                   (tot_pad - tot_pad // 2 - stride + 1, tot_pad // 2), lambda:
                   (filter_dim - stride, tot_pad - filter_dim + 1))
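
For intuition, the padding arithmetic above can be checked with plain Python integers; the numbers below are illustrative, while the formulas are exactly the ones in the function.

filter_dim, stride, dilation = 3, 2, 2
tot_filter_dim = filter_dim + (filter_dim - 1) * (dilation - 1)          # 3 + 2*1 = 5
tot_pad = tot_filter_dim + stride - 2 + max(tot_filter_dim - stride, 0)  # 'VALID': 8
# Since filter_dim >= stride, the returned padding pair is:
pad_before = tot_pad - tot_pad // 2 - stride + 1   # 3
pad_after = tot_pad // 2                           # 4
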
def _secant2_inner(value_and_gradients_function, initial_args, val_0, val_c,
                   f_lim, sufficient_decrease_param, curvature_param):
    """Helper function for secant square."""
    # Apply the `update` function on active branch members to squeeze their
    # bracketing interval.
    update_result = update(value_and_gradients_function,
                           initial_args.left,
                           initial_args.right,
                           val_c,
                           f_lim,
                           active=initial_args.active)

    # Update active and failed flags, update left/right on non-failed entries.
    active = initial_args.active & ~update_result.failed
    failed = initial_args.failed | update_result.failed
    val_left = _val_where(active, update_result.left, initial_args.left)
    val_right = _val_where(active, update_result.right, initial_args.right)

    # Check if new `c` points should be generated.
    updated_left = active & tf.equal(val_left.x, val_c.x)
    updated_right = active & tf.equal(val_right.x, val_c.x)
    is_new = updated_left | updated_right

    next_c = tf.where(updated_left, _secant(initial_args.left, val_left),
                      val_c.x)
    next_c = tf.where(updated_right, _secant(initial_args.right, val_right),
                      next_c)
    in_range = (val_left.x <= next_c) & (next_c <= val_right.x)

    # Figure out if an extra function evaluation is needed for new `c` points.
    needs_extra_eval = tf.reduce_any(input_tensor=in_range & is_new)
    num_evals = initial_args.num_evals + update_result.num_evals
    num_evals = num_evals + tf.cast(needs_extra_eval, num_evals.dtype)

    next_args = _Secant2Result(
        active=active & in_range,  # No longer active if `c` is out of range.
        converged=initial_args.converged,
        failed=failed,
        num_evals=num_evals,
        left=val_left,
        right=val_right)

    def _apply_inner_update():
        next_val_c = prefer_static.cond(
            needs_extra_eval,
            (lambda: _apply(value_and_gradients_function, next_c)),
            (lambda: val_c))
        return _secant2_inner_update(value_and_gradients_function, next_args,
                                     val_0, next_val_c, f_lim,
                                     sufficient_decrease_param,
                                     curvature_param)

    return prefer_static.cond(tf.reduce_any(input_tensor=next_args.active),
                              _apply_inner_update, lambda: next_args)
Example #4
def pad_batch_dimension_for_multiple_chains(
    observed_time_series, model, chain_batch_shape):
  """"Expand the observed time series with extra batch dimension(s)."""
  # Running with multiple chains introduces an extra batch dimension. In
  # general we also need to pad the observed time series with a matching batch
  # dimension.
  #
  # For example, suppose our model has batch shape [3, 4] and
  # the observed time series has shape `concat([[5], [3, 4], [100]])`,
  # corresponding to `sample_shape`, `batch_shape`, and `num_timesteps`
  # respectively. The model will produce distributions with batch shape
  # `concat([chain_batch_shape, [3, 4]])`, so we pad `observed_time_series` to
  # have matching shape `[5, 1, 3, 4, 100]`, where the added `1` dimension
  # between the sample and batch shapes will broadcast to `chain_batch_shape`.

  observed_time_series = maybe_expand_trailing_dim(
      observed_time_series)  # Guarantee `event_ndims=2`

  event_ndims = 2  # event_shape = [num_timesteps, observation_size=1]

  model_batch_ndims = (
      model.batch_shape.ndims if model.batch_shape.ndims is not None else
      tf.shape(input=model.batch_shape_tensor())[0])

  # Compute ndims from chain_batch_shape.
  chain_batch_shape = tf.convert_to_tensor(
      value=chain_batch_shape, name='chain_batch_shape', dtype=tf.int32)
  if not chain_batch_shape.shape.is_fully_defined():
    raise ValueError('Batch shape must have static rank. (given: {})'.format(
        chain_batch_shape))
  if chain_batch_shape.shape.ndims == 0:  # expand int `k` to `[k]`.
    chain_batch_shape = chain_batch_shape[tf.newaxis]
  chain_batch_ndims = tf.compat.dimension_value(chain_batch_shape.shape[0])

  def do_padding(observed_time_series_tensor):
    current_sample_shape = tf.shape(
        input=observed_time_series_tensor)[:-(model_batch_ndims + event_ndims)]
    current_batch_and_event_shape = tf.shape(
        input=observed_time_series_tensor)[-(model_batch_ndims + event_ndims):]
    return tf.reshape(
        tensor=observed_time_series_tensor,
        shape=tf.concat([
            current_sample_shape,
            tf.ones([chain_batch_ndims], dtype=tf.int32),
            current_batch_and_event_shape], axis=0))

  # Padding is only needed if the observed time series has sample shape.
  observed_time_series = prefer_static.cond(
      (dist_util.prefer_static_rank(observed_time_series) >
       model_batch_ndims + event_ndims),
      lambda: do_padding(observed_time_series),
      lambda: observed_time_series)

  return observed_time_series
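
A standalone sketch of the reshape described in the comment above, using plain TF ops and illustrative shapes (sample shape `[5]`, model batch `[3, 4]`, 100 timesteps, observation size 1):

import tensorflow as tf

observed = tf.zeros([5, 3, 4, 100, 1])   # sample + model batch + [num_timesteps, 1]
model_batch_ndims, event_ndims, chain_batch_ndims = 2, 2, 1

sample_shape = tf.shape(observed)[:-(model_batch_ndims + event_ndims)]
batch_and_event_shape = tf.shape(observed)[-(model_batch_ndims + event_ndims):]
padded = tf.reshape(
    observed,
    tf.concat([sample_shape,
               tf.ones([chain_batch_ndims], dtype=tf.int32),
               batch_and_event_shape], axis=0))
# padded.shape == (5, 1, 3, 4, 100, 1); the inserted 1 broadcasts against
# `chain_batch_shape` when the model produces chain-batched distributions.
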
Example #5
    def _loop_body(curr_interval):
        """The loop body."""
        secant2_raw_result = hzl.secant2(value_and_gradients_function, val_0,
                                         curr_interval, f_lim,
                                         sufficient_decrease_param,
                                         curvature_param)
        secant2_result = HagerZhangLineSearchResult(
            converged=secant2_raw_result.converged,
            failed=secant2_raw_result.failed,
            iterations=curr_interval.iterations + 1,
            func_evals=secant2_raw_result.num_evals,
            left=secant2_raw_result.left,
            right=secant2_raw_result.right)

        should_check_shrinkage = ~(secant2_result.converged
                                   | secant2_result.failed)

        def _do_check_shrinkage():
            """Check if interval has shrinked enough."""
            old_width = curr_interval.right.x - curr_interval.left.x
            new_width = secant2_result.right.x - secant2_result.left.x
            sufficient_shrinkage = new_width < old_width * shrinkage_param
            func_is_flat = (
                _very_close(curr_interval.left.f, curr_interval.right.f)
                & _very_close(secant2_result.left.f, secant2_result.right.f))

            new_converged = (should_check_shrinkage & sufficient_shrinkage
                             & func_is_flat)
            needs_inner_bisect = should_check_shrinkage & ~sufficient_shrinkage

            inner_bisect_args = secant2_result._replace(
                converged=secant2_result.converged | new_converged)

            def _apply_inner_bisect():
                return _line_search_inner_bisection(
                    value_and_gradients_function, inner_bisect_args,
                    needs_inner_bisect, f_lim)

            return prefer_static.cond(
                tf.reduce_any(input_tensor=needs_inner_bisect),
                _apply_inner_bisect, lambda: inner_bisect_args)

        next_args = prefer_static.cond(
            tf.reduce_any(input_tensor=should_check_shrinkage),
            _do_check_shrinkage, lambda: secant2_result)

        interval_shrunk = (~next_args.failed
                           & _very_close(next_args.left.x, next_args.right.x))
        return [
            next_args._replace(converged=next_args.converged | interval_shrunk)
        ]
def _interleave(a, b):
  """Interleaves two `Tensor`s along their first axis."""
  # [a b c ...] [d e f ...] -> [a d b e c f ...]
  num_elems_a = prefer_static.shape(a)[0]
  num_elems_b = prefer_static.shape(b)[0]

  def _interleave_with_b(a):
    return tf.reshape(
        tf.stack([a, b], axis=1),
        prefer_static.concat([[2 * num_elems_b],
                              prefer_static.shape(a)[1:]], axis=0))
  return prefer_static.cond(
      prefer_static.equal(num_elems_a, num_elems_b + 1),
      lambda: tf.concat([_interleave_with_b(a[:-1]), a[-1:]], axis=0),
      lambda: _interleave_with_b(a))
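
A quick concrete check of the two branches handled by `_interleave` above, written with plain TF ops on the off-by-one case:

import tensorflow as tf

a = tf.constant([1, 3, 5, 7])   # One element longer than b.
b = tf.constant([2, 4, 6])

# Equal-length core: stack along a new axis 1, then flatten.
core = tf.reshape(tf.stack([a[:-1], b], axis=1), [-1])        # [1 2 3 4 5 6]
# The off-by-one branch of the cond re-appends the trailing element of `a`.
interleaved = tf.concat([core, a[-1:]], axis=0)               # [1 2 3 4 5 6 7]
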
Example #7
 def _expand_and_maybe_replace():
     """Performs the expansion step."""
     expanded = face_centroid + expansion * (reflected - face_centroid)
     expanded_objective_value = objective_function(expanded)
     expanded_is_better = (expanded_objective_value <
                           objective_at_reflected)
     accept_expanded_fn = lambda: (expanded, expanded_objective_value)
     accept_reflected_fn = lambda: (reflected, objective_at_reflected)
     next_pt, next_objective_value = prefer_static.cond(
         expanded_is_better, accept_expanded_fn, accept_reflected_fn)
     next_simplex = _replace_at_index(simplex, worst_index, next_pt)
     next_objective_at_simplex = _replace_at_index(objective_values,
                                                   worst_index,
                                                   next_objective_value)
     return False, next_simplex, next_objective_at_simplex, 1
Example #8
def _secant2_inner_update(value_and_gradients_function,
                          initial_args,
                          val_0,
                          val_c,
                          f_lim,
                          sufficient_decrease_param,
                          curvature_param):
  """Helper function for secant-square step."""
  # Fail if `val_c` is no longer finite.
  new_failed = initial_args.active & ~is_finite(val_c)
  active = initial_args.active & ~new_failed
  failed = initial_args.failed | new_failed

  # We converge when we find a point satisfying the Wolfe conditions; in that
  # case we set `val_left = val_right = val_c`.
  found_wolfe = active & _satisfies_wolfe(
      val_0, val_c, f_lim, sufficient_decrease_param, curvature_param)
  val_left = val_where(found_wolfe, val_c, initial_args.left)
  val_right = val_where(found_wolfe, val_c, initial_args.right)
  converged = initial_args.converged | found_wolfe
  active = active & ~found_wolfe

  # If any active batch members remain, we apply the `update` function to
  # further squeeze their corresponding left/right bracketing intervals.
  def _apply_update():
    update_result = update(
        value_and_gradients_function, val_left, val_right, val_c, f_lim,
        active=active)
    return _Secant2Result(
        active=tf.zeros_like(active),  # End of secant2, no actives anymore.
        converged=converged,
        failed=failed | update_result.failed,
        num_evals=initial_args.num_evals + update_result.num_evals,
        left=update_result.left,
        right=update_result.right)

  # Otherwise just return the current results.
  def _default():
    return _Secant2Result(
        active=active,
        converged=converged,
        failed=failed,
        num_evals=initial_args.num_evals,
        left=val_left,
        right=val_right)

  return prefer_static.cond(
      tf.reduce_any(active), _apply_update, _default)
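
For reference, `sufficient_decrease_param` (delta) and `curvature_param` (sigma) parameterize the Wolfe test used above. Below is a hedged sketch of the standard weak Wolfe conditions for a univariate line search, written against a hypothetical namedtuple rather than the module's own value-and-gradient type:

import collections

# Hypothetical stand-in for the module's point type: position, value, derivative.
Point = collections.namedtuple('Point', ['x', 'f', 'df'])

def satisfies_wolfe_sketch(val_0, val_c, delta, sigma):
  """Weak Wolfe conditions at step `val_c.x` along a descent direction."""
  sufficient_decrease = val_c.f <= val_0.f + delta * val_c.x * val_0.df
  curvature = val_c.df >= sigma * val_0.df
  return sufficient_decrease and curvature

# f(t) = (t - 1)**2: f(0) = 1, f'(0) = -2; at t = 0.5: f = 0.25, f' = -1.
assert satisfies_wolfe_sketch(Point(0.0, 1.0, -2.0), Point(0.5, 0.25, -1.0),
                              delta=0.1, sigma=0.9)
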
Example #9
        def _default_fn():
            """Default action."""
            new_width = secant2_result.right.x - secant2_result.left.x
            old_width = val_right.x - val_left.x
            sufficient_shrinkage = new_width < shrinkage_param * old_width

            def _sufficient_shrinkage_fn():
                """Action to perform if secant2 shrank the interval sufficiently."""
                func_is_flat = (
                    (tf.math.nextafter(val_left.f, val_right.f) >= val_right.f)
                    & (tf.math.nextafter(secant2_result.left.f,
                                         secant2_result.right.f) >=
                       secant2_result.right.f))
                is_flat_retval = _LineSearchInnerResult(
                    iteration=iteration,
                    found_wolfe=True,
                    failed=False,
                    num_evals=evals,
                    left=secant2_result.left,
                    right=secant2_result.left)
                not_is_flat_retval = _LineSearchInnerResult(
                    iteration=iteration,
                    found_wolfe=False,
                    failed=False,
                    num_evals=evals,
                    left=secant2_result.left,
                    right=secant2_result.right)

                return prefer_static.cond(func_is_flat,
                                          true_fn=lambda: is_flat_retval,
                                          false_fn=lambda: not_is_flat_retval)

            def _insufficient_shrinkage_fn():
                """Action to perform if secant2 didn't shrink the interval enough."""
                update_result = _line_search_inner_bisection(
                    value_and_gradients_function, secant2_result.left,
                    secant2_result.right, f_lim)
                return _LineSearchInnerResult(iteration=iteration,
                                              found_wolfe=False,
                                              failed=update_result.failed,
                                              num_evals=evals +
                                              update_result.num_evals,
                                              left=update_result.left,
                                              right=update_result.right)

            return prefer_static.cond(sufficient_shrinkage,
                                      true_fn=_sufficient_shrinkage_fn,
                                      false_fn=_insufficient_shrinkage_fn)
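
The `tf.math.nextafter` comparisons in `_sufficient_shrinkage_fn` are a one-ulp flatness test: if nudging `left.f` one representable float toward `right.f` already reaches or passes it, the two values are equal or adjacent in floating point. A tiny illustration:

import tensorflow as tf

left = tf.constant(1.0)
close = tf.constant(1.0 + 1e-16)   # Rounds to 1.0 in float32.
far = tf.constant(1.1)

flat = tf.math.nextafter(left, close) >= close   # True: within one ulp.
not_flat = tf.math.nextafter(left, far) >= far   # False: clearly separated.
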
  def _marginal_hidden_probs(self):
    """Compute marginal pdf for each individual observable."""

    num_states = self.transition_distribution.batch_shape_tensor()[-1]
    log_init = _extract_log_probs(num_states,
                                  self.initial_distribution)
    initial_log_probs = tf.broadcast_to(log_init,
                                        ps.concat([self.batch_shape_tensor(),
                                                   [num_states]],
                                                  axis=0))

    # initial_log_probs :: batch_shape num_states

    no_transition_result = initial_log_probs[tf.newaxis, ...]

    def _scan_multiple_steps():
      """Perform `scan` operation when `num_steps` > 1."""

      transition_log_probs = _extract_log_probs(num_states,
                                                self.transition_distribution)

      def forward_step(log_probs, _):
        result = _log_vector_matrix(log_probs, transition_log_probs)
        # We know that `forward_step` must preserve the shape of the
        # tensor of probabilities of each state. This is because
        # the transition matrix must be square. But TensorFlow might
        # not know this so we explicitly tell it that the result has the
        # same shape.
        tensorshape_util.set_shape(result, log_probs.shape)
        return result

      dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)

      forward_log_probs = tf.scan(forward_step, dummy_index,
                                  initializer=initial_log_probs,
                                  name='forward_log_probs')

      result = tf.concat([[initial_log_probs], forward_log_probs],
                         axis=0)
      return result
    forward_log_probs = ps.cond(
        self._num_steps > 1,
        _scan_multiple_steps,
        lambda: no_transition_result)

    return tf.exp(forward_log_probs)
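
The `_scan_multiple_steps` pattern above (scan over `num_steps - 1` dummy elements, then concat the initial value back onto the front) recurs throughout these examples; here is a toy version with a trivial step function:

import tensorflow as tf

init = tf.constant([0.0, 1.0])
num_steps = 4

def step(state, _):
  return state + 1.0   # Toy stand-in for a forward/transition step.

dummy = tf.zeros(num_steps - 1)                       # One element per remaining step.
scanned = tf.scan(step, dummy, initializer=init)      # Shape [num_steps - 1, 2].
trajectory = tf.concat([[init], scanned], axis=0)     # Shape [num_steps, 2].
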
Example #11
    def _contraction():
        """Performs a contraction."""
        contracted = face_centroid + contraction * (reflected - face_centroid)
        objective_at_contracted = objective_function(contracted)
        is_contracted_acceptable = objective_at_contracted <= objective_at_reflected

        def _accept_contraction():
            next_simplex = _replace_at_index(simplex, worst_index, contracted)
            objective_at_next_simplex = _replace_at_index(
                objective_values, worst_index, objective_at_contracted)
            return (False, next_simplex, objective_at_next_simplex, 1)

        def _reject_contraction():
            return _shrink_towards_best(objective_function, simplex,
                                        best_index, shrinkage,
                                        batch_evaluate_objective)

        return prefer_static.cond(is_contracted_acceptable,
                                  _accept_contraction, _reject_contraction)
Example #12
    def _contraction():
        """Performs a contraction."""
        contracted = face_centroid - contraction * (face_centroid -
                                                    simplex[worst_index])
        objective_at_contracted = objective_function(contracted)
        is_contracted_acceptable = objective_at_contracted <= worst_objective_value

        def _accept_contraction():
            next_simplex = _replace_at_index(simplex, worst_index, contracted)
            objective_at_next_simplex = _replace_at_index(
                objective_values, worst_index, objective_at_contracted)
            return (False, next_simplex, objective_at_next_simplex,
                    np.int32(1))

        def _reject_contraction():
            return _shrink_towards_best(objective_function, simplex,
                                        best_index, shrinkage,
                                        batch_evaluate_objective)

        return ps.cond(is_contracted_acceptable, _accept_contraction,
                       _reject_contraction)
    def recursive_case():
      """Evaluate the next step of the recursion."""
      odd_elems = _scan(level - 1, reduced_elems)

      def even_length_case():
        return lowered_fn([odd_elem[:-1] for odd_elem in odd_elems],
                          [elem[2::2] for elem in elems])

      def odd_length_case():
        return lowered_fn([odd_elem for odd_elem in odd_elems],
                          [elem[2::2] for elem in elems])

      results = prefer_static.cond(
          prefer_static.equal(elem_length % 2, 0),
          even_length_case,
          odd_length_case)

      # The first element of a scan is the same as the first element
      # of the original `elems`.
      even_elems = [tf.concat([elem[0:1], result], axis=0)
                    for (elem, result) in zip(elems, results)]
      return list(map(_interleave, even_elems, odd_elems))
Example #14
  def _maybe_warn_increased_dof(self,
                                component_name,
                                component_ldj,
                                increased_dof):
    """Warns or raises when `increased_dof` is True."""
    # Short-circuit when the component LDJ is statically zero.
    if (tf.get_static_value(tf.rank(component_ldj)) == 0
        and tf.get_static_value(component_ldj) == 0):
      return

    # Short-circuit when increased_dof is statically False.
    increased_dof_ = tf.get_static_value(increased_dof)
    if increased_dof_ is False:  # pylint: disable=g-bool-id-comparison
      return

    error_message = (
        'Nested component "{}" in composition "{}" operates on inputs '
        'with increased degrees of freedom. This may result in an '
        'incorrect log_det_jacobian.'
        ).format(component_name, self.name)

    # When validate_args is True, we raise on increased DoF.
    if self._validate_args:
      if increased_dof_:
        raise ValueError(error_message)
      return assert_util.assert_equal(False, increased_dof, error_message)

    if (not tf.executing_eagerly() and
        control_flow_util.GraphOrParentsInXlaContext(tf1.get_default_graph())):
      return  # No StringFormat or Print ops in XLA.

    # Otherwise, we print a warning and continue.
    return ps.cond(
        pred=increased_dof,
        false_fn=tf.no_op,
        true_fn=lambda: tf.print(  # pylint: disable=g-long-lambda
            'WARNING: ' + error_message, output_stream=sys.stderr))
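
The closing `ps.cond` shows a common graph-mode warning idiom: choose between `tf.print` and `tf.no_op` on a runtime predicate, so the message is emitted only when the condition actually holds. A minimal sketch of the same idiom, assuming only stock TF ops:

import sys
import tensorflow as tf

def maybe_warn(pred, message):
  """Emit `message` on stderr only when `pred` is True at runtime."""
  return tf.cond(
      pred,
      true_fn=lambda: tf.print('WARNING: ' + message, output_stream=sys.stderr),
      false_fn=tf.no_op)

maybe_warn(tf.constant(True), 'increased degrees of freedom detected')
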
Example #15
def _update_inv_hessian(prev_state, next_state):
    """Update the BGFS state by computing the next inverse hessian estimate."""
    # Only update the inverse Hessian if not already failed or converged.
    should_update = ~next_state.converged & ~next_state.failed

    # Compute the normalization term (y^T . s); skip the update if it is zero.
    gradient_delta = next_state.objective_gradient - prev_state.objective_gradient
    position_delta = next_state.position - prev_state.position
    normalization_factor = tf.reduce_sum(gradient_delta * position_delta,
                                         axis=-1)
    should_update = should_update & ~tf.equal(normalization_factor, 0)

    def _do_update_inv_hessian():
        next_inv_hessian = _bfgs_inv_hessian_update(
            gradient_delta, position_delta, normalization_factor,
            prev_state.inverse_hessian_estimate)
        return bfgs_utils.update_fields(
            next_state,
            inverse_hessian_estimate=tf.where(
                should_update[..., tf.newaxis, tf.newaxis], next_inv_hessian,
                prev_state.inverse_hessian_estimate))

    return ps.cond(ps.reduce_any(should_update), _do_update_inv_hessian,
                   lambda: next_state)
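
The `should_update[..., tf.newaxis, tf.newaxis]` expansion above is the usual way to apply a per-batch-member boolean mask to a matrix-valued state: the mask is rank-expanded so `tf.where` broadcasts it over the trailing matrix dimensions. A small standalone illustration:

import tensorflow as tf

should_update = tf.constant([True, False])   # One flag per batch member.
old = tf.zeros([2, 3, 3])                    # Batch of 3x3 matrices.
new = tf.ones([2, 3, 3])

# Mask shape [2, 1, 1] broadcasts against the [2, 3, 3] operands.
updated = tf.where(should_update[..., tf.newaxis, tf.newaxis], new, old)
# updated[0] is all ones; updated[1] is left untouched (all zeros).
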
Example #16
        def _do_check_shrinkage():
            """Check if interval has shrinked enough."""
            old_width = curr_interval.right.x - curr_interval.left.x
            new_width = secant2_result.right.x - secant2_result.left.x
            sufficient_shrinkage = new_width < old_width * shrinkage_param
            func_is_flat = (
                _very_close(curr_interval.left.f, curr_interval.right.f)
                & _very_close(secant2_result.left.f, secant2_result.right.f))

            new_converged = (should_check_shrinkage & sufficient_shrinkage
                             & func_is_flat)
            needs_inner_bisect = should_check_shrinkage & ~sufficient_shrinkage

            inner_bisect_args = secant2_result._replace(
                converged=secant2_result.converged | new_converged)

            def _apply_inner_bisect():
                return _line_search_inner_bisection(
                    value_and_gradients_function, inner_bisect_args,
                    needs_inner_bisect, f_lim)

            return prefer_static.cond(
                tf.reduce_any(input_tensor=needs_inner_bisect),
                _apply_inner_bisect, lambda: inner_bisect_args)
Example #17
        def _loop_body(  # pylint: disable=missing-docstring
                iter_, x_update_diff_norm_sq, x_update, hess_matmul_x_update):
            # Inner loop of the minimizer.
            #
            # This loop updates a single coordinate of x_update.  Ideally, an
            # iteration of this loop would set
            #
            #   x_update[j] += argmin{ LocalLoss(x_update + z*e_j) : z in R }
            #
            # where
            #
            #   LocalLoss(x_update')
            #     = LocalLossSmoothComponent(x_update')
            #         + l1_regularizer * (||x_start + x_update'||_1 -
            #                             ||x_start + x_update||_1)
            #    := (UnregularizedLoss(x_start + x_update') -
            #        UnregularizedLoss(x_start + x_update)
            #         + l2_regularizer * (||x_start + x_update'||_2**2 -
            #                             ||x_start + x_update||_2**2)
            #         + l1_regularizer * (||x_start + x_update'||_1 -
            #                             ||x_start + x_update||_1)
            #
            # In this algorithm we approximate the above argmin using (univariate)
            # proximal gradient descent:
            #
            # (*)  x_update[j] = prox_{t * l1_regularizer * L1}(
            #                 x_update[j] -
            #                 t * d/dz|z=0 UnivariateLocalLossSmoothComponent(z))
            #
            # where
            #
            #   UnivariateLocalLossSmoothComponent(z)
            #       := LocalLossSmoothComponent(x_update + z*e_j)
            #
            # and we approximate
            #
            #       d/dz UnivariateLocalLossSmoothComponent(z)
            #     = grad LocalLossSmoothComponent(x_update))[j]
            #    ~= (grad LossSmoothComponent(x_start)
            #         + x_update matmul HessianOfLossSmoothComponent(x_start))[j].
            #
            # To choose the parameter t, we squint and pretend that the inner term of
            # (*) is a Newton update as if we were using Newton's method to minimize
            # UnivariateLocalLossSmoothComponent.  That is, we choose t such that
            #
            #   -t * d/dz ULLSC = -learning_rate * (d/dz ULLSC) / (d^2/dz^2 ULLSC)
            #
            # at z=0.  Hence
            #
            #   t = learning_rate / (d^2/dz^2|z=0 ULLSC)
            #     = learning_rate / HessianOfLossSmoothComponent(
            #                           x_start + x_update)[j,j]
            #    ~= learning_rate / HessianOfLossSmoothComponent(
            #                           x_start)[j,j]
            #
            # The above approximation is equivalent to assuming that
            # HessianOfUnregularizedLoss is constant, i.e., ignoring third-order
            # effects.
            #
            # Note that because LossSmoothComponent is (assumed to be) convex, t is
            # positive.

            # In above notation, coord = j.
            coord = iter_ % dims
            # x_update_diff_norm_sq := ||x_update_end - x_update_start||_2**2,
            # computed incrementally, where x_update_end and x_update_start are as
            # defined in the convergence criteria.  Accordingly, we reset
            # x_update_diff_norm_sq to zero at the beginning of each sweep.
            x_update_diff_norm_sq = tf.where(
                tf.equal(coord, 0), tf.zeros_like(x_update_diff_norm_sq),
                x_update_diff_norm_sq)

            # Recall that x_update and hess_matmul_x_update have the rightmost
            # dimension transposed to the leftmost dimension.
            w_old = x_start[..., coord] + x_update[coord, ...]
            # This is the coordinatewise Newton update if no L1 regularization.
            # In above notation, newton_step = -t * (approximation of d/dz|z=0 ULLSC).
            second_deriv = _hessian_diag_elt_with_l2(coord)
            newton_step = -_mul_ignoring_nones(  # pylint: disable=invalid-unary-operand-type
                learning_rate, grad_loss_with_l2[..., coord] +
                hess_matmul_x_update[coord, ...]) / second_deriv

            # Applying the soft-threshold operator accounts for L1 regularization.
            # In above notation, delta =
            #     prox_{t*l1_regularizer*L1}(w_old + newton_step) - w_old.
            delta = (soft_threshold(
                w_old + newton_step,
                _mul_ignoring_nones(learning_rate, l1_regularizer) /
                second_deriv) - w_old)

            def _do_update(x_update_diff_norm_sq, x_update,
                           hess_matmul_x_update):  # pylint: disable=missing-docstring
                hessian_column_with_l2 = sparse_or_dense_matvecmul(
                    hessian_unregularized_loss_outer,
                    hessian_unregularized_loss_middle *
                    _sparse_or_dense_matmul_onehot(
                        hessian_unregularized_loss_outer, coord),
                    adjoint_a=True)

                if l2_regularizer is not None:
                    hessian_column_with_l2 += _one_hot_like(
                        hessian_column_with_l2,
                        coord,
                        on_value=2. * l2_regularizer)

                # Move the batch dimensions of `hessian_column_with_l2` to rightmost in
                # order to conform to `hess_matmul_x_update`.
                n = tf.rank(hessian_column_with_l2)
                perm = tf.roll(tf.range(n), shift=1, axis=0)
                hessian_column_with_l2 = tf.transpose(a=hessian_column_with_l2,
                                                      perm=perm)

                # Update the entire batch at `coord` even if `delta` may be 0 at some
                # batch coordinates. In those cases, adding `delta` is a no-op.
                x_update = tf.tensor_scatter_add(x_update, [[coord]], [delta])

                with tf.control_dependencies([x_update]):
                    x_update_diff_norm_sq_ = x_update_diff_norm_sq + delta**2
                    hess_matmul_x_update_ = (hess_matmul_x_update +
                                             delta * hessian_column_with_l2)

                    # Hint that loop vars retain the same shape.
                    x_update_diff_norm_sq_.set_shape(
                        x_update_diff_norm_sq_.shape.merge_with(
                            x_update_diff_norm_sq.shape))
                    hess_matmul_x_update_.set_shape(
                        hess_matmul_x_update_.shape.merge_with(
                            hess_matmul_x_update.shape))

                    return [
                        x_update_diff_norm_sq_, x_update, hess_matmul_x_update_
                    ]

            inputs_to_update = [
                x_update_diff_norm_sq, x_update, hess_matmul_x_update
            ]
            return [iter_ + 1] + prefer_static.cond(
                # Note on why checking delta (a difference of floats) for equality to
                # zero is ok:
                #
                # First of all, x - x == 0 in floating point -- see
                # https://stackoverflow.com/a/2686671
                #
                # Delta will conceptually equal zero when one of the following holds:
                # (i)   |w_old + newton_step| <= threshold and w_old == 0
                # (ii)  |w_old + newton_step| > threshold and
                #       w_old + newton_step - sign(w_old + newton_step) * threshold
                #          == w_old
                #
                # In case (i) comparing delta to zero is fine.
                #
                # In case (ii), newton_step conceptually equals
                #     sign(w_old + newton_step) * threshold.
                # Also remember
                #     threshold = -newton_step / (approximation of d/dz|z=0 ULLSC).
                # So (ii) happens when
                #     (approximation of d/dz|z=0 ULLSC) == -sign(w_old + newton_step).
                # If we did not require LossSmoothComponent to be strictly convex,
                # then this could actually happen a non-negligible amount of the time,
                # e.g. if the loss function is piecewise linear and one of the pieces
                # has slope 1.  But since LossSmoothComponent is strictly convex, (ii)
                # should not systematically happen.
                tf.reduce_all(input_tensor=tf.equal(delta, 0.)),
                lambda: inputs_to_update,
                lambda: _do_update(*inputs_to_update))
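
The `soft_threshold` call above is the proximal operator of the L1 penalty discussed in the long comment; as a reminder of the standard formula (a sketch, not the module's implementation):

import tensorflow as tf

def soft_threshold_sketch(x, threshold):
  """prox of threshold*||.||_1: sign(x) * max(|x| - threshold, 0)."""
  return tf.sign(x) * tf.maximum(tf.abs(x) - threshold, 0.)

# Entries within the threshold collapse to zero; others shrink toward zero.
shrunk = soft_threshold_sketch(tf.constant([-2.0, -0.3, 0.5, 3.0]), 1.0)
# shrunk ~= [-1., 0., 0., 2.]
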
  def _sample_n(self, n, seed=None):
    init_seed, scan_seed, observation_seed = samplers.split_seed(
        seed, n=3, salt='HiddenMarkovModel')

    transition_batch_shape = self.transition_distribution.batch_shape_tensor()
    num_states = transition_batch_shape[-1]

    batch_shape = self.batch_shape_tensor()
    batch_size = ps.reduce_prod(batch_shape)
    # The batch sizes of the underlying initial distributions and
    # transition distributions might not match the batch size of
    # the HMM distribution.
    # As a result we need to ask for more samples from the
    # underlying distributions and then reshape the results into
    # the correct batch size for the HMM.
    init_repeat = (
        ps.reduce_prod(batch_shape) //
        ps.reduce_prod(self._initial_distribution.batch_shape_tensor()))
    init_state = self._initial_distribution.sample(n * init_repeat,
                                                   seed=init_seed)
    init_state = tf.reshape(init_state, [n, batch_size])
    # init_state :: n batch_size

    transition_repeat = (
        ps.reduce_prod(batch_shape) // ps.reduce_prod(
            transition_batch_shape[:-1]))

    init_shape = init_state.shape

    def generate_step(state_and_seed, _):
      """Take a single step in Markov chain."""
      state, seed = state_and_seed
      sample_seed, next_seed = samplers.split_seed(seed)

      gen = self._transition_distribution.sample(n * transition_repeat,
                                                 seed=sample_seed)
      # gen :: (n * transition_repeat) transition_batch

      new_states = tf.reshape(gen,
                              [n, batch_size, num_states])

      # new_states :: n batch_size num_states

      old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32)

      # old_states :: n batch_size num_states

      result = tf.reduce_sum(old_states_one_hot * new_states, axis=-1)
      # We know that `generate_step` must preserve the shape of the
      # tensor of states. This is because
      # the transition matrix must be square. But TensorFlow might
      # not know this so we explicitly tell it that the result has the
      # same shape.
      tensorshape_util.set_shape(result, init_shape)
      return result, next_seed

    def _scan_multiple_steps():
      """Take multiple steps with tf.scan."""
      dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
      hidden_states, _ = tf.scan(generate_step, dummy_index,
                                 initializer=(init_state, scan_seed))

      # TODO(b/115618503): add/use prepend_initializer to tf.scan
      return tf.concat([[init_state],
                        hidden_states], axis=0)
    hidden_states = ps.cond(
        self._num_steps > 1,
        _scan_multiple_steps,
        lambda: init_state[tf.newaxis, ...])

    hidden_one_hot = tf.one_hot(hidden_states, num_states,
                                dtype=self._observation_distribution.dtype)
    # hidden_one_hot :: num_steps n batch_size num_states

    # The observation distribution batch size might not match
    # the required batch size so as with the initial and
    # transition distributions we generate more samples and
    # reshape.
    observation_repeat = tf.maximum(
        batch_size // ps.reduce_prod(
            self._observation_distribution.batch_shape_tensor()[:-1]),
        1)

    if self._time_varying_observation_distribution:
      possible_observations = self._observation_distribution.sample(
          [observation_repeat * n], seed=observation_seed)
      # possible observations needs to have num_steps moved to the beginning.
      possible_observations = distribution_util.move_dimension(
          possible_observations,
          -(tf.size(self._observation_distribution.event_shape_tensor()) + 2),
          0)
    else:
      possible_observations = self._observation_distribution.sample(
          [self._num_steps, observation_repeat * n], seed=observation_seed)

    inner_shape = self._observation_distribution.event_shape_tensor()

    # possible_observations :: num_steps (observation_repeat * n)
    #                          observation_batch[:-1] num_states inner_shape

    possible_observations = tf.reshape(
        possible_observations,
        ps.concat([[self._num_steps, n],
                   batch_shape,
                   [num_states],
                   inner_shape], axis=0))

    # possible_observations :: steps n batch_size num_states inner_shape

    hidden_one_hot = tf.reshape(hidden_one_hot,
                                ps.concat([[self._num_steps, n],
                                           batch_shape,
                                           [num_states],
                                           ps.ones_like(inner_shape)],
                                          axis=0))

    # hidden_one_hot :: steps n batch_size num_states "inner_shape"

    observations = tf.reduce_sum(
        hidden_one_hot * possible_observations,
        axis=-1 - ps.size(inner_shape))
    # observations :: steps n batch_size inner_shape

    observations = distribution_util.move_dimension(observations, 0,
                                                    1 + ps.size(batch_shape))
    # returned :: n batch_shape steps inner_shape

    return observations
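
The `init_repeat` / `transition_repeat` / `observation_repeat` bookkeeping above handles component distributions whose batch shape is smaller than the HMM's: draw `repeat * n` samples, then reshape to `[n, batch_size, ...]`. A toy sketch of that reshape with illustrative sizes, using a plain random tensor in place of a distribution's samples:

import tensorflow as tf

n = 7                      # Requested number of HMM samples.
hmm_batch_size = 6         # Product of the HMM's batch shape.
component_batch = 3        # The component distribution's smaller batch size.
repeat = hmm_batch_size // component_batch   # Extra draws needed per sample.

draws = tf.random.normal([n * repeat, component_batch])   # Like `dist.sample(n * repeat)`.
draws = tf.reshape(draws, [n, hmm_batch_size])            # One draw per HMM batch member.
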
Example #19
0
    def posterior_mode(self, observations, mask=None, name=None):
        """Compute maximum likelihood sequence of hidden states.

    When this function is provided with a sequence of observations
    `x[0], ..., x[num_steps - 1]`, it returns the sequence of hidden
    states `z[0], ..., z[num_steps - 1]`, drawn from the underlying
    Markov chain, that is most likely to yield those observations.

    It uses the [Viterbi algorithm](
    https://en.wikipedia.org/wiki/Viterbi_algorithm).

    Note: the behavior of this function is undefined if the
    `observations` argument represents impossible observations
    from the model.

    Note: if there isn't a unique most likely sequence then one
    of the equally most likely sequences is chosen.

    Args:
      observations: A tensor representing a batch of observations made on the
        hidden Markov model.  The rightmost dimensions of this tensor correspond
        to the dimensions of the observation distributions of the underlying
        Markov chain.  The next dimension from the right indexes the steps in a
        sequence of observations from a single sample from the hidden Markov
        model.  The size of this dimension should match the `num_steps`
        parameter of the hidden Markov model object.  The other dimensions are
        the dimensions of the batch and these are broadcast with the hidden
        Markov model's parameters.
      mask: optional bool-type `Tensor` with rightmost dimension matching
        `num_steps` indicating which observations the result of this
        function should be conditioned on. When the mask has value
        `True` the corresponding observations aren't used.
        If `mask` is `None` then all of the observations are used.
        The `mask` dimensions left of the last are broadcast with the
        HMM batch as well as with the observations.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Returns:
      posterior_mode: A `Tensor` representing the most likely sequence of hidden
        states. The rightmost dimension of this tensor will equal the
        `num_steps` parameter providing one hidden state for each step. The
        other dimensions are those of the batch.

    Raises:
      ValueError: if the `observations` tensor does not consist of
      sequences of `num_steps` observations.

    #### Examples

    ```python
    tfd = tfp.distributions

    # A simple weather model.

    # Represent a cold day with 0 and a hot day with 1.
    # Suppose the first day of a sequence has a 0.8 chance of being cold.

    initial_distribution = tfd.Categorical(probs=[0.8, 0.2])

    # Suppose a cold day has a 30% chance of being followed by a hot day
    # and a hot day has a 20% chance of being followed by a cold day.

    transition_distribution = tfd.Categorical(probs=[[0.7, 0.3],
                                                     [0.2, 0.8]])

    # Suppose additionally that on each day the temperature is
    # normally distributed with mean and standard deviation 0 and 5 on
    # a cold day and mean and standard deviation 15 and 10 on a hot day.

    observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.])

    # This gives the hidden Markov model:

    model = tfd.HiddenMarkovModel(
        initial_distribution=initial_distribution,
        transition_distribution=transition_distribution,
        observation_distribution=observation_distribution,
        num_steps=7)

    # Suppose we observe gradually rising temperatures over a week:
    temps = [-2., 0., 2., 4., 6., 8., 10.]

    # We can now compute the most probable sequence of hidden states:

    model.posterior_mode(temps)

    # The result is [0 0 0 0 0 1 1] telling us that the transition
    # from "cold" to "hot" most likely happened between the
    # 5th and 6th days.
    ```
    """

        with tf.name_scope(name or "posterior_mode"):
            observations = tf.convert_to_tensor(observations,
                                                name="observations")
            if mask is not None:
                mask = tf.convert_to_tensor(mask,
                                            name="mask",
                                            dtype_hint=tf.bool)
            with tf.control_dependencies(self._runtime_assertions):
                observation_tensor_shape = tf.shape(observations)
                mask_tensor_shape = tf.shape(
                    mask) if mask is not None else None

                with self._observation_mask_shape_preconditions(
                        observation_tensor_shape, mask_tensor_shape):
                    observation_log_probs = self._observation_log_probs(
                        observations, mask)
                    log_prob = self._log_init + observation_log_probs[0]

                    def _reduce_multiple_steps():
                        """Perform `reduce_max` operation when `num_steps` > 1."""
                        def forward_step(previous_step_pair,
                                         log_prob_observation):
                            log_prob_previous = previous_step_pair[0]
                            log_prob = (
                                log_prob_previous[..., tf.newaxis] +
                                self._log_trans +
                                log_prob_observation[..., tf.newaxis, :])
                            most_likely_given_successor = tf.argmax(log_prob,
                                                                    axis=-2)
                            max_log_p_given_successor = tf.reduce_max(log_prob,
                                                                      axis=-2)
                            return (max_log_p_given_successor,
                                    most_likely_given_successor)

                        forward_log_probs, all_most_likely_given_successor = tf.scan(
                            forward_step,
                            observation_log_probs[1:],
                            initializer=(log_prob,
                                         tf.zeros(tf.shape(log_prob),
                                                  dtype=tf.int64)),
                            name="forward_log_probs")

                        most_likely_end = tf.argmax(forward_log_probs[-1],
                                                    axis=-1)

                        # We require the operation that gives C from A and B where
                        # C[i...j] = A[i...j, B[i...j]]
                        # and A = most_likely_given_successor
                        #     B = most_likely_successor.
                        # tf.gather requires indices of known shape, so instead we use a
                        # reduction with tf.one_hot(B) to pick out the elements of A.
                        def backward_step(most_likely_successor,
                                          most_likely_given_successor):
                            return tf.reduce_sum(
                                (most_likely_given_successor *
                                 tf.one_hot(most_likely_successor,
                                            self._num_states,
                                            dtype=tf.int64)),
                                axis=-1)

                        backward_scan = tf.scan(
                            backward_step,
                            all_most_likely_given_successor,
                            most_likely_end,
                            reverse=True)
                        most_likely_sequences = tf.concat(
                            [backward_scan, [most_likely_end]], axis=0)
                        return distribution_util.move_dimension(
                            most_likely_sequences, 0, -1)

                    return prefer_static.cond(
                        self.num_steps > 1, _reduce_multiple_steps,
                        lambda: tf.argmax(log_prob, axis=-1)[..., tf.newaxis])
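
The `backward_step` above realizes `C[i...j] = A[i...j, B[i...j]]` without `tf.gather`, by one-hot encoding the indices `B` and reducing over the last axis of `A`. A tiny standalone version of the trick:

import tensorflow as tf

a = tf.constant([[10, 11, 12],
                 [20, 21, 22]], dtype=tf.int64)   # A[i, k]
b = tf.constant([2, 0], dtype=tf.int64)           # B[i], an index into the last axis.

# C[i] = A[i, B[i]] via one-hot masking plus a sum over the last axis.
c = tf.reduce_sum(a * tf.one_hot(b, depth=3, dtype=tf.int64), axis=-1)   # [12, 20]
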
Example #20
    def _sample_n(self, n, seed=None):
        with tf.control_dependencies(self._runtime_assertions):
            strm = SeedStream(seed, salt="HiddenMarkovModel")

            num_states = self._num_states

            batch_shape = self.batch_shape_tensor()
            batch_size = tf.reduce_prod(batch_shape)

            # The batch sizes of the underlying initial distributions and
            # transition distributions might not match the batch size of
            # the HMM distribution.
            # As a result we need to ask for more samples from the
            # underlying distributions and then reshape the results into
            # the correct batch size for the HMM.
            init_repeat = (
                tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
                    self._initial_distribution.batch_shape_tensor()))
            init_state = self._initial_distribution.sample(n * init_repeat,
                                                           seed=strm())
            init_state = tf.reshape(init_state, [n, batch_size])
            # init_state :: n batch_size

            transition_repeat = (
                tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
                    self._transition_distribution.batch_shape_tensor()[:-1]))

            def generate_step(state, _):
                """Take a single step in Markov chain."""

                gen = self._transition_distribution.sample(n *
                                                           transition_repeat,
                                                           seed=strm())
                # gen :: (n * transition_repeat) transition_batch

                new_states = tf.reshape(gen, [n, batch_size, num_states])

                # new_states :: n batch_size num_states

                old_states_one_hot = tf.one_hot(state,
                                                num_states,
                                                dtype=tf.int32)

                # old_states :: n batch_size num_states

                return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

            def _scan_multiple_steps():
                """Take multiple steps with tf.scan."""
                dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
                if seed is not None:
                    # Force parallel_iterations to 1 to ensure reproducibility
                    # b/139210489
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state,
                                            parallel_iterations=1)
                else:
                    # Invoke default parallel_iterations behavior
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state)

                # TODO(b/115618503): add/use prepend_initializer to tf.scan
                return tf.concat([[init_state], hidden_states], axis=0)

            hidden_states = prefer_static.cond(
                self._num_steps > 1, _scan_multiple_steps,
                lambda: init_state[tf.newaxis, ...])

            hidden_one_hot = tf.one_hot(
                hidden_states,
                num_states,
                dtype=self._observation_distribution.dtype)
            # hidden_one_hot :: num_steps n batch_size num_states

            # The observation distribution batch size might not match
            # the required batch size so as with the initial and
            # transition distributions we generate more samples and
            # reshape.
            observation_repeat = (batch_size // tf.reduce_prod(
                self._observation_distribution.batch_shape_tensor()[:-1]))

            possible_observations = self._observation_distribution.sample(
                [self._num_steps, observation_repeat * n], seed=strm())

            inner_shape = self._observation_distribution.event_shape

            # possible_observations :: num_steps (observation_repeat * n)
            #                          observation_batch[:-1] num_states inner_shape

            possible_observations = tf.reshape(
                possible_observations,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           inner_shape],
                          axis=0))

            # possible_observations :: steps n batch_size num_states inner_shape

            hidden_one_hot = tf.reshape(
                hidden_one_hot,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           tf.ones_like(inner_shape)],
                          axis=0))

            # hidden_one_hot :: steps n batch_size num_states "inner_shape"

            observations = tf.reduce_sum(hidden_one_hot *
                                         possible_observations,
                                         axis=-1 - tf.size(inner_shape))

            # observations :: steps n batch_size inner_shape

            observations = distribution_util.move_dimension(
                observations, 0, 1 + tf.size(batch_shape))

            # returned :: n batch_shape steps inner_shape

            return observations
Example #21
def _bracket_and_search(value_and_gradients_function,
                        val_0,
                        val_c,
                        f_lim,
                        max_iterations,
                        shrinkage_param=None,
                        expansion_param=None,
                        sufficient_decrease_param=None,
                        curvature_param=None):
    """Brackets the minimum and performs a line search.

  Args:
    value_and_gradients_function: A Python callable that accepts a real scalar
      tensor and returns an object that can be converted to a namedtuple.
      The namedtuple should have fields 'f' and 'df' that correspond to scalar
      tensors of real dtype containing the value of the function and its
      derivative at that point. The other namedtuple fields, if present,
      should be tensors or sequences (possibly nested) of tensors.
      In the usual optimization application, this function would be generated by
      projecting the multivariate objective function along some specific
      direction. The direction is determined by some other procedure but should
      be a descent direction (i.e. the derivative of the projected univariate
      function must be negative at 0.).
    val_0: Instance of `_FnDFn` containing the value and gradient of the
      objective at 0. The gradient must be negative (i.e. must be a descent
      direction).
    val_c: Instance of `_FnDFn` containing the initial step size and the value
      and gradient of the objective at the initial step size. The step size
      must be positive and finite.
    f_lim: Scalar `Tensor` of float dtype.
    max_iterations: Positive scalar `Tensor` of integral dtype. The maximum
      number of iterations to perform in the line search. The number of
      iterations used to bracket the minimum is also counted against this
      parameter.
    shrinkage_param: Scalar positive Tensor of real dtype. Must be less than
      `1.`. Corresponds to the parameter `gamma` in [Hager and Zhang (2006)][2].
    expansion_param: Scalar positive `Tensor` of real dtype. Must be greater
      than `1.`. Used to expand the initial interval in case it does not bracket
      a minimum. Corresponds to `rho` in [Hager and Zhang (2006)][2].
    sufficient_decrease_param: Positive scalar `Tensor` of real dtype.
      Bounded above by the curvature param. Corresponds to `delta` in the
      terminology of [Hager and Zhang (2006)][2].
    curvature_param: Positive scalar `Tensor` of real dtype. Bounded above
      by `1.`. Corresponds to 'sigma' in the terminology of
      [Hager and Zhang (2006)][2].

  Returns:
    A namedtuple containing the following fields.
      iteration: A scalar int32 `Tensor`. The number of iterations consumed.
      found_wolfe: A scalar boolean `Tensor`. Indicates whether a point
        satisfying the Wolfe conditions has been found. If this is True, the
        interval will be degenerate (i.e. left and right below
        will be identical).
      failed: A scalar boolean `Tensor`. Indicates if invalid function or
        gradient values were encountered (i.e. infinity or NaNs).
      num_evals: A scalar int32 `Tensor`. The total number of function
        evaluations made.
      left: Instance of _FnDFn. The position and the associated value and
        derivative at the updated left end point of the interval.
      right: Instance of _FnDFn. The position and the associated value and
        derivative at the updated right end point of the interval.
  """
    bracket_result = hzl.bracket(value_and_gradients_function,
                                 val_0,
                                 val_c,
                                 f_lim,
                                 max_iterations,
                                 expansion_param=expansion_param)

    # If the bracketing failed, or we have already exhausted all the allowed
    # iterations, we return an error.
    failed = (bracket_result.failed
              | tf.greater_equal(bracket_result.iteration, max_iterations))

    def _bracketing_failed_fn():
        return _LineSearchInnerResult(iteration=bracket_result.iteration,
                                      found_wolfe=False,
                                      failed=True,
                                      num_evals=bracket_result.num_evals,
                                      left=val_0,
                                      right=val_c)

    def _bracketing_success_fn():
        """Performs line search."""
        result = _line_search_after_bracketing(
            value_and_gradients_function,
            val_0,
            bracket_result.left,
            bracket_result.right,
            f_lim,
            bracket_result.iteration,
            max_iterations,
            sufficient_decrease_param=sufficient_decrease_param,
            curvature_param=curvature_param,
            shrinkage_param=shrinkage_param)

        return _LineSearchInnerResult(iteration=result.iteration,
                                      found_wolfe=result.found_wolfe,
                                      failed=result.failed,
                                      num_evals=bracket_result.num_evals +
                                      result.num_evals,
                                      left=result.left,
                                      right=result.right)

    return prefer_static.cond(failed,
                              true_fn=_bracketing_failed_fn,
                              false_fn=_bracketing_success_fn)
Example #22
        def _valid_inputs_fn():
            """Performs bracketing and line search if inputs are valid."""
            # If the value or the gradient at the supplied step is not finite,
            # we attempt to repair it.
            step_size_too_large = ~(tf.math.is_finite(val_c_input.df)
                                    & tf.math.is_finite(val_c_input.f))

            def _is_too_large_fn():
                return _fix_step_size(value_and_gradients_function,
                                      val_c_input, step_size_shrink_param)

            val_c, fix_evals = prefer_static.cond(
                step_size_too_large,
                _is_too_large_fn,
                lambda: (val_c_input, 0))

            # Check if c is fixed now.
            valid_at_c = hzl.is_finite(val_c) & (val_c.x > 0)

            def _failure_fn():
                # If c is still not good, just return 0.
                return HagerZhangLineSearchResult(
                    converged=tf.convert_to_tensor(value=True,
                                                   name='converged'),
                    failed=tf.convert_to_tensor(value=False, name='failed'),
                    func_evals=prepare_evals + fix_evals,
                    iterations=tf.convert_to_tensor(value=0),
                    left_pt=val_0.x,
                    objective_at_left_pt=val_0.f,
                    grad_objective_at_left_pt=val_0.df,
                    right_pt=val_0.x,
                    objective_at_right_pt=val_0.f,
                    grad_objective_at_right_pt=val_0.df,
                    full_result=val_0.full_result)

            def success_fn():
                """Bracketing and searching to do if all inputs are valid."""
                result = _bracket_and_search(
                    value_and_gradients_function,
                    val_0,
                    val_c,
                    f_lim,
                    max_iterations,
                    shrinkage_param=shrinkage_param,
                    expansion_param=expansion_param,
                    sufficient_decrease_param=sufficient_decrease_param,
                    curvature_param=curvature_param)
                converged = tf.convert_to_tensor(value=result.found_wolfe,
                                                 name='converged')
                return HagerZhangLineSearchResult(
                    converged=converged,
                    failed=tf.convert_to_tensor(value=result.failed,
                                                name='failed'),
                    func_evals=result.num_evals + prepare_evals + fix_evals,
                    iterations=result.iteration,
                    left_pt=result.left.x,
                    objective_at_left_pt=result.left.f,
                    grad_objective_at_left_pt=result.left.df,
                    right_pt=result.right.x,
                    objective_at_right_pt=result.right.f,
                    grad_objective_at_right_pt=result.right.df,
                    full_result=result.left.full_result)

            return prefer_static.cond(valid_at_c,
                                      true_fn=success_fn,
                                      false_fn=_failure_fn)
def scan_associative(fn,
                     elems,
                     max_num_levels=48,
                     validate_args=False,
                     name=None):
    """Perform a scan with an associative binary operation, in parallel.

  The associative scan operation computes the cumulative sum, or
  [all-prefix sum](https://en.wikipedia.org/wiki/Prefix_sum), of a set of
  elements under an associative binary operation [1]. For example, using the
  ordinary addition operator `fn = lambda a, b: a + b`, this is equivalent to
  the ordinary cumulative sum `tf.math.cumsum` along axis 0. This method
  supports the general case of arbitrary associative binary operations operating
  on `Tensor`s or structures of `Tensor`s:

  ```python
  associative_scan(fn, elems) = tf.stack([
    elems[0],
    fn(elems[0], elems[1]),
    fn(elems[0], fn(elems[1], elems[2])),
    ...
    fn(elems[0], fn(elems[1], fn(..., fn(elems[-2], elems[-1]))),
  ], axis=0)
  ```

  The associative structure allows the computation to be decomposed
  and executed by parallel reduction. Where a naive sequential
  implementation would loop over all `N` elements, this method requires
  only a logarithmic number (`2 * ceil(log_2 N)`) of sequential steps, and
  can thus yield substantial performance speedups from hardware-accelerated
  vectorization. The total number of invocations of the binary operation
  (including those performed in parallel) is
  `2 * (N / 2 + N / 4 + ... + 1) = 2N - 2`
  --- i.e., approximately twice as many as a naive approach.
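  For example, with `N = 8` elements this comes to `2 * ceil(log_2 8) = 6`
  sequential steps and `2 * (4 + 2 + 1) = 14 = 2 * 8 - 2` total invocations of
  `fn`, versus `7` invocations for a purely sequential cumulative sum.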

  [1] Blelloch, Guy E.
      [Prefix sums and their applications](
      https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf)
      Technical Report CMU-CS-90-190,
      School of Computer Science,
      Carnegie Mellon University, 1990.

  Args:
    fn: Python callable implementing an associative binary operation with
      signature `r = fn(a, b)`. This must satisfy associativity:
      `fn(a, fn(b, c)) == fn(fn(a, b), c)`. The inputs and result are
      (possibly nested structures of) `Tensor`(s), matching `elems`. Each
      `Tensor` has a leading batch dimension in place of `elem_length`; the `fn`
      is expected to map over this dimension. The result `r` has the same shape
      (and structure) as the two inputs `a` and `b`.
    elems: A (possibly nested structure of) `Tensor`(s), each with leading
      dimension `elem_length`. Note that `elem_length` determines the number
      of recursive steps required to perform the scan: if, in graph mode,
      this is not statically available, then ops will be created to
      handle any `elem_length` up to the maximum dimension of a `Tensor`.
    max_num_levels: Python `int`. The size
      of the first dimension of the tensors in `elems` must be less than
      `2**(max_num_levels + 1)`. The default value is sufficiently large
      for most needs. Lowering this value can reduce graph-building time when
      `scan_associative` is used with inputs of unknown shape.
      Default value: `48`.
    validate_args: Python `bool`. When `True`, runtime checks
      for invalid inputs are performed. This may carry a performance cost.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
  Returns:
    result: A (possibly nested structure of) `Tensor`(s) of the same shape
      and structure as `elems`, in which the `k`th element is the result of
      recursively applying `fn` to combine the first `k` elements of
      `elems`. For example, given `elems = [a, b, c, ...]`, the result
      would be `[a, fn(a, b), fn(fn(a, b), c), ...]`.

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  import operator

  # Example 1: Partial sums of numbers.

  tfp.math.associative_scan(operator.add, tf.range(0, 4))
  # ==> [ 0, 1, 3, 6]

  # Example 2: Partial products of random matrices.

  dist = tfp.distributions.Normal(loc=0., scale=1.)
  matrices = dist.sample(sample_shape=[100, 2, 2])
  tfp.math.associative_scan(tf.matmul, matrices)
  ```
  """
    def lowered_fn(a, b):
        # Lower `fn` to operate on flattened sequences of elems.
        with tf.name_scope('fn'):
            return tf.nest.flatten(
                fn(tf.nest.pack_sequence_as(elems, a),
                   tf.nest.pack_sequence_as(elems, b)))

    elems_flat = [
        tf.convert_to_tensor(elem) for elem in tf.nest.flatten(elems)
    ]

    # Summary of algorithm:
    #
    # Consider elements of `_scan(elems)` at odd indices. That's the same as first
    # summing successive pairs of elements of `elems` and performing a scan on
    # that half sized tensor. We perform the latter scan by recursion.
    #
    # Now consider the even elements of `_scan(elems)`. These can be computed
    # from the odd elements of `_scan(elems)` by adding each odd element of
    # `_scan(elems)` to the matching even element in the original `elems`.
    #
    # We return the odd and even elements interleaved.
    #
    # For the base case of the recursion we return the first element
    # of `elems` followed by the sum of the first two elements computed as
    # a (small two-down-to-one) reduction step.

    # The following is a pictorial representation of the algorithm using the
    # variables in the code below. The operator '+' is used to represent
    # the binary operation.
    # Note how the recursive call to `_scan` operates on a reduced form of the
    # input array in which successive pairs have already been summed.

    # elems     x0         x1   x2         x3    x4         x5    ...
    #           |\         /    | \        /     | \        /
    #           | \       /     |  \      /      |  \      /
    #           |  \     /      |   \    /       |   \    /
    #           |   \   /       |    \  /        |    \  /
    # reduced   |  x0+x1        |   x2+x3        |    x4+x5       ...
    # _elems    |    |          |     |          |       |
    #           |    |          |     |          |       |
    #           |    |          |     |          |       |
    # _scan(..) |    |          |     |          |       |
    #        +--|----+----------|-----+----------|-------+----    ...
    #        |  |               |                |
    #        |  |               |                |
    #        +--|----+----------|-----+----------|-------+----    ...
    #           |    |          |     |          |       |
    # odd       |  x0+x1        |   x0+...+x3    |     x0+..+x5   ...
    # _elems    |    | \        |     |      \   |       |
    #           |    |  \       |     |       \  |       |
    # even      |    |   \      |     |        \ |       |
    # _elems    x0   |   x0+...+x2    |       x0+...+x4  |        ...
    #           |    |          |     |          |       |
    # inter     |    |          |     |          |       |
    # leave(..) |    |          |     |          |       |
    #           x0 x0+x1 x0+...+x2  x0+...+x3 x0+...+x4 x0+...+x5 ...

    # TODO(b/150374456): if the sizes of all of the tensors can be determined
    # statically then we don't need a `level` parameter.
    def _scan(level, elems):
        """Perform scan on `elems`."""
        elem_length = prefer_static.shape(elems[0])[0]

        # Apply `fn` to reduce adjacent pairs to a single entry.
        a = [elem[0:-1:2] for elem in elems]
        b = [elem[1::2] for elem in elems]
        reduced_elems = lowered_fn(a, b)

        def handle_base_case_elem_length_two():
            return [
                tf.concat([elem[0:1], reduced_elem], axis=0)
                for (reduced_elem, elem) in zip(reduced_elems, elems)
            ]

        def handle_base_case_elem_length_three():
            reduced_reduced_elems = lowered_fn(reduced_elems,
                                               [elem[2:3] for elem in elems])
            return [
                tf.concat([elem[0:1], reduced_elem, reduced_reduced_elem],
                          axis=0)
                for (reduced_reduced_elem, reduced_elem,
                     elem) in zip(reduced_reduced_elems, reduced_elems, elems)
            ]

        # Base case of recursion: assumes `elem_length` is 2 or 3.
        at_base_case = prefer_static.logical_or(
            prefer_static.equal(elem_length, 2),
            prefer_static.equal(elem_length, 3))
        base_value = lambda: prefer_static.cond(  # pylint: disable=g-long-lambda
            prefer_static.equal(elem_length, 2),
            handle_base_case_elem_length_two,
            handle_base_case_elem_length_three)

        if level <= 0:
            return base_value()

        def recursive_case():
            """Evaluate the next step of the recursion."""
            odd_elems = _scan(level - 1, reduced_elems)

            def even_length_case():
                return lowered_fn([odd_elem[:-1] for odd_elem in odd_elems],
                                  [elem[2::2] for elem in elems])

            def odd_length_case():
                return lowered_fn([odd_elem for odd_elem in odd_elems],
                                  [elem[2::2] for elem in elems])

            results = prefer_static.cond(
                prefer_static.equal(elem_length % 2, 0), even_length_case,
                odd_length_case)

            # The first element of a scan is the same as the first element
            # of the original `elems`.
            even_elems = [
                tf.concat([elem[0:1], result], axis=0)
                for (elem, result) in zip(elems, results)
            ]
            return list(map(_interleave, even_elems, odd_elems))

        return prefer_static.cond(at_base_case, base_value, recursive_case)

    with tf.name_scope(name if name else 'scan_associative'):
        elem_length, assertions = _validate_elem_length(
            max_num_levels, elems_flat)

    with tf.control_dependencies(assertions if validate_args else []):
        return prefer_static.cond(
            elem_length < 2,
            lambda: elems,
            lambda: (
                tf.nest.pack_sequence_as(  # pylint: disable=g-long-lambda
                    elems, _scan(max_num_levels - 1, elems_flat))))
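As a standalone illustration of the odd/even recursion sketched in the comments
above, here is a minimal NumPy version restricted to power-of-two lengths (the
implementation above additionally handles odd lengths and the size-2/3 base
cases); `blelloch_scan` is a name used only for this sketch:

```python
import numpy as np

def blelloch_scan(fn, xs):
  """Inclusive scan via pairwise reduction + recursion (power-of-two only)."""
  if len(xs) == 1:
    return xs
  reduced = fn(xs[0:-1:2], xs[1::2])            # Reduce adjacent pairs.
  odd = blelloch_scan(fn, reduced)              # Scan the half-sized problem.
  even = np.concatenate([xs[:1], fn(odd[:-1], xs[2::2])])
  out = np.empty(len(xs), dtype=xs.dtype)
  out[1::2], out[0::2] = odd, even              # Interleave odd/even results.
  return out

xs = np.arange(1, 9)
assert np.array_equal(blelloch_scan(np.add, xs), np.cumsum(xs))
```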
Example #25
0
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     log_weights,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     resample_criterion_fn,
                     has_observation=True,
                     seed=None):
    """Advances the particle filter by a single time step."""
    with tf.name_scope('filter_one_step'):
        seed = SeedStream(seed, 'filter_one_step')
        num_particles = prefer_static.shape(log_weights)[0]

        proposed_particles, proposal_log_weights = _propose_with_log_weights(
            step=step - 1,
            particles=previous_particles,
            transition_fn=transition_fn,
            proposal_fn=proposal_fn,
            seed=seed)
        log_weights = tf.nn.log_softmax(proposal_log_weights + log_weights,
                                        axis=-1)

        # If this step has an observation, compute its weights and marginal
        # likelihood (and otherwise, leave weights unchanged).
        observation_log_weights = prefer_static.cond(
            has_observation,
            lambda: prefer_static.broadcast_to(  # pylint: disable=g-long-lambda
                _compute_observation_log_weights(step, proposed_particles,
                                                 observation, observation_fn),
                prefer_static.shape(log_weights)),
            lambda: tf.zeros_like(log_weights))

        unnormalized_log_weights = log_weights + observation_log_weights
        step_log_marginal_likelihood = tf.math.reduce_logsumexp(
            unnormalized_log_weights, axis=0)
        log_weights = (unnormalized_log_weights - step_log_marginal_likelihood)

        # Adaptive resampling: resample particles iff the specified criterion
        # is met.
        do_resample = resample_criterion_fn(unnormalized_log_weights)

        # Some batch elements may require resampling and others not, so
        # we first do the resampling for all elements, then select whether to use
        # the resampled values for each batch element according to
        # `do_resample`. If there were no batching, we might prefer to use
        # `tf.cond` to avoid the resampling computation on steps where it's not
        # needed---but we're ultimately interested in adaptive resampling
        # for statistical (not computational) purposes, so this isn't a dealbreaker.
        resampled_particles, resample_indices = _resample(proposed_particles,
                                                          log_weights,
                                                          resample_independent,
                                                          seed=seed)

        uniform_weights = (prefer_static.zeros_like(log_weights) -
                           prefer_static.log(num_particles))
        (resampled_particles, resample_indices,
         log_weights) = tf.nest.map_structure(
             lambda r, p: prefer_static.where(do_resample, r, p),
             (resampled_particles, resample_indices, uniform_weights),
             (proposed_particles, _dummy_indices_like(resample_indices),
              log_weights))

    return ParticleFilterStepResults(
        particles=resampled_particles,
        log_weights=log_weights,
        parent_indices=resample_indices,
        step_log_marginal_likelihood=step_log_marginal_likelihood)
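The weight update above normalizes entirely in log space: subtracting the
per-step log marginal likelihood (a `logsumexp` over particles) from the
unnormalized log-weights leaves weights that sum to one in probability space.
A numeric sketch of just that step, using hypothetical values for three
particles:

```python
import tensorflow as tf

unnormalized_log_weights = tf.constant([-1.0, -2.0, -0.5])  # Hypothetical.
step_log_marginal_likelihood = tf.math.reduce_logsumexp(
    unnormalized_log_weights, axis=0)
log_weights = unnormalized_log_weights - step_log_marginal_likelihood
# The normalized weights sum to 1.
tf.debugging.assert_near(tf.reduce_sum(tf.exp(log_weights)), 1.0)
```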
Example #26
0
def hager_zhang(value_and_gradients_function,
                initial_step_size=None,
                value_at_initial_step=None,
                value_at_zero=None,
                threshold_use_approximate_wolfe_condition=1e-6,
                shrinkage_param=0.66,
                expansion_param=5.0,
                sufficient_decrease_param=0.1,
                curvature_param=0.9,
                step_size_shrink_param=0.1,
                max_iterations=50,
                name=None):
    """The Hager Zhang line search algorithm.

  Performs an inexact line search based on the algorithm of
  [Hager and Zhang (2006)][2].
  The univariate objective function `value_and_gradients_function` is typically
  generated by projecting a multivariate objective function along a search
  direction. Suppose the multivariate function to be minimized is
  `g(x1,x2, .. xn)`. Let (d1, d2, ..., dn) be the direction along which we wish
  to perform a line search. Then the projected univariate function to be used
  for line search is

  ```None
    f(a) = g(x1 + d1 * a, x2 + d2 * a, ..., xn + dn * a)
  ```

  The directional derivative along (d1, d2, ..., dn) is needed for this
  procedure. This also corresponds to the derivative of the projected function
  `f(a)` with respect to `a`. Note that this derivative must be negative for
  `a = 0` if the direction is a descent direction.

  The usual stopping criterion for the line search is the satisfaction of the
  (weak) Wolfe conditions. For details of the Wolfe conditions, see
  ref. [3]. On a finite precision machine, the exact Wolfe conditions can
  be difficult to satisfy when one is very close to the minimum and as argued
  by [Hager and Zhang (2005)][1], one can only expect the minimum to be
  determined within square root of machine precision. To improve the situation,
  they propose to replace the Wolfe conditions with an approximate version
  depending on the derivative of the function which is applied only when one
  is very close to the minimum. The following algorithm implements this
  enhanced scheme.
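
  In the notation of the parameters below (`delta = sufficient_decrease_param`,
  `sigma = curvature_param`), the weak Wolfe conditions tested at a candidate
  step `a` can be written as

  ```None
    f(a) <= f(0) + delta * a * f'(0)   (sufficient decrease condition)
    f'(a) >= sigma * f'(0)             (curvature condition)
  ```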

  ### Usage:

  The primary use of line search methods is as an internal component of a class
  of optimization algorithms (called line-search-based methods, as opposed to
  trust region methods). Hence, the end user will typically not want to access
  line search directly. In particular, inexact line search should not be
  confused with a univariate minimization method. The stopping criterion of
  line search is the satisfaction of the Wolfe conditions, not the discovery of
  the minimum of the function.

  With this caveat in mind, the following example illustrates the standalone
  usage of the line search.

  ```python
    # Define value and gradient namedtuple
    ValueAndGradient = namedtuple('ValueAndGradient', ['x', 'f', 'df'])
    # Define a quadratic target with minimum at 1.3.
    def value_and_gradients_function(x):
      return ValueAndGradient(x=x, f=(x - 1.3) ** 2, df=2 * (x-1.3))
    # Set initial step size.
    step_size = tf.constant(0.1)
    ls_result = tfp.optimizer.linesearch.hager_zhang(
        value_and_gradients_function, initial_step_size=step_size)
    # Evaluate the results.
    with tf.Session() as session:
      results = session.run(ls_result)
      # Ensure convergence.
      assert results.converged
      # If the line search converged, the left and the right ends of the
      # bracketing interval are identical.
      assert results.left.x == results.right.x
      # Print the number of evaluations and the final step size.
      print ("Final Step Size: %f, Evaluations: %d" % (results.left.x,
                                                       results.func_evals))
  ```

  ### References:
  [1]: William Hager, Hongchao Zhang. A new conjugate gradient method with
    guaranteed descent and an efficient line search. SIAM J. Optim., Vol 16. 1,
    pp. 170-172. 2005.
    https://www.math.lsu.edu/~hozhang/papers/cg_descent.pdf

  [2]: William Hager, Hongchao Zhang. Algorithm 851: CG_DESCENT, a conjugate
    gradient method with guaranteed descent. ACM Transactions on Mathematical
    Software, Vol 32., 1, pp. 113-137. 2006.
    http://users.clas.ufl.edu/hager/papers/CG/cg_compare.pdf

  [3]: Jorge Nocedal, Stephen Wright. Numerical Optimization. Springer Series in
    Operations Research. pp 33-36. 2006

  Args:
    value_and_gradients_function: A Python callable that accepts a real scalar
      tensor and returns a namedtuple with the fields 'x', 'f', and 'df' that
      correspond to scalar tensors of real dtype containing the point at which
      the function was evaluated, the value of the function, and its
      derivative at that point. The other namedtuple fields, if present,
      should be tensors or sequences (possibly nested) of tensors.
      In the usual optimization application, this function would be generated by
      projecting the multivariate objective function along some specific
      direction. The direction is determined by some other procedure but should
      be a descent direction (i.e. the derivative of the projected univariate
      function must be negative at 0.).
      Alternatively, the function may represent the batching of `n` such line
      functions (e.g. projecting a single multivariate objective function along
      `n` distinct directions at once) accepting n points as input, i.e. a
      tensor of shape [n], and the fields 'x', 'f' and 'df' in the returned
      namedtuple should each be a tensor of shape [n], with the corresponding
      input points, function values, and derivatives at those input points.
    initial_step_size: (Optional) Scalar positive `Tensor` of real dtype, or
      a tensor of shape [n] in batching mode. The initial value (or values) to
      try to bracket the minimum. Default is `1.` as a float32.
      Note that this point need not necessarily bracket the minimum for the line
      search to work correctly but the supplied value must be greater than 0.
      A good initial value will make the search converge faster.
    value_at_initial_step: (Optional) The full return value of evaluating
      value_and_gradients_function at initial_step_size, i.e. a namedtuple with
      'x', 'f', 'df', if already known by the caller. If supplied the value of
      `initial_step_size` will be ignored, otherwise the tuple will be computed
      by evaluating value_and_gradients_function.
    value_at_zero: (Optional) The full return value of
      value_and_gradients_function at `0.`, i.e. a namedtuple with
      'x', 'f', 'df', if already known by the caller. If not supplied the tuple
      will be computed by evaluating value_and_gradients_function.
    threshold_use_approximate_wolfe_condition: Scalar positive `Tensor`
      of real dtype. Corresponds to the parameter 'epsilon' in
      [Hager and Zhang (2006)][2]. Used to estimate the
      threshold at which the line search switches to approximate Wolfe
      conditions.
    shrinkage_param: Scalar positive Tensor of real dtype. Must be less than
      `1.`. Corresponds to the parameter `gamma` in
      [Hager and Zhang (2006)][2].
      If the secant**2 step does not shrink the bracketing interval by this
      proportion, a bisection step is performed to reduce the interval width.
    expansion_param: Scalar positive `Tensor` of real dtype. Must be greater
      than `1.`. Used to expand the initial interval in case it does not bracket
      a minimum. Corresponds to `rho` in [Hager and Zhang (2006)][2].
    sufficient_decrease_param: Positive scalar `Tensor` of real dtype.
      Bounded above by the curvature param. Corresponds to `delta` in the
      terminology of [Hager and Zhang (2006)][2].
    curvature_param: Positive scalar `Tensor` of real dtype. Bounded above
      by `1.`. Corresponds to 'sigma' in the terminology of
      [Hager and Zhang (2006)][2].
    step_size_shrink_param: Positive scalar `Tensor` of real dtype. Bounded
      above by `1`. If the supplied step size is too big (i.e. either the
      objective value or the gradient at that point is infinite), this factor
      is used to shrink the step size until it is finite.
    max_iterations: Positive scalar `Tensor` of integral dtype or None. The
      maximum number of iterations to perform in the line search. The number of
      iterations used to bracket the minimum is also counted against this
      parameter.
    name: (Optional) Python str. The name prefixed to the ops created by this
      function. If not supplied, the default name 'hager_zhang' is used.

  Returns:
    results: A namedtuple containing the following attributes.
      converged: Boolean `Tensor` of shape [n]. Whether a point satisfying
        Wolfe/Approx wolfe was found.
      failed: Boolean `Tensor` of shape [n]. Whether line search failed e.g.
        if either the objective function or the gradient are not finite at
        an evaluation point.
      iterations: Scalar int32 `Tensor`. Number of line search iterations made.
      func_evals: Scalar int32 `Tensor`. Number of function evaluations made.
      left: A namedtuple, as returned by value_and_gradients_function,
        of the left end point of the final bracketing interval. Values are
        equal to those of `right` on batch members where converged is True.
        Otherwise, it corresponds to the last interval computed.
      right: A namedtuple, as returned by value_and_gradients_function,
        of the right end point of the final bracketing interval. Values are
        equal to those of `left` on batch members where converged is True.
        Otherwise, it corresponds to the last interval computed.
  """
    with tf.compat.v1.name_scope(name, 'hager_zhang', [
            initial_step_size, value_at_initial_step, value_at_zero,
            threshold_use_approximate_wolfe_condition, shrinkage_param,
            expansion_param, sufficient_decrease_param, curvature_param
    ]):
        val_0, val_initial, f_lim, prepare_evals = _prepare_args(
            value_and_gradients_function, initial_step_size,
            value_at_initial_step, value_at_zero,
            threshold_use_approximate_wolfe_condition)

        valid_inputs = (hzl.is_finite(val_0) & (val_0.df < 0)
                        & tf.math.is_finite(val_initial.x) &
                        (val_initial.x > 0))

        # Note: _fix_step_size returns immediately if either all inputs are invalid
        # or none need fixing.
        fix_step_evals, val_c, fix_failed = _fix_step_size(
            value_and_gradients_function, val_initial, valid_inputs,
            step_size_shrink_param)

        failed = ~valid_inputs | fix_failed
        init_interval = HagerZhangLineSearchResult(
            converged=tf.zeros_like(failed),  # i.e. all False.
            failed=failed,
            func_evals=prepare_evals + fix_step_evals,
            iterations=tf.convert_to_tensor(value=0),
            left=val_0,
            right=val_c)

        def _apply_bracket_and_search():
            """Bracketing and searching to do for valid inputs."""
            return _bracket_and_search(value_and_gradients_function,
                                       init_interval, f_lim, max_iterations,
                                       shrinkage_param, expansion_param,
                                       sufficient_decrease_param,
                                       curvature_param)

        return prefer_static.cond(tf.reduce_any(input_tensor=~failed),
                                  _apply_bracket_and_search,
                                  lambda: init_interval)
Example #27
0
    def posterior_marginals(self, observations, mask=None, name=None):
        """Compute marginal posterior distribution for each state.

    This function computes, for each time step, the marginal
    conditional probability that the hidden Markov model was in
    each possible state given the observations that were made
    at each time step.
    So if the hidden states are `z[0],...,z[num_steps - 1]` and
    the observations are `x[0], ..., x[num_steps - 1]`, then
    this function computes `P(z[i] | x[0], ..., x[num_steps - 1])`
    for all `i` from `0` to `num_steps - 1`.

    This operation is sometimes called smoothing. It uses a form
    of the forward-backward algorithm.

    Note: the behavior of this function is undefined if the
    `observations` argument represents impossible observations
    from the model.

    Args:
      observations: A tensor representing a batch of observations
        made on the hidden Markov model.  The rightmost dimension of this tensor
        gives the steps in a sequence of observations from a single sample from
        the hidden Markov model. The size of this dimension should match the
        `num_steps` parameter of the hidden Markov model object. The other
        dimensions are the dimensions of the batch and these are broadcast with
        the hidden Markov model's parameters.
      mask: optional bool-type `Tensor` with rightmost dimension matching
        `num_steps` indicating which observations the result of this
        function should be conditioned on. When the mask has value
        `True` the corresponding observations aren't used.
        If `mask` is `None` then all of the observations are used.
        The `mask` dimensions left of the last are broadcast with the
        HMM batch as well as with the observations.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Returns:
      posterior_marginal: A `Categorical` distribution object representing the
        marginal probability of the hidden Markov model being in each state at
        each step. The rightmost dimension of the `Categorical` distributions
        batch will equal the `num_steps` parameter providing one marginal
        distribution for each step. The other dimensions are the dimensions
        corresponding to the batch of observations.

    Raises:
      ValueError: if rightmost dimension of `observations` does not
        have size `num_steps`.
    """

        with tf.name_scope(name or "posterior_marginals"):
            with tf.control_dependencies(self._runtime_assertions):
                observation_tensor_shape = tf.shape(observations)
                mask_tensor_shape = tf.shape(
                    mask) if mask is not None else None

                with self._observation_mask_shape_preconditions(
                        observation_tensor_shape, mask_tensor_shape):
                    observation_log_probs = self._observation_log_probs(
                        observations, mask)
                    log_prob = self._log_init + observation_log_probs[0]
                    log_transition = self._log_trans
                    log_adjoint_prob = tf.zeros_like(log_prob)

                    def _scan_multiple_steps_forwards():
                        def forward_step(log_previous_step,
                                         log_prob_observation):
                            return _log_vector_matrix(
                                log_previous_step,
                                log_transition) + log_prob_observation

                        forward_log_probs = tf.scan(forward_step,
                                                    observation_log_probs[1:],
                                                    initializer=log_prob,
                                                    name="forward_log_probs")
                        return tf.concat([[log_prob], forward_log_probs],
                                         axis=0)

                    forward_log_probs = prefer_static.cond(
                        self._num_steps > 1, _scan_multiple_steps_forwards,
                        lambda: tf.convert_to_tensor([log_prob]))

                    total_log_prob = tf.reduce_logsumexp(forward_log_probs[-1],
                                                         axis=-1)

                    def _scan_multiple_steps_backwards():
                        """Perform `scan` operation when `num_steps` > 1."""
                        def backward_step(log_previous_step,
                                          log_prob_observation):
                            return _log_matrix_vector(
                                log_transition,
                                log_prob_observation + log_previous_step)

                        backward_log_adjoint_probs = tf.scan(
                            backward_step,
                            observation_log_probs[1:],
                            initializer=log_adjoint_prob,
                            reverse=True,
                            name="backward_log_adjoint_probs")

                        return tf.concat(
                            [backward_log_adjoint_probs, [log_adjoint_prob]],
                            axis=0)

                    backward_log_adjoint_probs = prefer_static.cond(
                        self._num_steps > 1, _scan_multiple_steps_backwards,
                        lambda: tf.convert_to_tensor([log_adjoint_prob]))

                    log_likelihoods = forward_log_probs + backward_log_adjoint_probs

                    marginal_log_probs = distribution_util.move_dimension(
                        log_likelihoods - total_log_prob[..., tf.newaxis], 0,
                        -2)

                    return categorical.Categorical(logits=marginal_log_probs)
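For context, a minimal usage sketch of `posterior_marginals` on a small
hand-built model (the distribution parameters and observations below are
illustrative, not taken from the source; `probs_parameter` is assumed to be
available on the returned `Categorical`):

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

# A two-state hidden Markov model observed for five steps.
hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 3.], scale=[0.5, 0.5]),
    num_steps=5)

posterior = hmm.posterior_marginals(observations=[0.1, -0.2, 2.9, 3.1, 0.0])
posterior.probs_parameter()  # Shape [5, 2]: one marginal per time step.
```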
Example #28
0
def _get_search_direction(state):
    """Computes the search direction to follow at the current state.

  On the `k`-th iteration of the main L-BFGS algorithm, the state has collected
  the most recent `m` correction pairs in position_deltas and gradient_deltas,
  where `k = state.num_iterations` and `m = min(k, num_correction_pairs)`.

  Assuming these, the code below is an implementation of the L-BFGS two-loop
  recursion algorithm given by [Nocedal and Wright(2006)][1]:

  ```None
    q_direction = objective_gradient
    for i in reversed(range(m)):  # First loop.
      inv_rho[i] = gradient_deltas[i]^T * position_deltas[i]
      alpha[i] = position_deltas[i]^T * q_direction / inv_rho[i]
      q_direction = q_direction - alpha[i] * gradient_deltas[i]

    kth_inv_hessian_factor = (gradient_deltas[-1]^T * position_deltas[-1] /
                              gradient_deltas[-1]^T * gradient_deltas[-1])
    r_direction = kth_inv_hessian_factor * I * q_direction

    for i in range(m):  # Second loop.
      beta = gradient_deltas[i]^T * r_direction / inv_rho[i]
      r_direction = r_direction + position_deltas[i] * (alpha[i] - beta)

    return -r_direction  # Approximates - H_k * objective_gradient.
  ```

  Args:
    state: A `LBfgsOptimizerResults` tuple with the current state of the
      search procedure.

  Returns:
    A real `Tensor` of the same shape as the `state.position`. The direction
    along which to perform line search.
  """
    # The number of correction pairs that have been collected so far.
    num_elements = tf.minimum(
        state.num_iterations,
        distribution_util.prefer_static_shape(state.position_deltas)[0])

    def _two_loop_algorithm():
        """L-BFGS two-loop algorithm."""
        # Correction pairs are always appended to the end, so only the latest
        # `num_elements` vectors have valid position/gradient deltas.
        position_deltas = state.position_deltas[-num_elements:]
        gradient_deltas = state.gradient_deltas[-num_elements:]

        # Pre-compute all `inv_rho[i]`s.
        inv_rhos = tf.reduce_sum(gradient_deltas * position_deltas, axis=-1)

        def first_loop(acc, args):
            _, q_direction = acc
            position_delta, gradient_delta, inv_rho = args
            alpha = tf.reduce_sum(position_delta * q_direction,
                                  axis=-1) / inv_rho
            direction_delta = alpha[..., tf.newaxis] * gradient_delta
            return (alpha, q_direction - direction_delta)

        # Run first loop body computing and collecting `alpha[i]`s, while also
        # computing the updated `q_direction` at each step.
        zero = tf.zeros_like(inv_rhos[0])
        alphas, q_directions = tf.scan(
            first_loop, [position_deltas, gradient_deltas, inv_rhos],
            initializer=(zero, state.objective_gradient),
            reverse=True)

        # We use `H^0_k = gamma_k * I` as an estimate for the initial inverse
        # hessian for the k-th iteration; then `r_direction = H^0_k * q_direction`.
        gamma_k = inv_rhos[-1] / tf.reduce_sum(
            gradient_deltas[-1] * gradient_deltas[-1], axis=-1)
        r_direction = gamma_k[..., tf.newaxis] * q_directions[0]

        def second_loop(r_direction, args):
            alpha, position_delta, gradient_delta, inv_rho = args
            beta = tf.reduce_sum(gradient_delta * r_direction,
                                 axis=-1) / inv_rho
            direction_delta = (alpha - beta)[..., tf.newaxis] * position_delta
            return r_direction + direction_delta

        # Finally, run second loop body computing the updated `r_direction` at each
        # step.
        r_directions = tf.scan(
            second_loop, [alphas, position_deltas, gradient_deltas, inv_rhos],
            initializer=r_direction)
        return -r_directions[-1]

    return prefer_static.cond(
        tf.equal(num_elements, 0),
        lambda: -state.objective_gradient,
        _two_loop_algorithm)
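As a sanity check on the two-loop recursion in the docstring above, the
following NumPy sketch compares it against the explicit BFGS inverse-Hessian
update for a single correction pair (`m = 1`); the vectors are arbitrary
illustrative values:

```python
import numpy as np

rng = np.random.default_rng(0)
g = rng.normal(size=4)   # objective gradient
s = rng.normal(size=4)   # position delta
y = rng.normal(size=4)   # gradient delta

rho = 1.0 / np.dot(y, s)
gamma = np.dot(s, y) / np.dot(y, y)   # kth_inv_hessian_factor

# Two-loop recursion with a single correction pair.
alpha = rho * np.dot(s, g)
q = g - alpha * y
r = gamma * q
beta = rho * np.dot(y, r)
r = r + s * (alpha - beta)
two_loop_direction = -r

# Explicit BFGS update of the initial inverse Hessian, for comparison.
eye = np.eye(4)
h = ((eye - rho * np.outer(s, y)) @ (gamma * eye) @ (eye - rho * np.outer(y, s))
     + rho * np.outer(s, s))
assert np.allclose(two_loop_direction, -h @ g)
```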
def fit_one_step(
    model_matrix,
    response,
    model,
    model_coefficients_start=None,
    predicted_linear_response_start=None,
    l2_regularizer=None,
    dispersion=None,
    offset=None,
    learning_rate=None,
    fast_unsafe_numerics=True,
    l2_regularization_penalty_factor=None,
    name=None):
  """Runs one step of Fisher scoring.

  Args:
    model_matrix: (Batch of) `float`-like, matrix-shaped `Tensor` where each row
      represents a sample's features.
    response: (Batch of) vector-shaped `Tensor` where each element represents a
      sample's observed response (to the corresponding row of features). Must
      have same `dtype` as `model_matrix`.
    model: `tfp.glm.ExponentialFamily`-like instance used to construct the
      negative log-likelihood loss, gradient, and expected Hessian (i.e., the
      Fisher information matrix).
    model_coefficients_start: Optional (batch of) vector-shaped `Tensor`
      representing the initial model coefficients, one for each column in
      `model_matrix`. Must have same `dtype` as `model_matrix`.
      Default value: Zeros.
    predicted_linear_response_start: Optional `Tensor` with `shape`, `dtype`
      matching `response`; represents `offset` shifted initial linear
      predictions based on `model_coefficients_start`.
      Default value: `offset` if `model_coefficients is None`, and
      `tf.linalg.matvec(model_matrix, model_coefficients_start) + offset`
      otherwise.
    l2_regularizer: Optional scalar `Tensor` representing L2 regularization
      penalty, i.e.,
      `loss(w) = sum{-log p(y[i]|x[i],w) : i=1..n} + l2_regularizer ||w||_2^2`.
      Default value: `None` (i.e., no L2 regularization).
    dispersion: Optional (batch of) `Tensor` representing `response` dispersion,
      i.e., as in, `p(y|theta) := exp((y theta - A(theta)) / dispersion)`.
      Must broadcast with rows of `model_matrix`.
      Default value: `None` (i.e., "no dispersion").
    offset: Optional `Tensor` representing constant shift applied to
      `predicted_linear_response`.  Must broadcast to `response`.
      Default value: `None` (i.e., `tf.zeros_like(response)`).
    learning_rate: Optional (batch of) scalar `Tensor` used to dampen iterative
      progress. Typically only needed if optimization diverges, should be no
      larger than `1` and typically very close to `1`.
      Default value: `None` (i.e., `1`).
    fast_unsafe_numerics: Optional Python `bool` indicating if solve should be
      based on Cholesky or QR decomposition.
      Default value: `True` (i.e., "prefer speed via Cholesky decomposition").
    l2_regularization_penalty_factor: Optional (batch of) vector-shaped
      `Tensor`, representing a separate penalty factor to apply to each model
      coefficient, length equal to columns in `model_matrix`. Each penalty
      factor multiplies l2_regularizer to allow differential regularization. Can
      be 0 for some coefficients, which implies no regularization. Default is 1
      for all coefficients.
      `loss(w) = sum{-log p(y[i]|x[i],w) : i=1..n} + l2_regularizer ||w *
        l2_regularization_penalty_factor||_2^2`
    name: Python `str` used as name prefix to ops created by this function.
      Default value: `"fit_one_step"`.

  Returns:
    model_coefficients: (Batch of) vector-shaped `Tensor`; represents the
      next estimate of the model coefficients, one for each column in
      `model_matrix`.
    predicted_linear_response: `response`-shaped `Tensor` representing linear
      predictions based on new `model_coefficients`, i.e.,
      `tf.linalg.matvec(model_matrix, model_coefficients_next) + offset`.
  """
  with tf.name_scope(name or 'fit_one_step'):

    [
        model_matrix,
        response,
        model_coefficients_start,
        predicted_linear_response_start,
        offset,
    ] = prepare_args(
        model_matrix,
        response,
        model_coefficients_start,
        predicted_linear_response_start,
        offset)

    # Compute: mean, grad(mean, predicted_linear_response_start), and variance.
    mean, variance, grad_mean = model(predicted_linear_response_start)

    # If either `grad_mean` or `variance` is non-finite or zero, then we'll
    # replace it with a value such that the row is zeroed out. Although this
    # procedure may seem circuitous, it is necessary to ensure this algorithm is
    # itself differentiable.
    is_valid = (
        tf.math.is_finite(grad_mean) & tf.not_equal(grad_mean, 0.)
        & tf.math.is_finite(variance) & (variance > 0.))

    def mask_if_invalid(x, mask):
      return tf.where(
          is_valid, x, np.array(mask, dtype_util.as_numpy_dtype(x.dtype)))

    # Run one step of iteratively reweighted least-squares.
    # Compute "`z`", the adjusted predicted linear response.
    # z = predicted_linear_response_start
    #     + learning_rate * (response - mean) / grad_mean
    z = (response - mean) / mask_if_invalid(grad_mean, 1.)
    # TODO(jvdillon): Rather than use learning rate, we should consider using
    # backtracking line search.
    if learning_rate is not None:
      z *= learning_rate[..., tf.newaxis]
    z += predicted_linear_response_start
    if offset is not None:
      z -= offset

    # Compute "`w`", the per-sample weight.
    if dispersion is not None:
      # For convenience, we'll now scale the variance by the dispersion factor.
      variance *= dispersion
    w = (
        mask_if_invalid(grad_mean, 0.) *
        tf.math.rsqrt(mask_if_invalid(variance, np.inf)))

    a = model_matrix * w[..., tf.newaxis]
    b = z * w
    # Solve `min{ || A @ model_coefficients - b ||_2**2 : model_coefficients }`
    # where `@` denotes `matmul`.

    if l2_regularizer is None:
      l2_regularizer = np.array(0, dtype_util.as_numpy_dtype(a.dtype))
    else:
      l2_regularizer_ = distribution_util.maybe_get_static_value(
          l2_regularizer, dtype_util.as_numpy_dtype(a.dtype))
      if l2_regularizer_ is not None:
        l2_regularizer = l2_regularizer_

    def _embed_l2_regularization():
      """Adds synthetic observations to implement L2 regularization."""
      # `tf.matrix_solve_ls` does not respect the `l2_regularization` argument
      # when `fast_unsafe_numerics` is `False`. This function adds synthetic
      # observations to the data to implement the regularization instead.
      # Adding observations `sqrt(l2_regularizer) * I` is mathematically
      # equivalent to adding the term
      # `-l2_regularizer ||coefficients||_2**2` to the log-likelihood.
      num_model_coefficients = num_cols(model_matrix)
      batch_shape = tf.shape(model_matrix)[:-2]
      if l2_regularization_penalty_factor is None:
        eye = tf.eye(
            num_model_coefficients, batch_shape=batch_shape, dtype=a.dtype)
      else:
        eye = tf.linalg.tensor_diag(
            tf.cast(l2_regularization_penalty_factor, dtype=a.dtype))
        broadcasted_shape = prefer_static.concat(
            [batch_shape, [num_model_coefficients, num_model_coefficients]],
            axis=0)
        eye = tf.broadcast_to(eye, broadcasted_shape)
      a_ = tf.concat([a, tf.sqrt(l2_regularizer) * eye], axis=-2)
      b_ = distribution_util.pad(
          b, count=num_model_coefficients, axis=-1, back=True)
      # Return l2_regularizer=0 since its now embedded.
      l2_regularizer_ = np.array(0, dtype_util.as_numpy_dtype(a.dtype))
      return a_, b_, l2_regularizer_

    a, b, l2_regularizer = prefer_static.cond(
        prefer_static.reduce_all([
            prefer_static.logical_or(
                not(fast_unsafe_numerics),
                l2_regularization_penalty_factor is not None),
            l2_regularizer > 0.
        ]),
        _embed_l2_regularization,
        lambda: (a, b, l2_regularizer))

    model_coefficients_next = tf.linalg.lstsq(
        a,
        b[..., tf.newaxis],
        fast=fast_unsafe_numerics,
        l2_regularizer=l2_regularizer,
        name='model_coefficients_next')
    model_coefficients_next = model_coefficients_next[..., 0]

    # TODO(b/79122261): The approach used in `matrix_solve_ls` could be made
    # faster by avoiding explicitly forming Q and instead keeping the
    # factorization in 'implicit' form with stacked (rescaled) Householder
    # vectors underneath the 'R' and then applying the (accumulated)
    # reflectors in the appropriate order to apply Q'. However, we don't
    # presently do this because we lack core TF functionality. For reference,
    # the vanilla QR approach is:
    #   q, r = tf.linalg.qr(a)
    #   c = tf.matmul(q, b, adjoint_a=True)
    #   model_coefficients_next = tf.matrix_triangular_solve(
    #       r, c, lower=False, name='model_coefficients_next')

    predicted_linear_response_next = compute_predicted_linear_response(
        model_matrix,
        model_coefficients_next,
        offset,
        name='predicted_linear_response_next')

    return model_coefficients_next, predicted_linear_response_next
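A hedged usage sketch of a single Fisher-scoring step on synthetic
linear-Gaussian data (assuming the function is exposed as
`tfp.glm.fit_one_step` alongside the `tfp.glm.Normal` model family, as in
recent TFP releases; data sizes and the seed are arbitrary):

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

rng = np.random.default_rng(0)
x = tf.constant(rng.normal(size=(50, 3)), dtype=tf.float32)   # model_matrix
w_true = tf.constant([1.0, -2.0, 0.5])
y = tf.linalg.matvec(x, w_true) + 0.1 * tf.constant(
    rng.normal(size=50), dtype=tf.float32)                    # response

coeffs, linear_response = tfp.glm.fit_one_step(
    model_matrix=x, response=y, model=tfp.glm.Normal())
# With a Gaussian model and identity link, one Fisher-scoring step is the
# least-squares solution, so `coeffs` should already be close to `w_true`.
```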
Example #30
0
    def test_missing_arg2(self):
        x = tf.constant(1)
        with self.assertRaises(TypeError):
            ps.cond(True, lambda: x)