Esempio n. 1
0
    def _parameter_control_dependencies(self, is_init):
        if not self.validate_args:
            return []
        low = tf.convert_to_tensor(self.low)
        high = tf.convert_to_tensor(self.high)
        peak = tf.convert_to_tensor(self.peak)
        assertions = []
        if (is_init != tensor_util.is_ref(self.low)
                and is_init != tensor_util.is_ref(self.high)):
            assertions.append(
                assert_util.assert_less(
                    low,
                    high,
                    message='triangular not defined when low >= high.'))
        if (is_init != tensor_util.is_ref(self.low)
                and is_init != tensor_util.is_ref(self.peak)):
            assertions.append(
                assert_util.assert_less_equal(
                    low,
                    peak,
                    message='triangular not defined when low > peak.'))
        if (is_init != tensor_util.is_ref(self.high)
                and is_init != tensor_util.is_ref(self.peak)):
            assertions.append(
                assert_util.assert_less_equal(
                    peak,
                    high,
                    message='triangular not defined when peak > high.'))

        return assertions
Esempio n. 2
0
def maybe_assert_continuous_bernoulli_param_correctness(
    is_init, validate_args, probs, logits):
  """Return assertions for `Bernoulli`-type distributions.

  Args:
    is_init: Flag; when truthy the dtype of the active parameter is checked
      eagerly. Also compared against `tensor_util.is_ref(probs)` to decide
      whether the range checks are built.
    validate_args: When falsy no assertion ops are created.
    probs: Probability parameter, or `None`.
    logits: Logits parameter, or `None`; when not `None` it is the parameter
      whose dtype is checked.

  Returns:
    A list of assertion ops (possibly empty).

  Raises:
    TypeError: if `is_init` is truthy and the active parameter's dtype is not
      floating.
  """
  if is_init:
    if logits is None:
      param, param_name = probs, 'probs'
    else:
      param, param_name = logits, 'logits'
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must having floating type.'.format(param_name))

  # Only `probs` has a range constraint; nothing to do without it.
  if not validate_args or probs is None:
    return []
  if is_init == tensor_util.is_ref(probs):
    return []

  probs = tf.convert_to_tensor(probs)
  upper = tf.constant(1.0, probs.dtype)
  return [
      assert_util.assert_non_negative(
          probs, message='probs has components less than 0.'),
      assert_util.assert_less_equal(
          probs, upper, message='probs has components greater than 1.'),
  ]
  def _inverse(self, y):
    """Map `y` back to `x` as `log(y[..., :-1]) - log(y[..., -1:])`.

    Derivation: with `y[i] = exp(x[i]) / normalization` and
    `y[end] = 1 / normalization`, it follows that
    `x[i] = log(y[i]) - log(y[end])`.

    When `self.validate_args` is truthy, the computation is placed under
    assertions that the last axis of `y` sums to one (within twice the dtype's
    machine epsilon) and that every element lies in `[0, 1]`.
    """
    checks = []
    if self.validate_args:
      sum_tolerance = 2. * np.finfo(dtype_util.as_numpy_dtype(y.dtype)).eps
      checks = [
          assert_util.assert_near(
              tf.reduce_sum(y, axis=-1),
              tf.ones([], y.dtype),
              sum_tolerance,
              message='Last dimension of `y` must sum to `1`.'),
          assert_util.assert_less_equal(
              y, tf.ones([], y.dtype),
              message='Elements of `y` must be less than or equal to `1`.'),
          assert_util.assert_non_negative(
              y, message='Elements of `y` must be non-negative.'),
      ]

    with tf.control_dependencies(checks):
      log_y = tf.math.log(y)
      # Peel the final entry off the last axis; it carries the log-normalizer.
      log_leading, log_final = tf.split(
          log_y, num_or_size_splits=[-1, 1], axis=-1)
    return log_leading - log_final
