コード例 #1
0
    def _parameter_control_dependencies(self, is_init):
        """Collects validation ops for `loc`, `atol` and `rtol`.

        Args:
          is_init: Python `bool`; `True` when invoked from the constructor.

        Returns:
          A list of assertion ops, or `[]` when `validate_args` is `False`.

        Raises:
          ValueError: at init, when `self._is_vector` is set and `loc` has a
            statically known rank below 1.
        """
        checks = []

        # Shape checks are safe to build at init time because Variable-backed
        # arguments are assumed to keep a fixed shape.
        if is_init and self._is_vector:
            rank_msg = "Argument `loc` must be at least rank 1."
            loc_rank = tensorshape_util.rank(self.loc.shape)
            if loc_rank is None:
                if self.validate_args:
                    checks.append(assert_util.assert_rank_at_least(
                        self.loc, 1, message=rank_msg))
            elif loc_rank < 1:
                raise ValueError(rank_msg)

        if not self.validate_args:
            assert not checks  # Should never happen
            return []

        # Each tolerance is checked in exactly one phase: at init for plain
        # tensors, at call time for Variable-backed (ref) values.
        tolerances = (
            (self.atol, "Argument 'atol' must be non-negative"),
            (self.rtol, "Argument 'rtol' must be non-negative"),
        )
        for tol, tol_msg in tolerances:
            if is_init != tensor_util.is_ref(tol):
                checks.append(
                    assert_util.assert_non_negative(tol, message=tol_msg))
        return checks
コード例 #2
0
 def _assertions(self, t):
     """Returns a non-negativity check on `t` when `validate_args` is set."""
     checks = []
     if self.validate_args:
         checks.append(
             assert_util.assert_non_negative(t, message="Argument y was negative"))
     return checks
コード例 #3
0
 def _maybe_assert_valid_x(self, x):
     """Returns a check that forward input `x` is >= 0 when validating."""
     if self.validate_args:
         return [assert_util.assert_non_negative(
             x, message='Forward transformation input must be at least 0.')]
     return []
コード例 #4
0
def maybe_assert_bernoulli_param_correctness(is_init, validate_args, probs,
                                             probits):
    """Return assertions for `ProbitBernoulli`-type distributions.

    Args:
      is_init: Python `bool`; `True` when called during construction, in which
        case the dtype of the active parameter is checked eagerly.
      validate_args: Python `bool`; when `False`, no assertions are built.
      probs: `Tensor`-like probabilities, or `None`.
      probits: `Tensor`-like probits, or `None`. When given, it is the
        parameter whose dtype is checked at init instead of `probs`.

    Returns:
      A list of assertion ops checking that `probs` lies in `[0, 1]`; empty
      when `validate_args` is `False`, `probs` is `None`, or this phase is not
      responsible for checking `probs`.

    Raises:
      TypeError: at init, if the active parameter is not floating-point.
    """
    if is_init:
        if probits is None:
            param, param_name = probs, 'probs'
        else:
            param, param_name = probits, 'probits'
        if not dtype_util.is_floating(param.dtype):
            raise TypeError(
                'Argument `{}` must having floating type.'.format(param_name))

    if not validate_args or probs is None:
        return []

    # Value checks run in one phase only: at init for plain tensors, at call
    # time for Variable-backed (ref) parameters.
    if is_init == tensor_util.is_ref(probs):
        return []

    probs = tf.convert_to_tensor(probs)
    upper = tf.constant(1., probs.dtype)
    return [
        assert_util.assert_non_negative(
            probs, message='probs has components less than 0.'),
        assert_util.assert_less_equal(
            probs, upper, message='probs has components greater than 1.'),
    ]
