def _parameter_control_dependencies(self, is_init):
  """Builds validation checks for `total_count` and, when set, `probs`.

  Args:
    is_init: Boolean. A parameter's checks are built only when
      `is_init != tensor_util.is_ref(parameter)`.

  Returns:
    A list of the objects returned by the assertion helpers; `[]` when
    `self.validate_args` is falsy.
  """
  if not self.validate_args:
    return []

  checks = []

  if is_init != tensor_util.is_ref(self.total_count):
    count = tf.convert_to_tensor(self.total_count)
    checks.append(assert_util.assert_non_negative(
        count, message='Argument `total_count` must be non-negative.'))
    checks.append(distribution_util.assert_integer_form(
        count,
        message='Argument `total_count` cannot contain fractional components.'))

  # `probs` is optional; nothing is checked when it was not supplied.
  if self._probs is not None and is_init != tensor_util.is_ref(self._probs):
    probs = tf.convert_to_tensor(self._probs)
    upper = tf.constant(1., probs.dtype)
    checks.append(assert_util.assert_non_negative(
        probs, message='probs has components less than 0.'))
    checks.append(assert_util.assert_less_equal(
        probs, upper, message='probs has components greater than 1.'))

  return checks
def _parameter_control_dependencies(self, is_init):
  """Builds non-negativity checks for `rate1` and `rate2`.

  Args:
    is_init: Boolean. A rate is checked only when it is not `None` and
      `is_init != tensor_util.is_ref(rate)`.

  Returns:
    A list of the objects returned by `assert_util.assert_non_negative`;
    `[]` when `self.validate_args` is falsy.
  """
  if not self.validate_args:
    return []

  # Each rate is paired with the message reported when it is negative.
  candidates = (
      (self._rate1, 'Argument `rate1` must be non-negative.'),
      (self._rate2, 'Argument `rate2` must be non-negative.'),
  )
  checks = []
  for rate, msg in candidates:
    if rate is None:
      continue
    if is_init != tensor_util.is_ref(rate):
      checks.append(assert_util.assert_non_negative(rate, message=msg))
  return checks
def _inverse(self, y):
  """Computes `log(y[..., :-1]) - log(y[..., -1:])`.

  Derivation of the inverse mapping:
    y[i]   = exp(x[i]) / normalization
    y[end] = 1 / normalization
  hence
    x[i] = log(exp(x[i])) - log(y[end]) - log(normalization)
         = log(exp(x[i]) / normalization) - log(y[end])
         = log(y[i]) - log(y[end])

  Args:
    y: Tensor whose last axis is split into all-but-last and last entries.

  Returns:
    The log of the leading entries minus the log of the final entry, computed
    under control dependencies on the validation checks (if any).
  """
  # The checks are built first so that CSE can catch that the same work
  # happens again in _inverse_log_det_jacobian.
  if self.validate_args:
    tolerance = 2. * np.finfo(dtype_util.as_numpy_dtype(y.dtype)).eps
    checks = [
        assert_util.assert_near(
            tf.reduce_sum(y, axis=-1),
            tf.ones([], y.dtype),
            tolerance,
            message='Last dimension of `y` must sum to `1`.'),
        assert_util.assert_less_equal(
            y, tf.ones([], y.dtype),
            message='Elements of `y` must be less than or equal to `1`.'),
        assert_util.assert_non_negative(
            y, message='Elements of `y` must be non-negative.'),
    ]
  else:
    checks = []

  with tf.control_dependencies(checks):
    log_y = tf.math.log(y)
    # The trailing size-1 slice is log(y[end]), i.e. -log(normalization).
    leading, log_last = tf.split(
        log_y, num_or_size_splits=[-1, 1], axis=-1)
    return leading - log_last