Esempio n. 4
0
  def _parameter_control_dependencies(self, is_init):
    if not self.validate_args:
      return []
    mean_direction = tf.convert_to_tensor(self.mean_direction)
    concentration = tf.convert_to_tensor(self.concentration)

    assertions = []
    if is_init != tensor_util.is_ref(self._mean_direction):
      assertions.append(
          assert_util.assert_greater(
              tf.shape(mean_direction)[-1],
              1,
              message='`mean_direction` may not have scalar event shape'))
      assertions.append(
          assert_util.assert_less_equal(
              tf.shape(mean_direction)[-1],
              5,
              message='von Mises-Fisher ndims > 5 is not currently supported'))
      assertions.append(
          assert_util.assert_near(
              1.,
              tf.linalg.norm(mean_direction, axis=-1),
              message='`mean_direction` must be unit-length'))
    if is_init != tensor_util.is_ref(self._concentration):
      assertions.append(
          assert_util.assert_non_negative(
              concentration, message='`concentration` must be non-negative'))
    return assertions
Esempio n. 5
0
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
    """Calculates the reshaped dimensions (replacing up to one -1 in reshape).

    Args:
      original_shape: Shape being reshaped from.
      new_shape: Requested shape; may contain a single `-1` placeholder.
      validate: When truthy, runtime assertions are returned as the third
        element of the result.
      name: Optional name scope; defaults to `'calculate_reshape'`.

    Returns:
      A triple `(resolved_shape, static_shape, validations)`. When `new_shape`
      is fully defined statically, `resolved_shape` is `np.int32(static_shape)`
      and `validations` is empty.
    """
    static_shape = tensorshape_util.constant_value_as_shape(new_shape)
    if tensorshape_util.is_fully_defined(static_shape):
        return np.int32(static_shape), static_shape, []

    with tf.name_scope(name or 'calculate_reshape'):
        total_size = tf.reduce_prod(original_shape)
        is_implicit = tf.equal(new_shape, -1)
        # Negating the product cancels the sign of a single `-1` entry; the
        # `maximum` clamps the divisor to at least 1.
        inferred_dim = (
            total_size // tf.maximum(1, -tf.reduce_prod(new_shape)))
        # Assumes exactly one `-1`.
        resolved_shape = tf.where(is_implicit, inferred_dim, new_shape)

        if not validate:
            return resolved_shape, static_shape, []

        validations = [
            assert_util.assert_rank(
                original_shape, 1, message='Original shape must be a vector.'),
            assert_util.assert_rank(
                new_shape, 1, message='New shape must be a vector.'),
            assert_util.assert_less_equal(
                tf.math.count_nonzero(is_implicit, dtype=tf.int32),
                1,
                message='At most one dimension can be unknown.'),
            assert_util.assert_positive(
                resolved_shape, message='Shape elements must be >=-1.'),
            assert_util.assert_equal(
                tf.reduce_prod(resolved_shape),
                total_size,
                message='Shape sizes do not match.'),
        ]
        return resolved_shape, static_shape, validations
def maybe_assert_negative_binomial_param_correctness(
    is_init, validate_args, total_count, probs, logits):
  """Return assertions for `NegativeBinomial`-type distributions.

  Args:
    is_init: Flag; when truthy the dtype of the active parameter is checked
      eagerly. Also compared against `tensor_util.is_ref(...)` of
      `total_count` and `probs` to decide which checks are built.
    validate_args: When falsy no assertion ops are created.
    total_count: Count parameter; asserted positive and integer-valued.
    probs: Probability parameter, or `None`; asserted to lie in `[0, 1]`.
    logits: Logits parameter, or `None`; when not `None` it is the parameter
      whose dtype is checked.

  Returns:
    A list of assertion ops (possibly empty).

  Raises:
    TypeError: if `is_init` is truthy and the active parameter's dtype is not
      floating.
  """
  if is_init:
    if logits is None:
      param, param_name = probs, 'probs'
    else:
      param, param_name = logits, 'logits'
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must having floating type.'.format(param_name))

  checks = []
  if not validate_args:
    return checks

  if is_init != tensor_util.is_ref(total_count):
    total_count = tf.convert_to_tensor(total_count)
    checks.append(
        assert_util.assert_positive(
            total_count,
            message='`total_count` has components less than or equal to 0.'))
    checks.append(
        distribution_util.assert_integer_form(
            total_count,
            message='`total_count` has fractional components.'))

  if probs is not None and is_init != tensor_util.is_ref(probs):
    probs = tf.convert_to_tensor(probs)
    upper = tf.constant(1., probs.dtype)
    checks.append(
        assert_util.assert_non_negative(
            probs, message='`probs` has components less than 0.'))
    checks.append(
        assert_util.assert_less_equal(
            probs, upper, message='`probs` has components greater than 1.'))

  return checks
  def _parameter_control_dependencies(self, is_init):
    assertions = []

    logits = self._logits
    probs = self._probs
    param, name = (probs, 'probs') if logits is None else (logits, 'logits')

    # In init, we can always build shape and dtype checks because
    # we assume shape doesn't change for Variable backed args.
    if is_init:
      if not dtype_util.is_floating(param.dtype):
        raise TypeError('Argument `{}` must having floating type.'.format(name))

      msg = 'Argument `{}` must have rank at least 1.'.format(name)
      shape_static = tensorshape_util.dims(param.shape)
      if shape_static is not None:
        if len(shape_static) < 1:
          raise ValueError(msg)
      elif self.validate_args:
        param = tf.convert_to_tensor(param)
        assertions.append(
            assert_util.assert_rank_at_least(param, 1, message=msg))
        with tf.control_dependencies(assertions):
          param = tf.identity(param)

      msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
      msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
          name, dtype_util.max(tf.int32))
      event_size = shape_static[-1] if shape_static is not None else None
      if event_size is not None:
        if event_size < 1:
          raise ValueError(msg1)
        if event_size > dtype_util.max(tf.int32):
          raise ValueError(msg2)
      elif self.validate_args:
        param = tf.convert_to_tensor(param)
        assertions.append(assert_util.assert_greater_equal(
            tf.shape(param)[-1], 1, message=msg1))
        # NOTE: For now, we leave out a runtime assertion that
        # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
        # will fail before we get to this point.

    if not self.validate_args:
      assert not assertions  # Should never happen.
      return []

    if probs is not None:
      probs = param  # reuse tensor conversion from above
      if is_init != tensor_util.is_ref(probs):
        probs = tf.convert_to_tensor(probs)
        one = tf.ones([], dtype=probs.dtype)
        assertions.extend([
            assert_util.assert_non_negative(probs),
            assert_util.assert_less_equal(probs, one),
            assert_util.assert_near(
                tf.reduce_sum(probs, axis=-1), one,
                message='Argument `probs` must sum to 1.'),
        ])

    return assertions
