def _parameter_control_dependencies(self, is_init):
        """Build validity checks for `loc`, `atol` and `rtol`.

        Args:
          is_init: Python `bool`; `True` when called from `__init__`.

        Returns:
          A list of assertion ops; empty when `validate_args` is `False`.

        Raises:
          ValueError: at init, when `self._is_vector` is set and `loc` is
            statically known to have rank 0.
        """
        checks = []

        # Shape checks only run at init: Variable-backed args are assumed to
        # keep their shape afterwards.
        if is_init and self._is_vector:
            rank_msg = "Argument `loc` must be at least rank 1."
            static_rank = tensorshape_util.rank(self.loc.shape)
            if static_rank is None:
                if self.validate_args:
                    checks.append(
                        assert_util.assert_rank_at_least(
                            self.loc, 1, message=rank_msg))
            elif static_rank < 1:
                raise ValueError(rank_msg)

        if not self.validate_args:
            assert not checks  # Should never happen
            return []

        tolerances = (
            (self.atol, "Argument 'atol' must be non-negative"),
            (self.rtol, "Argument 'rtol' must be non-negative"),
        )
        for tol, tol_msg in tolerances:
            if is_init != tensor_util.is_ref(tol):
                checks.append(
                    assert_util.assert_non_negative(tol, message=tol_msg))
        return checks
Example #2
0
 def _parameter_control_dependencies(self, is_init):
     """Return assertions checking that `low < high`.

     The check depends on both `low` and `high`. It is built at init time
     when neither of them is Variable-backed, because both values are then
     fixed. It is built at call time when at least one of them is
     Variable-backed, because that value may have changed since init.

     Args:
       is_init: Python `bool`; `True` when called from `__init__`.

     Returns:
       A list of assertion ops; empty when `validate_args` is `False`.
     """
     if not self.validate_args:
         return []
     assertions = []
     # Writing this as `is_init != is_ref(low) and is_init != is_ref(high)`
     # would be False both at init and at call time when exactly one of the
     # parameters is a ref, so the check would never run.
     if is_init != (tensor_util.is_ref(self.low)
                    or tensor_util.is_ref(self.high)):
         assertions.append(
             assert_util.assert_less(
                 self.low,
                 self.high,
                 message='uniform not defined when low >= high.'))
     return assertions
Example #3
0
 def _parameter_control_dependencies(self, is_init):
   """Assert that `concentration` and `scale` are positive.

   Args:
     is_init: Python `bool`; `True` when called from `__init__`.

   Returns:
     A list of assertion ops; empty when `validate_args` is `False`.
   """
   if not self.validate_args:
     return []
   checks = []
   params = (
       (self.concentration, '`concentration` must be positive.'),
       (self.scale, '`scale` must be positive.'),
   )
   for value, msg in params:
     if is_init == tensor_util.is_ref(value):
       continue
     checks.append(assert_util.assert_positive(value, message=msg))
   return checks