コード例 #5
0
    def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                       validate_args, name):
        """Helper to __init__ which ensures override batch/event_shape are valid.

        Args:
          override_shape: shape-like override, or `None` (treated as `[]`).
          base_is_scalar: boolean (static value or tensor) telling whether the
            base distribution is scalar.
          validate_args: Python `bool`; when `True`, checks that cannot be
            decided statically are added as runtime assertions.
          name: name for the converted `override_shape` tensor.

        Returns:
          `self._empty` if the override is statically scalar; otherwise the
          `int32` override tensor, gated on any runtime assertions.

        Raises:
          TypeError: if the converted override is not an integer tensor.
          ValueError: if a statically decidable check fails (rank != 1, a
            negative element, or both base and override are non-scalar).
        """
        shape = tf.convert_to_tensor(
            [] if override_shape is None else override_shape,
            dtype=tf.int32,
            name=name)
        if not dtype_util.is_integer(shape.dtype):
            raise TypeError("shape override must be an integer")

        shape_is_scalar = _is_scalar_from_shape_tensor(shape)
        if tf.get_static_value(shape_is_scalar):
            return self._empty

        runtime_checks = []

        # Check 1: the override must be a vector (rank 1).
        vector_msg = "shape override must be a vector"
        shape_rank = tensorshape_util.rank(shape.shape)
        if shape_rank is None:
            if validate_args:
                runtime_checks.append(
                    assert_util.assert_rank(shape, 1, message=vector_msg))
        elif shape_rank != 1:
            raise ValueError(vector_msg)

        # Check 2: every element of the override must be non-negative.
        nonneg_msg = "shape override must have non-negative elements"
        shape_value = tf.get_static_value(shape)
        if shape_value is None:
            if validate_args:
                runtime_checks.append(
                    assert_util.assert_non_negative(shape, message=nonneg_msg))
        elif any(dim < 0 for dim in shape_value):
            raise ValueError(nonneg_msg)

        # Check 3: base and override must not both be non-scalar.
        scalar_msg = "base distribution not scalar"
        both_nonscalar = prefer_static.logical_and(
            prefer_static.logical_not(base_is_scalar),
            prefer_static.logical_not(shape_is_scalar))
        both_nonscalar_value = tf.get_static_value(both_nonscalar)
        if both_nonscalar_value is None:
            if validate_args:
                runtime_checks.append(
                    assert_util.assert_equal(
                        both_nonscalar, False, message=scalar_msg))
        elif both_nonscalar_value:
            raise ValueError(scalar_msg)

        if runtime_checks:
            return distribution_util.with_dependencies(runtime_checks, shape)
        return shape
コード例 #6
0
 def _maybe_assert_valid_x(self, x):
     """Returns a check that `1 + power * x >= 0` when validating.

     Nothing is checked when `validate_args` is off or `power == 0`; in the
     latter case the condition `1 + 0 * x >= 0` holds trivially.
     """
     if not self.validate_args:
         return []
     if self.power == 0.:
         # Also avoids dividing by zero when formatting the message below.
         return []
     shifted = 1. + self.power * x
     return [
         assert_util.assert_non_negative(
             shifted,
             message='Forward transformation input must be at least {}.'.format(
                 -1. / self.power))
     ]
コード例 #7
0
 def _maybe_assert_valid_sample(self, x, dtype):
     """Returns `x`, gated on simplex checks when `validate_args` is set.

     Checks that every entry is in `[0, 1]` and that the last axis sums to one.
     """
     if self.validate_args:
         unit = tf.ones([], dtype=dtype)
         simplex_checks = [
             assert_util.assert_non_negative(x),
             assert_util.assert_less_equal(x, unit),
             assert_util.assert_near(unit, tf.reduce_sum(x, axis=[-1])),
         ]
         x = distribution_util.with_dependencies(simplex_checks, x)
     return x
コード例 #8
0
ファイル: gumbel.py プロジェクト: HackerShohag/SuggestBot-bn
 def _maybe_assert_valid_y(self, y):
     """Returns checks that inverse input `y` lies in `[0, 1]` when validating.

     NOTE(review): the lower-bound message says "greater than 0" but the check
     is `y >= 0` (non-negative); the message is kept as-is to preserve behavior.
     """
     if not self.validate_args:
         return []
     lower_check = assert_util.assert_non_negative(
         y, message='Inverse transformation input must be greater than 0.')
     upper_check = assert_util.assert_less_equal(
         y,
         tf.constant(1., y.dtype),
         message='Inverse transformation input must be less than or equal to 1.')
     return [lower_check, upper_check]