Esempio n. 8
0
  def _prob(self, x):
    """Evaluate the piecewise-linear density at `x`.

    When `self.validate_args` is truthy, `x` is first gated on assertions that
    it lies within `[self.low, self.high]`. Values outside that interval map
    to zero.
    """
    if self.validate_args:
      in_support = [
          assert_util.assert_greater_equal(x, self.low),
          assert_util.assert_less_equal(x, self.high),
      ]
      with tf.control_dependencies(in_support):
        x = tf.identity(x)

    x_like_high = _broadcast_to(x, [self.high])
    on_rising_edge = tf.logical_and(
        x_like_high > self.low, x_like_high <= self.peak)

    width = self.high - self.low
    # Density for low <= x <= high. It is a triangle, so each of the two line
    # segments is computed separately and then selected per element.
    # Rising segment: from (low, 0) to (peak, 2 / (high - low)).
    rising = 2. * (x - self.low) / (width * (self.peak - self.low))
    # Falling segment: from (peak, 2 / (high - low)) to (high, 0).
    falling = 2. * (self.high - x) / (width * (self.high - self.peak))
    density_in_support = tf.where(on_rising_edge, rising, falling)

    x_like_peak = _broadcast_to(x, [self.peak])
    out_of_support = tf.logical_or(
        x_like_peak < self.low, x_like_peak > self.high)

    full_shape = tf.broadcast_dynamic_shape(
        tf.shape(input=x), self.batch_shape_tensor())

    return tf.where(
        out_of_support,
        tf.zeros(full_shape, dtype=self.dtype),
        density_in_support)
