Beispiel #1
0
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        # In init, we can always build shape and dtype checks because
        # we assume shape doesn't change for Variable backed args.
        if is_init and self._is_vector:
            msg = "Argument `loc` must be at least rank 1."
            if tensorshape_util.rank(self.loc.shape) is not None:
                if tensorshape_util.rank(self.loc.shape) < 1:
                    raise ValueError(msg)
            elif self.validate_args:
                assertions.append(
                    assert_util.assert_rank_at_least(self.loc, 1, message=msg))

        if not self.validate_args:
            assert not assertions  # Should never happen
            return []

        if is_init != tensor_util.is_ref(self.atol):
            assertions.append(
                assert_util.assert_non_negative(
                    self.atol, message="Argument 'atol' must be non-negative"))
        if is_init != tensor_util.is_ref(self.rtol):
            assertions.append(
                assert_util.assert_non_negative(
                    self.rtol, message="Argument 'rtol' must be non-negative"))
        return assertions
Beispiel #2
0
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if (is_init != tensor_util.is_ref(self.low)
             and is_init != tensor_util.is_ref(self.high)):
         assertions.append(
             assert_util.assert_less(
                 self.low,
                 self.high,
                 message='uniform not defined when low >= high.'))
     return assertions
Beispiel #3
0
 def _parameter_control_dependencies(self, is_init):
   if not self.validate_args:
     return []
   assertions = []
   if is_init != tensor_util.is_ref(self.scale):
     assertions.append(assert_util.assert_positive(
         self.scale,
         message='Argument `scale` must be positive.'))
   if is_init != tensor_util.is_ref(self.concentration):
     assertions.append(assert_util.assert_positive(
         self.concentration,
         message='Argument `concentration` must be positive.'))
   return assertions
def maybe_assert_categorical_param_correctness(is_init, validate_args, probs,
                                               logits):
    """Return assertions for `Categorical`-type distributions.

    The parameter that is checked statically is `logits` when it is given,
    otherwise `probs`.

    Args:
      is_init: Python `bool`; `True` for the constructor-time call.
      validate_args: Python `bool`; whether runtime assertions are wanted.
      probs: the `probs` parameter, or `None`.
      logits: the `logits` parameter, or `None`.

    Returns:
      A list of assertion ops; empty when `validate_args` is False.

    Raises:
      TypeError: at construction, if the chosen parameter is not floating.
      ValueError: at construction, if its static rank is known to be 0.
    """
    checks = []

    # Shape and dtype checks are only built during construction, on the
    # assumption that Variable-backed args never change shape afterwards.
    if is_init:
        if logits is None:
            param, param_name = probs, 'probs'
        else:
            param, param_name = logits, 'logits'

        if not dtype_util.is_floating(param.dtype):
            raise TypeError(
                'Argument `{}` must having floating type.'.format(param_name))

        rank_msg = 'Argument `{}` must have rank at least 1.'.format(
            param_name)
        rank = tensorshape_util.rank(param.shape)
        if rank is None:
            if validate_args:
                param = tf.convert_to_tensor(param)
                # Keep the converted tensor so it is not converted again below.
                if logits is None:
                    probs = param
                else:
                    probs, logits = None, param
                checks.append(
                    assert_util.assert_rank_at_least(param, 1,
                                                     message=rank_msg))
        elif rank < 1:
            raise ValueError(rank_msg)

    if not validate_args:
        assert not checks  # Should never happen.
        return []

    if logits is not None and is_init != tensor_util.is_ref(logits):
        logits = tf.convert_to_tensor(logits)
        checks.extend(distribution_util.assert_categorical_event_shape(logits))

    if probs is not None and is_init != tensor_util.is_ref(probs):
        probs = tf.convert_to_tensor(probs)
        checks.append(assert_util.assert_non_negative(probs))
        checks.append(
            assert_util.assert_near(
                tf.reduce_sum(probs, axis=-1),
                np.array(1, dtype=dtype_util.as_numpy_dtype(probs.dtype)),
                message='Argument `probs` must sum to 1.'))
        checks.extend(distribution_util.assert_categorical_event_shape(probs))

    return checks
Beispiel #5
0
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if is_init != tensor_util.is_ref(self.concentration0):
         assertions.append(
             assert_util.assert_positive(
                 self.concentration0,
                 message="Argument `concentration0` must be positive."))
     if is_init != tensor_util.is_ref(self.concentration1):
         assertions.append(
             assert_util.assert_positive(
                 self.concentration1,
                 message="Argument `concentration1` must be positive."))
     return assertions
