Example #1
0
def maybe_assert_categorical_param_correctness(is_init, validate_args, probs,
                                               logits):
    """Return assertions for `Categorical`-type distributions.

    The parameter inspected by the static checks is `logits` when it is not
    `None`, otherwise `probs` (presumably callers pass exactly one of them).

    Static checks (floating dtype, rank >= 1 when the rank is statically
    known) run only when `is_init` is `True`. They raise immediately,
    regardless of `validate_args`. Runtime assertion ops are only built when
    `validate_args` is `True`.

    Args:
      is_init: Python `bool`; `True` when called at construction time.
      validate_args: Python `bool`; whether to build runtime assertions.
      probs: Probabilities (tensor-like), or `None`.
      logits: Logits (tensor-like), or `None`.

    Returns:
      A list of assertion ops; always empty when `validate_args` is `False`.

    Raises:
      TypeError: if `is_init` and the inspected parameter is not floating point.
      ValueError: if `is_init` and the inspected parameter's static rank is < 1.
    """
    assertions = []

    # In init, we can always build shape and dtype checks because
    # we assume shape doesn't change for Variable backed args.
    if is_init:
        x, name = (probs, 'probs') if logits is None else (logits, 'logits')

        if not dtype_util.is_floating(x.dtype):
            raise TypeError(
                'Argument `{}` must have floating type.'.format(name))

        msg = 'Argument `{}` must have rank at least 1.'.format(name)
        ndims = tensorshape_util.rank(x.shape)
        if ndims is not None:
            if ndims < 1:
                raise ValueError(msg)
        elif validate_args:
            # Rank is unknown statically: convert once and retain the converted
            # tensor in place of the original argument for the checks below.
            x = tf.convert_to_tensor(x)
            probs, logits = (x, None) if logits is None else (None, x)
            assertions.append(
                assert_util.assert_rank_at_least(x, 1, message=msg))

    if not validate_args:
        # Assertions are only appended above under `validate_args`.
        assert not assertions  # Should never happen.
        return []

    # Non-mutable arguments are validated at init; mutable ones (per
    # `tensor_util.is_mutable`) are validated on the non-init calls instead.
    if logits is not None and is_init != tensor_util.is_mutable(logits):
        logits = tf.convert_to_tensor(logits)
        assertions.extend(
            distribution_util.assert_categorical_event_shape(logits))

    if probs is not None and is_init != tensor_util.is_mutable(probs):
        probs = tf.convert_to_tensor(probs)
        assertions.extend([
            assert_util.assert_non_negative(probs),
            assert_util.assert_near(
                tf.reduce_sum(probs, axis=-1),
                np.array(1, dtype=dtype_util.as_numpy_dtype(probs.dtype)),
                message='Argument `probs` must sum to 1.')
        ])
        assertions.extend(
            distribution_util.assert_categorical_event_shape(probs))

    return assertions
Example #2
0
    def _parameter_control_dependencies(self, is_init):
        """Collect the parameter-validation assertions for this distribution.

        The event-shape check on `concentration` runs only when `is_init` is
        `True`. The value checks on `total_count` and `concentration` run when
        `is_init` is `True` for non-ref parameters, and when `is_init` is
        `False` for ref parameters (as reported by `tensor_util.is_ref`).

        Args:
          is_init: Python `bool`; presumably `True` when invoked during
            construction.

        Returns:
          A list of assertion ops; empty whenever `self.validate_args` is
          `False`.
        """
        checks = []

        if is_init and self.validate_args:
            # The helper covers both statically known and dynamic shapes.
            checks.extend(
                distribution_util.assert_categorical_event_shape(
                    self._concentration))

        total_count_due = is_init != tensor_util.is_ref(self._total_count)
        if total_count_due and self.validate_args:
            count = tf.convert_to_tensor(self._total_count)
            checks += [
                distribution_util.assert_casting_closed(
                    count,
                    target_dtype=tf.int32,
                    message=
                    'total_count cannot contain fractional components.'),
                assert_util.assert_non_negative(
                    count, message='total_count must be non-negative'),
            ]

        concentration_due = is_init != tensor_util.is_ref(self._concentration)
        if concentration_due and self.validate_args:
            checks.append(
                assert_util.assert_positive(
                    self._concentration,
                    message='Concentration parameter must be positive.'))
        return checks
Example #3
0
  def _parameter_control_dependencies(self, is_init):
    """Build the assertions that validate `total_count` and `concentration`.

    The event-shape check on `concentration` is built only when `is_init` is
    `True`. Each value check is built when `is_init` differs from
    `tensor_util.is_ref(param)`: at init for non-ref parameters, otherwise
    for ref parameters.

    Args:
      is_init: Python `bool`; presumably `True` when invoked during
        construction.

    Returns:
      A list of assertion ops; empty whenever `self.validate_args` is `False`.
    """
    deps = []

    if is_init and self.validate_args:
      # The helper handles both statically known and dynamic event shapes.
      deps += distribution_util.assert_categorical_event_shape(
          self._concentration)

    check_total_count = is_init != tensor_util.is_ref(self._total_count)
    if check_total_count and self.validate_args:
      deps += distribution_util.assert_nonnegative_integer_form(
          self._total_count)

    check_concentration = is_init != tensor_util.is_ref(self._concentration)
    if check_concentration and self.validate_args:
      deps.append(
          assert_util.assert_positive(
              self._concentration,
              message='Concentration parameter must be positive.'))

    return deps