Esempio n. 9
0
  def _parameter_control_dependencies(self, is_init):
    """Check `scale` dtype/squareness and that `df` is at least the dimension.

    Args:
      is_init: Flag; when truthy the dtype and squareness checks run, and the
        `df >= dimension` check is done eagerly if both values are statically
        known. Otherwise the check becomes a runtime assertion, built only
        when `self.validate_args` is truthy.

    Returns:
      A list of assertion ops (possibly empty).

    Raises:
      TypeError: if `scale` does not have a floating dtype.
      ValueError: if `scale` is not square, or the statically known `df` is
        less than the statically known dimension.
    """
    checks = []
    if is_init:
      if not dtype_util.is_floating(self._scale.dtype):
        raise TypeError(
            'scale.dtype={} is not a floating-point type.'.format(
                self._scale.dtype))
      if not self._scale.is_square:
        raise ValueError('scale must be square.')
      dtype_util.assert_same_float_dtype([self._df, self._scale])

    static_df = tf.get_static_value(self._df)
    static_dim = tf.compat.dimension_value(self._scale.shape[-1])
    msg = ('Degrees of freedom (`df = {}`) cannot be less than dimension of '
           'scale matrix (`scale.dimension = {}`).')

    if is_init and static_df is not None and static_dim is not None:
      # Everything is known statically: compare with NumPy and fail eagerly.
      static_df = np.asarray(static_df)
      static_dim = np.asarray(static_dim)
      # Promote zero-rank arrays to one element so both sides have a shape.
      if not static_dim.shape:
        static_dim = static_dim[np.newaxis, ...]
      if not static_df.shape:
        static_df = static_df[np.newaxis, ...]
      if np.any(static_df < static_dim):
        raise ValueError(msg.format(static_df, static_dim))
      return checks

    if not self.validate_args:
      return checks

    if (is_init != tensor_util.is_ref(self._df) or
        is_init != tensor_util.is_ref(self._scale)):
      df = tf.convert_to_tensor(self._df)
      dimension = self._dimension()
      checks.append(assert_util.assert_less_equal(
          dimension, df, message=(msg.format(df, dimension))))

    return checks
Esempio n. 10
0
  def _parameter_control_dependencies(self, is_init):
    if not self.validate_args:
      return []

    assertions = []

    if is_init != tensor_util.is_ref(self.total_count):
      total_count = tf.convert_to_tensor(self.total_count)
      msg1 = 'Argument `total_count` must be non-negative.'
      msg2 = 'Argument `total_count` cannot contain fractional components.'
      assertions += [
          assert_util.assert_non_negative(total_count, message=msg1),
          distribution_util.assert_integer_form(total_count, message=msg2),
      ]

    if self._probs is not None:
      if is_init != tensor_util.is_ref(self._probs):
        probs = tf.convert_to_tensor(self._probs)
        one = tf.constant(1., probs.dtype)
        assertions += [
            assert_util.assert_non_negative(
                probs, message='probs has components less than 0.'),
            assert_util.assert_less_equal(
                probs, one, message='probs has components greater than 1.')
        ]

    return assertions
Esempio n. 11
0
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         # Avoid computing intermediates needed to construct the assertions.
         return []
     assertions = []
     if is_init != tensor_util.is_ref(self._batch_shape_unexpanded):
         implicit_dim_mask = ps.equal(self._batch_shape_unexpanded, -1)
         assertions.append(
             assert_util.assert_rank(self._batch_shape_unexpanded,
                                     1,
                                     message='New shape must be a vector.'))
         assertions.append(
             assert_util.assert_less_equal(
                 tf.math.count_nonzero(implicit_dim_mask, dtype=tf.int32),
                 1,
                 message='At most one dimension can be unknown.'))
         assertions.append(
             assert_util.assert_non_negative(
                 self._batch_shape_unexpanded + 1,
                 message='Shape elements must be >=-1.'))
         # Check that the old and new shapes are the same size.
         expanded_new_shape, original_size = self._calculate_new_shape()
         new_size = ps.reduce_prod(expanded_new_shape)
         assertions.append(
             assert_util.assert_equal(new_size,
                                      tf.cast(original_size,
                                              new_size.dtype),
                                      message='Shape sizes do not match.'))
     return assertions