def _maybe_assert_valid_x(self, x):
  """Returns a check that `1 + power * x` is non-negative.

  Args:
    x: Forward-transformation input.

  Returns:
    `[]` when `self.validate_args` is falsy or `self.power == 0.`; otherwise
    a one-element list holding the `assert_non_negative` result.
  """
  if not self.validate_args or self.power == 0.:
    return []

  shifted = 1. + self.power * x
  # The bound reported to the user is where `1 + power * x` crosses zero.
  msg = 'Forward transformation input must be at least {}.'.format(
      -1. / self.power)
  return [assert_util.assert_non_negative(shifted, message=msg)]
def _inverse(self, y):
  """Returns the pair `(-y, y)`.

  Args:
    y: Value to invert. When `self.validate_args` is truthy it is first
      wrapped with a dependency on a non-negativity check.

  Returns:
    A 2-tuple `(-y, y)`.
  """
  if not self.validate_args:
    return -y, y

  non_negative = assert_util.assert_non_negative(
      y, message="Argument y was negative")
  checked_y = distribution_util.with_dependencies([non_negative], y)
  return -checked_y, checked_y
def _parameter_control_dependencies(self, is_init):
  """Builds validation checks for `total_count` and `concentration`.

  Every check in this method is gated on `self.validate_args`, so the method
  returns immediately when validation is off instead of re-testing the flag
  in each branch (and without calling `tensor_util.is_ref` needlessly).

  Args:
    is_init: Boolean. The categorical event-shape check is built only when
      this is truthy; per-parameter value checks are built only when
      `is_init != tensor_util.is_ref(parameter)`.

  Returns:
    A list of the objects returned by the assertion helpers; `[]` when
    `self.validate_args` is falsy.
  """
  if not self.validate_args:
    return []

  assertions = []

  if is_init:
    # assert_categorical_event_shape handles both the static and dynamic case.
    assertions.extend(
        distribution_util.assert_categorical_event_shape(self._concentration))

  if is_init != tensor_util.is_ref(self._total_count):
    total_count = tf.convert_to_tensor(self._total_count)
    assertions.append(
        distribution_util.assert_casting_closed(
            total_count,
            target_dtype=tf.int32,
            message='total_count cannot contain fractional components.'))
    assertions.append(
        assert_util.assert_non_negative(
            total_count, message='total_count must be non-negative'))

  if is_init != tensor_util.is_ref(self._concentration):
    assertions.append(
        assert_util.assert_positive(
            self._concentration,
            message='Concentration parameter must be positive.'))

  return assertions
def _parameter_control_dependencies(self, is_init):
  """Validates whichever of `logits` / `probs` parameterizes this object.

  Static dtype and shape problems raise immediately during init, regardless
  of `validate_args`. Runtime assertions are only built when
  `self.validate_args` is truthy.

  Args:
    is_init: Boolean. Dtype/shape checks run only when this is truthy; the
      `probs` value checks are built only when
      `is_init != tensor_util.is_ref(probs)`.

  Returns:
    A list of the objects returned by the `assert_util` helpers; `[]` when
    `self.validate_args` is falsy.

  Raises:
    TypeError: If `is_init` and the parameter dtype is not floating.
    ValueError: If `is_init` and the statically known rank is below 1, or the
      statically known final dimension is below 1 or above the `tf.int32`
      maximum.
  """
  assertions = []

  logits = self._logits
  probs = self._probs
  # `logits` wins when present; `probs` is examined only if `logits` is None.
  param, name = (probs, 'probs') if logits is None else (logits, 'logits')

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(param.dtype):
      raise TypeError('Argument `{}` must having floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    shape_static = tensorshape_util.dims(param.shape)
    if shape_static is not None:
      # Rank is known statically: fail now rather than at run time.
      if len(shape_static) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_rank_at_least(param, 1, message=msg))
      # Rebind `param` so later uses of it depend on the rank assertion.
      with tf.control_dependencies(assertions):
        param = tf.identity(param)

    msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
    msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
        name, dtype_util.max(tf.int32))
    event_size = shape_static[-1] if shape_static is not None else None
    if event_size is not None:
      if event_size < 1:
        raise ValueError(msg1)
      if event_size > dtype_util.max(tf.int32):
        raise ValueError(msg2)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(assert_util.assert_greater_equal(
          tf.shape(param)[-1], 1, message=msg1))
      # NOTE: For now, we leave out a runtime assertion that
      # `tf.shape(param)[-1] <= tf.int32.max`. An earlier `tf.shape` call
      # will fail before we get to this point.

  if not self.validate_args:
    # Every append above is guarded by `self.validate_args`.
    assert not assertions  # Should never happen.
    return []

  if probs is not None:
    probs = param  # reuse tensor conversion from above
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
      one = tf.ones([], dtype=probs.dtype)
      assertions.extend([
          assert_util.assert_non_negative(probs),
          assert_util.assert_less_equal(probs, one),
          assert_util.assert_near(
              tf.reduce_sum(probs, axis=-1), one,
              message='Argument `probs` must sum to 1.'),
      ])

  return assertions