def maybe_assert_categorical_param_correctness(is_init, validate_args, probs,
                                               logits):
    """Return assertions for `Categorical`-type distributions.

    The parameter inspected for dtype/rank is `logits` when it is not `None`,
    otherwise `probs`.

    Args:
      is_init: Python `bool`; `True` when called from `__init__`, in which
        case dtype and rank are also checked.
      validate_args: Python `bool`; whether to build runtime assertions.
      probs: the `probs` parameter, or `None`.
      logits: the `logits` parameter, or `None`.

    Returns:
      A list of assertion ops; empty when `validate_args` is `False`.

    Raises:
      TypeError: at init, if the inspected parameter is not floating point.
      ValueError: at init, if the inspected parameter is statically rank 0.
    """
    checks = []

    # Shape and dtype are assumed fixed for Variable-backed args, so they are
    # only inspected once, during init.
    if is_init:
        if logits is None:
            param, param_name = probs, 'probs'
        else:
            param, param_name = logits, 'logits'

        if not dtype_util.is_floating(param.dtype):
            raise TypeError(
                'Argument `{}` must having floating type.'.format(param_name))

        rank_msg = 'Argument `{}` must have rank at least 1.'.format(
            param_name)
        static_rank = tensorshape_util.rank(param.shape)
        if static_rank is None:
            if validate_args:
                param = tf.convert_to_tensor(param)
                # Keep the converted tensor so it is not converted again below.
                if logits is None:
                    probs, logits = param, None
                else:
                    probs, logits = None, param
                checks.append(
                    assert_util.assert_rank_at_least(
                        param, 1, message=rank_msg))
        elif static_rank < 1:
            raise ValueError(rank_msg)

    if not validate_args:
        assert not checks  # Should never happen.
        return []

    if logits is not None and is_init != tensor_util.is_ref(logits):
        logits = tf.convert_to_tensor(logits)
        checks.extend(distribution_util.assert_categorical_event_shape(logits))

    if probs is not None and is_init != tensor_util.is_ref(probs):
        probs = tf.convert_to_tensor(probs)
        unit = np.array(1, dtype=dtype_util.as_numpy_dtype(probs.dtype))
        checks.append(assert_util.assert_non_negative(probs))
        checks.append(
            assert_util.assert_near(
                tf.reduce_sum(probs, axis=-1),
                unit,
                message='Argument `probs` must sum to 1.'))
        checks.extend(distribution_util.assert_categorical_event_shape(probs))

    return checks
 def _parameter_control_dependencies(self, is_init):
     """Assert that `concentration0` and `concentration1` are positive.

     Args:
       is_init: Python `bool`; `True` when called from `__init__`.

     Returns:
       A list of assertion ops; empty when `validate_args` is `False`.
     """
     if not self.validate_args:
         return []
     params = (
         (self.concentration0,
          "Argument `concentration0` must be positive."),
         (self.concentration1,
          "Argument `concentration1` must be positive."),
     )
     return [
         assert_util.assert_positive(value, message=msg)
         for value, msg in params
         if is_init != tensor_util.is_ref(value)
     ]
def maybe_assert_bernoulli_param_correctness(is_init, validate_args, probs,
                                             probits):
    """Return assertions for `ProbitBernoulli`-type distributions.

    Args:
      is_init: Python `bool`; `True` when called from `__init__`, in which
        case the dtype of the supplied parameter is checked.
      validate_args: Python `bool`; whether to build runtime assertions.
      probs: the `probs` parameter, or `None`.
      probits: the `probits` parameter, or `None`.

    Returns:
      A list of assertion ops; empty when `validate_args` is `False` or when
      `probs` does not need (re)checking.

    Raises:
      TypeError: at init, if the supplied parameter (`probits` when it is not
        `None`, else `probs`) is not floating point.
    """
    if is_init:
        if probits is None:
            param, param_name = probs, 'probs'
        else:
            param, param_name = probits, 'probits'
        if not dtype_util.is_floating(param.dtype):
            raise TypeError(
                'Argument `{}` must having floating type.'.format(param_name))

    if not validate_args:
        return []

    # Only `probs` is range-checked here; nothing is asserted about `probits`.
    if probs is None or is_init == tensor_util.is_ref(probs):
        return []

    probs = tf.convert_to_tensor(probs)
    one = tf.constant(1., probs.dtype)
    return [
        assert_util.assert_non_negative(
            probs, message='probs has components less than 0.'),
        assert_util.assert_less_equal(
            probs, one, message='probs has components greater than 1.'),
    ]
Example #7
0
 def _parameter_control_dependencies(self, is_init):
   """Assert that `power` is strictly greater than 1.

   Args:
     is_init: Python `bool`; `True` when called from `__init__`.

   Returns:
     A list with at most one assertion op; empty when `validate_args` is
     `False` or `power` does not need (re)checking.
   """
   if not self.validate_args or is_init == tensor_util.is_ref(self.power):
     return []
   one = np.ones([], self.power.dtype.as_numpy_dtype)
   return [
       assert_util.assert_greater(
           self.power, one, message='`power` must be greater than 1.'),
   ]
