def maybe_assert_categorical_param_correctness(
        is_init, validate_args, probs, logits):
    """Build the validation assertions for `Categorical`-style parameters.

    Exactly one of `probs` / `logits` is expected to be non-`None`; when
    `logits` is `None` the checks are applied to `probs`, otherwise to
    `logits`.

    Args:
      is_init: Truthy when called while the distribution is being constructed.
        Static dtype and rank checks are only performed in that case.
      validate_args: When falsy, no assertions are returned (static checks
        during init may still raise).
      probs: Probabilities parameter, or `None`.
      logits: Logits parameter, or `None`.

    Returns:
      A list of assertions; always an empty list when `validate_args` is falsy.

    Raises:
      TypeError: During init, if the selected parameter's dtype is not
        floating.
      ValueError: During init, if the selected parameter's statically known
        rank is less than 1.
    """
    checks = []

    # At init time the dtype and (static) shape can always be inspected, on
    # the assumption that the shape of a Variable-backed argument is fixed.
    if is_init:
        if logits is None:
            param, label = probs, 'probs'
        else:
            param, label = logits, 'logits'

        if not dtype_util.is_floating(param.dtype):
            raise TypeError(
                'Argument `{}` must having floating type.'.format(label))

        rank_msg = 'Argument `{}` must have rank at least 1.'.format(label)
        static_rank = tensorshape_util.rank(param.shape)
        if static_rank is None:
            # Rank unknown statically: defer to a runtime assertion, but only
            # when validation was requested.
            if validate_args:
                param = tf.convert_to_tensor(param)
                # Retain tensor conversion: re-bind whichever argument was
                # selected so the checks below reuse the converted tensor.
                if logits is None:
                    probs, logits = param, None
                else:
                    probs, logits = None, param
                checks.append(
                    assert_util.assert_rank_at_least(
                        param, 1, message=rank_msg))
        elif static_rank < 1:
            raise ValueError(rank_msg)

    if not validate_args:
        assert not checks  # Should never happen.
        return []

    # A parameter is checked when `is_init` differs from its mutability flag.
    if logits is not None and is_init != tensor_util.is_mutable(logits):
        logits = tf.convert_to_tensor(logits)
        checks.extend(
            distribution_util.assert_categorical_event_shape(logits))

    if probs is not None and is_init != tensor_util.is_mutable(probs):
        probs = tf.convert_to_tensor(probs)
        non_negative = assert_util.assert_non_negative(probs)
        row_totals = tf.reduce_sum(probs, axis=-1)
        unit = np.array(1, dtype=dtype_util.as_numpy_dtype(probs.dtype))
        sums_to_one = assert_util.assert_near(
            row_totals, unit, message='Argument `probs` must sum to 1.')
        checks.extend([non_negative, sums_to_one])
        checks.extend(
            distribution_util.assert_categorical_event_shape(probs))

    return checks
def _parameter_control_dependencies(self, is_init):
    """Collect assertions validating `total_count` and `concentration`.

    Every assertion is gated on `self.validate_args`. The event-shape check on
    the concentration is added only at init. The value checks on each
    parameter are added when `is_init` differs from
    `tensor_util.is_ref(<parameter>)`.

    Args:
      is_init: Truthy when called during construction of the distribution.

    Returns:
      A list of assertions (possibly empty).
    """
    checks = []
    concentration = self._concentration
    total_count = self._total_count

    if is_init and self.validate_args:
        # assert_categorical_event_shape handles both the static and dynamic
        # case.
        checks.extend(
            distribution_util.assert_categorical_event_shape(concentration))

    if is_init != tensor_util.is_ref(total_count) and self.validate_args:
        # Convert once so both assertions below share the same tensor.
        total_count = tf.convert_to_tensor(total_count)
        integer_valued = distribution_util.assert_casting_closed(
            total_count,
            target_dtype=tf.int32,
            message='total_count cannot contain fractional components.')
        non_negative = assert_util.assert_non_negative(
            total_count, message='total_count must be non-negative')
        checks.extend([integer_valued, non_negative])

    if is_init != tensor_util.is_ref(concentration) and self.validate_args:
        checks.append(
            assert_util.assert_positive(
                concentration,
                message='Concentration parameter must be positive.'))

    return checks
def _parameter_control_dependencies(self, is_init):
    """Collect assertions validating `total_count` and `concentration`.

    Every assertion is gated on `self.validate_args`. The event-shape check on
    the concentration is added only at init. The value checks on each
    parameter are added when `is_init` differs from
    `tensor_util.is_ref(<parameter>)`.

    Args:
      is_init: Truthy when called during construction of the distribution.

    Returns:
      A list of assertions (possibly empty).
    """
    checks = []
    concentration = self._concentration
    total_count = self._total_count

    if is_init and self.validate_args:
        # assert_categorical_event_shape handles both the static and dynamic
        # case.
        checks.extend(
            distribution_util.assert_categorical_event_shape(concentration))

    if is_init != tensor_util.is_ref(total_count) and self.validate_args:
        checks.extend(
            distribution_util.assert_nonnegative_integer_form(total_count))

    if is_init != tensor_util.is_ref(concentration) and self.validate_args:
        checks.append(
            assert_util.assert_positive(
                concentration,
                message='Concentration parameter must be positive.'))

    return checks