def _parameter_control_dependencies(self, is_init):
  """Builds validation checks for the requested (unexpanded) batch shape.

  Args:
    is_init: Boolean. Checks are built only when
      `is_init != tensor_util.is_ref(self._batch_shape_unexpanded)`.

  Returns:
    A list of the objects returned by the `assert_util` helpers; `[]` when
    `self.validate_args` is falsy.
  """
  if not self.validate_args:
    # Avoid computing intermediates needed to construct the assertions.
    return []

  checks = []
  requested_shape = self._batch_shape_unexpanded
  if is_init != tensor_util.is_ref(requested_shape):
    # Entries equal to -1 mark dimensions left unspecified.
    unknown_mask = ps.equal(requested_shape, -1)
    checks.append(assert_util.assert_rank(
        requested_shape, 1, message='New shape must be a vector.'))
    num_unknown = tf.math.count_nonzero(unknown_mask, dtype=tf.int32)
    checks.append(assert_util.assert_less_equal(
        num_unknown, 1, message='At most one dimension can be unknown.'))
    checks.append(assert_util.assert_non_negative(
        requested_shape + 1, message='Shape elements must be >=-1.'))

    # Check that the old and new shapes are the same size.
    expanded_shape, old_size = self._calculate_new_shape()
    expanded_size = ps.reduce_prod(expanded_shape)
    checks.append(assert_util.assert_equal(
        expanded_size,
        tf.cast(old_size, expanded_size.dtype),
        message='Shape sizes do not match.'))

  return checks
def _parameter_control_dependencies(self, is_init):
  """Builds validation checks for `temperature` and, when set, `probs`.

  Args:
    is_init: Boolean. A parameter's checks are built only when
      `is_init != tensor_util.is_ref(parameter)`.

  Returns:
    A list of the objects returned by the `assert_util` helpers; `[]` when
    `self.validate_args` is falsy.
  """
  if not self.validate_args:
    return []

  checks = []

  if is_init != tensor_util.is_ref(self._temperature):
    temperature = tf.convert_to_tensor(self._temperature)
    checks.append(assert_util.assert_positive(
        temperature, message='Argument `temperature` must be positive.'))

  # `probs` is optional; nothing is checked when it was not supplied.
  if self._probs is not None and is_init != tensor_util.is_ref(self._probs):
    probs = tf.convert_to_tensor(self._probs)
    upper = tf.constant(1., probs.dtype)
    checks.append(assert_util.assert_non_negative(
        probs, message='Argument `probs` has components less than 0.'))
    checks.append(assert_util.assert_less_equal(
        probs, upper,
        message='Argument `probs` has components greater than 1.'))

  return checks
