def _parameter_control_dependencies(self, is_init):
  assertions = []

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init and self._is_vector:
    msg = "Argument `loc` must be at least rank 1."
    if tensorshape_util.rank(self.loc.shape) is not None:
      if tensorshape_util.rank(self.loc.shape) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      assertions.append(
          assert_util.assert_rank_at_least(self.loc, 1, message=msg))

  if not self.validate_args:
    assert not assertions  # Should never happen
    return []

  if is_init != tensor_util.is_ref(self.atol):
    assertions.append(
        assert_util.assert_non_negative(
            self.atol, message="Argument 'atol' must be non-negative"))
  if is_init != tensor_util.is_ref(self.rtol):
    assertions.append(
        assert_util.assert_non_negative(
            self.rtol, message="Argument 'rtol' must be non-negative"))
  return assertions
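# Usage sketch (an assumption made for illustration, not the library's actual
# plumbing): the assertions returned above are typically consumed by wrapping
# downstream ops in `tf.control_dependencies`.  With the
# `is_init != tensor_util.is_ref(param)` convention, `tf.Variable`-backed
# (reference) parameters are re-checked at call time, since their values may
# change, while plain `Tensor` parameters are checked once at construction.
def _sample_with_checks(dist, sample_shape):
  with tf.control_dependencies(
      dist._parameter_control_dependencies(is_init=False)):
    return dist.sample(sample_shape)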
def _assertions(self, t):
  if not self.validate_args:
    return []
  return [
      assert_util.assert_non_negative(
          t, message="Argument y was negative")
  ]
def maybe_assert_bernoulli_param_correctness(
    is_init, validate_args, probs, logits):
  """Return assertions for `Bernoulli`-type distributions."""
  if is_init:
    x, name = (probs, 'probs') if logits is None else (logits, 'logits')
    if not dtype_util.is_floating(x.dtype):
      raise TypeError(
          'Argument `{}` must have floating type.'.format(name))
  if not validate_args:
    return []

  assertions = []
  if probs is not None:
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
      one = tf.constant(1., probs.dtype)
      assertions += [
          assert_util.assert_non_negative(
              probs, message='probs has components less than 0.'),
          assert_util.assert_less_equal(
              probs, one, message='probs has components greater than 1.')
      ]

  return assertions
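# Illustrative only: a sketch of how a `Bernoulli`-like distribution might
# delegate its parameter checks to the helper above from its
# `_parameter_control_dependencies` hook.  `_BernoulliLike` and its attribute
# names are assumptions made for this example, not the library's actual class.
class _BernoulliLike(object):

  def __init__(self, probs=None, logits=None, validate_args=False):
    self._probs = probs
    self._logits = logits
    self.validate_args = validate_args

  def _parameter_control_dependencies(self, is_init):
    return maybe_assert_bernoulli_param_correctness(
        is_init, self.validate_args, self._probs, self._logits)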
def _assertions(self, t):
  if not self.validate_args:
    return []
  return [
      assert_util.assert_non_negative(
          t, message="All elements must be non-negative.")
  ]
def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                   validate_args, name):
  """Helper to __init__ which ensures override batch/event_shape are valid."""
  if override_shape is None:
    override_shape = []

  override_shape = tf.convert_to_tensor(
      override_shape, dtype=tf.int32, name=name)

  if not dtype_util.is_integer(override_shape.dtype):
    raise TypeError("shape override must be an integer")

  override_is_scalar = _is_scalar_from_shape_tensor(override_shape)
  if tf.get_static_value(override_is_scalar):
    return self._empty

  dynamic_assertions = []

  if tensorshape_util.rank(override_shape.shape) is not None:
    if tensorshape_util.rank(override_shape.shape) != 1:
      raise ValueError("shape override must be a vector")
  elif validate_args:
    dynamic_assertions += [
        assert_util.assert_rank(
            override_shape, 1, message="shape override must be a vector")
    ]

  if tf.get_static_value(override_shape) is not None:
    if any(s < 0 for s in tf.get_static_value(override_shape)):
      raise ValueError("shape override must have non-negative elements")
  elif validate_args:
    dynamic_assertions += [
        assert_util.assert_non_negative(
            override_shape,
            message="shape override must have non-negative elements")
    ]

  is_both_nonscalar = prefer_static.logical_and(
      prefer_static.logical_not(base_is_scalar),
      prefer_static.logical_not(override_is_scalar))
  if tf.get_static_value(is_both_nonscalar) is not None:
    if tf.get_static_value(is_both_nonscalar):
      raise ValueError("base distribution not scalar")
  elif validate_args:
    dynamic_assertions += [
        assert_util.assert_equal(
            is_both_nonscalar, False,
            message="base distribution not scalar")
    ]

  if not dynamic_assertions:
    return override_shape
  return distribution_util.with_dependencies(
      dynamic_assertions, override_shape)
def _maybe_assert_valid_x(self, x):
  if not self.validate_args or self.power == 0.:
    return []
  return [
      assert_util.assert_non_negative(
          1. + self.power * x,
          message='Forward transformation input must be at least {}.'.format(
              -1. / self.power))
  ]
def _maybe_assert_valid_sample(self, x, dtype):
  if not self.validate_args:
    return x
  one = tf.ones([], dtype=dtype)
  return distribution_util.with_dependencies([
      assert_util.assert_non_negative(x),
      assert_util.assert_less_equal(x, one),
      assert_util.assert_near(one, tf.reduce_sum(x, axis=[-1])),
  ], x)
def _maybe_assert_valid_y(self, y):
  if not self.validate_args:
    return []
  is_positive = assert_util.assert_non_negative(
      y, message='Inverse transformation input must be greater than 0.')
  less_than_one = assert_util.assert_less_equal(
      y, tf.constant(1., y.dtype),
      message='Inverse transformation input must be less than or equal to 1.')
  return [is_positive, less_than_one]
def _maybe_assert_valid(self, x):
  if not self.validate_args:
    return x
  return distribution_util.with_dependencies([
      assert_util.assert_non_negative(
          x, message="sample must be non-negative"),
      assert_util.assert_less_equal(
          x, tf.ones([], self.concentration0.dtype),
          message="sample must be no larger than `1`."),
  ], x)
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  if is_init == tensor_util.is_ref(self.total_count):
    return []
  total_count = tf.convert_to_tensor(self.total_count)
  msg1 = 'Argument `total_count` must be non-negative.'
  msg2 = 'Argument `total_count` cannot contain fractional components.'
  return [
      assert_util.assert_non_negative(total_count, message=msg1),
      distribution_util.assert_integer_form(total_count, message=msg2),
  ]
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  if is_init != tensor_util.is_ref(self.concentration):
    # concentration >= 1
    # TODO(b/111451422, b/115950951) Generalize to concentration > 0.
    assertions.append(
        assert_util.assert_non_negative(
            self.concentration - 1,
            message='Argument `concentration` must be >= 1.'))
  return assertions
def maybe_assert_negative_binomial_param_correctness(
    is_init, validate_args, total_count, probs, logits):
  """Return assertions for `NegativeBinomial`-type distributions."""
  if is_init:
    x, name = (probs, 'probs') if logits is None else (logits, 'logits')
    if not dtype_util.is_floating(x.dtype):
      raise TypeError(
          'Argument `{}` must have floating type.'.format(name))
  if not validate_args:
    return []

  assertions = []
  if is_init != tensor_util.is_ref(total_count):
    total_count = tf.convert_to_tensor(total_count)
    assertions.extend([
        assert_util.assert_non_negative(
            total_count,
            message='`total_count` has components less than 0.'),
        distribution_util.assert_integer_form(
            total_count,
            message='`total_count` has fractional components.')
    ])
  if probs is not None:
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
      one = tf.constant(1., probs.dtype)
      assertions.extend([
          assert_util.assert_non_negative(
              probs, message='`probs` has components less than 0.'),
          assert_util.assert_less_equal(
              probs, one,
              message='`probs` has components greater than 1.')
      ])

  return assertions
def maybe_assert_categorical_param_correctness(
    is_init, validate_args, probs, logits):
  """Return assertions for `Categorical`-type distributions."""
  assertions = []

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    x, name = (probs, 'probs') if logits is None else (logits, 'logits')
    if not dtype_util.is_floating(x.dtype):
      raise TypeError(
          'Argument `{}` must have floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    ndims = tensorshape_util.rank(x.shape)
    if ndims is not None:
      if ndims < 1:
        raise ValueError(msg)
    elif validate_args:
      x = tf.convert_to_tensor(x)
      probs = x if logits is None else None  # Retain tensor conversion.
      logits = x if probs is None else None
      assertions.append(
          assert_util.assert_rank_at_least(x, 1, message=msg))

  if not validate_args:
    assert not assertions  # Should never happen.
    return []

  if logits is not None:
    if is_init != tensor_util.is_ref(logits):
      logits = tf.convert_to_tensor(logits)
      assertions.extend(
          distribution_util.assert_categorical_event_shape(logits))

  if probs is not None:
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
      assertions.extend([
          assert_util.assert_non_negative(probs),
          assert_util.assert_near(
              tf.reduce_sum(probs, axis=-1),
              np.array(1, dtype=dtype_util.as_numpy_dtype(probs.dtype)),
              message='Argument `probs` must sum to 1.')
      ])
      assertions.extend(
          distribution_util.assert_categorical_event_shape(probs))

  return assertions
def _assertions(self, t):
  if not self.validate_args:
    return []
  return [
      assert_util.assert_non_negative(
          t, message="Inverse transformation input must be greater than 0."),
      assert_util.assert_less_equal(
          t, dtype_util.as_numpy_dtype(t.dtype)(1.),
          message="Inverse transformation input must be less than or equal "
                  "to 1.")
  ]
def _maybe_validate_rightmost_transposed_ndims(
    rightmost_transposed_ndims, validate_args, name=None):
  """Checks that `rightmost_transposed_ndims` is valid."""
  with tf.name_scope(name or 'maybe_validate_rightmost_transposed_ndims'):
    assertions = []
    if not dtype_util.is_integer(rightmost_transposed_ndims.dtype):
      raise TypeError('`rightmost_transposed_ndims` must be integer type.')

    if tensorshape_util.rank(rightmost_transposed_ndims.shape) is not None:
      if tensorshape_util.rank(rightmost_transposed_ndims.shape) != 0:
        raise ValueError(
            '`rightmost_transposed_ndims` must be a scalar, '
            'saw rank: {}.'.format(
                tensorshape_util.rank(rightmost_transposed_ndims.shape)))
    elif validate_args:
      assertions += [
          assert_util.assert_rank(rightmost_transposed_ndims, 0)
      ]

    rightmost_transposed_ndims_ = tf.get_static_value(
        rightmost_transposed_ndims)
    msg = '`rightmost_transposed_ndims` must be non-negative.'
    if rightmost_transposed_ndims_ is not None:
      if rightmost_transposed_ndims_ < 0:
        raise ValueError(
            msg[:-1] + ', saw: {}.'.format(rightmost_transposed_ndims_))
    elif validate_args:
      assertions += [
          assert_util.assert_non_negative(
              rightmost_transposed_ndims, message=msg)
      ]

    return assertions
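# A small usage sketch (assumed for illustration, not taken from the library):
# with a statically known value the checks above raise eager Python
# exceptions, while an unknown value combined with `validate_args=True` yields
# graph-time assert ops that the caller attaches via control dependencies.
rightmost = tf.convert_to_tensor(2, dtype=tf.int32)
checks = _maybe_validate_rightmost_transposed_ndims(
    rightmost, validate_args=True)
with tf.control_dependencies(checks):
  rightmost = tf.identity(rightmost)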
def _maybe_assert_valid_x(self, x):
  if not self.validate_args:
    return []
  return [
      assert_util.assert_non_negative(
          x, message='Forward transformation input must be at least 0.')
  ]
def _parameter_control_dependencies(self, is_init):
  assertions = []

  logits = self._logits
  probs = self._probs
  param, name = (probs, 'probs') if logits is None else (logits, 'logits')

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must have floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    shape_static = tensorshape_util.dims(param.shape)
    if shape_static is not None:
      if len(shape_static) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_rank_at_least(param, 1, message=msg))

    msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
    msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
        name, tf.int32.max)
    event_size = shape_static[-1] if shape_static is not None else None
    if event_size is not None:
      if event_size < 1:
        raise ValueError(msg1)
      if event_size > tf.int32.max:
        raise ValueError(msg2)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_greater_equal(
              tf.shape(param)[-1:], 1, message=msg1))
      # NOTE: For now, we leave out a runtime assertion that
      # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
      # will fail before we get to this point.

  if not self.validate_args:
    assert not assertions  # Should never happen.
    return []

  if is_init != tensor_util.is_ref(self.temperature):
    assertions.append(assert_util.assert_positive(self.temperature))

  if probs is not None:
    probs = param  # reuse tensor conversion from above
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
      one = tf.ones([], dtype=probs.dtype)
      assertions.extend([
          assert_util.assert_non_negative(probs),
          assert_util.assert_less_equal(probs, one),
          assert_util.assert_near(
              tf.reduce_sum(probs, axis=-1), one,
              message='Argument `probs` must sum to 1.'),
      ])

  return assertions
def __init__(self,
             mean_direction,
             concentration,
             validate_args=False,
             allow_nan_stats=True,
             name='VonMisesFisher'):
  """Creates a new `VonMisesFisher` instance.

  Args:
    mean_direction: Floating-point `Tensor` with shape [B1, ... Bn, D].
      A unit vector indicating the mode of the distribution, or the
      unit-normalized direction of the mean. (This is *not* in general the
      mean of the distribution; the mean is not generally in the support of
      the distribution.) NOTE: `D` is currently restricted to <= 5.
    concentration: Floating-point `Tensor` having batch shape [B1, ... Bn]
      broadcastable with `mean_direction`. The level of concentration of
      samples around the `mean_direction`. `concentration=0` indicates a
      uniform distribution over the unit hypersphere, and `concentration=+inf`
      indicates a `Deterministic` distribution (delta function) at
      `mean_direction`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: For known-bad arguments, i.e. unsupported event dimension.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([mean_direction, concentration],
                                    tf.float32)
    mean_direction = tf.convert_to_tensor(
        mean_direction, name='mean_direction', dtype=dtype)
    concentration = tf.convert_to_tensor(
        concentration, name='concentration', dtype=dtype)

    assertions = [
        assert_util.assert_non_negative(
            concentration, message='`concentration` must be non-negative'),
        assert_util.assert_greater(
            tf.shape(mean_direction)[-1], 1,
            message='`mean_direction` may not have scalar event shape'),
        assert_util.assert_near(
            1., tf.linalg.norm(mean_direction, axis=-1),
            message='`mean_direction` must be unit-length')
    ] if validate_args else []

    static_event_dim = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(mean_direction.shape, 1)[-1])
    if static_event_dim is not None and static_event_dim > 5:
      raise ValueError('vMF ndims > 5 is not currently supported')
    elif validate_args:
      assertions += [
          assert_util.assert_less_equal(
              tf.shape(mean_direction)[-1], 5,
              message='vMF ndims > 5 is not currently supported')
      ]

    with tf.control_dependencies(assertions):
      self._mean_direction = tf.identity(mean_direction)
      self._concentration = tf.identity(concentration)
    dtype_util.assert_same_float_dtype(
        [self._mean_direction, self._concentration])

    # mean_direction is always reparameterized.
    # concentration is reparameterized only for event_dim == 3, via an
    # inversion sampler.
    reparameterization_type = (
        reparameterization.FULLY_REPARAMETERIZED
        if static_event_dim == 3
        else reparameterization.NOT_REPARAMETERIZED)

    super(VonMisesFisher, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=reparameterization_type,
        parameters=parameters,
        name=name)
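# Example usage (a sketch against the TensorFlow Probability public API; the
# variable names are illustrative).  A batch of two 3-dimensional vMF
# distributions with different concentrations; each `mean_direction` row must
# be unit-length.
import tensorflow_probability as tfp

vmf = tfp.distributions.VonMisesFisher(
    mean_direction=[[0., 1., 0.], [1., 0., 0.]],
    concentration=[1., 10.],
    validate_args=True)
samples = vmf.sample(5)            # shape: [5, 2, 3]
log_probs = vmf.log_prob(samples)  # shape: [5, 2]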