def _assertions(self, x):
  """Builds the validation assertions for a candidate matrix `x`.

  Args:
    x: Tensor to validate. It must expose a `dtype` attribute, which is used
      to build the zero diagonal below.

  Returns:
    An empty list when `self.validate_args` is falsy. Otherwise a list of
    four assertions checking, in order, that `x` has rank at least 2, that
    its two innermost dimensions are equal, that every entry above the main
    diagonal is zero, and that no diagonal entry is zero.
  """
  if not self.validate_args:
    return []
  shape = tf.shape(x)
  is_matrix = assert_util.assert_rank_at_least(
      x, 2, message="Input must have rank at least 2.")
  is_square = assert_util.assert_equal(
      shape[-2], shape[-1], message="Input must be a square matrix.")
  # Overwrite the diagonal with zeros and keep only the upper band, leaving
  # the strictly-upper-triangular part of `x`. The zero diagonal is built in
  # `x`'s own dtype: a hard-coded float32 diagonal cannot be combined with
  # inputs of any other dtype (e.g. float64).
  above_diagonal = tf.linalg.band_part(
      tf.linalg.set_diag(x, tf.zeros(shape[:-1], dtype=x.dtype)), 0, -1)
  is_lower_triangular = assert_util.assert_equal(
      above_diagonal,
      tf.zeros_like(above_diagonal),
      message="Input must be lower triangular.")
  # A lower triangular matrix is nonsingular iff all its diagonal entries are
  # nonzero.
  diag_part = tf.linalg.diag_part(x)
  is_nonsingular = assert_util.assert_none_equal(
      diag_part,
      tf.zeros_like(diag_part),
      message="Input must have all diagonal entries nonzero.")
  return [is_matrix, is_square, is_lower_triangular, is_nonsingular]
Пример #2
0
    def _parameter_control_dependencies(self, is_init):
        """Collects validation assertions for the LU parameters.

        A parameter is validated only when `is_init` differs from
        `tensor_util.is_ref(parameter)`.

        Args:
          is_init: Value compared against `tensor_util.is_ref(...)` of
            `self.lower_upper` and `self.permutation` to decide which
            parameters are validated in this call.

        Returns:
          A list of assertions. It is empty when `self.validate_args` is
          falsy or when neither parameter needs checking.
        """
        if not self.validate_args:
            return []

        checks = []
        factors = None
        check_lower_upper = is_init != tensor_util.is_ref(self.lower_upper)

        # The reconstruction assertions take both parameters, so they are
        # built whenever either one needs checking.
        if (check_lower_upper
                or is_init != tensor_util.is_ref(self.permutation)):
            factors, pivots = self._broadcast_params()
            checks.extend(
                lu_reconstruct_assertions(factors, pivots,
                                          self.validate_args))

        if check_lower_upper:
            if factors is None:
                factors = tf.convert_to_tensor(self.lower_upper)
            nonzero_diagonal = assert_util.assert_none_equal(
                tf.linalg.diag_part(factors),
                tf.zeros([], dtype=factors.dtype),
                message='Invertible `lower_upper` must have nonzero diagonal.')
            checks.append(nonzero_diagonal)

        return checks
Пример #3
0
    def __init__(self, hinge_softness=None, validate_args=False,
                 name="softplus"):
        """Sets up the bijector and its optional `hinge_softness` parameter.

        Args:
          hinge_softness: Optional value converted to a `Tensor` and stored on
            the instance. `None` is stored unchanged.
          validate_args: Python `bool`. When truthy and `hinge_softness` is
            given, a non-zero assertion is attached to the stored tensor.
          name: Name scope under which the ops are created; also forwarded to
            the parent constructor.
        """
        with tf.name_scope(name):
            if hinge_softness is not None:
                self._hinge_softness = tf.convert_to_tensor(
                    value=hinge_softness, name="hinge_softness")
                if validate_args:
                    # Zero scalar expressed in the parameter's own dtype.
                    zero = dtype_util.as_numpy_dtype(
                        self._hinge_softness.dtype)(0)
                    nonzero_check = assert_util.assert_none_equal(
                        zero,
                        self.hinge_softness,
                        message="hinge_softness must be non-zero")
                    # Re-store the parameter with the assertion attached as a
                    # dependency.
                    self._hinge_softness = distribution_util.with_dependencies(
                        [nonzero_check], self.hinge_softness)
            else:
                self._hinge_softness = None

        super(Softplus, self).__init__(
            forward_min_event_ndims=0,
            validate_args=validate_args,
            name=name)