def maybe_assert_negative_binomial_param_correctness(
    is_init, validate_args, total_count, probs, logits):
  """Builds parameter checks for `NegativeBinomial`-type distributions.

  Args:
    is_init: Boolean. When truthy, the dtype of the active parameter is
      checked; value checks are built only when
      `is_init != tensor_util.is_ref(parameter)`.
    validate_args: When falsy, no value checks are built.
    total_count: Checked to be positive and of integer form.
    probs: Optional; when not `None`, checked to lie in `[0, 1]`.
    logits: Optional; when not `None` it (rather than `probs`) is the
      parameter whose dtype is checked.

  Returns:
    A list of the objects returned by the assertion helpers; `[]` when
    `validate_args` is falsy.

  Raises:
    TypeError: If `is_init` and the active parameter's dtype is not floating.
  """
  if is_init:
    if logits is None:
      active, label = probs, 'probs'
    else:
      active, label = logits, 'logits'
    if not dtype_util.is_floating(active.dtype):
      raise TypeError(
          'Argument `{}` must having floating type.'.format(label))

  if not validate_args:
    return []

  checks = []

  if is_init != tensor_util.is_ref(total_count):
    total_count = tf.convert_to_tensor(total_count)
    checks.append(assert_util.assert_positive(
        total_count,
        message='`total_count` has components less than or equal to 0.'))
    checks.append(distribution_util.assert_integer_form(
        total_count,
        message='`total_count` has fractional components.'))

  if probs is not None and is_init != tensor_util.is_ref(probs):
    probs = tf.convert_to_tensor(probs)
    upper = tf.constant(1., probs.dtype)
    checks.append(assert_util.assert_non_negative(
        probs, message='`probs` has components less than 0.'))
    checks.append(assert_util.assert_less_equal(
        probs, upper, message='`probs` has components greater than 1.'))

  return checks
def _sample_control_dependencies(self, x):
  """Returns a non-negativity check on the sample `x`.

  Args:
    x: Sample to validate.

  Returns:
    A one-element list with the `assert_non_negative` result, or `[]` when
    `self.validate_args` is falsy.
  """
  if self.validate_args:
    return [
        assert_util.assert_non_negative(
            x, message='Sample must be non-negative.'),
    ]
  return []
def _maybe_validate_rightmost_transposed_ndims(
    rightmost_transposed_ndims, validate_args, name=None):
  """Checks that `rightmost_transposed_ndims` is a non-negative int scalar.

  Anything that can be decided statically raises immediately; otherwise a
  runtime assertion is added, but only when `validate_args` is truthy.

  Args:
    rightmost_transposed_ndims: Object exposing `.dtype` and `.shape`.
    validate_args: Whether runtime assertions may be created.
    name: Optional name scope; defaults to
      'maybe_validate_rightmost_transposed_ndims'.

  Returns:
    A (possibly empty) list of runtime assertions.

  Raises:
    TypeError: If the dtype is not an integer type.
    ValueError: If the static rank is known and non-zero, or the static value
      is known and negative.
  """
  with tf.name_scope(name or 'maybe_validate_rightmost_transposed_ndims'):
    checks = []
    if not dtype_util.is_integer(rightmost_transposed_ndims.dtype):
      raise TypeError('`rightmost_transposed_ndims` must be integer type.')

    static_rank = tensorshape_util.rank(rightmost_transposed_ndims.shape)
    if static_rank is None:
      # Rank unknown until run time.
      if validate_args:
        checks.append(assert_util.assert_rank(rightmost_transposed_ndims, 0))
    elif static_rank != 0:
      raise ValueError('`rightmost_transposed_ndims` must be a scalar, '
                       'saw rank: {}.'.format(static_rank))

    static_value = tf.get_static_value(rightmost_transposed_ndims)
    msg = '`rightmost_transposed_ndims` must be non-negative.'
    if static_value is None:
      # Value unknown until run time.
      if validate_args:
        checks.append(assert_util.assert_non_negative(
            rightmost_transposed_ndims, message=msg))
    elif static_value < 0:
      # Drop the trailing period so the offending value can be appended.
      raise ValueError(msg[:-1] + ', saw: {}.'.format(static_value))

    return checks
