def maybe_assert_bernoulli_param_correctness(
    is_init, validate_args, probs, probits):
  """Builds the parameter checks for `ProbitBernoulli`-type distributions.

  Args:
    is_init: Python `bool`; `True` when called from the constructor.
    validate_args: Python `bool`; when `False` no runtime assertions are made.
    probs: The `probs` parameter, or `None`.
    probits: The `probits` parameter, or `None`.

  Returns:
    A (possibly empty) list of assertion ops constraining `probs` to `[0, 1]`.

  Raises:
    TypeError: At init time, if the supplied parameter is not floating point.
  """
  if is_init:
    # The dtype check is static, so it runs regardless of `validate_args`.
    if probits is None:
      param, param_name = probs, 'probs'
    else:
      param, param_name = probits, 'probits'
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must having floating type.'.format(param_name))

  if not validate_args:
    return []
  if probs is None:
    return []
  # Assertions are only built when `is_init` and `is_ref(probs)` disagree.
  if is_init == tensor_util.is_ref(probs):
    return []

  probs = tf.convert_to_tensor(probs)
  upper_bound = tf.constant(1., probs.dtype)
  return [
      assert_util.assert_non_negative(
          probs, message='probs has components less than 0.'),
      assert_util.assert_less_equal(
          probs, upper_bound,
          message='probs has components greater than 1.'),
  ]
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
  """Resolves a reshape target, filling in at most one `-1` entry.

  Args:
    original_shape: Shape vector of the data being reshaped.
    new_shape: Requested shape vector; may contain a single `-1`.
    validate: Python `bool`; when `True`, runtime assertions are returned.
    name: Optional name scope; defaults to `'calculate_reshape'`.

  Returns:
    A tuple `(expanded_new_shape, static_new_shape, validations)` where
    `validations` is a list of assertion ops (empty when `validate` is `False`
    or when `new_shape` is fully known statically).
  """
  static_new_shape = tensorshape_util.constant_value_as_shape(new_shape)
  if tensorshape_util.is_fully_defined(static_new_shape):
    # Nothing to infer: the static shape is the answer.
    return np.int32(static_new_shape), static_new_shape, []

  with tf.name_scope(name or 'calculate_reshape'):
    total_size = tf.reduce_prod(original_shape)
    is_unknown = tf.equal(new_shape, -1)
    # Assumes exactly one `-1`: then `-prod(new_shape)` is the product of the
    # remaining entries.
    known_size = tf.maximum(1, -tf.reduce_prod(new_shape))
    inferred_dim = total_size // known_size
    resolved_shape = tf.where(is_unknown, inferred_dim, new_shape)

    if not validate:
      return resolved_shape, static_new_shape, []

    checks = [
        assert_util.assert_rank(
            original_shape, 1, message='Original shape must be a vector.'),
        assert_util.assert_rank(
            new_shape, 1, message='New shape must be a vector.'),
        assert_util.assert_less_equal(
            tf.math.count_nonzero(is_unknown, dtype=tf.int32),
            1,
            message='At most one dimension can be unknown.'),
        assert_util.assert_positive(
            resolved_shape, message='Shape elements must be >=-1.'),
        assert_util.assert_equal(
            tf.reduce_prod(resolved_shape),
            total_size,
            message='Shape sizes do not match.'),
    ]
    return resolved_shape, static_new_shape, checks
def _maybe_assert_valid_sample(self, x, dtype):
  """Attaches range and sum-to-one checks to `x` when validating.

  Args:
    x: The sample to check.
    dtype: dtype used to build the constant `1` the checks compare against.

  Returns:
    `x` unchanged when `validate_args` is off; otherwise `x` gated on
    assertions that it is non-negative, at most one, and sums to (nearly) one
    along its last axis.
  """
  if self.validate_args:
    unity = tf.ones([], dtype=dtype)
    checks = [
        assert_util.assert_non_negative(x),
        assert_util.assert_less_equal(x, unity),
        assert_util.assert_near(unity, tf.reduce_sum(x, axis=[-1])),
    ]
    return distribution_util.with_dependencies(checks, x)
  return x