def maybe_assert_bernoulli_param_correctness(
    is_init, validate_args, probs, logits):
  """Return assertions for `Bernoulli`-type distributions.

  Args:
    is_init: Python `bool`; `True` for the constructor-time call.
    validate_args: Python `bool`; whether runtime assertions are wanted.
    probs: the `probs` parameter, or `None`.
    logits: the `logits` parameter, or `None`.

  Returns:
    A list of assertion ops; empty when `validate_args` is False.

  Raises:
    TypeError: at construction, if the supplied parameter (`logits` when
      given, otherwise `probs`) is not floating point.
  """
  if is_init:
    param, param_name = ((probs, 'probs') if logits is None
                         else (logits, 'logits'))
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must having floating type.'.format(param_name))

  if not validate_args:
    return []

  # Only `probs` is range-checked here.
  if probs is None or is_init == tensor_util.is_ref(probs):
    return []

  probs = tf.convert_to_tensor(probs)
  one = tf.constant(1., probs.dtype)
  return [
      assert_util.assert_non_negative(
          probs, message='probs has components less than 0.'),
      assert_util.assert_less_equal(
          probs, one, message='probs has components greater than 1.'),
  ]
Beispiel #7
0
 def _parameter_control_dependencies(self, is_init):
   if not self.validate_args:
     return []
   assertions = []
   if self._rate is not None:
     if is_init != tensor_util.is_ref(self._rate):
       assertions.append(assert_util.assert_positive(
           self._rate,
           message='Argument `rate` must be positive.'))
   return assertions
Beispiel #8
0
 def _parameter_control_dependencies(self, is_init):
     """Combines categorical-parameter checks with `total_count` checks.

     Args:
       is_init: Python `bool`; `True` for the constructor-time call.

     Returns:
       The list returned by `maybe_assert_categorical_param_correctness`,
       extended with `total_count` assertions when `validate_args` is set.
     """
     checks = categorical_lib.maybe_assert_categorical_param_correctness(
         is_init, self.validate_args, self._probs, self._logits)
     if self.validate_args and is_init != tensor_util.is_ref(self.total_count):
         checks.extend(
             distribution_util.assert_nonnegative_integer_form(
                 self.total_count))
     return checks
Beispiel #9
0
 def _parameter_control_dependencies(self, is_init):
   if is_init:
     dtype_util.assert_same_float_dtype([self.loc, self.scale])
   if not self.validate_args:
     return []
   assertions = []
   if is_init != tensor_util.is_ref(self._scale):
     assertions.append(assert_util.assert_positive(
         self._scale, message='Argument `scale` must be positive.'))
   return assertions
Beispiel #10
0
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if is_init != tensor_util.is_ref(self.tailweight):
         assertions.append(
             assert_util.assert_positive(
                 self.tailweight,
                 message="Argument `tailweight` must be positive."))
     return assertions
Beispiel #11
0
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if is_init != tensor_util.is_ref(self.power):
         assertions.append(
             assert_util.assert_greater(
                 self.power,
                 np.ones([], self.power.dtype.as_numpy_dtype),
                 message='`power` must be greater than 1.'))
     return assertions
Beispiel #12
0
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     for concentration in [self.concentration0, self.concentration1]:
         if is_init != tensor_util.is_ref(concentration):
             assertions.append(
                 assert_util.assert_positive(
                     concentration,
                     message="Concentration parameter must be positive."))
     return assertions
Beispiel #13
0
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if is_init != tensor_util.is_ref(self.concentration):
         # concentration >= 1
         # TODO(b/111451422, b/115950951) Generalize to concentration > 0.
         assertions.append(
             assert_util.assert_non_negative(
                 self.concentration - 1,
                 message='Argument `concentration` must be >= 1.'))
     return assertions
Beispiel #14
0
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     if is_init == tensor_util.is_ref(self.total_count):
         return []
     total_count = tf.convert_to_tensor(self.total_count)
     msg1 = 'Argument `total_count` must be non-negative.'
     msg2 = 'Argument `total_count` cannot contain fractional components.'
     return [
         assert_util.assert_non_negative(total_count, message=msg1),
         distribution_util.assert_integer_form(total_count, message=msg2),
     ]
