def _parameter_control_dependencies(self, is_init):
  """Collects runtime assertions on `total_count` and `concentration`.

  Args:
    is_init: Python `bool`. The event-shape check is added only when it is
      `True`. Each per-parameter group is added when `is_init` differs from
      `tensor_util.is_ref(<parameter>)`. Presumably this means
      construction-time checks for plain tensors and per-use checks for
      reference (variable-like) parameters -- TODO confirm against callers.

  Returns:
    A list of assertions produced by the `assert_*` helpers; empty unless
    `self.validate_args` is truthy.
  """
  validate = self.validate_args
  checks = []

  if is_init and validate:
    # assert_categorical_event_shape handles both the static and dynamic case.
    checks.extend(distribution_util.assert_categorical_event_shape(
        self._concentration))

  total_count_due = is_init != tensor_util.is_ref(self._total_count)
  if total_count_due and validate:
    n = tf.convert_to_tensor(self._total_count)
    checks += [
        distribution_util.assert_casting_closed(
            n, target_dtype=tf.int32,
            message='total_count cannot contain fractional components.'),
        assert_util.assert_non_negative(
            n, message='total_count must be non-negative'),
    ]

  concentration_due = is_init != tensor_util.is_ref(self._concentration)
  if concentration_due and validate:
    checks.append(assert_util.assert_positive(
        self._concentration,
        message='Concentration parameter must be positive.'))

  return checks
def _parameter_control_dependencies(self, is_init):
  """Collects runtime assertions on the probs/logits and `total_count`.

  Args:
    is_init: Python `bool`, forwarded to
      `categorical_lib.maybe_assert_categorical_param_correctness`. The
      `total_count` checks are added when it differs from
      `tensor_util.is_ref(self.total_count)`.

  Returns:
    The list returned by the categorical helper. When `self.validate_args`
    is truthy, it is extended with the `total_count` assertions.
  """
  assertions = categorical_lib.maybe_assert_categorical_param_correctness(
      is_init, self.validate_args, self._probs, self._logits)
  if not self.validate_args:
    return assertions
  if is_init != tensor_util.is_ref(self.total_count):
    total_count = tf.convert_to_tensor(self.total_count)
    # Explicit messages, matching the equivalent total_count checks elsewhere,
    # so a failing assertion says which constraint was violated.
    assertions.append(distribution_util.assert_casting_closed(
        total_count, target_dtype=tf.int32,
        message='total_count cannot contain fractional components.'))
    assertions.append(assert_util.assert_non_negative(
        total_count, message='total_count must be non-negative.'))
  return assertions
def _sample_control_dependencies(self, x):
  """Builds validation assertions for candidate OrderedLogistic samples `x`.

  Args:
    x: value(s) to validate as ordinal outcomes.

  Returns:
    A list of assertions; empty when `self.validate_args` is falsy.
  """
  assertions = []
  if not self.validate_args:
    return assertions
  assertions.append(distribution_util.assert_casting_closed(
      x, target_dtype=tf.int32))
  assertions.append(assert_util.assert_non_negative(x))
  # With K cutpoints the valid outcomes are 0..K (as the message states),
  # i.e. K + 1 categories. Assumes `_num_categories()` returns that count
  # (K + 1) -- TODO confirm. The bound must therefore be strict:
  # `assert_less_equal` would wrongly accept x == K + 1.
  assertions.append(
      assert_util.assert_less(
          x, tf.cast(self._num_categories(), x.dtype),
          message=('OrderedLogistic samples must be `>= 0` and `<= K` '
                   'where `K` is the number of cutpoints.')))
  return assertions
def _sample_control_dependencies(self, x):
  """Builds validation assertions for candidate Categorical samples `x`.

  Args:
    x: value(s) to validate as category indices.

  Returns:
    A list of assertions; empty when `self.validate_args` is falsy.
  """
  assertions = []
  if not self.validate_args:
    return assertions
  assertions.append(distribution_util.assert_casting_closed(
      x, target_dtype=tf.int32))
  assertions.append(assert_util.assert_non_negative(
      x, message='Categorical samples must be non-negative.'))
  # Valid indices are 0..n-1 (as the message states), so the upper bound
  # against the category count must be strict; `assert_less_equal` would
  # wrongly accept x == n.
  assertions.append(
      assert_util.assert_less(
          x, tf.cast(self._num_categories(), x.dtype),
          message=('Categorical samples must be between `0` and `n-1` '
                   'where `n` is the number of categories.')))
  return assertions
def _sample_control_dependencies(self, counts):
  """Validates sampled/observed `counts` against this distribution.

  Checks that `counts` are integer-valued, non-negative, and elementwise no
  larger than `self.total_count`.

  Args:
    counts: value(s) to validate.

  Returns:
    A list of three assertions, or an empty list when `self.validate_args`
    is falsy.
  """
  if not self.validate_args:
    # Validation disabled: nothing to check.
    return []

  fractional_check = distribution_util.assert_casting_closed(
      counts, target_dtype=tf.int32,
      message='counts cannot contain fractional components.')
  sign_check = assert_util.assert_non_negative(
      counts, message='counts must be non-negative.')
  bound_check = assert_util.assert_less_equal(
      counts, self.total_count,
      message=('Sampled counts must be itemwise less than '
               'or equal to `total_count` parameter.'))
  return [fractional_check, sign_check, bound_check]
def _parameter_control_dependencies(self, is_init):
  """Collects runtime assertions on `total_count`, `low` and `high`.

  Args:
    is_init: Python `bool`. When `True`, batch-shape compatibility is checked
      eagerly. Each assertion group is added when `is_init` differs from
      `tensor_util.is_ref` of the corresponding parameter(s).

  Returns:
    A list of assertions; empty when `self.validate_args` is falsy.

  Raises:
    ValueError: if `is_init` is `True` and `self._batch_shape()` raises
      `ValueError` (incompatible parameter shapes). The original error is
      chained as the cause.
  """
  if not self.validate_args:
    return []

  if is_init:
    try:
      self._batch_shape()
    except ValueError as e:
      # Chain explicitly so the underlying broadcasting error is preserved
      # as the direct cause in the traceback.
      raise ValueError(
          'Arguments `total_count`, `low` and `high` must have compatible '
          'shapes; total_count.shape={}, low.shape={}, '
          'high.shape={}.'.format(tf.shape(self.total_count),
                                  tf.shape(self.low),
                                  tf.shape(self.high))) from e

  assertions = []
  if is_init != tensor_util.is_ref(self.total_count):
    total_count = tf.convert_to_tensor(self.total_count)
    limit = BATES_TOTAL_COUNT_STABILITY_LIMITS[self.dtype]
    msg = '`total_count` must be representable as a 32-bit integer.'
    assertions.extend([
        assert_util.assert_positive(
            total_count, message='`total_count` must be positive.'),
        distribution_util.assert_casting_closed(
            total_count, target_dtype=tf.int32, message=msg),
        # Per-dtype upper limit from BATES_TOTAL_COUNT_STABILITY_LIMITS.
        assert_util.assert_less_equal(
            tf.cast(total_count, self.dtype),
            tf.cast(limit, self.dtype),
            message='`total_count` > {} is numerically unstable.'.format(
                limit)),
    ])
  if is_init != (tensor_util.is_ref(self.low) or
                 tensor_util.is_ref(self.high)):
    assertions.append(
        assert_util.assert_less(
            self.low, self.high, message='`low` must be less than `high`.'))
  return assertions