def _maybe_assert_valid_sample(self, counts):
  """Validates sampled `counts` and returns them, gated on the checks.

  Args:
    counts: The counts to check.

  Returns:
    `counts` unchanged when `validate_args` is off; otherwise the result of
    the non-negative-integer-form check, additionally gated on
    `counts <= self.total_count`.
  """
  if self.validate_args:
    counts = distribution_util.embed_check_nonnegative_integer_form(counts)
    within_total = assert_util.assert_less_equal(
        counts,
        self.total_count,
        message=('Sampled counts must be itemwise less than '
                 'or equal to `total_count` parameter.'))
    return distribution_util.with_dependencies([within_total], counts)
  return counts
def _maybe_assert_valid(self, x):
  """Gates `x` on `0 <= x <= 1` checks when `validate_args` is set.

  Args:
    x: The sample to check.

  Returns:
    `x` itself when validation is off, otherwise `x` with the range
    assertions attached as dependencies.
  """
  if self.validate_args:
    # The upper bound is built in the dtype of `concentration0`.
    upper_bound = tf.ones([], self.concentration0.dtype)
    lower_check = assert_util.assert_non_negative(
        x, message="sample must be non-negative")
    upper_check = assert_util.assert_less_equal(
        x, upper_bound, message="sample must be no larger than `1`.")
    return distribution_util.with_dependencies([lower_check, upper_check], x)
  return x
def _maybe_assert_valid_y(self, y):
  """Returns assertions that `y` lies in `[0, 1]`, or `[]` if not validating.

  Args:
    y: Input to the inverse transformation.

  Returns:
    A list of two assertion ops (lower bound, then upper bound), or an empty
    list when `validate_args` is off.
  """
  if self.validate_args:
    upper_bound = tf.constant(1., y.dtype)
    return [
        assert_util.assert_non_negative(
            y,
            message='Inverse transformation input must be greater than 0.'),
        assert_util.assert_less_equal(
            y,
            upper_bound,
            message=('Inverse transformation input must be less than or '
                     'equal to 1.')),
    ]
  return []
def _validate_correlationness(self, x):
  """Gates `x` on correlation-matrix checks when validation applies.

  No checks are made when `validate_args` is off or when
  `input_output_cholesky` is set.

  Args:
    x: Candidate correlation matrix (or batch of them).

  Returns:
    `x` unchanged when no checks apply, otherwise an identity of `x` that
    depends on entry-range, unit-diagonal and symmetry assertions.
  """
  if not (self.validate_args and not self.input_output_cholesky):
    return x

  as_x_dtype = dtype_util.as_numpy_dtype(x.dtype)
  lower_check = assert_util.assert_less_equal(
      as_x_dtype(-1), x, message='Correlations must be >= -1.')
  upper_check = assert_util.assert_less_equal(
      x, as_x_dtype(1), message='Correlations must be <= 1.')
  diagonal_check = assert_util.assert_near(
      tf.linalg.diag_part(x),
      as_x_dtype(1),
      message='Self-correlations must be = 1.')
  symmetry_check = assert_util.assert_near(
      x,
      tf.linalg.matrix_transpose(x),
      message='Correlation matrices must be symmetric')

  with tf.control_dependencies(
      [lower_check, upper_check, diagonal_check, symmetry_check]):
    return tf.identity(x)
def _assertions(self, t):
  """Returns `[0, 1]` range assertions for `t`, or `[]` if not validating.

  Args:
    t: Input to the inverse transformation.

  Returns:
    A list holding a non-negativity assertion followed by a `<= 1` assertion,
    or an empty list when `validate_args` is off.
  """
  if self.validate_args:
    upper_bound = dtype_util.as_numpy_dtype(t.dtype)(1.)
    lower_check = assert_util.assert_non_negative(
        t, message="Inverse transformation input must be greater than 0.")
    upper_check = assert_util.assert_less_equal(
        t,
        upper_bound,
        message=("Inverse transformation input must be less than or equal "
                 "to 1."))
    return [lower_check, upper_check]
  return []
