def _assert_valid_sample(self, x):
  """Returns `x`, gated on sample-validity assertions when validation is on.

  Args:
    x: Sample `Tensor`. Presumably log-probabilities along the last axis
      (NOTE(review): inferred from the checks below; confirm with callers).

  Returns:
    `x` itself when `self.validate_args` is False. Otherwise `x` wrapped with
    control dependencies on two assertions:
      * every entry of `x` is `<= 0`;
      * `logsumexp(x)` over the last axis is near `0` (within
        `check_ops.assert_near`'s default tolerance), i.e. `exp(x)` sums to
        approximately one along that axis.
  """
  if not self.validate_args:
    return x
  # Build the assertions in the same order as before: non-positivity first,
  # then the normalization check.
  non_positive = check_ops.assert_non_positive(x)
  log_normalized = check_ops.assert_near(
      array_ops.zeros([], dtype=self.dtype),
      math_ops.reduce_logsumexp(x, axis=[-1]))
  return control_flow_ops.with_dependencies([non_positive, log_normalized], x)
def _log_unnormalized_prob(self, x):
  """Returns `self.scale * <self._loc, x>` with the last axis dropped.

  Args:
    x: Sample `Tensor` whose last axis is the event axis. When
      `self.validate_args` is True, its Euclidean norm along that axis is
      asserted to be `1` within `atol=1e-3`.

  Returns:
    The dot product of `self._loc` and `x` over the last axis, multiplied by
    `self.scale`, reshaped to drop the trailing axis kept by `keepdims=True`.
    NOTE(review): the reshape only succeeds if the product's last axis has
    size 1, so `self.scale` presumably carries a trailing singleton axis;
    confirm.
  """
  if self.validate_args:
    unit_norm_checks = [
        check_ops.assert_near(linalg_ops.norm(x, axis=-1), 1, atol=1e-3)]
  else:
    unit_norm_checks = []
  with ops.control_dependencies(unit_norm_checks):
    concentration = self.scale
    dot = math_ops.reduce_sum(self._loc * x, axis=-1, keepdims=True)
    output = concentration * dot
  # Drop the trailing axis.
  leading_shape = ops.convert_to_tensor(array_ops.shape(output)[:-1])
  return array_ops.reshape(output, leading_shape)
def _assert_valid_sample(self, x):
  """Optionally attaches validity checks to a sample of log-probabilities.

  If `self.validate_args` is set, the returned tensor carries control
  dependencies asserting that each entry of `x` is non-positive and that the
  log-sum-exp of `x` over its last axis is close to zero (default
  `check_ops.assert_near` tolerance). Otherwise `x` is returned as-is.

  Args:
    x: Sample `Tensor` (NOTE(review): presumably log-probabilities along the
      last axis; confirm with callers).

  Returns:
    `x`, possibly with the assertions above as control dependencies.
  """
  if self.validate_args:
    checks = [check_ops.assert_non_positive(x)]
    checks.append(
        check_ops.assert_near(
            array_ops.zeros([], dtype=self.dtype),
            math_ops.reduce_logsumexp(x, axis=[-1])))
    x = control_flow_ops.with_dependencies(checks, x)
  return x
def _maybe_assert_valid_sample(self, x):
  """Checks the validity of a sample when `self.validate_args` is set.

  Args:
    x: Sample `Tensor` whose last axis is the event axis.

  Returns:
    `x` unchanged when validation is off. Otherwise `x` gated on two
    assertions:
      * every entry is strictly positive;
      * the entries along the last axis sum to approximately `1` (default
        `check_ops.assert_near` tolerance).
  """
  if self.validate_args:
    positive = check_ops.assert_positive(x, message="samples must be positive")
    sums_to_one = check_ops.assert_near(
        array_ops.ones([], dtype=self.dtype),
        math_ops.reduce_sum(x, -1),
        message="sample last-dimension must sum to `1`")
    x = control_flow_ops.with_dependencies([positive, sums_to_one], x)
  return x
def __init__(self,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="von-Mises-Fisher"):
  """Construct von-Mises-Fisher distributions with mean direction `loc` and concentration `scale`.

  Args:
    loc: Floating point tensor; the mean direction(s) of the distribution(s).
      When `validate_args` is `True`, its Euclidean norm along the last axis
      is asserted to be `1` (within `atol=1e-5`). The size of its last axis
      is read from the static shape.
    scale: Floating point tensor; the concentration of the distribution(s).
      When `validate_args` is `True`, it is asserted to contain only strictly
      positive values.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if `loc` and `scale` do not share the same floating-point
      `dtype` (raised by `check_ops.assert_same_float_dtype`).
  """
  # Snapshot the constructor arguments. On CPython < 3.13 a bare `locals()`
  # returns the frame's live namespace dict. Any later sync of the frame
  # (another `locals()` call, a debugger or coverage tracer) would then
  # silently add later bindings to the recorded parameters.
  parameters = dict(locals())
  with ops.name_scope(name, values=[loc, scale]):
    with ops.control_dependencies(
        [check_ops.assert_positive(scale),
         check_ops.assert_near(linalg_ops.norm(loc, axis=-1), 1, atol=1e-5)]
        if validate_args else []):
      self._loc = array_ops.identity(loc, name="loc")
      self._scale = array_ops.identity(scale, name="scale")
      check_ops.assert_same_float_dtype([self._loc, self._scale])

  super(VonMisesFisher, self).__init__(
      dtype=self._scale.dtype,
      reparameterization_type=distributions.FULLY_REPARAMETERIZED,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      parameters=parameters,
      graph_parents=[self._loc, self._scale],
      name=name)

  # Size of loc's last axis taken from its static shape, as an int32 tensor
  # and as a tensor of the distribution's dtype.
  self.__m = math_ops.cast(self._loc.shape[-1], dtypes.int32)
  self.__mf = math_ops.cast(self.__m, dtype=self.dtype)
  # One-hot vector selecting index 0 of an m-dimensional axis.
  self.__e1 = array_ops.one_hot([0], self.__m, dtype=self.dtype,
                                name='standard-direction')