Example #8
0
 def _parameter_control_dependencies(self, is_init):
     """Assert that the degrees-of-freedom parameter `df` is positive.

     Args:
       is_init: Python `bool`; `True` when called from `__init__`.

     Returns:
       A list with at most one assertion op; empty when `validate_args` is
       `False` or `df` does not need (re)checking.
     """
     if not self.validate_args:
         return []
     if is_init == tensor_util.is_ref(self._df):
         return []
     return [
         assert_util.assert_positive(
             self._df, message='Argument `df` must be positive.'),
     ]
 def _parameter_control_dependencies(self, is_init):
     """Combine the categorical parameter checks with a `total_count` check.

     Args:
       is_init: Python `bool`; `True` when called from `__init__`.

     Returns:
       The list from `maybe_assert_categorical_param_correctness`. When
       `validate_args` is set and `total_count` needs (re)checking, it is
       extended with the non-negative-integer assertions for `total_count`.
     """
     checks = categorical_lib.maybe_assert_categorical_param_correctness(
         is_init, self.validate_args, self._probs, self._logits)
     needs_count_check = (
         self.validate_args and
         is_init != tensor_util.is_ref(self.total_count))
     if needs_count_check:
         checks.extend(
             distribution_util.assert_nonnegative_integer_form(
                 self.total_count))
     return checks
Example #10
0
 def _parameter_control_dependencies(self, is_init):
     """Assert that both concentration parameters are positive.

     Args:
       is_init: Python `bool`; `True` when called from `__init__`.

     Returns:
       A list of assertion ops; empty when `validate_args` is `False`.
     """
     if not self.validate_args:
         return []
     return [
         assert_util.assert_positive(
             conc, message="Concentration parameter must be positive.")
         for conc in (self.concentration0, self.concentration1)
         if is_init != tensor_util.is_ref(conc)
     ]
Example #11
0
 def _parameter_control_dependencies(self, is_init):
     """Check the `loc`/`scale` dtypes at init, then assert `scale` > 0.

     Args:
       is_init: Python `bool`; `True` when called from `__init__`.

     Returns:
       A list with at most one assertion op; empty when `validate_args` is
       `False` or `scale` does not need (re)checking.
     """
     if is_init:
         # Delegates the dtype check to `dtype_util.assert_same_float_dtype`.
         dtype_util.assert_same_float_dtype([self.loc, self.scale])
     if not self.validate_args or is_init == tensor_util.is_ref(self._scale):
         return []
     return [
         assert_util.assert_positive(
             self._scale, message='Argument `scale` must be positive.'),
     ]
Example #12
0
 def _parameter_control_dependencies(self, is_init):
   """Assert that `hinge_softness`, when set, is non-zero.

   Args:
     is_init: Python `bool`; `True` when called from `__init__`.

   Returns:
     A list with at most one assertion op; empty when `validate_args` is
     `False`, `hinge_softness` is `None`, or it does not need (re)checking.
   """
   if not self.validate_args:
     return []
   if (self.hinge_softness is None or
       is_init == tensor_util.is_ref(self.hinge_softness)):
     return []
   zero = dtype_util.as_numpy_dtype(self._hinge_softness.dtype)(0)
   return [
       assert_util.assert_none_equal(
           zero,
           self.hinge_softness,
           message='Argument `hinge_softness` must be non-zero.'),
   ]
Example #13
0
 def _parameter_control_dependencies(self, is_init):
     """Assert that `concentration` is at least 1.

     Args:
       is_init: Python `bool`; `True` when called from `__init__`.

     Returns:
       A list with at most one assertion op; empty when `validate_args` is
       `False` or `concentration` does not need (re)checking.
     """
     if not self.validate_args:
         return []
     if is_init == tensor_util.is_ref(self.concentration):
         return []
     # concentration >= 1 is expressed as concentration - 1 >= 0.
     # TODO(b/111451422, b/115950951) Generalize to concentration > 0.
     return [
         assert_util.assert_non_negative(
             self.concentration - 1,
             message='Argument `concentration` must be >= 1.'),
     ]
 def _parameter_control_dependencies(self, is_init):
     """Assert that `scale`, when set, contains no zeros.

     Args:
       is_init: Python `bool`; `True` when called from `__init__`.

     Returns:
       A list with at most one assertion op; empty when `validate_args` is
       `False`, `scale` is `None`, or it does not need (re)checking.
     """
     if not self.validate_args:
         return []
     if self.scale is None or is_init == tensor_util.is_ref(self.scale):
         return []
     return [
         assert_util.assert_none_equal(
             self.scale,
             tf.zeros([], dtype=self._scale.dtype),
             message='Argument `scale` must be non-zero.'),
     ]