def _parameter_control_dependencies(self, is_init):
  """Returns assertions that `probs` lies in `(0, 1]`.

  Args:
    is_init: Python `bool`; `True` when called from the constructor.

  Returns:
    A list of assertion ops. It is empty when `validate_args` is off, when
    `probs` was not supplied, or when `is_init` equals
    `tensor_util.is_ref(probs)`.
  """
  if not self.validate_args:
    return []

  raw_probs = self._probs
  if raw_probs is None:
    return []
  # Assertions are only built when `is_init` and `is_ref(probs)` disagree.
  if is_init == tensor_util.is_ref(raw_probs):
    return []

  probs = tf.convert_to_tensor(raw_probs)
  return [
      assert_util.assert_positive(
          probs, message='Argument `probs` must be positive.'),
      assert_util.assert_less_equal(
          probs,
          dtype_util.as_numpy_dtype(self.dtype)(1.),
          message='Argument `probs` must be less than or equal to 1.'),
  ]
def _make_runtime_assertions(self, distribution, reinterpreted_batch_ndims,
                             validate_args):
  """Checks `reinterpreted_batch_ndims` against the batch rank.

  Args:
    distribution: The wrapped distribution whose batch shape is inspected.
    reinterpreted_batch_ndims: Number of batch dims to reinterpret.
    validate_args: Python `bool`; enables the runtime assertion used when the
      comparison cannot be made statically.

  Returns:
    A list with at most one assertion op.

  Raises:
    ValueError: If both quantities are statically known and
      `reinterpreted_batch_ndims` exceeds the batch rank.
  """
  static_ndims = tf.get_static_value(reinterpreted_batch_ndims)
  batch_rank = tensorshape_util.rank(distribution.batch_shape)

  if batch_rank is not None and static_ndims is not None:
    # Everything is known statically: fail eagerly, no graph assertion needed.
    if static_ndims > batch_rank:
      raise ValueError("reinterpreted_batch_ndims({}) cannot exceed "
                       "distribution.batch_ndims({})".format(
                           static_ndims, batch_rank))
    return []

  if not validate_args:
    return []

  dynamic_rank = prefer_static.rank_from_shape(
      distribution.batch_shape_tensor, distribution.batch_shape)
  return [
      assert_util.assert_less_equal(
          reinterpreted_batch_ndims,
          dynamic_rank,
          message=("reinterpreted_batch_ndims cannot exceed "
                   "distribution.batch_ndims")),
  ]
def maybe_assert_negative_binomial_param_correctness(
    is_init, validate_args, total_count, probs, logits):
  """Builds the parameter checks for `NegativeBinomial`-type distributions.

  Args:
    is_init: Python `bool`; `True` when called from the constructor.
    validate_args: Python `bool`; when `False` no runtime assertions are made.
    total_count: The `total_count` parameter.
    probs: The `probs` parameter, or `None`.
    logits: The `logits` parameter, or `None`.

  Returns:
    A (possibly empty) list of assertion ops: `total_count` must be a
    non-negative integer-valued quantity and `probs` must lie in `[0, 1]`.

  Raises:
    TypeError: At init time, if the supplied `probs`/`logits` is not floating
      point.
  """
  if is_init:
    # The dtype check is static, so it runs regardless of `validate_args`.
    if logits is None:
      param, param_name = probs, 'probs'
    else:
      param, param_name = logits, 'logits'
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must having floating type.'.format(param_name))

  if not validate_args:
    return []

  checks = []

  # Each parameter is only checked when `is_init` and `is_ref(...)` disagree.
  if is_init != tensor_util.is_ref(total_count):
    total_count = tf.convert_to_tensor(total_count)
    checks.append(
        assert_util.assert_non_negative(
            total_count,
            message='`total_count` has components less than 0.'))
    checks.append(
        distribution_util.assert_integer_form(
            total_count,
            message='`total_count` has fractional components.'))

  if probs is not None and is_init != tensor_util.is_ref(probs):
    probs = tf.convert_to_tensor(probs)
    upper_bound = tf.constant(1., probs.dtype)
    checks.append(
        assert_util.assert_non_negative(
            probs, message='`probs` has components less than 0.'))
    checks.append(
        assert_util.assert_less_equal(
            probs, upper_bound,
            message='`probs` has components greater than 1.'))

  return checks