def _assertions(self, t):
  """Returns a non-negativity check on `t`.

  Args:
    t: Value to validate.

  Returns:
    A one-element list with the `assert_non_negative` result, or `[]` when
    `self.validate_args` is falsy.
  """
  if self.validate_args:
    non_negative = assert_util.assert_non_negative(
        t, message="Argument y was negative")
    return [non_negative]
  return []
def maybe_assert_continuous_bernoulli_param_correctness(
    is_init, validate_args, probs, logits):
  """Builds `probs`/`logits` checks for continuous-Bernoulli-type parameters.

  Args:
    is_init: Boolean. When truthy, the dtype of the active parameter is
      checked; `probs` value checks are built only when
      `is_init != tensor_util.is_ref(probs)`.
    validate_args: When falsy, no value checks are built.
    probs: Optional; when not `None`, checked to lie in `[0, 1]`.
    logits: Optional; when not `None` it (rather than `probs`) is the
      parameter whose dtype is checked.

  Returns:
    A list of the objects returned by the `assert_util` helpers; `[]` when
    `validate_args` is falsy.

  Raises:
    TypeError: If `is_init` and the active parameter's dtype is not floating.
  """
  if is_init:
    if logits is None:
      active, label = probs, 'probs'
    else:
      active, label = logits, 'logits'
    if not dtype_util.is_floating(active.dtype):
      raise TypeError('Argument `{}` must having floating type.'.format(label))

  if not validate_args:
    return []

  checks = []
  if probs is not None and is_init != tensor_util.is_ref(probs):
    probs = tf.convert_to_tensor(probs)
    upper = tf.constant(1.0, probs.dtype)
    checks.append(assert_util.assert_non_negative(
        probs, message='probs has components less than 0.'))
    checks.append(assert_util.assert_less_equal(
        probs, upper, message='probs has components greater than 1.'))
  return checks
def _parameter_control_dependencies(self, is_init):
  """Builds validation checks for `mean_direction` and `concentration`.

  Args:
    is_init: Boolean. A parameter's checks are built only when
      `is_init != tensor_util.is_ref(parameter)`.

  Returns:
    A list of the objects returned by the `assert_util` helpers; `[]` when
    `self.validate_args` is falsy.
  """
  if not self.validate_args:
    return []

  direction = tf.convert_to_tensor(self.mean_direction)
  kappa = tf.convert_to_tensor(self.concentration)

  checks = []

  if is_init != tensor_util.is_ref(self._mean_direction):
    event_size = tf.shape(direction)[-1]
    checks.append(assert_util.assert_greater(
        event_size, 1,
        message='`mean_direction` must be a vector of at least size 2.'))
    # The constant 1 is cast to `self.dtype` before being compared to the norm.
    unit = tf.cast(1., self.dtype)
    checks.append(assert_util.assert_near(
        unit, tf.linalg.norm(direction, axis=-1),
        message='`mean_direction` must be unit-length'))

  if is_init != tensor_util.is_ref(self._concentration):
    checks.append(assert_util.assert_non_negative(
        kappa, message='`concentration` must be non-negative'))

  return checks
def _assertions(self, t):
  """Returns a check that every element of `t` is non-negative.

  Args:
    t: Value to validate.

  Returns:
    A one-element list with the `assert_non_negative` result, or `[]` when
    `self.validate_args` is falsy.
  """
  if self.validate_args:
    non_negative = assert_util.assert_non_negative(
        t, message='All elements must be non-negative.')
    return [non_negative]
  return []
