Code example #1
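These snippets are excerpts from TensorFlow Probability and omit their imports. A typical header for them, assuming the internal module layout of the TFP source tree (paths cited from memory), looks like this:

import numpy as np
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static
from tensorflow_probability.python.internal import tensorshape_util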
def maybe_check_quadrature_param(param, name, validate_args):
    """Helper which checks validity of `loc` and `scale` init args."""
    with tf.name_scope("check_" + name):
        assertions = []
        if tensorshape_util.rank(param.shape) is not None:
            if tensorshape_util.rank(param.shape) == 0:
                raise ValueError("Mixing params must be a (batch of) vector; "
                                 "{}.rank={} is not at least one.".format(
                                     name, tensorshape_util.rank(param.shape)))
        elif validate_args:
            assertions.append(
                assert_util.assert_rank_at_least(
                    param,
                    1,
                    message=("Mixing params must be a (batch of) vector; "
                             "{}.rank is not at least one.".format(name))))

        # TODO(jvdillon): Remove once we support k-mixtures.
        if tensorshape_util.with_rank_at_least(param.shape, 1)[-1] is not None:
            if tf.compat.dimension_value(param.shape[-1]) != 1:
                raise NotImplementedError(
                    "Currently only bimixtures are supported; "
                    "{}.shape[-1]={} is not 1.".format(
                        name, tf.compat.dimension_value(param.shape[-1])))
        elif validate_args:
            assertions.append(
                assert_util.assert_equal(
                    tf.shape(param)[-1],
                    1,
                    message=("Currently only bimixtures are supported; "
                             "{}.shape[-1] is not 1.".format(name))))

        if assertions:
            return distribution_util.with_dependencies(assertions, param)
        return param
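The helper above shows a pattern that recurs throughout these examples: when shape information is statically known, fail fast with a Python exception; otherwise, and only if `validate_args` is set, attach a runtime assertion to the returned tensor. A minimal sketch of that pattern in plain TensorFlow (the helper name `check_is_vector` is illustrative, not a TFP internal):

import tensorflow as tf

def check_is_vector(param, validate_args):
    # Static rank available: raise immediately at trace time.
    if param.shape.rank is not None:
        if param.shape.rank == 0:
            raise ValueError('param must be a (batch of) vector.')
        return param
    # Rank unknown statically: defer the check to runtime if requested.
    if validate_args:
        assertion = tf.debugging.assert_rank_at_least(
            param, 1, message='param must be a (batch of) vector.')
        with tf.control_dependencies([assertion]):
            return tf.identity(param)
    return param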
Code example #2
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
    """Calculates the reshaped dimensions (replacing up to one -1 in reshape)."""
    batch_shape_static = tensorshape_util.constant_value_as_shape(new_shape)
    if tensorshape_util.is_fully_defined(batch_shape_static):
        return np.int32(batch_shape_static), batch_shape_static, []
    with tf.name_scope(name or 'calculate_reshape'):
        original_size = tf.reduce_prod(original_shape)
        implicit_dim = tf.equal(new_shape, -1)
        size_implicit_dim = (original_size //
                             tf.maximum(1, -tf.reduce_prod(new_shape)))
        expanded_new_shape = tf.where(  # Assumes exactly one `-1`.
            implicit_dim, size_implicit_dim, new_shape)
        validations = [] if not validate else [  # pylint: disable=g-long-ternary
            assert_util.assert_rank(
                original_shape, 1, message='Original shape must be a vector.'),
            assert_util.assert_rank(
                new_shape, 1, message='New shape must be a vector.'),
            assert_util.assert_less_equal(
                tf.math.count_nonzero(implicit_dim, dtype=tf.int32),
                1,
                message='At most one dimension can be unknown.'),
            assert_util.assert_positive(
                expanded_new_shape, message='Shape elements must be >=-1.'),
            assert_util.assert_equal(tf.reduce_prod(expanded_new_shape),
                                     original_size,
                                     message='Shape sizes do not match.'),
        ]
        return expanded_new_shape, batch_shape_static, validations
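The arithmetic that resolves the single `-1` above deserves a note: `tf.reduce_prod(new_shape)` is negative exactly when a `-1` is present, so negating it recovers the product of the known dimensions. A small standalone check of that trick (shapes here are illustrative):

import tensorflow as tf

original_shape = tf.constant([6, 2], dtype=tf.int32)
new_shape = tf.constant([3, -1], dtype=tf.int32)

original_size = tf.reduce_prod(original_shape)   # 12
implicit_dim = tf.equal(new_shape, -1)           # [False, True]
# prod([3, -1]) == -3, so -prod gives the product of the known dims (3).
size_implicit_dim = original_size // tf.maximum(1, -tf.reduce_prod(new_shape))
expanded_new_shape = tf.where(implicit_dim, size_implicit_dim, new_shape)
print(expanded_new_shape.numpy())                # [3 4]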
Code example #3
def _maybe_check_valid_map_values(map_values, validate_args):
    """Validate `map_values` if `validate_args`==True."""
    assertions = []

    message = 'Rank of map_values must be 1.'
    if tensorshape_util.rank(map_values.shape) is not None:
        if tensorshape_util.rank(map_values.shape) != 1:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank(map_values, 1, message=message))

    message = 'Size of map_values must be greater than 0.'
    if tensorshape_util.num_elements(map_values.shape) is not None:
        if tensorshape_util.num_elements(map_values.shape) == 0:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_greater(tf.size(map_values), 0,
                                       message=message))

    if validate_args:
        assertions.append(
            assert_util.assert_equal(
                tf.math.is_strictly_increasing(map_values),
                True,
                message='map_values is not strictly increasing.'))

    return assertions
Code example #4
File: linalg.py Project: HackerShohag/SuggestBot-bn
def lu_reconstruct_assertions(lower_upper, perm, validate_args):
    """Returns list of assertions related to `lu_reconstruct` assumptions."""
    assertions = []

    message = 'Input `lower_upper` must have at least 2 dimensions.'
    if tensorshape_util.rank(lower_upper.shape) is not None:
        if tensorshape_util.rank(lower_upper.shape) < 2:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank_at_least(lower_upper,
                                             rank=2,
                                             message=message))

    message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
    if (tensorshape_util.rank(lower_upper.shape) is not None
            and tensorshape_util.rank(perm.shape) is not None):
        if (tensorshape_util.rank(lower_upper.shape) !=
                tensorshape_util.rank(perm.shape) + 1):
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank(lower_upper,
                                    rank=tf.rank(perm) + 1,
                                    message=message))

    message = '`lower_upper` must be square.'
    if tensorshape_util.is_fully_defined(lower_upper.shape[-2:]):
        if lower_upper.shape[-2] != lower_upper.shape[-1]:
            raise ValueError(message)
    elif validate_args:
        m, n = tf.split(tf.shape(lower_upper)[-2:], num_or_size_splits=2)
        assertions.append(assert_util.assert_equal(m, n, message=message))

    return assertions
Code example #5
def _maybe_validate_perm(perm, validate_args, name=None):
    """Checks that `perm` is valid."""
    with tf.name_scope(name or 'maybe_validate_perm'):
        assertions = []
        if not dtype_util.is_integer(perm.dtype):
            raise TypeError('`perm` must be integer type')

        msg = '`perm` must be a vector.'
        if tensorshape_util.rank(perm.shape) is not None:
            if tensorshape_util.rank(perm.shape) != 1:
                raise ValueError(msg[:-1] + ', saw rank: {}.'.format(
                    tensorshape_util.rank(perm.shape)))
        elif validate_args:
            assertions += [assert_util.assert_rank(perm, 1, message=msg)]

        perm_ = tf.get_static_value(perm)
        msg = '`perm` must be a valid permutation vector.'
        if perm_ is not None:
            if not np.all(np.arange(np.size(perm_)) == np.sort(perm_)):
                raise ValueError(msg[:-1] + ', saw: {}.'.format(perm_))
        elif validate_args:
            assertions += [
                assert_util.assert_equal(tf.sort(perm),
                                         tf.range(tf.size(perm)),
                                         message=msg)
            ]

        return assertions
Code example #6
    def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                       validate_args, name):
        """Helper to __init__ which ensures override batch/event_shape are valid."""
        if override_shape is None:
            override_shape = []

        override_shape = tf.convert_to_tensor(override_shape,
                                              dtype=tf.int32,
                                              name=name)

        if not dtype_util.is_integer(override_shape.dtype):
            raise TypeError("shape override must be an integer")

        override_is_scalar = _is_scalar_from_shape_tensor(override_shape)
        if tf.get_static_value(override_is_scalar):
            return self._empty

        dynamic_assertions = []

        if tensorshape_util.rank(override_shape.shape) is not None:
            if tensorshape_util.rank(override_shape.shape) != 1:
                raise ValueError("shape override must be a vector")
        elif validate_args:
            dynamic_assertions += [
                assert_util.assert_rank(
                    override_shape,
                    1,
                    message="shape override must be a vector")
            ]

        if tf.get_static_value(override_shape) is not None:
            if any(s < 0 for s in tf.get_static_value(override_shape)):
                raise ValueError(
                    "shape override must have non-negative elements")
        elif validate_args:
            dynamic_assertions += [
                assert_util.assert_non_negative(
                    override_shape,
                    message="shape override must have non-negative elements")
            ]

        is_both_nonscalar = prefer_static.logical_and(
            prefer_static.logical_not(base_is_scalar),
            prefer_static.logical_not(override_is_scalar))
        if tf.get_static_value(is_both_nonscalar) is not None:
            if tf.get_static_value(is_both_nonscalar):
                raise ValueError("base distribution not scalar")
        elif validate_args:
            dynamic_assertions += [
                assert_util.assert_equal(
                    is_both_nonscalar,
                    False,
                    message="base distribution not scalar")
            ]

        if not dynamic_assertions:
            return override_shape
        return distribution_util.with_dependencies(dynamic_assertions,
                                                   override_shape)
Code example #7
 def _assertions(self, x):
     if not self.validate_args:
         return []
     x_shape = tf.shape(x)
     is_matrix = assert_util.assert_rank_at_least(
         x, 2, message="Input must have rank at least 2.")
     is_square = assert_util.assert_equal(
         x_shape[-2], x_shape[-1], message="Input must be a square matrix.")
     diag_part_x = tf.linalg.diag_part(x)
     is_lower_triangular = assert_util.assert_equal(
         tf.linalg.band_part(x, 0, -1),  # Preserves triu, zeros rest.
         tf.linalg.diag(diag_part_x),
         message="Input must be lower triangular.")
     is_positive_diag = assert_util.assert_positive(
         diag_part_x,
         message="Input must have all positive diagonal entries.")
     return [is_matrix, is_square, is_lower_triangular, is_positive_diag]
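The lower-triangularity assertion above relies on a band-part identity: `tf.linalg.band_part(x, 0, -1)` keeps the diagonal and everything above it, so it equals `tf.linalg.diag(diag_part_x)` exactly when the strictly upper triangle of `x` is zero. A standalone illustration (the matrix is illustrative):

import tensorflow as tf

x = tf.constant([[1., 0.],
                 [3., 2.]])
upper = tf.linalg.band_part(x, 0, -1)               # diagonal + upper triangle
diag_only = tf.linalg.diag(tf.linalg.diag_part(x))  # diagonal only
# Equal iff the strictly upper triangle is zero, i.e. `x` is lower triangular.
print(bool(tf.reduce_all(tf.equal(upper, diag_only))))  # True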
Code example #8
 def _assertions(self, t):
     if not self.validate_args:
         return []
     is_matrix = assert_util.assert_rank_at_least(t, 2)
     is_square = assert_util.assert_equal(tf.shape(t)[-2], tf.shape(t)[-1])
     is_positive_definite = assert_util.assert_positive(
         tf.linalg.diag_part(t), message="Input must be positive definite.")
     return [is_matrix, is_square, is_positive_definite]
Code example #9
 def _observation_mask_shape_preconditions(self, observation_tensor_shape,
                                           mask_tensor_shape):
     shape_condition = [
         assert_util.assert_equal(
             observation_tensor_shape[-1 - self._underlying_event_rank],
             self._num_steps,
             message="The tensor `observations` must consist of sequences"
             "of observations from `HiddenMarkovModel` of length"
             "`num_steps`.")
     ]
     if mask_tensor_shape is not None:
         shape_condition.append(
             assert_util.assert_equal(
                 mask_tensor_shape[-1],
                 self._num_steps,
                 message="The tensor `mask` must consist of sequences"
                 "of length `num_steps`."))
     return tf.control_dependencies(shape_condition)
Code example #10
 def validate_equal_last_dim(tensor_a, tensor_b, message):
     # Note: `validate_args` is a free variable captured from the enclosing
     # scope; code example #27 shows this helper defined inline where
     # `validate_args` is bound.
     event_size_a = tf.compat.dimension_value(tensor_a.shape[-1])
     event_size_b = tf.compat.dimension_value(tensor_b.shape[-1])
     if event_size_a is not None and event_size_b is not None:
         if event_size_a != event_size_b:
             raise ValueError(message)
     elif validate_args:
         return assert_util.assert_equal(tf.shape(tensor_a)[-1],
                                         tf.shape(tensor_b)[-1],
                                         message=message)
Code example #11
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        message = 'Distributions must have the same `batch_shape`'

        if is_init:
            batch_shapes = tf.nest.flatten(self._cached_batch_shape)
            if all(tensorshape_util.is_fully_defined(b) for b in batch_shapes):
                if batch_shapes[1:] != batch_shapes[:-1]:
                    raise ValueError('{}; found: {}.'.format(
                        message, batch_shapes))

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        if self.validate_args:
            batch_shapes = self._cached_batch_shape
            if not all(
                    tensorshape_util.is_fully_defined(s)
                    for s in tf.nest.flatten(batch_shapes)):
                batch_shapes = tf.nest.map_structure(
                    lambda static_shape, shape_tensor:  # pylint: disable=g-long-lambda
                    (static_shape if tensorshape_util.is_fully_defined(
                        static_shape) else shape_tensor),
                    batch_shapes,
                    self._cached_batch_shape_tensor)
            batch_shapes = tf.nest.flatten(batch_shapes)
            assertions.extend(
                assert_util.assert_equal(  # pylint: disable=g-complex-comprehension
                    b1,
                    b2,
                    message='{}.'.format(message))
                for b1, b2 in zip(batch_shapes[1:], batch_shapes[:-1]))
            assertions.extend(
                assert_util.assert_equal(  # pylint: disable=g-complex-comprehension
                    tf.size(b1),
                    tf.size(b2),
                    message='{}.'.format(message))
                for b1, b2 in zip(batch_shapes[1:], batch_shapes[:-1]))

        return assertions
Code example #12
 def _maybe_assert_valid_sample(self, counts):
   """Check counts for proper shape, values, then return tensor version."""
   if not self.validate_args:
     return counts
   counts = distribution_util.embed_check_nonnegative_integer_form(counts)
   return distribution_util.with_dependencies([
       assert_util.assert_equal(
           self.total_count,
           tf.reduce_sum(counts, axis=-1),
           message='counts last-dimension must sum to `self.total_count`'),
   ], counts)
Code example #13
 def _maybe_assert_valid_sample(self, counts):
     """Check counts for proper shape, values, then return tensor version."""
     if not self.validate_args:
         return []
     assertions = distribution_util.assert_nonnegative_integer_form(counts)
     assertions.append(
         assert_util.assert_equal(
             self.total_count,
             tf.reduce_sum(counts, axis=-1),
             message='counts must sum to `self.total_count`'))
     return assertions
Code example #14
 def _assertions(self, x):
   if not self.validate_args:
     return []
   shape = tf.shape(x)
   is_matrix = assert_util.assert_rank_at_least(
       x, 2, message="Input must have rank at least 2.")
   is_square = assert_util.assert_equal(
       shape[-2], shape[-1], message="Input must be a square matrix.")
   above_diagonal = tf.linalg.band_part(
        tf.linalg.set_diag(x, tf.zeros(shape[:-1], dtype=x.dtype)), 0, -1)
   is_lower_triangular = assert_util.assert_equal(
       above_diagonal,
       tf.zeros_like(above_diagonal),
       message="Input must be lower triangular.")
   # A lower triangular matrix is nonsingular iff all its diagonal entries are
   # nonzero.
   diag_part = tf.linalg.diag_part(x)
   is_nonsingular = assert_util.assert_none_equal(
       diag_part,
       tf.zeros_like(diag_part),
       message="Input must have all diagonal entries nonzero.")
   return [is_matrix, is_square, is_lower_triangular, is_nonsingular]
Code example #15
def _validate_block_sizes(block_sizes, bijectors, validate_args):
  """Helper to validate block sizes."""
  block_sizes_shape = block_sizes.shape
  if tensorshape_util.is_fully_defined(block_sizes_shape):
    if (tensorshape_util.rank(block_sizes_shape) != 1 or
        (tensorshape_util.num_elements(block_sizes_shape) != len(bijectors))):
      raise ValueError(
          '`block_sizes` must be `None`, or a vector of the same length as '
          '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of '
          'length {}'.format(block_sizes_shape, len(bijectors)))
    return block_sizes
  elif validate_args:
    message = ('`block_sizes` must be `None`, or a vector of the same length '
               'as `bijectors`.')
    with tf.control_dependencies([
        assert_util.assert_equal(
            tf.size(block_sizes), len(bijectors), message=message),
        assert_util.assert_equal(tf.rank(block_sizes), 1)
    ]):
      return tf.identity(block_sizes)
  else:
    return block_sizes
Code example #16
 def _z(self, x, scale, concentration):
     loc = tf.convert_to_tensor(self.loc)
     if self.validate_args:
         valid = (x >= loc) & ((concentration >= 0) |
                               (x <= loc - scale / concentration))
         with tf.control_dependencies([
                 assert_util.assert_equal(
                     valid,
                     True,
                     message='`x` outside distribution\'s support.')
         ]):
             x = tf.identity(x)
     return (x - loc) / scale
Code example #17
 def _forward(self, x):
     map_values = tf.convert_to_tensor(self.map_values)
     if self.validate_args:
         with tf.control_dependencies([
                 assert_util.assert_equal(
                     (0 <= x) & (x < tf.size(map_values)),
                     True,
                     message='indices out of bound')
         ]):
             x = tf.identity(x)
     # If we want batch dims in self.map_values, we can (after broadcasting),
     # use:
     # tf.gather(self.map_values, x, batch_dims=-1, axis=-1)
     return tf.gather(map_values, indices=x)
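Once the bounds check passes, the forward pass is a plain table lookup. A minimal illustration of the gather (values are illustrative):

import tensorflow as tf

map_values = tf.constant([10., 20., 30.])
x = tf.constant([2, 0, 1])
# Each index in `x` selects the corresponding entry of `map_values`.
print(tf.gather(map_values, indices=x).numpy())  # [30. 10. 20.]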
Code example #18
  def _assert_compatible_shape(self, index, sample_shape, samples):
    requested_shape, _ = self._expand_sample_shape_to_vector(
        tf.convert_to_tensor(sample_shape, dtype=tf.int32),
        name='requested_shape')
    actual_shape = prefer_static.shape(samples)
    actual_rank = prefer_static.rank_from_shape(actual_shape)
    requested_rank = prefer_static.rank_from_shape(requested_shape)

    # We test for two properties we expect of yielded distributions:
    # (1) The rank of the tensor of generated samples must be at least
    #     as large as the rank requested.
    # (2) The requested shape must be a prefix of the shape of the
    #     generated tensor of samples.
    # We attempt to perform test (1) statically first.
    # We don't need to do this explicitly for test (2) because
    # `assert_equal` evaluates statically if it can.
    static_actual_rank = tf.get_static_value(actual_rank)
    static_requested_rank = tf.get_static_value(requested_rank)

    assertion_message = ('Samples yielded by distribution #{} are not '
                         'consistent with `sample_shape` passed to '
                         '`JointDistributionCoroutine` '
                         'distribution.'.format(index))

    # TODO Remove this static check (b/138738650)
    if (static_actual_rank is not None and
        static_requested_rank is not None):
      # We're able to statically check the rank
      if static_actual_rank < static_requested_rank:
        raise ValueError(assertion_message)
      else:
        control_dependencies = []
    else:
      # We're not able to statically check the rank
      control_dependencies = [
          assert_util.assert_greater_equal(
              actual_rank, requested_rank,
              message=assertion_message)
          ]

    with tf.control_dependencies(control_dependencies):
      trimmed_actual_shape = actual_shape[:requested_rank]

    control_dependencies = [
        assert_util.assert_equal(
            requested_shape, trimmed_actual_shape,
            message=assertion_message)
    ]

    return control_dependencies
Code example #19
File: lkj.py Project: HackerShohag/SuggestBot-bn
 def _validate_dimension(self, x):
     x = tf.convert_to_tensor(x, name='x')
     if tensorshape_util.is_fully_defined(x.shape[-2:]):
         if (tensorshape_util.dims(x.shape)[-2] == tensorshape_util.dims(
                 x.shape)[-1] == self.dimension):
             pass
         else:
             raise ValueError(
                 'Input dimension mismatch: expected [..., {}, {}], got {}'.
                 format(self.dimension, self.dimension,
                        tensorshape_util.dims(x.shape)))
     elif self.validate_args:
         msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
             self.dimension, self.dimension, tf.shape(x))
         with tf.control_dependencies([
                 assert_util.assert_equal(tf.shape(x)[-2],
                                          self.dimension,
                                          message=msg),
                 assert_util.assert_equal(tf.shape(x)[-1],
                                          self.dimension,
                                          message=msg)
         ]):
             x = tf.identity(x)
     return x
Code example #20
 def _maybe_assert_valid_sample(self, samples):
     """Check counts for proper shape, values, then return tensor version."""
     if not self.validate_args:
         return samples
     with tf.control_dependencies([
             assert_util.assert_near(1.,
                                     tf.linalg.norm(samples, axis=-1),
                                     message='samples must be unit length'),
             assert_util.assert_equal(
                 tf.shape(samples)[-1:],
                 self.event_shape_tensor(),
                 message=
                 ('samples must have innermost dimension matching that of '
                  '`self.mean_direction`')),
     ]):
         return tf.identity(samples)
Code example #21
 def _prob(self, x):
     if self.validate_args:
         is_vector_check = assert_util.assert_rank_at_least(x, 1)
         right_vec_space_check = assert_util.assert_equal(
             self.event_shape_tensor(),
             tf.gather(tf.shape(x),
                       tf.rank(x) - 1),
             message=
             "Argument 'x' not defined in the same space R^k as this distribution"
         )
         with tf.control_dependencies([is_vector_check]):
             with tf.control_dependencies([right_vec_space_check]):
                 x = tf.identity(x)
     loc = tf.convert_to_tensor(self.loc)
     return tf.cast(tf.reduce_all(tf.abs(x - loc) <= self._slack(loc),
                                  axis=-1),
                    dtype=self.dtype)
Code example #22
def maybe_check_wont_broadcast(flat_xs, validate_args):
    """Verifies that `parts` don't broadcast."""
    flat_xs = tuple(flat_xs)  # So we can receive generators.
    if not validate_args:
        # Note: we don't try static validation because it is theoretically
        # possible that a user wants to take advantage of broadcasting.
        # Only when `validate_args` is `True` do we enforce the validation.
        return flat_xs
    msg = 'Broadcasting probably indicates an error in model specification.'
    s = tuple(prefer_static.shape(x) for x in flat_xs)
    if all(prefer_static.is_numpy(s_) for s_ in s):
        if not all(np.all(a == b) for a, b in zip(s[1:], s[:-1])):
            raise ValueError(msg)
        return flat_xs
    assertions = [
        assert_util.assert_equal(a, b, message=msg)
        for a, b in zip(s[1:], s[:-1])
    ]
    with tf.control_dependencies(assertions):
        return tuple(tf.identity(x) for x in flat_xs)
Code example #23
File: sample.py Project: HackerShohag/SuggestBot-bn
def _kl_sample(a, b, name='kl_sample'):
    """Batched KL divergence `KL(a || b)` for Sample distributions.

  We can leverage the fact that:

  ```
  KL(Sample(a) || Sample(b)) = sum(KL(a || b))
  ```

  where the sum is over the `sample_shape` dims.

  Args:
    a: Instance of `Sample` distribution.
    b: Instance of `Sample` distribution.
    name: (optional) name to use for created ops.
      Default value: `"kl_sample"`.

  Returns:
    kldiv: Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the `sample_shape` of `a` and `b` don't match.
  """
    assertions = []
    a_ss = tf.get_static_value(a.sample_shape)
    b_ss = tf.get_static_value(b.sample_shape)
    msg = '`a.sample_shape` must be identical to `b.sample_shape`.'
    if a_ss is not None and b_ss is not None:
        if not np.array_equal(a_ss, b_ss):
            raise ValueError(msg)
    elif a.validate_args or b.validate_args:
        assertions.append(
            assert_util.assert_equal(a.sample_shape,
                                     b.sample_shape,
                                     message=msg))
    with tf.control_dependencies(assertions):
        kl = kullback_leibler.kl_divergence(a.distribution,
                                            b.distribution,
                                            name=name)
        n = prefer_static.reduce_prod(a.sample_shape)
        return tf.cast(x=n, dtype=kl.dtype) * kl
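A quick numeric sanity check of the identity in the docstring, assuming the public `tfp.distributions.Sample` API and the KL registration above:

import tensorflow_probability as tfp

tfd = tfp.distributions

a = tfd.Sample(tfd.Normal(loc=0., scale=1.), sample_shape=3)
b = tfd.Sample(tfd.Normal(loc=1., scale=2.), sample_shape=3)
# KL(Sample(a) || Sample(b)) == prod(sample_shape) * KL(a || b).
kl_sample = tfd.kl_divergence(a, b)
kl_base = tfd.kl_divergence(tfd.Normal(0., 1.), tfd.Normal(1., 2.))
print(kl_sample.numpy(), (3. * kl_base).numpy())  # same value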
Code example #24
File: linalg.py Project: HackerShohag/SuggestBot-bn
def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
    """Returns list of assertions related to `lu_solve` assumptions."""
    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)

    message = 'Input `rhs` must have at least 2 dimensions.'
    if rhs.shape.ndims is not None:
        if rhs.shape.ndims < 2:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank_at_least(rhs, rank=2, message=message))

    message = '`lower_upper.shape[-1]` must equal `rhs.shape[-2]`.'
    if (tf.compat.dimension_value(lower_upper.shape[-1]) is not None
            and tf.compat.dimension_value(rhs.shape[-2]) is not None):
        if lower_upper.shape[-1] != rhs.shape[-2]:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_equal(tf.shape(lower_upper)[-1],
                                     tf.shape(rhs)[-2],
                                     message=message))

    return assertions
Code example #25
File: wishart.py Project: HackerShohag/SuggestBot-bn
    def __init__(self,
                 df,
                 scale=None,
                 scale_tril=None,
                 input_output_cholesky=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Wishart"):
        """Construct Wishart distributions.

    Args:
      df: `float` or `double` `Tensor`. Degrees of freedom, must be greater than
        or equal to dimension of the scale matrix.
      scale: `float` or `double` `Tensor`. The symmetric positive definite
        scale matrix of the distribution. Exactly one of `scale` and
        `scale_tril` must be passed.
      scale_tril: `float` or `double` `Tensor`. The Cholesky factorization
        of the symmetric positive definite scale matrix of the distribution.
        Exactly one of `scale` and `scale_tril` must be passed.
      input_output_cholesky: Python `bool`. If `True`, functions whose input or
        output have the semantics of samples assume inputs are in Cholesky form
        and return outputs in Cholesky form. In particular, if this flag is
        `True`, input to `log_prob` is presumed of Cholesky form and output from
        `sample`, `mean`, and `mode` are of Cholesky form.  Setting this
        argument to `True` is purely a computational optimization and does not
        change the underlying distribution; for instance, `mean` returns the
        Cholesky of the mean, not the mean of Cholesky factors. The `variance`
        and `stddev` methods are unaffected by this flag.
        Default value: `False` (i.e., input/output does not have Cholesky
        semantics).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if zero or both of `scale` and `scale_tril` are passed in.
    """
        parameters = dict(locals())

        with tf.name_scope(name) as name:
            with tf.name_scope("init"):
                if (scale is None) == (scale_tril is None):
                    raise ValueError(
                        "Must pass scale or scale_tril, but not both.")

                dtype = dtype_util.common_dtype([df, scale, scale_tril],
                                                tf.float32)
                df = tf.convert_to_tensor(df, name="df", dtype=dtype)
                if scale is not None:
                    scale = tf.convert_to_tensor(scale,
                                                 name="scale",
                                                 dtype=dtype)
                    if validate_args:
                        scale = distribution_util.assert_symmetric(scale)
                    scale_tril = tf.linalg.cholesky(scale)
                else:  # scale_tril is not None
                    scale_tril = tf.convert_to_tensor(scale_tril,
                                                      name="scale_tril",
                                                      dtype=dtype)
                    if validate_args:
                        scale_tril = distribution_util.with_dependencies([
                            assert_util.assert_positive(
                                tf.linalg.diag_part(scale_tril),
                                message="scale_tril must be positive definite"
                            ),
                            assert_util.assert_equal(
                                tf.shape(scale_tril)[-1],
                                tf.shape(scale_tril)[-2],
                                message="scale_tril must be square")
                        ], scale_tril)

            super(Wishart, self).__init__(
                df=df,
                scale_operator=tf.linalg.LinearOperatorLowerTriangular(
                    tril=scale_tril,
                    is_non_singular=True,
                    is_positive_definite=True,
                    is_square=True),
                input_output_cholesky=input_output_cholesky,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                name=name)
        self._parameters = parameters
Code example #26
File: mixture.py Project: HackerShohag/SuggestBot-bn
    def __init__(self,
                 cat,
                 components,
                 validate_args=False,
                 allow_nan_stats=True,
                 use_static_graph=False,
                 name="Mixture"):
        """Initialize a Mixture distribution.

    A `Mixture` is defined by a `Categorical` (`cat`, representing the
    mixture probabilities) and a list of `Distribution` objects
    all having matching dtype, batch shape, event shape, and continuity
    properties (the components).

    The `num_classes` of `cat` must be possible to infer at graph construction
    time and match `len(components)`.

    Args:
      cat: A `Categorical` distribution instance, representing the probabilities
          of `distributions`.
      components: A list or tuple of `Distribution` instances.
        Each instance must have the same type, be defined on the same domain,
        and have matching `event_shape` and `batch_shape`.
      validate_args: Python `bool`, default `False`. If `True`, raise a runtime
        error if batch or event ranks are inconsistent between cat and any of
        the distributions. This is only checked if the ranks cannot be
        determined statically at graph construction time.
      allow_nan_stats: Boolean, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      use_static_graph: Calls to `sample` will not rely on dynamic tensor
        indexing, allowing for some static graph compilation optimizations, but
        at the expense of sampling all underlying distributions in the mixture.
        (Possibly useful when running on TPUs).
        Default value: `False` (i.e., use dynamic indexing).
      name: A name for this distribution (optional).

    Raises:
      TypeError: If cat is not a `Categorical`, or `components` is not
        a list or tuple, or the elements of `components` are not
        instances of `Distribution`, or do not have matching `dtype`.
      ValueError: If `components` is an empty list or tuple, or its
        elements do not have a statically known event rank.
        If `cat.num_classes` cannot be inferred at graph creation time,
        or the constant value of `cat.num_classes` is not equal to
        `len(components)`, or all `components` and `cat` do not have
        matching static batch shapes, or all components do not
        have matching static event shapes.
    """
        parameters = dict(locals())
        if not isinstance(cat, categorical.Categorical):
            raise TypeError(
                "cat must be a Categorical distribution, but saw: %s" % cat)
        if not components:
            raise ValueError("components must be a non-empty list or tuple")
        if not isinstance(components, (list, tuple)):
            raise TypeError("components must be a list or tuple, but saw: %s" %
                            components)
        if not all(
                isinstance(c, distribution.Distribution) for c in components):
            raise TypeError(
                "all entries in components must be Distribution instances"
                " but saw: %s" % components)

        dtype = components[0].dtype
        if not all(d.dtype == dtype for d in components):
            raise TypeError("All components must have the same dtype, but saw "
                            "dtypes: %s" % [(d.name, d.dtype)
                                            for d in components])
        static_event_shape = components[0].event_shape
        static_batch_shape = cat.batch_shape
        for di, d in enumerate(components):
            if not tensorshape_util.is_compatible_with(static_batch_shape,
                                                       d.batch_shape):
                raise ValueError(
                    "components[{}] batch shape must be compatible with cat "
                    "shape and other component batch shapes".format(di))
            static_event_shape = tensorshape_util.merge_with(
                static_event_shape, d.event_shape)
            static_batch_shape = tensorshape_util.merge_with(
                static_batch_shape, d.batch_shape)
        if tensorshape_util.rank(static_event_shape) is None:
            raise ValueError(
                "Expected to know rank(event_shape) from components, but "
                "none of the components provide a static number of ndims")

        # Ensure that all batch and event ndims are consistent.
        with tf.name_scope(name) as name:
            num_components = cat._num_categories()
            static_num_components = tf.get_static_value(num_components)
            if static_num_components is None:
                raise ValueError(
                    "Could not infer number of classes from cat and unable "
                    "to compare this value to the number of components passed in."
                )
            # Possibly convert from numpy 0-D array.
            static_num_components = int(static_num_components)
            if static_num_components != len(components):
                raise ValueError(
                    "cat.num_classes != len(components): %d vs. %d" %
                    (static_num_components, len(components)))

            cat_batch_shape = cat.batch_shape_tensor()
            cat_batch_rank = tf.size(cat_batch_shape)
            if validate_args:
                batch_shapes = [d.batch_shape_tensor() for d in components]
                batch_ranks = [tf.size(bs) for bs in batch_shapes]
                check_message = ("components[%d] batch shape must match cat "
                                 "batch shape")
                self._assertions = [
                    assert_util.assert_equal(cat_batch_rank,
                                             batch_ranks[di],
                                             message=check_message % di)
                    for di in range(len(components))
                ]
                self._assertions += [
                    assert_util.assert_equal(cat_batch_shape,
                                             batch_shapes[di],
                                             message=check_message % di)
                    for di in range(len(components))
                ]
            else:
                self._assertions = []

            self._cat = cat
            self._components = list(components)
            self._num_components = static_num_components
            self._static_event_shape = static_event_shape
            self._static_batch_shape = static_batch_shape

            self._use_static_graph = use_static_graph
            if use_static_graph and static_num_components is None:
                raise ValueError(
                    "Number of categories must be known statically when "
                    "`static_sample=True`.")

        super(Mixture, self).__init__(
            dtype=dtype,
            reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name)
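A short usage sketch for the constructor above, assuming the standard `tfp.distributions` API (parameter values are illustrative):

import tensorflow_probability as tfp

tfd = tfp.distributions

mix = tfd.Mixture(
    cat=tfd.Categorical(probs=[0.3, 0.7]),
    components=[
        tfd.Normal(loc=-1., scale=0.1),
        tfd.Normal(loc=1., scale=0.5),
    ])
# The mixture mean is the probability-weighted component mean:
# 0.3 * (-1.0) + 0.7 * 1.0 = 0.4.
print(mix.mean().numpy())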
Code example #27
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        # For `logits` and `probs`, we only want to have an assertion on what the
        # user actually passed. For now, we access the underlying categorical's
        # _logits and _probs directly. After the 2019-10-01 deprecation, it would
        # also work to use .logits() and .probs().
        logits = self._categorical._logits
        probs = self._categorical._probs
        outcomes = self._outcomes
        validate_args = self._validate_args

        # Build all shape and dtype checks during the `is_init` call.
        if is_init:

            def validate_equal_last_dim(tensor_a, tensor_b, message):
                event_size_a = tf.compat.dimension_value(tensor_a.shape[-1])
                event_size_b = tf.compat.dimension_value(tensor_b.shape[-1])
                if event_size_a is not None and event_size_b is not None:
                    if event_size_a != event_size_b:
                        raise ValueError(message)
                elif validate_args:
                    return assert_util.assert_equal(tf.shape(tensor_a)[-1],
                                                    tf.shape(tensor_b)[-1],
                                                    message=message)

            message = 'Size of outcomes must be greater than 0.'
            if tensorshape_util.num_elements(outcomes.shape) is not None:
                if tensorshape_util.num_elements(outcomes.shape) == 0:
                    raise ValueError(message)
            elif validate_args:
                assertions.append(
                    assert_util.assert_greater(tf.size(outcomes), 0,
                                               message=message))

            if logits is not None:
                maybe_assert = validate_equal_last_dim(
                    outcomes,
                    logits,
                    message='Last dimension of outcomes and logits must be '
                            'equal size.')
                if maybe_assert:
                    assertions.append(maybe_assert)

            if probs is not None:
                maybe_assert = validate_equal_last_dim(
                    outcomes,
                    probs,
                    message='Last dimension of outcomes and probs must be '
                            'equal size.')
                if maybe_assert:
                    assertions.append(maybe_assert)

            message = 'Rank of outcomes must be 1.'
            ndims = tensorshape_util.rank(outcomes.shape)
            if ndims is not None:
                if ndims != 1:
                    raise ValueError(message)
            elif validate_args:
                assertions.append(
                    assert_util.assert_rank(outcomes, 1, message=message))

        if not validate_args:
            assert not assertions  # Should never happen.
            return []

        if is_init != tensor_util.is_ref(outcomes):
            assertions.append(
                assert_util.assert_equal(
                    tf.math.is_strictly_increasing(outcomes),
                    True,
                    message='outcomes is not strictly increasing.'))

        return assertions
Code example #28
def _kl_independent(a, b, name="kl_independent"):
    """Batched KL divergence `KL(a || b)` for Independent distributions.

  We can leverage the fact that
  ```
  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
  ```
  where the sum is over the `reinterpreted_batch_ndims`.

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default "kl_independent".

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the event space for `a` and `b`, or their underlying
      distributions don't match.
  """
    p = a.distribution
    q = b.distribution

    # The KL between any two (non)-batched distributions is a scalar.
    # Given that the KL between two factored distributions is the sum, i.e.
    # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
    # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
    if (tensorshape_util.is_fully_defined(a.event_shape)
            and tensorshape_util.is_fully_defined(b.event_shape)):
        if a.event_shape == b.event_shape:
            if p.event_shape == q.event_shape:
                num_reduce_dims = (tensorshape_util.rank(a.event_shape) -
                                   tensorshape_util.rank(p.event_shape))
                reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]

                return tf.reduce_sum(kullback_leibler.kl_divergence(p,
                                                                    q,
                                                                    name=name),
                                     axis=reduce_dims)
            else:
                raise NotImplementedError(
                    "KL between Independents with different "
                    "event shapes not supported.")
        else:
            raise ValueError("Event shapes do not match.")
    else:
        with tf.control_dependencies([
                assert_util.assert_equal(a.event_shape_tensor(),
                                         b.event_shape_tensor()),
                assert_util.assert_equal(p.event_shape_tensor(),
                                         q.event_shape_tensor())
        ]):
            num_reduce_dims = (prefer_static.rank_from_shape(
                a.event_shape_tensor, a.event_shape) -
                               prefer_static.rank_from_shape(
                                   p.event_shape_tensor, p.event_shape))
            reduce_dims = prefer_static.range(-num_reduce_dims, 0, 1)
            return tf.reduce_sum(kullback_leibler.kl_divergence(p,
                                                                q,
                                                                name=name),
                                 axis=reduce_dims)
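A numeric sanity check of the factorization this function relies on, assuming the public `tfp.distributions.Independent` API:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

a = tfd.Independent(tfd.Normal(loc=tf.zeros(3), scale=1.),
                    reinterpreted_batch_ndims=1)
b = tfd.Independent(tfd.Normal(loc=tf.ones(3), scale=2.),
                    reinterpreted_batch_ndims=1)
# KL(Independent(a) || Independent(b)) == sum of the component KLs.
kl_independent = tfd.kl_divergence(a, b)
kl_summed = tf.reduce_sum(
    tfd.kl_divergence(tfd.Normal(tf.zeros(3), 1.),
                      tfd.Normal(tf.ones(3), 2.)))
print(kl_independent.numpy(), kl_summed.numpy())  # same value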
Code example #29
def _replace_event_shape_in_shape_tensor(input_shape, event_shape_in,
                                         event_shape_out, validate_args):
    """Replaces the rightmost dims in a `Tensor` representing a shape.

  Args:
    input_shape: a rank-1 `Tensor` of integers
    event_shape_in: the event shape expected to be present in rightmost dims
      of `input_shape`.
    event_shape_out: the event shape with which to replace `event_shape_in` in
      the rightmost dims of `input_shape`.
    validate_args: Python `bool` indicating whether arguments should
      be checked for correctness.

  Returns:
    output_shape: A rank-1 integer `Tensor` with the same contents as
      `input_shape` except for the event dims, which are replaced with
      `event_shape_out`.
    output_tensorshape: a `TensorShape` describing `output_shape`, to the
      extent it can be inferred statically.
  """
    output_tensorshape, is_validated = _replace_event_shape_in_tensorshape(
        tensorshape_util.constant_value_as_shape(input_shape), event_shape_in,
        event_shape_out)

    # TODO(b/124240153): Remove map(tf.identity, deps) once tf.function
    # correctly supports control_dependencies.
    validation_dependencies = (map(tf.identity,
                                   (event_shape_in,
                                    event_shape_out)) if validate_args else ())

    if (tensorshape_util.is_fully_defined(output_tensorshape)
            and (is_validated or not validate_args)):
        with tf.control_dependencies(validation_dependencies):
            output_shape = tf.convert_to_tensor(
                tensorshape_util.as_list(output_tensorshape),
                name='output_shape',
                dtype_hint=tf.int32)
        return output_shape, output_tensorshape

    with tf.control_dependencies(validation_dependencies):
        event_shape_in_ndims = (
            tf.size(event_shape_in)
            if tensorshape_util.num_elements(event_shape_in.shape) is None else
            tensorshape_util.num_elements(event_shape_in.shape))
        input_non_event_shape, input_event_shape = tf.split(
            input_shape, num_or_size_splits=[-1, event_shape_in_ndims])

    additional_assertions = []
    if is_validated:
        pass
    elif validate_args:
        # Check that `input_event_shape` and `event_shape_in` are compatible in the
        # sense that they have equal entries in any position that isn't a `-1` in
        # `event_shape_in`. Note that our validations at construction time ensure
        # there is at most one such entry in `event_shape_in`.
        mask = event_shape_in >= 0
        explicit_input_event_shape = tf.boolean_mask(input_event_shape,
                                                     mask=mask)
        explicit_event_shape_in = tf.boolean_mask(event_shape_in, mask=mask)
        additional_assertions.append(
            assert_util.assert_equal(
                explicit_input_event_shape,
                explicit_event_shape_in,
                message='Input `event_shape` does not match `event_shape_in`.')
        )
        # We don't explicitly additionally verify
        # `tf.size(input_shape) > tf.size(event_shape_in)` since `tf.split`
        # already makes this assertion.

    with tf.control_dependencies(additional_assertions):
        output_shape = tf.concat([input_non_event_shape, event_shape_out],
                                 axis=0,
                                 name='output_shape')

    return output_shape, output_tensorshape
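One detail in the body above is worth calling out: `tf.split` accepts a single `-1` in `num_or_size_splits`, which is what lets the function carve the rightmost event dims off the shape vector without knowing the batch rank. A standalone illustration:

import tensorflow as tf

input_shape = tf.constant([2, 5, 3, 4], dtype=tf.int32)
# Split off the rightmost 2 entries (the event dims); `-1` absorbs the rest.
non_event, event = tf.split(input_shape, num_or_size_splits=[-1, 2])
print(non_event.numpy(), event.numpy())  # [2 5] [3 4]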
Code example #30
    def __init__(self, permutation, axis=-1, validate_args=False, name=None):
        """Creates the `Permute` bijector.

    Args:
      permutation: An `int`-like vector-shaped `Tensor` representing the
        permutation to apply to the `axis` dimension of the transformed
        `Tensor`.
      axis: Scalar `int` `Tensor` representing the dimension over which to
        `tf.gather`. `axis` must be relative to the end (reading left to right)
        thus must be negative.
        Default value: `-1` (i.e., right-most).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str`, name given to ops managed by this object.

    Raises:
      TypeError: if `not dtype_util.is_integer(permutation.dtype)`.
      ValueError: if `permutation` does not contain exactly one of each of
        `{0, 1, ..., d}`.
      NotImplementedError: if `axis` is not known prior to graph execution.
      NotImplementedError: if `axis` is not negative.
    """
        with tf.name_scope(name or "permute") as name:
            axis = tf.convert_to_tensor(axis, name="axis")
            if not dtype_util.is_integer(axis.dtype):
                raise TypeError("axis.dtype ({}) should be `int`-like.".format(
                    dtype_util.name(axis.dtype)))
            permutation = tf.convert_to_tensor(permutation, name="permutation")
            if not dtype_util.is_integer(permutation.dtype):
                raise TypeError(
                    "permutation.dtype ({}) should be `int`-like.".format(
                        dtype_util.name(permutation.dtype)))
            p = tf.get_static_value(permutation)
            if p is not None:
                if set(p) != set(np.arange(p.size)):
                    raise ValueError(
                        "Permutation over `d` must contain exactly one of "
                        "each of `{0, 1, ..., d}`.")
            elif validate_args:
                p, _ = tf.math.top_k(-permutation,
                                     k=tf.shape(permutation)[-1],
                                     sorted=True)
                permutation = distribution_util.with_dependencies([
                    assert_util.assert_equal(
                        -p,
                        tf.range(tf.size(p)),
                        message=(
                            "Permutation over `d` must contain exactly one of "
                            "each of `{0, 1, ..., d}`.")),
                ], permutation)
            axis_ = tf.get_static_value(axis)
            if axis_ is None:
                raise NotImplementedError(
                    "`axis` must be known prior to graph "
                    "execution.")
            elif axis_ >= 0:
                raise NotImplementedError(
                    "`axis` must be relative the rightmost "
                    "dimension, i.e., negative.")
            else:
                forward_min_event_ndims = int(np.abs(axis_))
            self._permutation = permutation
            self._axis = axis
            super(Permute, self).__init__(
                forward_min_event_ndims=forward_min_event_ndims,
                is_constant_jacobian=True,
                validate_args=validate_args,
                name=name)
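A short usage sketch for the bijector above, assuming the public `tfp.bijectors.Permute` API:

import tensorflow as tf
import tensorflow_probability as tfp

bijector = tfp.bijectors.Permute(permutation=[2, 0, 1], validate_args=True)
x = tf.constant([1., 2., 3.])
# `forward` gathers along `axis` (here -1): output entry i is x[permutation[i]].
print(bijector.forward(x).numpy())                    # [3. 1. 2.]
print(bijector.inverse(bijector.forward(x)).numpy())  # [1. 2. 3.]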