Example #1
0
 def _assert_valid_sample(self, x):
     """Return `x`, optionally guarded by validity assertions.

     When `validate_args` is False the input passes through untouched.
     Otherwise the returned tensor carries control dependencies checking
     that `x` is non-positive everywhere and that `logsumexp(x, axis=-1)`
     is near zero, i.e. `x` looks like normalized log-probabilities.
     """
     if not self.validate_args:
         return x
     assertions = [
         check_ops.assert_non_positive(x),
         check_ops.assert_near(array_ops.zeros([], dtype=self.dtype),
                               math_ops.reduce_logsumexp(x, axis=[-1])),
     ]
     return control_flow_ops.with_dependencies(assertions, x)
Example #2
0
    def _log_unnormalized_prob(self, x):
        """Compute `scale * sum(loc * x, axis=-1)` with the trailing
        singleton axis squeezed away via reshape.

        When `validate_args` is True, first asserts that `x` is (nearly)
        unit-norm along its last axis.
        """
        checks = []
        if self.validate_args:
            checks.append(
                check_ops.assert_near(linalg_ops.norm(x, axis=-1), 1,
                                      atol=1e-3))
        with ops.control_dependencies(checks):
            inner_product = self.scale * math_ops.reduce_sum(
                self._loc * x, axis=-1, keepdims=True)
        # Drop the kept last dimension: reshape to the leading shape.
        target_shape = ops.convert_to_tensor(
            array_ops.shape(inner_product)[:-1])
        return array_ops.reshape(inner_product, target_shape)
 def _assert_valid_sample(self, x):
   """Pass `x` through, attaching validity checks when requested.

   With `validate_args` enabled, `x` must be non-positive everywhere and
   `logsumexp(x, axis=-1)` must be near zero (valid log-probabilities).
   """
   if not self.validate_args:
     return x
   near_zero = check_ops.assert_near(
       array_ops.zeros([], dtype=self.dtype),
       math_ops.reduce_logsumexp(x, axis=[-1]))
   return control_flow_ops.with_dependencies(
       [check_ops.assert_non_positive(x), near_zero], x)
Example #4
0
 def _maybe_assert_valid_sample(self, x):
   """Return `x`, guarded by simplex-membership checks when validating.

   The checks require every entry of `x` to be strictly positive and the
   last dimension to sum to one.
   """
   if not self.validate_args:
     return x
   assertions = [
       check_ops.assert_positive(x, message="samples must be positive"),
       check_ops.assert_near(
           array_ops.ones([], dtype=self.dtype),
           math_ops.reduce_sum(x, -1),
           message="sample last-dimension must sum to `1`"),
   ]
   return control_flow_ops.with_dependencies(assertions, x)
    def __init__(self, loc, scale, validate_args=False, allow_nan_stats=True,
                 name="von-Mises-Fisher"):
        """Construct von-Mises-Fisher distributions with mean direction
        `loc` and concentration `scale`.

        Args:
          loc: Floating point tensor; the mean direction of the
            distribution(s). When `validate_args` is `True`, its last axis
            must be (nearly) unit-norm.
          scale: Floating point tensor; the concentration of the
            distribution(s). Must contain only non-negative values.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite
            possibly degrading runtime performance. When `False` invalid
            inputs may silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          TypeError: if `loc` and `scale` have different `dtype`.
        """
        # Capture constructor arguments for the base class's `parameters`
        # bookkeeping before any other locals are created.
        parameters = locals()
        with ops.name_scope(name, values=[loc, scale]):
            # When validating: scale must be strictly positive and `loc`
            # must be unit-norm along the last axis (within 1e-5).
            with ops.control_dependencies(
                    [check_ops.assert_positive(scale),
                     check_ops.assert_near(linalg_ops.norm(loc, axis=-1),
                                           1, atol=1e-5)] if validate_args else []):
                self._loc = array_ops.identity(loc, name="loc")
                self._scale = array_ops.identity(scale, name="scale")
                # Presumably raises TypeError at construction time if the
                # dtypes disagree; return value intentionally unused.
                check_ops.assert_same_float_dtype([self._loc, self._scale])

        super(VonMisesFisher, self).__init__(
            dtype=self._scale.dtype,
            reparameterization_type=distributions.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._loc, self._scale],
            name=name)

        # Event dimensionality m, taken from loc's static last dimension,
        # as an int32 tensor and as a float in the distribution's dtype.
        self.__m = math_ops.cast(self._loc.shape[-1], dtypes.int32)
        self.__mf = math_ops.cast(self.__m, dtype=self.dtype)
        # One-hot "standard direction" basis vector e1 = [1, 0, ..., 0];
        # NOTE(review): presumably used internally as a reference axis for
        # sampling/rotation — confirm against the class's other methods.
        self.__e1 = array_ops.one_hot([0], self.__m, dtype=self.dtype,
                                      name='standard-direction')