def _parameter_control_dependencies(self, is_init):
  """Builds validation checks for `mean_direction` and `concentration`.

  Args:
    is_init: Boolean. A parameter's checks are built only when
      `is_init != tensor_util.is_ref(parameter)`.

  Returns:
    A list of the objects returned by the `assert_util` helpers; `[]` when
    `self.validate_args` is falsy.
  """
  if not self.validate_args:
    return []

  mean_direction = tf.convert_to_tensor(self.mean_direction)
  concentration = tf.convert_to_tensor(self.concentration)

  assertions = []
  if is_init != tensor_util.is_ref(self._mean_direction):
    event_size = tf.shape(mean_direction)[-1]
    assertions.append(
        assert_util.assert_greater(
            event_size, 1,
            message='`mean_direction` may not have scalar event shape'))
    assertions.append(
        assert_util.assert_less_equal(
            event_size, 5,
            message='von Mises-Fisher ndims > 5 is not currently supported'))
    # Compare against a constant of the same dtype as `mean_direction`. A bare
    # Python `1.` as the first operand would fix the comparison dtype to the
    # default float type and fail for e.g. float64 parameters.
    one = tf.ones([], dtype=mean_direction.dtype)
    assertions.append(
        assert_util.assert_near(
            one, tf.linalg.norm(mean_direction, axis=-1),
            message='`mean_direction` must be unit-length'))

  if is_init != tensor_util.is_ref(self._concentration):
    assertions.append(
        assert_util.assert_non_negative(
            concentration, message='`concentration` must be non-negative'))

  return assertions
def _maybe_assert_valid_x(self, x):
  """Returns a non-negativity check on the forward input `x`.

  Args:
    x: Forward-transformation input.

  Returns:
    A one-element list with the `assert_non_negative` result, or `[]` when
    `self.validate_args` is falsy.
  """
  if self.validate_args:
    non_negative = assert_util.assert_non_negative(
        x, message='Forward transformation input must be at least 0.')
    return [non_negative]
  return []
def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                   validate_args, name):
  """Helper to __init__ which ensures override batch/event_shape are valid.

  Statically decidable problems raise immediately; otherwise runtime
  assertions are attached to the returned shape, but only when
  `validate_args` is truthy.

  Args:
    override_shape: Requested shape override, or `None` (treated as `[]`).
    base_is_scalar: Whether the base distribution's corresponding shape is
      scalar.
    validate_args: Whether runtime assertions may be created.
    name: Name given to the converted shape tensor.

  Returns:
    `self._empty` when the override is statically scalar; otherwise the
    converted `override_shape`, wrapped with dependencies on any runtime
    assertions that were created.

  Raises:
    TypeError: If the converted override does not have an integer dtype.
    ValueError: If the override is statically not a vector, statically has a
      negative element, or both base and override are statically non-scalar.
  """
  shape = tf.convert_to_tensor(
      value=[] if override_shape is None else override_shape,
      dtype=tf.int32,
      name=name)

  if not dtype_util.is_integer(shape.dtype):
    raise TypeError("shape override must be an integer")

  override_is_scalar = _is_scalar_from_shape_tensor(shape)
  if tf.get_static_value(override_is_scalar):
    return self._empty

  runtime_checks = []

  # Rank: must be a vector.
  static_rank = tensorshape_util.rank(shape.shape)
  if static_rank is None:
    if validate_args:
      runtime_checks.append(assert_util.assert_rank(
          shape, 1, message="shape override must be a vector"))
  elif static_rank != 1:
    raise ValueError("shape override must be a vector")

  # Values: every element must be non-negative.
  static_shape = tf.get_static_value(shape)
  if static_shape is None:
    if validate_args:
      runtime_checks.append(assert_util.assert_non_negative(
          shape, message="shape override must have non-negative elements"))
  elif any(dim < 0 for dim in static_shape):
    raise ValueError("shape override must have non-negative elements")

  # An override is only allowed on top of a scalar base.
  is_both_nonscalar = prefer_static.logical_and(
      prefer_static.logical_not(base_is_scalar),
      prefer_static.logical_not(override_is_scalar))
  static_both_nonscalar = tf.get_static_value(is_both_nonscalar)
  if static_both_nonscalar is None:
    if validate_args:
      runtime_checks.append(assert_util.assert_equal(
          is_both_nonscalar, False, message="base distribution not scalar"))
  elif static_both_nonscalar:
    raise ValueError("base distribution not scalar")

  if runtime_checks:
    return distribution_util.with_dependencies(runtime_checks, shape)
  return shape