コード例 #9
0
 def _maybe_assert_valid(self, x):
     """Returns sample `x`, gated on `0 <= x <= 1` checks when validating."""
     if self.validate_args:
         unit_interval_checks = [
             assert_util.assert_non_negative(
                 x, message="sample must be non-negative"),
             assert_util.assert_less_equal(
                 x,
                 tf.ones([], self.concentration0.dtype),
                 message="sample must be no larger than `1`."),
         ]
         x = distribution_util.with_dependencies(unit_interval_checks, x)
     return x
コード例 #10
0
 def _parameter_control_dependencies(self, is_init):
     """Returns a check that `concentration >= 1` when validating.

     The check is built in one phase only: at init for plain tensors, at call
     time for Variable-backed (ref) `concentration`.
     """
     if (not self.validate_args or
         is_init == tensor_util.is_ref(self.concentration)):
         return []
     # TODO(b/111451422, b/115950951) Generalize to concentration > 0.
     return [
         assert_util.assert_non_negative(
             self.concentration - 1,
             message='Argument `concentration` must be >= 1.')
     ]
コード例 #11
0
 def _parameter_control_dependencies(self, is_init):
   """Returns checks that `total_count` is non-negative and integer-valued.

   Empty when `validate_args` is off, or when this phase (init vs. call time)
   is not the one responsible for checking `total_count`.
   """
   if self.validate_args and is_init != tensor_util.is_ref(self.total_count):
     count = tf.convert_to_tensor(self.total_count)
     nonneg_msg = 'Argument `total_count` must be non-negative.'
     integer_msg = 'Argument `total_count` cannot contain fractional components.'
     return [
         assert_util.assert_non_negative(count, message=nonneg_msg),
         distribution_util.assert_integer_form(count, message=integer_msg),
     ]
   return []
コード例 #12
0
def maybe_assert_negative_binomial_param_correctness(is_init, validate_args,
                                                     total_count, probs,
                                                     logits):
    """Return assertions for `NegativeBinomial`-type distributions.

    Args:
      is_init: Python `bool`; `True` when called during construction, in which
        case the dtype of the active parameter (`logits` if given, otherwise
        `probs`) is checked eagerly.
      validate_args: Python `bool`; when `False`, no assertions are built.
      total_count: `Tensor`-like count parameter.
      probs: `Tensor`-like probabilities, or `None`.
      logits: `Tensor`-like logits, or `None`.

    Returns:
      A list of assertion ops: `total_count` non-negative and integer-valued,
      and `probs` (when given) within `[0, 1]`.

    Raises:
      TypeError: at init, if the active parameter is not floating-point.
    """
    if is_init:
        if logits is None:
            param, param_name = probs, 'probs'
        else:
            param, param_name = logits, 'logits'
        if not dtype_util.is_floating(param.dtype):
            raise TypeError(
                'Argument `{}` must having floating type.'.format(param_name))

    if not validate_args:
        return []

    checks = []
    # Each parameter is checked in one phase only: at init for plain tensors,
    # at call time for Variable-backed (ref) values.
    if is_init != tensor_util.is_ref(total_count):
        total_count = tf.convert_to_tensor(total_count)
        checks += [
            assert_util.assert_non_negative(
                total_count,
                message='`total_count` has components less than 0.'),
            distribution_util.assert_integer_form(
                total_count,
                message='`total_count` has fractional components.'),
        ]
    if probs is not None and is_init != tensor_util.is_ref(probs):
        probs = tf.convert_to_tensor(probs)
        upper = tf.constant(1., probs.dtype)
        checks += [
            assert_util.assert_non_negative(
                probs, message='`probs` has components less than 0.'),
            assert_util.assert_less_equal(
                probs,
                upper,
                message='`probs` has components greater than 1.'),
        ]

    return checks