def _prob(self, x):
  """Evaluates the triangular density at `x`.

  Args:
    x: Points at which to evaluate the density.

  Returns:
    The density, which is zero wherever `x < low` or `x > high`. When
    `validate_args` is set, `x` is first gated on `low <= x <= high`.
  """
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  peak = tf.convert_to_tensor(self.peak)

  if self.validate_args:
    support_checks = [
        assert_util.assert_greater_equal(x, low),
        assert_util.assert_less_equal(x, high),
    ]
    with tf.control_dependencies(support_checks):
      x = tf.identity(x)

  width = high - low

  # Inside the support the density is a triangle, so each edge is handled
  # separately.
  on_rising_edge = (x >= low) & (x <= peak)
  # Line segment from (low, 0) to (peak, 2 / (high - low)).
  rising = 2. * (x - low) / (width * (peak - low))
  # Line segment from (peak, 2 / (high - low)) to (high, 0).
  falling = 2. * (high - x) / (width * (high - peak))
  inside = tf.where(on_rising_edge, rising, falling)

  outside_support = (x < low) | (x > high)
  return tf.where(outside_support, tf.zeros_like(x), inside)
def _parameter_control_dependencies(self, is_init):
  """Builds static checks and runtime assertions for `logits` / `probs`.

  Whichever of `self._logits` / `self._probs` was supplied is checked
  (`logits` takes precedence when it is not `None`).

  Args:
    is_init: Python `bool`; `True` when called from the constructor. Dtype,
      rank and event-size checks are only performed in that case.

  Returns:
    A list of assertion ops; always empty when `validate_args` is off.

  Raises:
    TypeError: At init, if the parameter is not floating point.
    ValueError: At init, if the statically known rank is below 1, or the
      statically known final dimension is below 1 or above `tf.int32.max`.
  """
  assertions = []

  logits = self._logits
  probs = self._probs
  param, name = (probs, 'probs') if logits is None else (logits, 'logits')

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must having floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    shape_static = tensorshape_util.dims(param.shape)
    if shape_static is not None:
      # Rank is known statically: fail immediately instead of adding an op.
      if len(shape_static) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      # Rank unknown: defer to a runtime assertion, and make the later
      # `tf.shape(param)` lookup depend on it.
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_rank_at_least(param, 1, message=msg))
      with tf.control_dependencies(assertions):
        param = tf.identity(param)

    msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
    msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
        name, tf.int32.max)
    # `None` here means either the rank or the last dimension is not known
    # statically.
    event_size = shape_static[-1] if shape_static is not None else None
    if event_size is not None:
      if event_size < 1:
        raise ValueError(msg1)
      if event_size > tf.int32.max:
        raise ValueError(msg2)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_greater_equal(tf.shape(param)[-1], 1,
                                           message=msg1))
      # NOTE: For now, we leave out a runtime assertion that
      # `tf.shape(param)[-1] <= tf.int32.max`. An earlier `tf.shape` call
      # will fail before we get to this point.

  if not self.validate_args:
    # Every append above is guarded by `self.validate_args`.
    assert not assertions  # Should never happen.
    return []

  if probs is not None:
    probs = param  # reuse tensor conversion from above
    # Value checks are only built when `is_init` and `is_ref(probs)` disagree.
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
      one = tf.ones([], dtype=probs.dtype)
      assertions.extend([
          assert_util.assert_non_negative(probs),
          assert_util.assert_less_equal(probs, one),
          assert_util.assert_near(
              tf.reduce_sum(probs, axis=-1), one,
              message='Argument `probs` must sum to 1.'),
      ])

  return assertions