def __init__(self,
             loc=None,
             covariance_matrix=None,
             validate_args=False,
             allow_nan_stats=True,
             name="MultivariateNormalFullCovariance"):
  """Construct Multivariate Normal distribution on `R^k`.

  The `batch_shape` is the broadcast shape between `loc` and
  `covariance_matrix` arguments.

  The `event_shape` is given by last dimension of the matrix implied by
  `covariance_matrix`. The last dimension of `loc` (if provided) must
  broadcast with this.

  A non-batch `covariance_matrix` matrix is a `k x k` symmetric positive
  definite matrix. In other words it is (real) symmetric with all eigenvalues
  strictly positive.

  Additional leading dimensions (if any) will index batches.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
      `b >= 0` and `k` is the event size.
    covariance_matrix: Floating-point, symmetric positive definite `Tensor` of
      same `dtype` as `loc`. The strict upper triangle of `covariance_matrix`
      is ignored, so if `covariance_matrix` is not symmetric no error will be
      raised (unless `validate_args is True`). `covariance_matrix` has shape
      `[B1, ..., Bb, k, k]` where `b >= 0` and `k` is the event size.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if neither `loc` nor `covariance_matrix` are specified.
  """
  # Snapshot the constructor arguments. On CPython < 3.13 a bare `locals()`
  # returns the frame's live namespace dict. Below, `name` and
  # `covariance_matrix` are rebound and `scale_tril` is introduced, and
  # `self._parameters` is assigned only at the very end. Without a copy, any
  # frame sync (a later `locals()` call, a debugger or coverage tracer) would
  # leak those values into the recorded parameters.
  parameters = dict(locals())

  # Convert the covariance_matrix up to a scale_tril and call MVNTriL.
  with ops.name_scope(name) as name:
    with ops.name_scope("init", values=[loc, covariance_matrix]):
      if covariance_matrix is None:
        scale_tril = None
      else:
        covariance_matrix = ops.convert_to_tensor(
            covariance_matrix, name="covariance_matrix")
        if validate_args:
          # cholesky() ignores the upper triangular part, so an asymmetric
          # input would otherwise be accepted silently; assert symmetry
          # explicitly.
          covariance_matrix = control_flow_ops.with_dependencies([
              check_ops.assert_near(
                  covariance_matrix,
                  array_ops.matrix_transpose(covariance_matrix),
                  message="Matrix was not symmetric")], covariance_matrix)
        # No need to validate that covariance_matrix is non-singular.
        # LinearOperatorLowerTriangular has an assert_non_singular method that
        # is called by the Bijector.
        scale_tril = linalg_ops.cholesky(covariance_matrix)
      super(MultivariateNormalFullCovariance, self).__init__(
          loc=loc,
          scale_tril=scale_tril,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
  # Record this constructor's own arguments. This replaces whatever the parent
  # constructor may have stored (presumably its scale_tril-based arguments).
  self._parameters = parameters
def _train_op_fn(loss):
  """Returns a constant train op that first asserts `loss` ~= `expected_loss`.

  Both values are cast to float and compared with `atol` from the enclosing
  scope. The returned op is `constant_op.constant(expected_train_result)`,
  created under a control dependency on that assertion.
  """
  loss_check = check_ops.assert_near(
      math_ops.to_float(expected_loss),
      math_ops.to_float(loss),
      atol=atol,
      name='assert_loss')
  with ops.control_dependencies([loss_check]):
    return constant_op.constant(expected_train_result)
def _train_op_fn(loss):
  """Returns a constant train op gated on `loss` matching `self._default_loss`.

  The comparison casts both values to float and uses
  `check_ops.assert_near`'s default tolerance. The returned op is
  `constant_op.constant(expected_train_result)`, created under a control
  dependency on that assertion.
  """
  expected = math_ops.to_float(self._default_loss)
  actual = math_ops.to_float(loss)
  loss_check = check_ops.assert_near(expected, actual, name='assert_loss')
  with ops.control_dependencies([loss_check]):
    return constant_op.constant(expected_train_result)