コード例 #13
0
def maybe_assert_categorical_param_correctness(is_init, validate_args, probs,
                                               logits):
    """Return assertions for `Categorical`-type distributions.

    Args:
      is_init: Python `bool`; `True` when called during construction. At init,
        dtype and static rank of the active parameter (`logits` if given,
        otherwise `probs`) are checked eagerly.
      validate_args: Python `bool`; when `False`, no assertions are built.
      probs: `Tensor`-like probabilities, or `None`.
      logits: `Tensor`-like logits, or `None`.

    Returns:
      A list of assertion ops; `[]` when `validate_args` is `False`.

    Raises:
      TypeError: at init, if the active parameter is not floating-point.
      ValueError: at init, if its static rank is known and below 1.
    """
    checks = []

    # Shape and dtype checks are safe at init: Variable-backed arguments are
    # assumed to keep a fixed shape.
    if is_init:
        if logits is None:
            param, param_name = probs, 'probs'
        else:
            param, param_name = logits, 'logits'

        if not dtype_util.is_floating(param.dtype):
            raise TypeError(
                'Argument `{}` must having floating type.'.format(param_name))

        rank_msg = 'Argument `{}` must have rank at least 1.'.format(param_name)
        param_rank = tensorshape_util.rank(param.shape)
        if param_rank is None:
            if validate_args:
                param = tf.convert_to_tensor(param)
                # Reuse the converted tensor for the value checks below; the
                # inactive parameter is cleared.
                if logits is None:
                    probs = param
                else:
                    probs, logits = None, param
                checks.append(
                    assert_util.assert_rank_at_least(param, 1, message=rank_msg))
        elif param_rank < 1:
            raise ValueError(rank_msg)

    if not validate_args:
        assert not checks  # Should never happen.
        return []

    if logits is not None and is_init != tensor_util.is_ref(logits):
        logits = tf.convert_to_tensor(logits)
        checks.extend(distribution_util.assert_categorical_event_shape(logits))

    if probs is not None and is_init != tensor_util.is_ref(probs):
        probs = tf.convert_to_tensor(probs)
        checks.extend([
            assert_util.assert_non_negative(probs),
            assert_util.assert_near(
                tf.reduce_sum(probs, axis=-1),
                np.array(1, dtype=dtype_util.as_numpy_dtype(probs.dtype)),
                message='Argument `probs` must sum to 1.'),
        ])
        checks.extend(distribution_util.assert_categorical_event_shape(probs))

    return checks
コード例 #14
0
 def _assertions(self, t):
     """Returns checks that `t` lies in `[0, 1]` when `validate_args` is set.

     NOTE(review): the lower-bound message says "greater than 0" but the check
     is `t >= 0` (non-negative); the message is kept as-is to preserve behavior.
     """
     checks = []
     if self.validate_args:
         lower_check = assert_util.assert_non_negative(
             t, message="Inverse transformation input must be greater than 0.")
         upper_check = assert_util.assert_less_equal(
             t,
             dtype_util.as_numpy_dtype(t.dtype)(1.),
             message="Inverse transformation input must be less than or equal to 1.")
         checks = [lower_check, upper_check]
     return checks
コード例 #15
0
def _maybe_validate_rightmost_transposed_ndims(rightmost_transposed_ndims,
                                               validate_args,
                                               name=None):
    """Checks that `rightmost_transposed_ndims` is valid.

    It must be an integer-typed scalar with a non-negative value. Statically
    decidable violations raise immediately; undecidable ones become runtime
    assertions when `validate_args` is set.

    Args:
      rightmost_transposed_ndims: tensor to validate.
      validate_args: Python `bool`; enables runtime assertions.
      name: optional name scope.

    Returns:
      A (possibly empty) list of runtime assertion ops.

    Raises:
      TypeError: if the dtype is not an integer type.
      ValueError: if the static rank is known and non-zero, or the static
        value is known and negative.
    """
    with tf.name_scope(name or 'maybe_validate_rightmost_transposed_ndims'):
        checks = []
        if not dtype_util.is_integer(rightmost_transposed_ndims.dtype):
            raise TypeError(
                '`rightmost_transposed_ndims` must be integer type.')

        static_rank = tensorshape_util.rank(rightmost_transposed_ndims.shape)
        if static_rank is None:
            if validate_args:
                checks.append(
                    assert_util.assert_rank(rightmost_transposed_ndims, 0))
        elif static_rank != 0:
            raise ValueError(
                '`rightmost_transposed_ndims` must be a scalar, '
                'saw rank: {}.'.format(static_rank))

        nonneg_msg = '`rightmost_transposed_ndims` must be non-negative.'
        static_value = tf.get_static_value(rightmost_transposed_ndims)
        if static_value is None:
            if validate_args:
                checks.append(
                    assert_util.assert_non_negative(rightmost_transposed_ndims,
                                                    message=nonneg_msg))
        elif static_value < 0:
            # Drop the trailing period so the observed value can be appended.
            raise ValueError(
                nonneg_msg[:-1] + ', saw: {}.'.format(static_value))

        return checks