def _sample_n(self, n, seed=None):
  """Draws `n` samples by sampling around (1, 0, ..., 0) and rotating.

  The first coordinate is drawn by `self._sample_3d` when the event size is 3
  and by a rejection loop otherwise; the remaining coordinates are a uniformly
  random unit direction scaled to give a unit-length sample. The result is
  passed through `self._rotate`. When `allow_nan_stats` is `False`, extra
  runtime checks are attached along the way.

  Args:
    n: Number of samples; becomes the leading dimension of the sample shape.
    seed: Seed used to build the `SeedStream` that feeds every random op.

  Returns:
    `self._rotate(samples)` for the generated `samples`.
  """
  # NOTE(review): the salt is spelled 'vom_...'; presumably changing it would
  # change the generated sample streams, so it is left untouched — confirm.
  seed = SeedStream(seed, salt='vom_mises_fisher')
  # The sampling strategy relies on the fact that vMF variates are symmetric
  # about the mean direction. Accordingly, if we have a sampling strategy for
  # the away-from-mean angle, then we can uniformly sample the remaining
  # dimensions on the S^{dim-2} sphere, and rotate these samples from a
  # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
  #
  # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
  # von-Mises distributed `x` value in [-1, 1], then uniformly select what
  # amounts to a "up" or "down" additional degree of freedom after unit
  # normalizing, followed by a final rotation to the desired mean direction
  # from a basis of (1, 0).
  #
  # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
  # unit sphere over which the distribution is uniform, in particular the
  # circle where x = \hat{x} intersects the unit sphere. We pick a point on
  # that circle, then rotate to the desired mean direction from a basis of
  # (1, 0, 0).
  # Prefer the static event size; fall back to the dynamic one.
  event_dim = (
      tf.compat.dimension_value(self.event_shape[0]) or
      self._event_shape_tensor()[0])

  sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
  dim = tf.cast(event_dim - 1, self.dtype)
  if event_dim == 3:
    # NOTE(review): presumably `_sample_3d` already returns a trailing unit
    # axis, since no `tf.newaxis` is added on this branch — confirm.
    samples_dim0 = self._sample_3d(n, seed=seed)
  else:
    # Wood'94 provides a rejection algorithm to sample the x coordinate.
    # Wood'94 definition of b:
    # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
    # https://stats.stackexchange.com/questions/156729 suggests:
    b = dim / (2 * self.concentration +
               tf.sqrt(4 * self.concentration**2 + dim**2))
    # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
    #     https://github.com/nicola-decao/s-vae-tf/
    x = (1 - b) / (1 + b)
    c = self.concentration * x + dim * tf.math.log1p(-x**2)
    beta = beta_lib.Beta(dim / 2, dim / 2)

    def cond_fn(w, should_continue):
      # Keep looping while any element is still flagged for resampling.
      del w
      return tf.reduce_any(should_continue)

    def body_fn(w, should_continue):
      z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
      # set_shape needed here because of b/139013403
      z.set_shape(w.shape)
      # Only elements still flagged receive a fresh proposal.
      w = tf.where(should_continue,
                   (1 - (1 + b) * z) / (1 - (1 - b) * z),
                   w)
      w = tf.debugging.check_numerics(w, 'w')
      unif = tf.random.uniform(
          sample_batch_shape, seed=seed(), dtype=self.dtype)
      # set_shape needed here because of b/139013403
      unif.set_shape(w.shape)
      # An element stays flagged only if it was flagged before and the
      # inequality below holds for its new proposal.
      should_continue = tf.logical_and(
          should_continue,
          self.concentration * w + dim * tf.math.log1p(-x * w) - c <
          tf.math.log(unif))
      return w, should_continue

    w = tf.zeros(sample_batch_shape, dtype=self.dtype)
    should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
    samples_dim0 = tf.while_loop(
        cond=cond_fn, body=body_fn, loop_vars=(w, should_continue))[0]
    samples_dim0 = samples_dim0[..., tf.newaxis]
  if not self._allow_nan_stats:
    # Verify samples are w/in -1, 1, with useful error output tensors (top
    # value rather than all values).
    with tf.control_dependencies([
        assert_util.assert_less_equal(
            samples_dim0,
            dtype_util.as_numpy_dtype(self.dtype)(1.01),
            data=[tf.math.top_k(tf.reshape(samples_dim0, [-1]))[0]]),
        assert_util.assert_greater_equal(
            samples_dim0,
            dtype_util.as_numpy_dtype(self.dtype)(-1.01),
            data=[-tf.math.top_k(tf.reshape(-samples_dim0, [-1]))[0]])
    ]):
      samples_dim0 = tf.identity(samples_dim0)
  samples_otherdims_shape = tf.concat([sample_batch_shape, [event_dim - 1]],
                                      axis=0)
  # Uniform direction for the remaining coordinates: normalized Gaussians.
  unit_otherdims = tf.math.l2_normalize(
      tf.random.normal(
          samples_otherdims_shape, seed=seed(), dtype=self.dtype),
      axis=-1)
  samples = tf.concat([
      samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
      tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
  ], axis=-1)
  samples = tf.math.l2_normalize(samples, axis=-1)
  if not self._allow_nan_stats:
    samples = tf.debugging.check_numerics(samples, 'samples')

  # Runtime assert that samples are unit length.
  if not self._allow_nan_stats:
    worst, idx = tf.math.top_k(
        tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
    with tf.control_dependencies([
        assert_util.assert_near(
            dtype_util.as_numpy_dtype(self.dtype)(0),
            worst,
            data=[
                worst, idx,
                tf.gather(tf.reshape(samples, [-1, event_dim]), idx)
            ],
            atol=1e-4,
            summarize=100)
    ]):
      samples = tf.identity(samples)
  # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
  # Now, we move the mode to `self.mean_direction` using a rotation matrix.
  if not self._allow_nan_stats:
    # Assert that the basis vector rotates to the mean direction, as expected.
    basis = tf.cast(tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                    self.dtype)
    with tf.control_dependencies([
        assert_util.assert_less(
            tf.linalg.norm(
                self._rotate(basis) - self.mean_direction, axis=-1),
            dtype_util.as_numpy_dtype(self.dtype)(1e-5))
    ]):
      return self._rotate(samples)
  return self._rotate(samples)
def __init__(self,
             mean_direction,
             concentration,
             validate_args=False,
             allow_nan_stats=True,
             name='VonMisesFisher'):
  """Builds a `VonMisesFisher` distribution.

  Args:
    mean_direction: Floating-point `Tensor` of shape [B1, ... Bn, D] holding
      unit vectors. Each one is the mode of the distribution, i.e. the
      unit-normalized direction of the mean. (It is generally *not* the mean
      itself, which usually lies outside the support.) NOTE: `D` is currently
      restricted to <= 5.
    concentration: Floating-point `Tensor` with batch shape [B1, ... Bn],
      broadcastable with `mean_direction`. Controls how tightly samples
      cluster around `mean_direction`: `concentration=0` gives the uniform
      distribution on the unit hypersphere, while `concentration=+inf` gives a
      `Deterministic` distribution (delta function) at `mean_direction`.
    validate_args: Python `bool`, default `False`. When `True`, distribution
      parameters are checked for validity, possibly at a runtime cost. When
      `False`, invalid inputs may silently produce incorrect outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to mark an undefined
      result. When `False`, an exception is raised if any batch member of a
      statistic is undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: For known-bad arguments, i.e. unsupported event dimension.
  """
  # Must stay the first statement so only the constructor arguments are
  # captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype(
        [mean_direction, concentration], tf.float32)
    mean_direction = tf.convert_to_tensor(
        mean_direction, name='mean_direction', dtype=dtype)
    concentration = tf.convert_to_tensor(
        concentration, name='concentration', dtype=dtype)

    checks = []
    if validate_args:
      checks.append(
          assert_util.assert_non_negative(
              concentration,
              message='`concentration` must be non-negative'))
      checks.append(
          assert_util.assert_greater(
              tf.shape(mean_direction)[-1],
              1,
              message='`mean_direction` may not have scalar event shape'))
      checks.append(
          assert_util.assert_near(
              1.,
              tf.linalg.norm(mean_direction, axis=-1),
              message='`mean_direction` must be unit-length'))

    event_size_static = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(mean_direction.shape, 1)[-1])
    if event_size_static is not None and event_size_static > 5:
      raise ValueError('vMF ndims > 5 is not currently supported')
    if validate_args:
      checks.append(
          assert_util.assert_less_equal(
              tf.shape(mean_direction)[-1],
              5,
              message='vMF ndims > 5 is not currently supported'))

    with tf.control_dependencies(checks):
      self._mean_direction = tf.identity(mean_direction)
      self._concentration = tf.identity(concentration)
      dtype_util.assert_same_float_dtype(
          [self._mean_direction, self._concentration])

    # mean_direction is always reparameterized.
    # concentration is only for event_dim==3, via an inversion sampler.
    if event_size_static == 3:
      reparam_type = reparameterization.FULLY_REPARAMETERIZED
    else:
      reparam_type = reparameterization.NOT_REPARAMETERIZED

    super(VonMisesFisher, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=reparam_type,
        parameters=parameters,
        name=name)