Example #15
0
 def _parameter_control_dependencies(self, is_init):
   """Assert that `total_count` is non-negative and integer-valued.

   Args:
     is_init: Python `bool`; `True` when called from `__init__`.

   Returns:
     A list of assertion ops; empty when `validate_args` is `False` or
     `total_count` does not need (re)checking.
   """
   checks = []
   if self.validate_args and is_init != tensor_util.is_ref(self.total_count):
     count = tf.convert_to_tensor(self.total_count)
     negative_msg = 'Argument `total_count` must be non-negative.'
     fraction_msg = (
         'Argument `total_count` cannot contain fractional components.')
     checks = [
         assert_util.assert_non_negative(count, message=negative_msg),
         distribution_util.assert_integer_form(count, message=fraction_msg),
     ]
   return checks
Example #16
0
    def _parameter_control_dependencies(self, is_init):
        """Return assertions checking `low < high` and `low <= peak <= high`.

        Each check relates two parameters. It is built at init time when
        neither of them is Variable-backed, because both values are then
        fixed. It is built at call time when at least one of them is
        Variable-backed, because that value may have changed since init.

        Args:
          is_init: Python `bool`; `True` when called from `__init__`.

        Returns:
          A list of assertion ops; empty when `validate_args` is `False` or
          no check needs (re)building.
        """
        if not self.validate_args:
            return []

        low_is_ref = tensor_util.is_ref(self.low)
        high_is_ref = tensor_util.is_ref(self.high)
        peak_is_ref = tensor_util.is_ref(self.peak)

        # Use `is_init != (a or b)` rather than
        # `is_init != a and is_init != b`. The latter is False both at init
        # and at call time when exactly one of the two parameters is a ref,
        # so such a check would never run.
        check_low_high = is_init != (low_is_ref or high_is_ref)
        check_low_peak = is_init != (low_is_ref or peak_is_ref)
        check_peak_high = is_init != (peak_is_ref or high_is_ref)
        if not (check_low_high or check_low_peak or check_peak_high):
            return []

        # Convert once and share the tensors between the checks.
        low = tf.convert_to_tensor(self.low)
        high = tf.convert_to_tensor(self.high)
        peak = tf.convert_to_tensor(self.peak)

        assertions = []
        if check_low_high:
            assertions.append(
                assert_util.assert_less(
                    low,
                    high,
                    message='triangular not defined when low >= high.'))
        # `peak` may coincide with `low` or `high`. Only strict violations
        # are rejected, as the messages state.
        if check_low_peak:
            assertions.append(
                assert_util.assert_less_equal(
                    low,
                    peak,
                    message='triangular not defined when low > peak.'))
        if check_peak_high:
            assertions.append(
                assert_util.assert_less_equal(
                    peak,
                    high,
                    message='triangular not defined when peak > high.'))

        return assertions
Example #17
0
    def _parameter_control_dependencies(self, is_init):
        """Validate `concentration` and `total_count`.

        Args:
          is_init: Python `bool`; `True` when called from `__init__`.

        Returns:
          A list of assertion ops; empty when `validate_args` is `False`.
        """
        if not self.validate_args:
            return []

        checks = []
        if is_init:
            # assert_categorical_event_shape handles both the static and
            # dynamic case.
            checks.extend(
                distribution_util.assert_categorical_event_shape(
                    self._concentration))
        if is_init != tensor_util.is_ref(self._total_count):
            checks.extend(
                distribution_util.assert_nonnegative_integer_form(
                    self._total_count))
        if is_init != tensor_util.is_ref(self._concentration):
            checks.append(
                assert_util.assert_positive(
                    self._concentration,
                    message='Concentration parameter must be positive.'))
        return checks