Beispiel #15
0
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if (self.hinge_softness is not None
             and is_init != tensor_util.is_ref(self.hinge_softness)):
         assertions.append(
             assert_util.assert_none_equal(
                 dtype_util.as_numpy_dtype(self._hinge_softness.dtype)(0),
                 self.hinge_softness,
                 message='Argument `hinge_softness` must be non-zero.'))
     return assertions
  def _parameter_control_dependencies(self, is_init):
    """Builds checks for `concentration` and `total_count`.

    Args:
      is_init: Python `bool`; `True` for the constructor-time call.

    Returns:
      A list of assertion ops; empty when `validate_args` is False.
    """
    checks = []
    if not self.validate_args:
      return checks

    concentration = self._concentration
    total_count = self._total_count

    if is_init:
      # assert_categorical_event_shape handles both the static and dynamic case.
      checks.extend(
          distribution_util.assert_categorical_event_shape(concentration))

    if is_init != tensor_util.is_ref(total_count):
      checks.extend(
          distribution_util.assert_nonnegative_integer_form(total_count))

    if is_init != tensor_util.is_ref(concentration):
      checks.append(assert_util.assert_positive(
          concentration, message='Concentration parameter must be positive.'))
    return checks
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if (self.scale is not None
             and is_init != tensor_util.is_ref(self.scale)):
         assertions.append(
             assert_util.assert_none_equal(
                 self.scale,
                 tf.zeros([], dtype=self._scale.dtype),
                 message='Argument `scale` must be non-zero.'))
     return assertions
    def _parameter_control_dependencies(self, is_init):
        if not self.validate_args:
            return []
        low = tf.convert_to_tensor(self.low)
        high = tf.convert_to_tensor(self.high)
        peak = tf.convert_to_tensor(self.peak)
        assertions = []
        if (is_init != tensor_util.is_ref(self.low)
                and is_init != tensor_util.is_ref(self.high)):
            assertions.append(
                assert_util.assert_less(
                    low,
                    high,
                    message='triangular not defined when low >= high.'))
        if (is_init != tensor_util.is_ref(self.low)
                and is_init != tensor_util.is_ref(self.peak)):
            assertions.append(
                assert_util.assert_less(
                    low,
                    peak,
                    message='triangular not defined when low > peak.'))
        if (is_init != tensor_util.is_ref(self.high)
                and is_init != tensor_util.is_ref(self.peak)):
            assertions.append(
                assert_util.assert_less(
                    peak,
                    high,
                    message='triangular not defined when peak > high.'))

        return assertions
def maybe_assert_negative_binomial_param_correctness(is_init, validate_args,
                                                     total_count, probs,
                                                     logits):
    """Return assertions for `NegativeBinomial`-type distributions.

    Args:
      is_init: Python `bool`; `True` for the constructor-time call.
      validate_args: Python `bool`; whether runtime assertions are wanted.
      total_count: the `total_count` parameter.
      probs: the `probs` parameter, or `None`.
      logits: the `logits` parameter, or `None`.

    Returns:
      A list of assertion ops; empty when `validate_args` is False.

    Raises:
      TypeError: at construction, if the supplied parameter (`logits` when
        given, otherwise `probs`) is not floating point.
    """
    if is_init:
        if logits is None:
            param, param_name = probs, 'probs'
        else:
            param, param_name = logits, 'logits'
        if not dtype_util.is_floating(param.dtype):
            raise TypeError(
                'Argument `{}` must having floating type.'.format(param_name))

    if not validate_args:
        return []

    checks = []

    if is_init != tensor_util.is_ref(total_count):
        total_count = tf.convert_to_tensor(total_count)
        checks.append(assert_util.assert_non_negative(
            total_count,
            message='`total_count` has components less than 0.'))
        checks.append(distribution_util.assert_integer_form(
            total_count,
            message='`total_count` has fractional components.'))

    probs_needs_check = (probs is not None
                         and is_init != tensor_util.is_ref(probs))
    if probs_needs_check:
        probs = tf.convert_to_tensor(probs)
        one = tf.constant(1., probs.dtype)
        checks.append(assert_util.assert_non_negative(
            probs, message='`probs` has components less than 0.'))
        checks.append(assert_util.assert_less_equal(
            probs, one, message='`probs` has components greater than 1.'))

    return checks
    def _parameter_control_dependencies(self, is_init):
        if not self.validate_args:
            return []
        assertions = []
        for param_name, param in dict(
                concentration=self.concentration,
                mixing_concentration=self.mixing_concentration,
                mixing_rate=self.mixing_rate).items():

            if is_init != tensor_util.is_ref(param):
                assertions.append(
                    assert_util.assert_positive(
                        param,
                        message='Argument `{}` must be positive.'.format(
                            param_name)))
        return assertions