def __init__(self,
             df,
             scale_operator,
             input_output_cholesky=False,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Construct Wishart distributions.

  Args:
    df: `float` or `double` tensor, the degrees of freedom of the
      distribution(s). `df` must be greater than or equal to `k`.
    scale_operator: `float` or `double` instance of `LinearOperator`.
    input_output_cholesky: Python `bool`. If `True`, functions whose input or
      output have the semantics of samples assume inputs are in Cholesky form
      and return outputs in Cholesky form. In particular, if this flag is
      `True`, input to `log_prob` is presumed of Cholesky form and output from
      `sample`, `mean`, and `mode` are of Cholesky form.  Setting this
      argument to `True` is purely a computational optimization and does not
      change the underlying distribution; for instance, `mean` returns the
      Cholesky of the mean, not the mean of Cholesky factors. The `variance`
      and `stddev` methods are unaffected by this flag.
      Default value: `False` (i.e., input/output does not have Cholesky
      semantics).
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: if scale is not floating-type
    TypeError: if scale.dtype != df.dtype
    ValueError: if `scale_operator` is not square
    ValueError: if df < k, where scale operator event shape is
      `(k, k)`
  """
  # Must stay the first statement so only the constructor arguments are
  # captured.
  parameters = dict(locals())
  self._input_output_cholesky = input_output_cholesky
  with tf.name_scope(name) as name:
    with tf.name_scope("init"):
      if not dtype_util.is_floating(scale_operator.dtype):
        raise TypeError(
            "scale_operator.dtype=%s is not a floating-point type" %
            scale_operator.dtype)
      if not scale_operator.is_square:
        # No debug printing here: evaluating the dense operator needs a
        # default session and could raise before the intended error.
        raise ValueError("scale_operator must be square.")

      self._scale_operator = scale_operator
      self._df = tf.convert_to_tensor(
          df, dtype=scale_operator.dtype, name="df")
      dtype_util.assert_same_float_dtype(
          [self._df, self._scale_operator])

      # Use the static trailing dimension when available, else the dynamic
      # one; either way it is stored in the operator's dtype.
      static_dimension = tf.compat.dimension_value(
          self._scale_operator.shape[-1])
      if static_dimension is None:
        self._dimension = tf.cast(
            self._scale_operator.domain_dimension_tensor(),
            dtype=self._scale_operator.dtype,
            name="dimension")
      else:
        self._dimension = tf.convert_to_tensor(
            static_dimension,
            dtype=self._scale_operator.dtype,
            name="dimension")

      df_val = tf.get_static_value(self._df)
      dim_val = tf.get_static_value(self._dimension)
      if df_val is not None and dim_val is not None:
        # Both values are known statically: fail at construction time.
        df_val = np.asarray(df_val)
        if not df_val.shape:
          df_val = [df_val]
        if np.any(df_val < dim_val):
          raise ValueError(
              "Degrees of freedom (df = %s) cannot be less than "
              "dimension of scale matrix (scale.dimension = %s)" %
              (df_val, dim_val))
      elif validate_args:
        # Otherwise defer the `dimension <= df` check to runtime. The message
        # arguments follow the order of its placeholders: df, then dimension.
        assertions = assert_util.assert_less_equal(
            self._dimension,
            self._df,
            message=("Degrees of freedom (df = %s) cannot be "
                     "less than dimension of scale matrix "
                     "(scale.dimension = %s)" %
                     (self._df, self._dimension)))
        self._df = distribution_util.with_dependencies(
            [assertions], self._df)
  super(_WishartLinearOperator, self).__init__(
      dtype=self._scale_operator.dtype,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
      parameters=parameters,
      name=name)