Example #18
0
def maybe_assert_negative_binomial_param_correctness(is_init, validate_args,
                                                     total_count, probs,
                                                     logits):
    """Return assertions for `NegativeBinomial`-type distributions.

    Args:
      is_init: Python `bool`; `True` when called from `__init__`, in which
        case the dtype of `logits` (or `probs` when `logits` is `None`) is
        checked.
      validate_args: Python `bool`; whether to build runtime assertions.
      total_count: the `total_count` parameter.
      probs: the `probs` parameter, or `None`.
      logits: the `logits` parameter, or `None`.

    Returns:
      A list of assertion ops; empty when `validate_args` is `False`.

    Raises:
      TypeError: at init, if the inspected parameter is not floating point.
    """
    if is_init:
        if logits is None:
            param, param_name = probs, 'probs'
        else:
            param, param_name = logits, 'logits'
        if not dtype_util.is_floating(param.dtype):
            raise TypeError(
                'Argument `{}` must having floating type.'.format(param_name))

    if not validate_args:
        return []

    checks = []
    if is_init != tensor_util.is_ref(total_count):
        count = tf.convert_to_tensor(total_count)
        checks.append(
            assert_util.assert_non_negative(
                count, message='`total_count` has components less than 0.'))
        checks.append(
            distribution_util.assert_integer_form(
                count, message='`total_count` has fractional components.'))
    if probs is not None and is_init != tensor_util.is_ref(probs):
        probs = tf.convert_to_tensor(probs)
        one = tf.constant(1., probs.dtype)
        checks.append(
            assert_util.assert_non_negative(
                probs, message='`probs` has components less than 0.'))
        checks.append(
            assert_util.assert_less_equal(
                probs, one, message='`probs` has components greater than 1.'))
    return checks
Example #19
0
    def _parameter_control_dependencies(self, is_init):
        """Assert that the three positive-valued parameters are positive.

        Args:
          is_init: Python `bool`; `True` when called from `__init__`.

        Returns:
          A list of assertion ops; empty when `validate_args` is `False`.
        """
        if not self.validate_args:
            return []
        named_params = (
            ('concentration', self.concentration),
            ('mixing_concentration', self.mixing_concentration),
            ('mixing_rate', self.mixing_rate),
        )
        return [
            assert_util.assert_positive(
                value,
                message='Argument `{}` must be positive.'.format(name))
            for name, value in named_params
            if is_init != tensor_util.is_ref(value)
        ]
Example #20
0
 def _parameter_control_dependencies(self, is_init):
     """Assert that `probs`, when set, lies in the interval (0, 1].

     Args:
       is_init: Python `bool`; `True` when called from `__init__`.

     Returns:
       A list of assertion ops; empty when `validate_args` is `False`,
       `probs` is `None`, or it does not need (re)checking.
     """
     if not self.validate_args:
         return []
     if self._probs is None or is_init == tensor_util.is_ref(self._probs):
         return []
     probs = tf.convert_to_tensor(self._probs)
     one = dtype_util.as_numpy_dtype(self.dtype)(1.)
     return [
         assert_util.assert_positive(
             probs, message='Argument `probs` must be positive.'),
         assert_util.assert_less_equal(
             probs,
             one,
             message='Argument `probs` must be less than or equal to 1.'),
     ]
Example #21
0
  def _parameter_control_dependencies(self, is_init):
    """Checks the validity of the concentration parameter.

    At init, `concentration` must be floating point, have rank >= 1 and a
    final dimension >= 2. Statically known violations raise. When the shape
    is not statically known, runtime assertions are added if `validate_args`
    is set. With `validate_args`, positivity is asserted whenever
    `is_init != tensor_util.is_ref(self.concentration)`.

    Args:
      is_init: Python `bool`; `True` when called from `__init__`.

    Returns:
      A list of assertion ops; empty when `validate_args` is `False`.

    Raises:
      TypeError: at init, if `concentration` is not floating point.
      ValueError: at init, if `concentration` is statically rank 0 or its
        statically known final dimension is smaller than 2.
    """
    checks = []

    # Shapes and dtypes are only inspected at init, on the assumption that
    # Variable-backed arguments keep their shape afterwards.
    if is_init:
      if not dtype_util.is_floating(self.concentration.dtype):
        raise TypeError('Argument `concentration` must be float type.')

      rank_msg = 'Argument `concentration` must have rank at least 1.'
      static_rank = tensorshape_util.rank(self.concentration.shape)
      if static_rank is None:
        if self.validate_args:
          checks.append(assert_util.assert_rank_at_least(
              self.concentration, 1, message=rank_msg))
      elif static_rank < 1:
        raise ValueError(rank_msg)

      size_msg = 'Argument `concentration` must have `event_size` at least 2.'
      static_size = tf.compat.dimension_value(self.concentration.shape[-1])
      if static_size is None:
        if self.validate_args:
          checks.append(assert_util.assert_less(
              1,
              tf.shape(self.concentration)[-1],
              message=size_msg))
      elif static_size < 2:
        raise ValueError(size_msg)

    if not self.validate_args:
      assert not checks  # Should never happen.
      return []

    if is_init != tensor_util.is_ref(self.concentration):
      checks.append(assert_util.assert_positive(
          self.concentration,
          message='Argument `concentration` must be positive.'))

    return checks