Beispiel #21
0
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if self._probs is not None:
         if is_init != tensor_util.is_ref(self._probs):
             probs = tf.convert_to_tensor(self._probs)
             assertions.append(
                 assert_util.assert_positive(
                     probs, message='Argument `probs` must be positive.'))
             assertions.append(
                 assert_util.assert_less_equal(
                     probs,
                     dtype_util.as_numpy_dtype(self.dtype)(1.),
                     message=
                     'Argument `probs` must be less than or equal to 1.'))
     return assertions
Beispiel #22
0
  def _parameter_control_dependencies(self, is_init):
    """Checks the validity of the concentration parameter."""
    assertions = []

    # In init, we can always build shape and dtype checks because
    # we assume shape doesn't change for Variable backed args.
    if is_init:
      if not dtype_util.is_floating(self.concentration.dtype):
        raise TypeError('Argument `concentration` must be float type.')

      msg = 'Argument `concentration` must have rank at least 1.'
      ndims = tensorshape_util.rank(self.concentration.shape)
      if ndims is not None:
        if ndims < 1:
          raise ValueError(msg)
      elif self.validate_args:
        assertions.append(assert_util.assert_rank_at_least(
            self.concentration, 1, message=msg))

      msg = 'Argument `concentration` must have `event_size` at least 2.'
      event_size = tf.compat.dimension_value(self.concentration.shape[-1])
      if event_size is not None:
        if event_size < 2:
          raise ValueError(msg)
      elif self.validate_args:
        assertions.append(assert_util.assert_less(
            1,
            tf.shape(self.concentration)[-1],
            message=msg))

    if not self.validate_args:
      assert not assertions  # Should never happen.
      return []

    if is_init != tensor_util.is_ref(self.concentration):
      assertions.append(assert_util.assert_positive(
          self.concentration,
          message='Argument `concentration` must be positive.'))

    return assertions
Beispiel #23
0
  def _parameter_control_dependencies(self, is_init):
    assertions = []

    if is_init:
      try:
        self._batch_shape()
      except ValueError:
        raise ValueError(
            'Arguments `loc` and `scale` must have compatible shapes; '
            'loc.shape={}, scale.shape={}.'.format(
                self.loc.shape, self.scale.shape))
      # We don't bother checking the shapes in the dynamic case because
      # all member functions access both arguments anyway.

    if not self.validate_args:
      assert not assertions  # Should never happen.
      return []

    if is_init != tensor_util.is_ref(self.scale):
      assertions.append(assert_util.assert_positive(
          self.scale, message='Argument `scale` must be positive.'))

    return assertions