Esempio n. 12
0
    def _parameter_control_dependencies(self, is_init):
        if not self.validate_args:
            return []

        assertions = []
        if is_init != tensor_util.is_ref(self._temperature):
            msg1 = 'Argument `temperature` must be positive.'
            temperature = tf.convert_to_tensor(self._temperature)
            assertions.append(
                assert_util.assert_positive(temperature, message=msg1))

        if self._probs is not None:
            if is_init != tensor_util.is_ref(self._probs):
                probs = tf.convert_to_tensor(self._probs)
                one = tf.constant(1., probs.dtype)
                assertions.extend([
                    assert_util.assert_non_negative(
                        probs,
                        message='Argument `probs` has components less than 0.'
                    ),
                    assert_util.assert_less_equal(
                        probs,
                        one,
                        message=
                        'Argument `probs` has components greater than 1.')
                ])

        return assertions
Esempio n. 13
0
    def _sample_control_dependencies(self, x):
        """Check that `x` has shape `[..., dimension, dimension]` and valid values.

        The trailing-shape check is done eagerly when the last two dimensions
        of `x` are statically known; otherwise it becomes a runtime assertion
        when `self.validate_args` is truthy. The value checks (entries in
        `[-1, 1]`, unit diagonal, symmetry) are added only when
        `self.validate_args` is truthy and `self.input_output_cholesky` is
        falsy.

        Returns:
          A list of assertion ops (possibly empty).

        Raises:
          ValueError: if the statically known trailing dimensions do not both
            equal `self.dimension`.
        """
        checks = []
        mismatch_template = (
            'Input dimension mismatch: expected [..., {}, {}], got {}')

        if tensorshape_util.is_fully_defined(x.shape[-2:]):
            static_dims = tensorshape_util.dims(x.shape)
            if not (static_dims[-2] == static_dims[-1] == self.dimension):
                raise ValueError(
                    mismatch_template.format(
                        self.dimension, self.dimension,
                        tensorshape_util.dims(x.shape)))
        elif self.validate_args:
            msg = mismatch_template.format(
                self.dimension, self.dimension, tf.shape(x))
            for axis in (-2, -1):
                checks.append(
                    assert_util.assert_equal(
                        tf.shape(x)[axis], self.dimension, message=msg))

        if not self.validate_args or self.input_output_cholesky:
            return checks

        checks.extend([
            assert_util.assert_less_equal(
                dtype_util.as_numpy_dtype(x.dtype)(-1),
                x,
                message='Correlations must be >= -1.',
                summarize=30),
            assert_util.assert_less_equal(
                x,
                dtype_util.as_numpy_dtype(x.dtype)(1),
                message='Correlations must be <= 1.',
                summarize=30),
            assert_util.assert_near(
                tf.linalg.diag_part(x),
                dtype_util.as_numpy_dtype(x.dtype)(1),
                message='Self-correlations must be = 1.',
                summarize=30),
            assert_util.assert_near(
                x,
                tf.linalg.matrix_transpose(x),
                message='Correlation matrices must be symmetric.',
                summarize=30),
        ])
        return checks
Esempio n. 14
0
 def _maybe_assert_valid_sample(self, x, dtype):
     if not self.validate_args:
         return x
     one = tf.ones([], dtype=dtype)
     return distribution_util.with_dependencies([
         assert_util.assert_non_negative(x),
         assert_util.assert_less_equal(x, one),
         assert_util.assert_near(one, tf.reduce_sum(x, axis=[-1])),
     ], x)
 def _sample_control_dependencies(self, x):
   assertions = []
   if not self.validate_args:
     return assertions
   assertions.append(assert_util.assert_greater_equal(
       x, self.low, message='Sample must be greater than or equal to `low`.'))
   assertions.append(assert_util.assert_less_equal(
       x, self.high, message='Sample must be less than or equal to `high`.'))
   return assertions