コード例 #16
0
ファイル: square.py プロジェクト: HackerShohag/SuggestBot-bn
 def _assertions(self, t):
   """Returns a non-negativity check on `t` when `validate_args` is set."""
   if self.validate_args:
     return [
         assert_util.assert_non_negative(
             t, message="All elements must be non-negative.")
     ]
   return []
コード例 #17
0
    def _parameter_control_dependencies(self, is_init):
        """Builds validation ops for the `logits`/`probs` parameter.

        At init (`is_init=True`) the active parameter (`logits` if set,
        otherwise `probs`) has its dtype and static shape checked eagerly.
        Statically detectable problems raise. When `validate_args` is set,
        runtime rank and last-dimension assertions are added for shapes that
        are not statically known. When `validate_args` is set, value checks on
        `probs` are also built: entries in `[0, 1]` and the last axis summing
        to 1.

        Args:
          is_init: Python `bool`; `True` when called from `__init__`.

        Returns:
          A list of assertion ops; `[]` when `validate_args` is `False`.

        Raises:
          TypeError: at init, if the active parameter is not floating-point.
          ValueError: at init, if the parameter's static rank is < 1, or its
            static final dimension is < 1 or > `tf.int32.max`.
        """
        assertions = []

        logits = self._logits
        probs = self._probs
        # The "active" parameter: `logits` when set, otherwise `probs`.
        param, name = (probs, 'probs') if logits is None else (logits,
                                                               'logits')

        # In init, we can always build shape and dtype checks because
        # we assume shape doesn't change for Variable backed args.
        if is_init:
            if not dtype_util.is_floating(param.dtype):
                raise TypeError(
                    'Argument `{}` must having floating type.'.format(name))

            msg = 'Argument `{}` must have rank at least 1.'.format(name)
            # Presumably `None` when the static rank is unknown, otherwise a
            # list of per-axis sizes (entries may be `None`) -- TODO confirm.
            shape_static = tensorshape_util.dims(param.shape)
            if shape_static is not None:
                if len(shape_static) < 1:
                    raise ValueError(msg)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_rank_at_least(param, 1, message=msg))
                # Gate every later use of `param` on the rank assertion,
                # presumably so `tf.shape(param)[-1]` below never runs on a
                # rank-0 input.
                with tf.control_dependencies(assertions):
                    param = tf.identity(param)

            msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
            msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
                name, tf.int32.max)
            # Safe to index: an empty `shape_static` already raised above.
            event_size = shape_static[-1] if shape_static is not None else None
            if event_size is not None:
                if event_size < 1:
                    raise ValueError(msg1)
                if event_size > tf.int32.max:
                    raise ValueError(msg2)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_greater_equal(tf.shape(param)[-1],
                                                     1,
                                                     message=msg1))
                # NOTE: For now, we leave out a runtime assertion that
                # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
                # will fail before we get to this point.

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        if probs is not None:
            # NOTE(review): this assumes only one of `logits`/`probs` is set;
            # if both were set, `param` would be `logits` here. Also, when
            # `param` was converted above, `is_ref` below sees the converted
            # value rather than the original attribute -- confirm intended.
            probs = param  # reuse tensor conversion from above
            # Value checks are built in one phase only: at init for plain
            # tensors, at call time for Variable-backed (ref) values.
            if is_init != tensor_util.is_ref(probs):
                probs = tf.convert_to_tensor(probs)
                one = tf.ones([], dtype=probs.dtype)
                assertions.extend([
                    assert_util.assert_non_negative(probs),
                    assert_util.assert_less_equal(probs, one),
                    assert_util.assert_near(
                        tf.reduce_sum(probs, axis=-1),
                        one,
                        message='Argument `probs` must sum to 1.'),
                ])

        return assertions