Пример #4
0
 def _assertions(self, t):
   """Returns the non-zero assertion for `t`, or nothing when disabled.

   Args:
     t: Tensor whose elements are compared against zero of `t.dtype`.

   Returns:
     A one-element list holding the assertion when `self.validate_args` is
     truthy, otherwise an empty list.
   """
   if self.validate_args:
     # Zero scalar in the dtype of `t`.
     zero = dtype_util.as_numpy_dtype(t.dtype)(0.)
     nonzero_check = assert_util.assert_none_equal(
         t, zero, message="All elements must be non-zero.")
     return [nonzero_check]
   return []
Пример #5
0
    def _create_scale_operator(self, identity_multiplier, diag, tril,
                               perturb_diag, perturb_factor, shift,
                               validate_args, dtype):
        """Construct `scale` from various components.

        Args:
          identity_multiplier: floating point rank 0 `Tensor` representing a
            scaling done to the identity matrix.
          diag: Floating-point `Tensor` representing the diagonal matrix.
            `diag` has shape `[N1, N2, ...  k]`, which represents a k x k
            diagonal matrix.
          tril: Floating-point `Tensor` representing the lower triangular
            matrix. `tril` has shape `[N1, N2, ...  k, k]`, which represents a
            k x k lower triangular matrix.
          perturb_diag: Floating-point `Tensor` representing the diagonal
            matrix of the low rank update.
          perturb_factor: Floating-point `Tensor` representing factor matrix.
          shift: Floating-point `Tensor` representing `shift` in
            `scale @ X + shift`.
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness.
          dtype: `DType` for arg `Tensor` conversions.

        Returns:
          scale. In the case of scaling by a constant, scale is a
          floating point `Tensor`. Otherwise, scale is a `LinearOperator`.

        Raises:
          ValueError: if all of `tril`, `diag` and `identity_multiplier` are
            `None`.
        """
        identity_multiplier = _as_tensor(identity_multiplier,
                                         "identity_multiplier", dtype)
        diag = _as_tensor(diag, "diag", dtype)
        tril = _as_tensor(tril, "tril", dtype)
        perturb_diag = _as_tensor(perturb_diag, "perturb_diag", dtype)
        perturb_factor = _as_tensor(perturb_factor, "perturb_factor", dtype)

        # If possible, use the low rank update to infer the shape of
        # the identity matrix, when scale represents a scaled identity matrix
        # with a low rank update.
        shape_hint = None
        if perturb_factor is not None:
            shape_hint = distribution_util.dimension_size(perturb_factor,
                                                          axis=-2)

        if self._is_only_identity_multiplier:
            if validate_args:
                # The failure text must go through `message=`: the third
                # positional parameter of `assert_none_equal` is not the
                # message (it is `summarize` in the TF2 signature).
                return distribution_util.with_dependencies([
                    assert_util.assert_none_equal(
                        identity_multiplier,
                        tf.zeros([], identity_multiplier.dtype),
                        message="identity_multiplier should be non-zero.")
                ], identity_multiplier)
            return identity_multiplier

        scale = _make_tril_scale(loc=shift,
                                 scale_tril=tril,
                                 scale_diag=diag,
                                 scale_identity_multiplier=identity_multiplier,
                                 validate_args=validate_args,
                                 assert_positive=False,
                                 shape_hint=shape_hint)

        if perturb_factor is not None:
            # Wrap the base operator in a low rank update built from
            # `perturb_factor` and the optional `perturb_diag`.
            return tf.linalg.LinearOperatorLowRankUpdate(
                scale,
                u=perturb_factor,
                diag_update=perturb_diag,
                is_diag_update_positive=perturb_diag is None,
                is_non_singular=True,  # Implied by is_positive_definite=True.
                is_self_adjoint=True,
                is_positive_definite=True,
                is_square=True)

        return scale