Esempio n. 16
0
 def _sample_control_dependencies(self, x):
   assertions = []
   if not self.validate_args:
     return assertions
   assertions.extend(distribution_util.assert_nonnegative_integer_form(x))
   assertions.append(
       assert_util.assert_less_equal(x, tf.ones([], dtype=x.dtype),
                                     message='Elements cannot exceed 1.'))
   return assertions
Esempio n. 17
0
 def _maybe_assert_valid_y(self, y):
   if not self.validate_args:
     return []
   is_positive = assert_util.assert_non_negative(
       y, message='Inverse transformation input must be greater than 0.')
   less_than_one = assert_util.assert_less_equal(
       y,
       tf.constant(1., y.dtype),
       message='Inverse transformation input must be less than or equal to 1.')
   return [is_positive, less_than_one]
Esempio n. 18
0
 def _maybe_assert_valid_y(self, y):
   if not self.validate_args:
     return y
   is_positive = assert_util.assert_non_negative(
       y, message="Inverse transformation input must be greater than 0.")
   less_than_one = assert_util.assert_less_equal(
       y,
       tf.constant(1., y.dtype),
       message="Inverse transformation input must be less than or equal to 1.")
   return distribution_util.with_dependencies([is_positive, less_than_one], y)
Esempio n. 19
0
 def _sample_control_dependencies(self, x):
   assertions = []
   if not self.validate_args:
     return assertions
   assertions.append(assert_util.assert_non_negative(
       x, message='Sample must be non-negative.'))
   assertions.append(assert_util.assert_less_equal(
       x, tf.ones([], x.dtype),
       message='Sample must be less than or equal to `1`.'))
   return assertions
Esempio n. 20
0
 def _maybe_assert_valid_sample(self, counts):
   """Check counts for proper shape, values, then return tensor version."""
   if not self.validate_args:
     return counts
   counts = distribution_util.embed_check_nonnegative_integer_form(counts)
   msg = ('Sampled counts must be itemwise less than '
          'or equal to `total_count` parameter.')
   return distribution_util.with_dependencies([
       assert_util.assert_less_equal(counts, self.total_count, message=msg),
   ], counts)
Esempio n. 21
0
 def _sample_control_dependencies(self, x):
   assertions = []
   if not self.validate_args:
     return assertions
   assertions.extend(distribution_util.assert_nonnegative_integer_form(x))
   assertions.append(
       assert_util.assert_less_equal(
           x, tf.cast(self._num_categories(), x.dtype),
           message=('StoppingRatioLogistic samples must be `>= 0` and `<= K` '
                    'where `K` is the number of cutpoints.')))
   return assertions
Esempio n. 22
0
 def _is_valid_correlation_matrix(self, x):
     if not self.validate_args or self.input_output_cholesky:
         return []
     return [
         assert_util.assert_less_equal(
             dtype_util.as_numpy_dtype(x.dtype)(-1),
             x,
             message='Correlations must be >= -1.'),
         assert_util.assert_less_equal(
             x,
             dtype_util.as_numpy_dtype(x.dtype)(1),
             message='Correlations must be <= 1.'),
         assert_util.assert_near(tf.linalg.diag_part(x),
                                 dtype_util.as_numpy_dtype(x.dtype)(1),
                                 message='Self-correlations must be = 1.'),
         assert_util.assert_near(
             x,
             tf.linalg.matrix_transpose(x),
             message='Correlation matrices must be symmetric')
     ]
Esempio n. 23
0
 def _maybe_assert_valid(self, x):
     if not self.validate_args:
         return x
     return distribution_util.with_dependencies([
         assert_util.assert_non_negative(
             x, message='Sample must be non-negative.'),
         assert_util.assert_less_equal(
             x,
             tf.ones([], self.concentration0.dtype),
             message='Sample must be less than or equal to `1`.'),
     ], x)