Example #22
0
    def _parameter_control_dependencies(self, is_init):
        """Check `loc`/`scale` shape compatibility and `scale` positivity.

        Args:
          is_init: Python `bool`; `True` when called from `__init__`.

        Returns:
          A list of assertion ops; empty when `validate_args` is `False`.

        Raises:
          ValueError: at init, if `self._batch_shape()` raises `ValueError`.
            The error is re-raised with both shapes in the message, and the
            original error is chained as its cause.
        """
        assertions = []

        if is_init:
            try:
                self._batch_shape()
            except ValueError as e:
                # Chain explicitly so the underlying error is reported as the
                # direct cause, not as one raised "during handling" of it.
                raise ValueError(
                    'Arguments `loc` and `scale` must have compatible shapes; '
                    'loc.shape={}, scale.shape={}.'.format(
                        self.loc.shape, self.scale.shape)) from e
            # We don't bother checking the shapes in the dynamic case because
            # all member functions access both arguments anyway.

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        if is_init != tensor_util.is_ref(self.scale):
            assertions.append(
                assert_util.assert_positive(
                    self.scale, message='Argument `scale` must be positive.'))

        return assertions
Example #23
0
    def _parameter_control_dependencies(self, is_init):
        """Build checks for the `probs` / `logits` parameter.

        The inspected parameter is `logits` when it is not `None`, otherwise
        `probs`. At init, it is checked for:
          - a floating dtype;
          - rank >= 1;
          - a final dimension in `[1, tf.int32.max]`.
        Statically known violations raise. Statically unknown shapes get
        runtime assertions when `validate_args` is set. When `probs` was
        supplied, runtime checks are also built that it is non-negative,
        <= 1, and sums to 1 along the last axis.

        Args:
          is_init: Python `bool`; `True` when called from `__init__`.

        Returns:
          A list of assertion ops; empty when `validate_args` is `False`.

        Raises:
          TypeError: at init, if the inspected parameter is not floating point.
          ValueError: at init, if the inspected parameter is statically rank 0
            or its static final dimension is < 1 or > `tf.int32.max`.
        """
        assertions = []

        logits = self._logits
        probs = self._probs
        # `param` is whichever of the two was supplied (`logits` wins if both).
        param, name = (probs, 'probs') if logits is None else (logits,
                                                               'logits')

        # In init, we can always build shape and dtype checks because
        # we assume shape doesn't change for Variable backed args.
        if is_init:
            if not dtype_util.is_floating(param.dtype):
                raise TypeError(
                    'Argument `{}` must having floating type.'.format(name))

            msg = 'Argument `{}` must have rank at least 1.'.format(name)
            shape_static = tensorshape_util.dims(param.shape)
            if shape_static is not None:
                if len(shape_static) < 1:
                    raise ValueError(msg)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_rank_at_least(param, 1, message=msg))
                # Route later uses of `param` through the rank assertion.
                with tf.control_dependencies(assertions):
                    param = tf.identity(param)

            msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
            msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
                name, tf.int32.max)
            # Safe to index: rank 0 already raised above.
            # NOTE(review): if `tensorshape_util.dims` returns Dimension
            # objects, an unknown last dim would be `Dimension(None)` (not
            # `None`) and neither branch below would check it. Confirm.
            event_size = shape_static[-1] if shape_static is not None else None
            if event_size is not None:
                if event_size < 1:
                    raise ValueError(msg1)
                if event_size > tf.int32.max:
                    raise ValueError(msg2)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_greater_equal(tf.shape(param)[-1],
                                                     1,
                                                     message=msg1))
                # NOTE: For now, we leave out a runtime assertion that
                # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
                # will fail before we get to this point.

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        if probs is not None:
            # Assumes only one of `probs`/`logits` is set; otherwise `param`
            # would hold `logits` here.
            probs = param  # reuse tensor conversion from above
            if is_init != tensor_util.is_ref(probs):
                probs = tf.convert_to_tensor(probs)
                one = tf.ones([], dtype=probs.dtype)
                assertions.extend([
                    assert_util.assert_non_negative(probs),
                    assert_util.assert_less_equal(probs, one),
                    assert_util.assert_near(
                        tf.reduce_sum(probs, axis=-1),
                        one,
                        message='Argument `probs` must sum to 1.'),
                ])

        return assertions
  def _parameter_control_dependencies(self, is_init):
    """Build checks relating `outcomes` to the underlying categorical.

    At init, `outcomes` must:
      - be non-empty;
      - be rank 1;
      - have a last dimension equal to that of whichever of `logits` /
        `probs` was supplied.
    These are checked statically when shapes are known, raising
    `ValueError`. Otherwise they become runtime assertions when
    `validate_args` is set. With `validate_args`, `outcomes` is also
    asserted to be strictly increasing whenever
    `is_init != tensor_util.is_ref(outcomes)`.

    Args:
      is_init: Python `bool`; `True` when called from `__init__`.

    Returns:
      A list of assertion ops; empty when `validate_args` is `False`.

    Raises:
      ValueError: at init, if a statically known shape violates one of the
        constraints above.
    """
    assertions = []

    # For `logits` and `probs`, we only want to have an assertion on what the
    # user actually passed. For now, we access the underlying categorical's
    # _logits and _probs directly. After the 2019-10-01 deprecation, it would
    # also work to use .logits() and .probs().
    # pylint: disable=protected-access
    logits = self._categorical._logits
    probs = self._categorical._probs
    # pylint: enable=protected-access
    outcomes = self._outcomes
    validate_args = self._validate_args

    # Build all shape and dtype checks during the `is_init` call.
    if is_init:
      def validate_equal_last_dim(tensor_a, tensor_b, message):
        """Check that two tensors share their final dimension size.

        Raises `ValueError` on a static mismatch. Returns an assertion op when
        a size is statically unknown and `validate_args` is set, else `None`.
        """
        event_size_a = tf.compat.dimension_value(tensor_a.shape[-1])
        event_size_b = tf.compat.dimension_value(tensor_b.shape[-1])
        if event_size_a is not None and event_size_b is not None:
          if event_size_a != event_size_b:
            raise ValueError(message)
        elif validate_args:
          return assert_util.assert_equal(
              tf.shape(tensor_a)[-1], tf.shape(tensor_b)[-1], message=message)
        return None

      message = 'Size of outcomes must be greater than 0.'
      num_elements = tensorshape_util.num_elements(outcomes.shape)
      if num_elements is not None:
        if num_elements == 0:
          raise ValueError(message)
      elif validate_args:
        assertions.append(
            assert_util.assert_greater(tf.size(outcomes), 0, message=message))

      # Check the rank before comparing last dimensions. On a statically
      # scalar shape, `outcomes.shape[-1]` raises `IndexError`, which would
      # hide this clearer error.
      message = 'Rank of outcomes must be 1.'
      ndims = tensorshape_util.rank(outcomes.shape)
      if ndims is not None:
        if ndims != 1:
          raise ValueError(message)
      elif validate_args:
        assertions.append(assert_util.assert_rank(outcomes, 1, message=message))

      if logits is not None:
        maybe_assert = validate_equal_last_dim(
            outcomes,
            logits,
            message='Last dimension of outcomes and logits must be equal size.')
        if maybe_assert is not None:
          assertions.append(maybe_assert)

      if probs is not None:
        maybe_assert = validate_equal_last_dim(
            outcomes,
            probs,
            message='Last dimension of outcomes and probs must be equal size.')
        if maybe_assert is not None:
          assertions.append(maybe_assert)

    if not validate_args:
      assert not assertions  # Should never happen.
      return []

    if is_init != tensor_util.is_ref(outcomes):
      assertions.append(
          assert_util.assert_equal(
              tf.math.is_strictly_increasing(outcomes),
              True,
              message='outcomes is not strictly increasing.'))

    return assertions