def _maybe_assert_valid_x(self, x):
  """Returns `x`, optionally gated on `1 + power * x` being non-negative.

  Args:
    x: Forward-transformation input.

  Returns:
    `x` unchanged when `self.validate_args` is falsy or `self.power == 0.`;
    otherwise `x` wrapped with a dependency on the non-negativity check.
  """
  if not self.validate_args or self.power == 0.:
    return x

  shifted = 1. + self.power * x
  # The bound reported to the user is where `1 + power * x` crosses zero.
  msg = "Forward transformation input must be at least {}.".format(
      -1. / self.power)
  check = assert_util.assert_non_negative(shifted, message=msg)
  return distribution_util.with_dependencies([check], x)
def _sample_control_dependencies(self, x):
  """Checks the sample's dtype, then returns a non-negativity check.

  `dtype_util.assert_same_float_dtype` is always called with `x` and
  `self.dtype`, independent of `self.validate_args`.

  Args:
    x: Sample to validate.

  Returns:
    A one-element list with the `assert_non_negative` result, or `[]` when
    `self.validate_args` is falsy.
  """
  dtype_util.assert_same_float_dtype(tensors=[x], dtype=self.dtype)
  if self.validate_args:
    return [
        assert_util.assert_non_negative(
            x, message='Sample must be non-negative.'),
    ]
  return []
def _parameter_control_dependencies(self, is_init):
  """Builds a non-negativity check for `concentration`.

  Args:
    is_init: Boolean. The check is built only when
      `is_init != tensor_util.is_ref(self.concentration)`.

  Returns:
    A list holding at most one `assert_non_negative` result; `[]` when
    `self.validate_args` is falsy.
  """
  if not self.validate_args:
    return []

  checks = []
  if is_init != tensor_util.is_ref(self.concentration):
    non_negative = assert_util.assert_non_negative(
        self.concentration,
        message='Argument `concentration` must be non-negative.')
    checks.append(non_negative)
  return checks
def _maybe_assert_valid_sample(self, x, dtype):
  """Returns `x`, optionally gated on it being a valid probability vector.

  Args:
    x: Sample to validate.
    dtype: Dtype used to build the constant `1` that `x` is compared against.

  Returns:
    `x` unchanged when `self.validate_args` is falsy; otherwise `x` wrapped
    with dependencies on checks that it is non-negative, at most 1, and sums
    to (approximately) 1 along the last axis.
  """
  if not self.validate_args:
    return x

  unit = tf.ones([], dtype=dtype)
  checks = [
      assert_util.assert_non_negative(x),
      assert_util.assert_less_equal(x, unit),
      assert_util.assert_near(unit, tf.reduce_sum(x, axis=[-1])),
  ]
  return distribution_util.with_dependencies(checks, x)
def _sample_control_dependencies(self, x):
  """Returns a check that `scale.solvevec(x - loc)` is non-negative.

  Args:
    x: Sample to validate.

  Returns:
    A one-element list with the `assert_non_negative` result, or `[]` when
    `self.validate_args` is falsy.
  """
  if not self.validate_args:
    return []

  # A missing `loc` is treated as a zero offset.
  if self.loc is None:
    offset = 0.
  else:
    offset = tf.convert_to_tensor(self.loc)
  solved = self.scale.solvevec(x - offset)
  return [
      assert_util.assert_non_negative(
          solved, message='Sample is not contained in the support.'),
  ]