Esempio n. 24
0
 def _maybe_assert_valid_sample(self, counts):
     """Check counts for proper shape, values, then return tensor version."""
     if not self.validate_args:
         return counts
     counts = distribution_util.embed_check_nonnegative_integer_form(counts)
     return distribution_util.with_dependencies([
         assert_util.assert_less_equal(
             counts,
             self.total_count,
             message='counts are not less than or equal to n.'),
     ], counts)
Esempio n. 25
0
 def _sample_control_dependencies(self, x):
   assertions = []
   if not self.validate_args:
     return assertions
   assertions.extend(distribution_util.assert_nonnegative_integer_form(x))
   assertions.append(
       assert_util.assert_less_equal(
           x, tf.cast(self._num_categories(), x.dtype),
           message=('Categorical samples must be between `0` and `n-1` '
                    'where `n` is the number of categories.')))
   return assertions
Esempio n. 26
0
 def _assertions(self, t):
   if not self.validate_args:
     return []
   return [
       assert_util.assert_non_negative(
           t, message="Inverse transformation input must be greater than 0."),
       assert_util.assert_less_equal(
           t,
           dtype_util.as_numpy_dtype(t.dtype)(1.),
           message="Inverse transformation input must be less than or equal "
           "to 1.")]
Esempio n. 27
0
 def _sample_control_dependencies(self, counts):
   """Check counts for proper values."""
   assertions = []
   if not self.validate_args:
     return assertions
   assertions.extend(distribution_util.assert_nonnegative_integer_form(counts))
   assertions.append(
       assert_util.assert_less_equal(
           counts, self.total_count,
           message=('Sampled counts must be itemwise less than '
                    'or equal to `total_count` parameter.')))
   return assertions
 def _call_quantile(self, value, name, **kwargs):
   """Convert `value` to a tensor, optionally validate it, and call `_quantile`.

   Args:
     value: Input to convert and forward to `self._quantile`.
     name: Name passed to `self._name_and_control_scope`.
     **kwargs: Forwarded unchanged to `self._quantile`.

   Returns:
     The result of `self._quantile(value, **kwargs)`.
   """
   with self._name_and_control_scope(name):
     # A nested `self.dtype` gives no single hint, so fall back to float32.
     if tf.nest.is_nested(self.dtype):
       dtype_hint = tf.float32
     else:
       dtype_hint = self.dtype
     value = tf.convert_to_tensor(value, name='value', dtype_hint=dtype_hint)

     if self.validate_args:
       unit_interval_checks = [
           assert_util.assert_less_equal(
               value, tf.cast(1, value.dtype),
               message='`value` must be <= 1'),
           assert_util.assert_greater_equal(
               value, tf.cast(0, value.dtype),
               message='`value` must be >= 0'),
       ]
       value = distribution_util.with_dependencies(
           unit_interval_checks, value)

     return self._quantile(value, **kwargs)
Esempio n. 29
0
 def _validate_correlationness(self, x):
     if not self.validate_args or self.input_output_cholesky:
         return x
     checks = [
         assert_util.assert_less_equal(
             dtype_util.as_numpy_dtype(x.dtype)(-1),
             x,
             message='Correlations must be >= -1.'),
         assert_util.assert_less_equal(
             x,
             dtype_util.as_numpy_dtype(x.dtype)(1),
             message='Correlations must be <= 1.'),
         assert_util.assert_near(tf.linalg.diag_part(x),
                                 dtype_util.as_numpy_dtype(x.dtype)(1),
                                 message='Self-correlations must be = 1.'),
         assert_util.assert_near(
             x,
             tf.linalg.matrix_transpose(x),
             message='Correlation matrices must be symmetric')
     ]
     with tf.control_dependencies(checks):
         return tf.identity(x)
Esempio n. 30
0
 def _assertions(self, t):
   if not self.validate_args:
     return []
   return [
       assert_util.assert_greater_equal(
           t,
           dtype_util.as_numpy_dtype(t.dtype)(-1),
           message="Inverse transformation input must be >= -1."),
       assert_util.assert_less_equal(
           t,
           dtype_util.as_numpy_dtype(t.dtype)(1),
           message="Inverse transformation input must be <= 1.")
   ]