コード例 #18
0
  def __init__(self,
               mean_direction,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name='VonMisesFisher'):
    """Creates a new `VonMisesFisher` instance.

    Args:
      mean_direction: Floating-point `Tensor` with shape [B1, ... Bn, D].
        A unit vector indicating the mode of the distribution, or the
        unit-normalized direction of the mean. (This is *not* in general the
        mean of the distribution; the mean is not generally in the support of
        the distribution.) NOTE: `D` is currently restricted to <= 5.
      concentration: Floating-point `Tensor` having batch shape [B1, ... Bn]
        broadcastable with `mean_direction`. The level of concentration of
        samples around the `mean_direction`. `concentration=0` indicates a
        uniform distribution over the unit hypersphere, and `concentration=+inf`
        indicates a `Deterministic` distribution (delta function) at
        `mean_direction`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: For known-bad arguments, i.e. unsupported event dimension
        (a statically known event size `D` greater than 5).
    """
    # Must stay the first statement: it snapshots the constructor arguments
    # before any other local variable exists.
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      # Both parameters are converted to a single shared floating dtype.
      dtype = dtype_util.common_dtype([mean_direction, concentration],
                                      tf.float32)
      mean_direction = tf.convert_to_tensor(
          mean_direction, name='mean_direction', dtype=dtype)
      concentration = tf.convert_to_tensor(
          concentration, name='concentration', dtype=dtype)
      # Runtime checks, built only when validating: concentration >= 0,
      # event size D > 1, and unit norm of `mean_direction` along its last axis.
      assertions = [
          assert_util.assert_non_negative(
              concentration, message='`concentration` must be non-negative'),
          assert_util.assert_greater(
              tf.shape(mean_direction)[-1],
              1,
              message='`mean_direction` may not have scalar event shape'),
          assert_util.assert_near(
              1.,
              tf.linalg.norm(mean_direction, axis=-1),
              message='`mean_direction` must be unit-length')
      ] if validate_args else []
      # Static size of the last (event) axis, or None when unknown.
      static_event_dim = tf.compat.dimension_value(
          tensorshape_util.with_rank_at_least(mean_direction.shape, 1)[-1])
      # NOTE(review): a statically known D of 1 is not rejected here; only the
      # runtime `assert_greater` above (under validate_args) catches it.
      if static_event_dim is not None and static_event_dim > 5:
        raise ValueError('vMF ndims > 5 is not currently supported')
      elif validate_args:
        # Reached when D is unknown or statically <= 5; a runtime bound is
        # added in both cases.
        assertions += [
            assert_util.assert_less_equal(
                tf.shape(mean_direction)[-1],
                5,
                message='vMF ndims > 5 is not currently supported')
        ]
      # Tie the assertions to the stored parameters via control dependencies.
      with tf.control_dependencies(assertions):
        self._mean_direction = tf.identity(mean_direction)
        self._concentration = tf.identity(concentration)
      dtype_util.assert_same_float_dtype(
          [self._mean_direction, self._concentration])
      # mean_direction is always reparameterized.
      # concentration is only for event_dim==3, via an inversion sampler.
      reparameterization_type = (
          reparameterization.FULLY_REPARAMETERIZED
          if static_event_dim == 3 else
          reparameterization.NOT_REPARAMETERIZED)
      super(VonMisesFisher, self).__init__(
          dtype=self._concentration.dtype,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          reparameterization_type=reparameterization_type,
          parameters=parameters,
          name=name)