def _sample_control_dependencies(self, x):
  """Returns checks that the sample `x` lies in `[0, 1]`.

  Args:
    x: Sample to validate.

  Returns:
    A two-element list (lower-bound check, then upper-bound check), or `[]`
    when `self.validate_args` is falsy.
  """
  if not self.validate_args:
    return []

  lower_ok = assert_util.assert_non_negative(
      x, message='Sample must be non-negative.')
  upper_ok = assert_util.assert_less_equal(
      x, tf.ones([], x.dtype),
      message='Sample must be less than or equal to `1`.')
  return [lower_ok, upper_ok]
def _maybe_assert_valid_y(self, y):
  """Returns a non-negativity check on the inverse input `y`.

  Args:
    y: Inverse-transformation input.

  Returns:
    A one-element list with the `assert_non_negative` result, or `[]` when
    `self.validate_args` is falsy.
  """
  if self.validate_args:
    msg = 'Inverse transformation input must be greater than or equal to 0.'
    return [assert_util.assert_non_negative(y, message=msg)]
  return []
def _maybe_assert_valid_total_count(self, total_count, validate_args):
  """Returns `total_count`, optionally gated on validity checks.

  Args:
    total_count: Value to validate.
    validate_args: When falsy, `total_count` is returned untouched.

  Returns:
    `total_count` unchanged, or wrapped with dependencies on checks that it
    is non-negative and of integer form.
  """
  if validate_args:
    checks = [
        assert_util.assert_non_negative(
            total_count, message='total_count must be non-negative.'),
        distribution_util.assert_integer_form(
            total_count,
            message='total_count cannot contain fractional components.'),
    ]
    return distribution_util.with_dependencies(checks, total_count)
  return total_count
def _maybe_assert_valid_y(self, y):
  """Returns checks that the inverse input `y` lies in `[0, 1]`.

  The lower-bound check is `assert_non_negative`, i.e. `y == 0` is accepted,
  so the message states "greater than or equal to 0" rather than the
  stricter "greater than 0".

  Args:
    y: Inverse-transformation input.

  Returns:
    A two-element list (lower-bound check, then upper-bound check), or `[]`
    when `self.validate_args` is falsy.
  """
  if not self.validate_args:
    return []

  is_non_negative = assert_util.assert_non_negative(
      y,
      message='Inverse transformation input must be greater than or equal '
              'to 0.')
  less_than_one = assert_util.assert_less_equal(
      y, tf.constant(1., y.dtype),
      message='Inverse transformation input must be less than or equal to 1.')
  return [is_non_negative, less_than_one]
def _sample_control_dependencies(self, x):
  """Returns checks that samples are non-negative and, optionally, integral.

  Args:
    x: Samples to validate.

  Returns:
    `[]` when `self.validate_args` is falsy; otherwise a list with the
    non-negativity check, followed by an integer-form check unless
    `self.interpolate_nondiscrete` is truthy.
  """
  if not self.validate_args:
    return []

  checks = [
      assert_util.assert_non_negative(
          x, message='samples must be non-negative'),
  ]
  if not self.interpolate_nondiscrete:
    checks.append(distribution_util.assert_integer_form(
        x, message='samples cannot contain fractional components.'))
  return checks
def _maybe_assert_valid_y(self, y):
  """Returns `y`, optionally gated on it lying in `[0, 1]`.

  The lower-bound check is `assert_non_negative`, i.e. `y == 0` is accepted,
  so the message states "greater than or equal to 0" rather than the
  stricter "greater than 0".

  Args:
    y: Inverse-transformation input.

  Returns:
    `y` unchanged when `self.validate_args` is falsy; otherwise `y` wrapped
    with dependencies on the lower- and upper-bound checks.
  """
  if not self.validate_args:
    return y

  is_non_negative = assert_util.assert_non_negative(
      y,
      message="Inverse transformation input must be greater than or equal "
              "to 0.")
  less_than_one = assert_util.assert_less_equal(
      y, tf.constant(1., y.dtype),
      message="Inverse transformation input must be less than or equal to 1.")
  return distribution_util.with_dependencies(
      [is_non_negative, less_than_one], y)