Beispiel #24
0
    def _parameter_control_dependencies(self, is_init):
        """Builds validity checks for `logits`/`probs` and `temperature`.

        At construction (`is_init=True`) this runs dtype and shape checks on
        whichever of `logits` / `probs` was supplied (`logits` when it is not
        `None`). It raises when the static shape is known to be invalid.
        When the shape is not statically known and `validate_args` is set,
        it appends runtime assertions instead. Value checks on
        `temperature` and `probs` follow the usual `is_init != is_ref`
        split: they are built at init for non-ref values and at call time
        for refs.

        Args:
          is_init: Python `bool`; `True` for the constructor-time call.

        Returns:
          A list of assertion ops; empty when `validate_args` is False.

        Raises:
          TypeError: at construction, if the parameter is not floating point.
          ValueError: at construction, if the parameter's static rank is 0,
            or its static final dimension is < 1 or > `tf.int32.max`.
        """
        assertions = []

        logits = self._logits
        probs = self._probs
        # `logits` takes precedence when it is not None.
        param, name = (probs, 'probs') if logits is None else (logits,
                                                               'logits')

        # In init, we can always build shape and dtype checks because
        # we assume shape doesn't change for Variable backed args.
        if is_init:
            if not dtype_util.is_floating(param.dtype):
                raise TypeError(
                    'Argument `{}` must having floating type.'.format(name))

            msg = 'Argument `{}` must have rank at least 1.'.format(name)
            # NOTE(review): presumably None when the static rank is unknown and
            # a list of per-dimension sizes otherwise -- confirm.
            shape_static = tensorshape_util.dims(param.shape)
            if shape_static is not None:
                if len(shape_static) < 1:
                    raise ValueError(msg)
            elif self.validate_args:
                # Rank not known statically: defer to a runtime assertion.
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_rank_at_least(param, 1, message=msg))

            msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
            msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
                name, tf.int32.max)
            # Only look at the static last dimension when the rank is known.
            event_size = shape_static[-1] if shape_static is not None else None
            if event_size is not None:
                if event_size < 1:
                    raise ValueError(msg1)
                if event_size > tf.int32.max:
                    raise ValueError(msg2)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_greater_equal(tf.shape(param)[-1:],
                                                     1,
                                                     message=msg1))
                # NOTE: For now, we leave out a runtime assertion that
                # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
                # will fail before we get to this point.

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        if is_init != tensor_util.is_ref(self.temperature):
            assertions.append(assert_util.assert_positive(self.temperature))

        if probs is not None:
            # NOTE(review): if both `logits` and `probs` were set, `param` is
            # `logits` here; presumably the constructor forbids passing both --
            # confirm.
            probs = param  # reuse tensor conversion from above
            if is_init != tensor_util.is_ref(probs):
                probs = tf.convert_to_tensor(probs)
                one = tf.ones([], dtype=probs.dtype)
                assertions.extend([
                    assert_util.assert_non_negative(probs),
                    assert_util.assert_less_equal(probs, one),
                    assert_util.assert_near(
                        tf.reduce_sum(probs, axis=-1),
                        one,
                        message='Argument `probs` must sum to 1.'),
                ])

        return assertions
Beispiel #25
0
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        # For `logits` and `probs`, we only want to have an assertion on what the
        # user actually passed. For now, we access the underlying categorical's
        # _logits and _probs directly. After the 2019-10-01 deprecation, it would
        # also work to use .logits() and .probs().
        logits = self._categorical._logits
        probs = self._categorical._probs
        outcomes = self._outcomes
        validate_args = self._validate_args

        # Build all shape and dtype checks during the `is_init` call.
        if is_init:

            def validate_equal_last_dim(tensor_a, tensor_b, message):
                event_size_a = tf.compat.dimension_value(tensor_a.shape[-1])
                event_size_b = tf.compat.dimension_value(tensor_b.shape[-1])
                if event_size_a is not None and event_size_b is not None:
                    if event_size_a != event_size_b:
                        raise ValueError(message)
                elif validate_args:
                    return assert_util.assert_equal(tf.shape(tensor_a)[-1],
                                                    tf.shape(tensor_b)[-1],
                                                    message=message)

            message = 'Size of outcomes must be greater than 0.'
            if tensorshape_util.num_elements(outcomes.shape) is not None:
                if tensorshape_util.num_elements(outcomes.shape) == 0:
                    raise ValueError(message)
            elif validate_args:
                assertions.append(
                    tf.assert_greater(tf.size(outcomes), 0, message=message))

            if logits is not None:
                maybe_assert = validate_equal_last_dim(
                    outcomes,
                    # pylint: disable=protected-access
                    self._categorical._logits,
                    # pylint: enable=protected-access
                    message=
                    'Last dimension of outcomes and logits must be equal size.'
                )
                if maybe_assert:
                    assertions.append(maybe_assert)

            if probs is not None:
                maybe_assert = validate_equal_last_dim(
                    outcomes,
                    probs,
                    message=
                    'Last dimension of outcomes and probs must be equal size.')
                if maybe_assert:
                    assertions.append(maybe_assert)

            message = 'Rank of outcomes must be 1.'
            ndims = tensorshape_util.rank(outcomes.shape)
            if ndims is not None:
                if ndims != 1:
                    raise ValueError(message)
            elif validate_args:
                assertions.append(
                    assert_util.assert_rank(outcomes, 1, message=message))

        if not validate_args:
            assert not assertions  # Should never happen.
            return []

        if is_init != tensor_util.is_ref(outcomes):
            assertions.append(
                assert_util.assert_equal(
                    tf.math.is_strictly_increasing(outcomes),
                    True,
                    message='outcomes is not strictly increasing.'))

        return assertions