Example #6
0
  def __init__(self,
               loc=None,
               covariance_matrix=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalFullCovariance"):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and
    `covariance_matrix` arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `covariance_matrix`. The last dimension of `loc` (if provided) must
    broadcast with this.

    A non-batch `covariance_matrix` matrix is a `k x k` symmetric positive
    definite matrix.  In other words it is (real) symmetric with all eigenvalues
    strictly positive.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      covariance_matrix: Floating-point, symmetric positive definite `Tensor` of
        same `dtype` as `loc`.  The strict upper triangle of `covariance_matrix`
        is ignored, so if `covariance_matrix` is not symmetric no error will be
        raised (unless `validate_args is True`).  `covariance_matrix` has shape
        `[B1, ..., Bb, k, k]` where `b >= 0` and `k` is the event size.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if neither `loc` nor `covariance_matrix` are specified.
    """
    # Capture constructor arguments before other locals are created.
    parameters = locals()

    # Convert the covariance_matrix up to a scale_tril and call MVNTriL.
    with ops.name_scope(name) as name:
      with ops.name_scope("init", values=[loc, covariance_matrix]):
        if covariance_matrix is None:
          scale_tril = None
        else:
          covariance_matrix = ops.convert_to_tensor(
              covariance_matrix, name="covariance_matrix")
          if validate_args:
            # Symmetry check: compare the matrix with its (batched)
            # transpose; attached as a control dependency so it fires at
            # graph execution time.
            covariance_matrix = control_flow_ops.with_dependencies([
                check_ops.assert_near(
                    covariance_matrix,
                    array_ops.matrix_transpose(covariance_matrix),
                    message="Matrix was not symmetric")], covariance_matrix)
          # No need to validate that covariance_matrix is non-singular.
          # LinearOperatorLowerTriangular has an assert_non_singular method that
          # is called by the Bijector.
          # However, cholesky() ignores the upper triangular part, so we do need
          # to separately assert symmetric.
          scale_tril = linalg_ops.cholesky(covariance_matrix)
        super(MultivariateNormalFullCovariance, self).__init__(
            loc=loc,
            scale_tril=scale_tril,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    # Overwrite the base class's parameter record with the arguments this
    # subclass was actually constructed with.
    self._parameters = parameters
Example #7
0
 def _train_op_fn(loss):
   """Fake train op: assert `loss` is near `expected_loss` (within `atol`),
   then emit the canned `expected_train_result`."""
   loss_check = check_ops.assert_near(
       math_ops.to_float(expected_loss), math_ops.to_float(loss),
       atol=atol, name='assert_loss')
   with ops.control_dependencies((loss_check,)):
     return constant_op.constant(expected_train_result)
Example #8
0
 def _train_op_fn(loss):
     """Fake train op: assert `loss` is near the default loss, then emit
     the canned `expected_train_result`."""
     assertion = check_ops.assert_near(
         math_ops.to_float(self._default_loss),
         math_ops.to_float(loss),
         name='assert_loss')
     with ops.control_dependencies((assertion,)):
         return constant_op.constant(expected_train_result)