Пример #6
0
def find_root_chandrupatla(objective_fn,
                           low,
                           high,
                           position_tolerance=1e-8,
                           value_tolerance=0.,
                           max_iterations=50,
                           stopping_policy_fn=tf.reduce_all,
                           validate_args=False,
                           name='find_root_chandrupatla'):
    r"""Finds root(s) of a scalar function using Chandrupatla's method.

    Chandrupatla's method [1, 2] is a root-finding algorithm that is guaranteed
    to converge if a root lies within the given bounds. It generalizes the
    [bisection method](https://en.wikipedia.org/wiki/Bisection_method); at each
    step it chooses to perform either bisection or inverse quadratic
    interpolation. This makes it similar in spirit to [Brent's method](
    https://en.wikipedia.org/wiki/Brent%27s_method), which also considers steps
    that use the secant method, but Chandrupatla's method is simpler and often
    converges at least as quickly [3].

    Args:
      objective_fn: Python callable for which roots are searched. It must be a
        callable of a single variable. `objective_fn` must return a `Tensor`
        with shape `batch_shape` and dtype matching `low` and `high`.
      low: Float `Tensor` of shape `batch_shape` representing a lower
        bound(s) on the value of a root(s).
      high: Float `Tensor` of shape `batch_shape` representing an upper
        bound(s) on the value of a root(s).
      position_tolerance: Optional `Tensor` representing the maximum absolute
        error in the positions of the estimated roots. Shape must broadcast
        with `batch_shape`.
        Default value: `1e-8`.
      value_tolerance: Optional `Tensor` representing the absolute error
        allowed in the value of the objective function. If the absolute value
        of `objective_fn` is less than or equal to `value_tolerance` at a given
        position, then that position is considered a root for the function.
        Shape must broadcast with `batch_shape`.
        Default value: `0.`.
      max_iterations: Optional `Tensor` or Python integer specifying the
        maximum number of steps to perform. Shape must broadcast with
        `batch_shape`.
        Default value: `50`.
      stopping_policy_fn: Python `callable` controlling the algorithm
        termination. It must be a callable accepting a `Tensor` of booleans of
        shape `batch_shape` (denoting whether each search is finished), and
        returning a scalar boolean `Tensor` indicating whether the overall
        search should stop. Typical values are `tf.reduce_all` (which returns
        only when the search is finished for all points), and `tf.reduce_any`
        (which returns as soon as the search is finished for any point).
        Default value: `tf.reduce_all` (returns only when the search is
          finished for all points).
      validate_args: Python `bool` indicating whether to validate arguments.
        When `True`, an assertion checks that the objective has different
        signs at `low` and `high`.
        Default value: `False`.
      name: Python `str` name prefixed to ops created by this function.
        Default value: 'find_root_chandrupatla'.

    Returns:
      root_search_results: A `RootSearchResults` containing the following
        items:
        estimated_root: `Tensor` containing, for each batch element, the
          endpoint of the final bracketing interval at which the objective has
          the smaller absolute value. If the search was successful within the
          specified tolerance, this position is a root of the objective
          function.
        objective_at_estimated_root: `Tensor` containing the value of the
          objective function at `estimated_root`. If the search was successful
          within the specified tolerance, then this is close to 0.
        num_iterations: The number of iterations performed.

    #### References

    [1] Tirupathi R. Chandrupatla. A new hybrid quadratic/bisection algorithm
        for finding the zero of a nonlinear function without using derivatives.
        _Advances in Engineering Software_, 28.3:145-149, 1997.
    [2] Philipp OJ Scherer. Computational Physics. _Springer Berlin_,
        Heidelberg, 2010.
        Section 6.1.7.3 https://books.google.com/books?id=cC-8BAAAQBAJ&pg=PA95
    [3] Jason Sachs. Ten Little Algorithms, Part 5: Quadratic Extremum
        Interpolation and Chandrupatla's Method (2015).
        https://www.embeddedrelated.com/showarticle/855.php
    """

    ################################################
    # Loop variables used by Chandrupatla's method:
    #
    #  a: endpoint of an interval `[min(a, b), max(a, b)]` containing the
    #     root. There is no guarantee as to which of `a` and `b` is larger.
    #  b: endpoint of an interval `[min(a, b), max(a, b)]` containing the
    #     root. There is no guarantee as to which of `a` and `b` is larger.
    #  f_a: value of the objective at `a`.
    #  f_b: value of the objective at `b`.
    #  t: the next position to be evaluated as the coefficient of a convex
    #    combination of `a` and `b` (i.e., a value in the unit interval).
    #  num_iterations: integer number of steps taken so far.
    #  converged: boolean indicating whether each batch element has converged.
    #
    # All variables have the same shape `batch_shape`.

    def _should_continue(a, b, f_a, f_b, t, num_iterations, converged):
        # Loop condition: keep iterating until `stopping_policy_fn` reports
        # that the search is over. An element counts as finished once it has
        # converged or has used up `max_iterations`.
        del a, b, f_a, f_b, t  # Unused.
        all_converged = stopping_policy_fn(
            tf.logical_or(converged, num_iterations >= max_iterations))
        return ~all_converged

    def _body(a, b, f_a, f_b, t, num_iterations, converged):
        """One step of Chandrupatla's method for root finding."""
        previous_loop_vars = (a, b, f_a, f_b, t, num_iterations, converged)
        # Elements that were already finished on entry keep their previous
        # state; see the final `_structure_broadcasting_where` below.
        finalized_elements = tf.logical_or(converged,
                                           num_iterations >= max_iterations)

        # Evaluate the new point.
        x_new = (1 - t) * a + t * b
        f_new = objective_fn(x_new)
        # Tighten the bounds.
        # In both branches `x_new` becomes the new `a`. `c` is the endpoint
        # that shares the sign of `f_new` (the one `x_new` replaces in the
        # bracket) and `b` is the opposite-sign endpoint.
        a, b, c, f_a, f_b, f_c = _structure_broadcasting_where(
            tf.equal(tf.math.sign(f_new), tf.math.sign(f_a)),
            (x_new, b, a, f_new, f_b, f_a), (x_new, a, b, f_new, f_a, f_b))

        # Check for convergence.
        f_best = tf.where(tf.abs(f_a) < tf.abs(f_b), f_a, f_b)
        # `position_tolerance` expressed as a fraction of the distance
        # between `b` and `c`.
        interval_tolerance = position_tolerance / (tf.abs(b - c))
        converged = tf.logical_or(interval_tolerance > 0.5,
                                  tf.math.abs(f_best) <= value_tolerance)

        # Propose next point to evaluate.
        xi = (a - b) / (c - b)
        phi = (f_a - f_b) / (f_c - f_b)
        t = tf.where(
            # Condition for inverse quadratic interpolation.
            tf.logical_and(1 - tf.math.sqrt(1 - xi) < phi,
                           tf.math.sqrt(xi) > phi),
            # Propose a point by inverse quadratic interpolation.
            (f_a / (f_b - f_a) * f_c / (f_b - f_c) + (c - a) / (b - a) * f_a /
             (f_c - f_a) * f_b / (f_c - f_b)),
            # Otherwise, just cut the interval in half (bisection).
            0.5)
        # Constrain the proposal to the current interval (0 < t < 1).
        t = tf.minimum(tf.maximum(t, interval_tolerance),
                       1 - interval_tolerance)

        # Update elements that haven't converged.
        return _structure_broadcasting_where(
            finalized_elements, previous_loop_vars,
            (a, b, f_a, f_b, t, num_iterations + 1, converged))

    with tf.name_scope(name):
        max_iterations = tf.convert_to_tensor(max_iterations,
                                              name='max_iterations',
                                              dtype_hint=tf.int32)
        a = tf.convert_to_tensor(low, name='lower_bound')
        b = tf.convert_to_tensor(high, name='upper_bound')
        f_a, f_b = objective_fn(a), objective_fn(b)
        # The batch shape is taken from the objective values at the two
        # bounds; every loop variable is broadcast to it below.
        batch_shape = ps.broadcast_shape(ps.shape(f_a), ps.shape(f_b))

        assertions = []
        if validate_args:
            assertions += [
                assert_util.assert_none_equal(
                    tf.math.sign(f_a),
                    tf.math.sign(f_b),
                    message='Bounds must be on different sides of a root.')
            ]

        with tf.control_dependencies(assertions):
            # Initial state: first proposal at the midpoint (t = 0.5), zero
            # iterations taken, nothing converged.
            initial_loop_vars = [
                a, b, f_a, f_b,
                tf.cast(0.5, dtype=f_a.dtype),
                tf.cast(0, dtype=max_iterations.dtype), False
            ]
            a, b, f_a, f_b, _, num_iterations, _ = tf.while_loop(
                _should_continue,
                _body,
                loop_vars=tf.nest.map_structure(
                    lambda x: tf.broadcast_to(x, batch_shape),
                    initial_loop_vars))

        # Report whichever endpoint has the smaller absolute objective value.
        x_best, f_best = _structure_broadcasting_where(
            tf.abs(f_a) < tf.abs(f_b), (a, f_a), (b, f_b))
    return RootSearchResults(estimated_root=x_best,
                             objective_at_estimated_root=f_best,
                             num_iterations=num_iterations)
Пример #7
0
 def _maybe_assert_valid(self, t):
   """Optionally attaches a non-zero check to `t`.

   Args:
     t: Tensor to validate. It must expose a `dtype` attribute.

   Returns:
     `t` unchanged when `self.validate_args` is falsy. Otherwise the result
     of `distribution_util.with_dependencies`, which ties `t` to an assertion
     that none of its elements equal zero.
   """
   if not self.validate_args:
     return t
   # Build the zero in `t`'s own dtype. A bare Python `0.` is converted
   # without reference to `t`, which can produce a dtype mismatch for
   # non-float32 inputs.
   zero = dtype_util.as_numpy_dtype(t.dtype)(0.)
   is_valid = assert_util.assert_none_equal(
       t, zero, message="All elements must be non-zero.")
   return distribution_util.with_dependencies([is_valid], t)