Example #1
def _validate_block_sizes(block_sizes, bijectors, validate_args):
    """Helper to validate block sizes."""
    block_sizes_shape = block_sizes.shape
    if tensorshape_util.is_fully_defined(block_sizes_shape):
        if (tensorshape_util.rank(block_sizes_shape) != 1
                or (tensorshape_util.num_elements(block_sizes_shape) !=
                    len(bijectors))):
            raise ValueError(
                '`block_sizes` must be `None`, or a vector of the same length as '
                '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of '
                'length {}'.format(block_sizes_shape, len(bijectors)))
        return block_sizes
    elif validate_args:
        message = (
            '`block_sizes` must be `None`, or a vector of the same length '
            'as `bijectors`.')
        with tf.control_dependencies([
                assert_util.assert_equal(tf.size(block_sizes),
                                         len(bijectors),
                                         message=message),
                assert_util.assert_equal(tf.rank(block_sizes), 1)
        ]):
            return tf.identity(block_sizes)
    else:
        return block_sizes
Example #2
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
    """Calculates the reshaped dimensions (replacing up to one -1 in reshape)."""
    batch_shape_static = tensorshape_util.constant_value_as_shape(new_shape)
    if tensorshape_util.is_fully_defined(batch_shape_static):
        return np.int32(batch_shape_static), batch_shape_static, []
    with tf.name_scope(name or 'calculate_reshape'):
        original_size = tf.reduce_prod(original_shape)
        implicit_dim = tf.equal(new_shape, -1)
        size_implicit_dim = (original_size //
                             tf.maximum(1, -tf.reduce_prod(new_shape)))
        expanded_new_shape = tf.where(  # Assumes exactly one `-1`.
            implicit_dim, size_implicit_dim, new_shape)
        validations = [] if not validate else [  # pylint: disable=g-long-ternary
            assert_util.assert_rank(
                original_shape, 1, message='Original shape must be a vector.'),
            assert_util.assert_rank(
                new_shape, 1, message='New shape must be a vector.'),
            assert_util.assert_less_equal(
                tf.math.count_nonzero(implicit_dim, dtype=tf.int32),
                1,
                message='At most one dimension can be unknown.'),
            assert_util.assert_positive(
                expanded_new_shape, message='Shape elements must be >=-1.'),
            assert_util.assert_equal(tf.reduce_prod(expanded_new_shape),
                                     original_size,
                                     message='Shape sizes do not match.'),
        ]
        return expanded_new_shape, batch_shape_static, validations
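The `-1` replacement above follows NumPy-style reshape semantics. A minimal NumPy sketch of the same arithmetic, with illustrative values only (not part of the library code):

import numpy as np

original_shape = np.array([2, 3])        # original size = 6
new_shape = np.array([-1, 3])            # one implicit dimension

original_size = np.prod(original_shape)
implicit_dim = new_shape == -1
# Product of the known entries; the leading minus cancels the single -1.
size_implicit_dim = original_size // max(1, -np.prod(new_shape))
expanded_new_shape = np.where(implicit_dim, size_implicit_dim, new_shape)
print(expanded_new_shape)                # [2 3]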
Example #3
def lu_reconstruct_assertions(lower_upper, perm, validate_args):
    """Returns list of assertions related to `lu_reconstruct` assumptions."""
    assertions = []

    message = 'Input `lower_upper` must have at least 2 dimensions.'
    if tensorshape_util.rank(lower_upper.shape) is not None:
        if tensorshape_util.rank(lower_upper.shape) < 2:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank_at_least(lower_upper,
                                             rank=2,
                                             message=message))

    message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
    if (tensorshape_util.rank(lower_upper.shape) is not None
            and tensorshape_util.rank(perm.shape) is not None):
        if (tensorshape_util.rank(lower_upper.shape) !=
                tensorshape_util.rank(perm.shape) + 1):
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank(lower_upper,
                                    rank=tf.rank(perm) + 1,
                                    message=message))

    message = '`lower_upper` must be square.'
    if tensorshape_util.is_fully_defined(lower_upper.shape[:-2]):
        if lower_upper.shape[-2] != lower_upper.shape[-1]:
            raise ValueError(message)
    elif validate_args:
        m, n = tf.split(tf.shape(lower_upper)[-2:], num_or_size_splits=2)
        assertions.append(assert_util.assert_equal(m, n, message=message))

    return assertions
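A hedged usage sketch, assuming the TFP-internal imports used by the helper above (`tensorshape_util`, `assert_util`) are available: the factors returned by `tf.linalg.lu` satisfy the assumptions being checked, so with fully static shapes no runtime assertions are produced.

import tensorflow as tf

x = tf.random.normal([4, 3, 3])
lower_upper, perm = tf.linalg.lu(x)   # shapes [4, 3, 3] and [4, 3]
assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args=True)
print(assertions)                     # [] -- every check passed statically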
Example #4
def _maybe_validate_perm(perm, validate_args, name=None):
    """Checks that `perm` is valid."""
    with tf.name_scope(name or 'maybe_validate_perm'):
        assertions = []
        if not dtype_util.is_integer(perm.dtype):
            raise TypeError('`perm` must be integer type')

        msg = '`perm` must be a vector.'
        if tensorshape_util.rank(perm.shape) is not None:
            if tensorshape_util.rank(perm.shape) != 1:
                raise ValueError(msg[:-1] + ', saw rank: {}.'.format(
                    tensorshape_util.rank(perm.shape)))
        elif validate_args:
            assertions += [assert_util.assert_rank(perm, 1, message=msg)]

        perm_ = tf.get_static_value(perm)
        msg = '`perm` must be a valid permutation vector.'
        if perm_ is not None:
            if not np.all(np.arange(np.size(perm_)) == np.sort(perm_)):
                raise ValueError(msg[:-1] + ', saw: {}.'.format(perm_))
        elif validate_args:
            assertions += [
                assert_util.assert_equal(tf.sort(perm),
                                         tf.range(tf.size(perm)),
                                         message=msg)
            ]

        return assertions
Example #5
def maybe_check_quadrature_param(param, name, validate_args):
    """Helper which checks validity of `loc` and `scale` init args."""
    with tf.name_scope("check_" + name):
        assertions = []
        if tensorshape_util.rank(param.shape) is not None:
            if tensorshape_util.rank(param.shape) == 0:
                raise ValueError("Mixing params must be a (batch of) vector; "
                                 "{}.rank={} is not at least one.".format(
                                     name, tensorshape_util.rank(param.shape)))
        elif validate_args:
            assertions.append(
                assert_util.assert_rank_at_least(
                    param,
                    1,
                    message=("Mixing params must be a (batch of) vector; "
                             "{}.rank is not at least one.".format(name))))

        # TODO(jvdillon): Remove once we support k-mixtures.
        if tensorshape_util.with_rank_at_least(param.shape, 1)[-1] is not None:
            if tf.compat.dimension_value(param.shape[-1]) != 1:
                raise NotImplementedError(
                    "Currently only bimixtures are supported; "
                    "{}.shape[-1]={} is not 1.".format(
                        name, tf.compat.dimension_value(param.shape[-1])))
        elif validate_args:
            assertions.append(
                assert_util.assert_equal(
                    tf.shape(param)[-1],
                    1,
                    message=("Currently only bimixtures are supported; "
                             "{}.shape[-1] is not 1.".format(name))))

        if assertions:
            return distribution_util.with_dependencies(assertions, param)
        return param
Example #6
 def _assertions(self, x):
   if not self.validate_args:
     return []
   x_shape = tf.shape(x)
   is_matrix = assert_util.assert_rank_at_least(
       x, 2, message="Input must have rank at least 2.")
   is_square = assert_util.assert_equal(
       x_shape[-2], x_shape[-1], message="Input must be a square matrix.")
   diag_part_x = tf.linalg.diag_part(x)
   is_lower_triangular = assert_util.assert_equal(
       tf.linalg.band_part(x, 0, -1),  # Preserves triu, zeros rest.
       tf.linalg.diag(diag_part_x),
       message="Input must be lower triangular.")
   is_positive_diag = assert_util.assert_positive(
       diag_part_x, message="Input must have all positive diagonal entries.")
   return [is_matrix, is_square, is_lower_triangular, is_positive_diag]
Example #7
def vector_size_to_square_matrix_size(d, validate_args, name=None):
    """Convert a vector size to a matrix size."""
    if isinstance(d, (float, int, np.generic, np.ndarray)):
        n = (-1 + np.sqrt(1 + 8 * d)) / 2.
        if float(int(n)) != n:
            raise ValueError(
                'Vector length {} is not a triangular number.'.format(d))
        return int(n)
    else:
        with tf.name_scope(name
                           or 'vector_size_to_square_matrix_size') as name:
            n = (-1. + tf.sqrt(1 + 8. * tf.cast(d, dtype=tf.float32))) / 2.
            if validate_args:
                with tf.control_dependencies([
                        assert_util.assert_equal(
                            tf.cast(tf.cast(n, dtype=tf.int32),
                                    dtype=tf.float32),
                            n,
                            data=[
                                'Vector length is not a triangular number: ', d
                            ],
                            message='Vector length is not a triangular number')
                ]):
                    n = tf.identity(n)
            return tf.cast(n, d.dtype)
Example #8
def _maybe_check_valid_map_values(map_values, validate_args):
    """Validate `map_values` if `validate_args`==True."""
    assertions = []

    message = 'Rank of map_values must be 1.'
    if tensorshape_util.rank(map_values.shape) is not None:
        if tensorshape_util.rank(map_values.shape) != 1:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank(map_values, 1, message=message))

    message = 'Size of map_values must be greater than 0.'
    if tensorshape_util.num_elements(map_values.shape) is not None:
        if tensorshape_util.num_elements(map_values.shape) == 0:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_greater(tf.size(map_values), 0,
                                       message=message))

    if validate_args:
        assertions.append(
            assert_util.assert_equal(
                tf.math.is_strictly_increasing(map_values),
                True,
                message='map_values is not strictly increasing.'))

    return assertions
Example #9
    def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                       validate_args, name):
        """Helper to __init__ which ensures override batch/event_shape are valid."""
        if override_shape is None:
            override_shape = []

        override_shape = tf.convert_to_tensor(override_shape,
                                              dtype=tf.int32,
                                              name=name)

        if not dtype_util.is_integer(override_shape.dtype):
            raise TypeError("shape override must be an integer")

        override_is_scalar = _is_scalar_from_shape_tensor(override_shape)
        if tf.get_static_value(override_is_scalar):
            return self._empty

        dynamic_assertions = []

        if tensorshape_util.rank(override_shape.shape) is not None:
            if tensorshape_util.rank(override_shape.shape) != 1:
                raise ValueError("shape override must be a vector")
        elif validate_args:
            dynamic_assertions += [
                assert_util.assert_rank(
                    override_shape,
                    1,
                    message="shape override must be a vector")
            ]

        if tf.get_static_value(override_shape) is not None:
            if any(s < 0 for s in tf.get_static_value(override_shape)):
                raise ValueError(
                    "shape override must have non-negative elements")
        elif validate_args:
            dynamic_assertions += [
                assert_util.assert_non_negative(
                    override_shape,
                    message="shape override must have non-negative elements")
            ]

        is_both_nonscalar = prefer_static.logical_and(
            prefer_static.logical_not(base_is_scalar),
            prefer_static.logical_not(override_is_scalar))
        if tf.get_static_value(is_both_nonscalar) is not None:
            if tf.get_static_value(is_both_nonscalar):
                raise ValueError("base distribution not scalar")
        elif validate_args:
            dynamic_assertions += [
                assert_util.assert_equal(
                    is_both_nonscalar,
                    False,
                    message="base distribution not scalar")
            ]

        if not dynamic_assertions:
            return override_shape
        return distribution_util.with_dependencies(dynamic_assertions,
                                                   override_shape)
Example #10
 def _assertions(self, t):
   if not self.validate_args:
     return []
   is_matrix = assert_util.assert_rank_at_least(t, 2)
   is_square = assert_util.assert_equal(tf.shape(t)[-2], tf.shape(t)[-1])
   is_positive_definite = assert_util.assert_positive(
       tf.linalg.diag_part(t), message="Input must be positive definite.")
   return [is_matrix, is_square, is_positive_definite]
Example #11
 def _observation_mask_shape_preconditions(self, observation_tensor_shape,
                                           mask_tensor_shape):
     shape_condition = [
         assert_util.assert_equal(
             observation_tensor_shape[-1 - self._underlying_event_rank],
             self._num_steps,
             message="The tensor `observations` must consist of sequences"
             "of observations from `HiddenMarkovModel` of length"
             "`num_steps`.")
     ]
     if mask_tensor_shape is not None:
         shape_condition.append(
             assert_util.assert_equal(
                 mask_tensor_shape[-1],
                 self._num_steps,
                 message="The tensor `mask` must consist of sequences"
                 "of length `num_steps`."))
     return tf.control_dependencies(shape_condition)
Example #12
 def validate_equal_last_dim(tensor_a, tensor_b, message):
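   # Note: `validate_args` used below is assumed to be captured from an
   # enclosing scope (this helper appears to be nested inside code that
   # defines it); it is not a parameter of this function.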
   event_size_a = tf.compat.dimension_value(tensor_a.shape[-1])
   event_size_b = tf.compat.dimension_value(tensor_b.shape[-1])
   if event_size_a is not None and event_size_b is not None:
     if event_size_a != event_size_b:
       raise ValueError(message)
   elif validate_args:
     return assert_util.assert_equal(
         tf.shape(tensor_a)[-1], tf.shape(tensor_b)[-1], message=message)
Example #13
 def _inverse_event_shape_tensor(self, output_shape_tensor):
     batch_shape, n = output_shape_tensor[:-2], output_shape_tensor[-1]
     if self.validate_args:
         is_square_matrix = assert_util.assert_equal(
             n, output_shape_tensor[-2], message='Matrix must be square.')
         with tf.control_dependencies([is_square_matrix]):
             n = tf.identity(n)
     d = tf.cast(n * (n + 1) / 2, output_shape_tensor.dtype)
     return tf.concat([batch_shape, [d]], axis=0)
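For reference, a tiny worked example (illustrative only) of the triangular-number arithmetic used here and in `vector_size_to_square_matrix_size` above: an `n x n` lower-triangular matrix has `d = n * (n + 1) / 2` free entries, and `n = (-1 + sqrt(1 + 8 * d)) / 2` recovers the matrix size.

import numpy as np

n = 3
d = n * (n + 1) // 2                       # 6 entries in a 3x3 lower triangle
n_back = int((-1 + np.sqrt(1 + 8 * d)) / 2)
print(d, n_back)                           # 6 3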
Example #14
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        message = 'Distributions must have the same `batch_shape`'

        if is_init:
            batch_shapes = tf.nest.flatten(self._cached_batch_shape)
            if all(tensorshape_util.is_fully_defined(b) for b in batch_shapes):
                if batch_shapes[1:] != batch_shapes[:-1]:
                    raise ValueError('{}; found: {}.'.format(
                        message, batch_shapes))

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        if self.validate_args:
            batch_shapes = self._cached_batch_shape
            if not all(
                    tensorshape_util.is_fully_defined(s)
                    for s in tf.nest.flatten(batch_shapes)):
                batch_shapes = tf.nest.map_structure(
                    lambda static_shape, shape_tensor:  # pylint: disable=g-long-lambda
                    (static_shape if tensorshape_util.is_fully_defined(
                        static_shape) else shape_tensor),
                    batch_shapes,
                    self._cached_batch_shape_tensor)
            batch_shapes = tf.nest.flatten(batch_shapes)
            assertions.extend(
                assert_util.assert_equal(  # pylint: disable=g-complex-comprehension
                    b1,
                    b2,
                    message='{}.'.format(message))
                for b1, b2 in zip(batch_shapes[1:], batch_shapes[:-1]))
            assertions.extend(
                assert_util.assert_equal(  # pylint: disable=g-complex-comprehension
                    tf.size(b1),
                    tf.size(b2),
                    message='{}.'.format(message))
                for b1, b2 in zip(batch_shapes[1:], batch_shapes[:-1]))

        return assertions
Example #15
 def _maybe_assert_valid_sample(self, counts):
     """Check counts for proper shape, values, then return tensor version."""
     if not self.validate_args:
         return []
     assertions = distribution_util.assert_nonnegative_integer_form(counts)
     assertions.append(
         assert_util.assert_equal(
             self.total_count,
             tf.reduce_sum(counts, axis=-1),
             message='counts must sum to `self.total_count`'))
     return assertions
Example #16
 def _maybe_assert_valid_sample(self, counts):
     """Check counts for proper shape, values, then return tensor version."""
     if not self.validate_args:
         return counts
     counts = distribution_util.embed_check_nonnegative_integer_form(counts)
     return distribution_util.with_dependencies([
         assert_util.assert_equal(
             self.total_count,
             tf.reduce_sum(counts, axis=-1),
             message='counts last-dimension must sum to `self.total_count`'
         ),
     ], counts)
Example #17
 def _z(self, x, scale, concentration):
     loc = tf.convert_to_tensor(self.loc)
     if self.validate_args:
         valid = (x >= loc) & ((concentration >= 0) |
                               (x <= loc - scale / concentration))
         with tf.control_dependencies([
                 assert_util.assert_equal(
                     valid,
                     True,
                     message='`x` outside distribution\'s support.')
         ]):
             x = tf.identity(x)
     return (x - loc) / scale
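The validity check above accepts `x >= loc` and, when `concentration < 0`, additionally requires `x <= loc - scale / concentration` (a bounded upper tail). A small worked check with illustrative values, mirroring that expression in plain Python:

loc, scale, concentration = 0., 1., -0.5
upper = loc - scale / concentration              # 0 - 1 / (-0.5) = 2.0
for x in (1.5, 2.5):
    valid = (x >= loc) and (concentration >= 0 or x <= upper)
    print(x, valid)                              # 1.5 True, then 2.5 False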
Example #18
    def _assert_compatible_shape(self, index, sample_shape, samples):
        requested_shape, _ = self._expand_sample_shape_to_vector(
            tf.convert_to_tensor(sample_shape, dtype=tf.int32),
            name='requested_shape')
        actual_shape = prefer_static.shape(samples)
        actual_rank = prefer_static.rank_from_shape(actual_shape)
        requested_rank = prefer_static.rank_from_shape(requested_shape)

        # We test for two properties we expect of yielded distributions:
        # (1) The rank of the tensor of generated samples must be at least
        #     as large as the rank requested.
        # (2) The requested shape must be a prefix of the shape of the
        #     generated tensor of samples.
        # We attempt to perform test (1) statically first.
        # We don't need to do this explicitly for test (2) because
        # `assert_equal` evaluates statically if it can.
        static_actual_rank = tf.get_static_value(actual_rank)
        static_requested_rank = tf.get_static_value(requested_rank)

        assertion_message = ('Samples yielded by distribution #{} are not '
                             'consistent with `sample_shape` passed to '
                             '`JointDistributionCoroutine` '
                             'distribution.'.format(index))

        # TODO Remove this static check (b/138738650)
        if (static_actual_rank is not None
                and static_requested_rank is not None):
            # We're able to statically check the rank
            if static_actual_rank < static_requested_rank:
                raise ValueError(assertion_message)
            else:
                control_dependencies = []
        else:
            # We're not able to statically check the rank
            control_dependencies = [
                assert_util.assert_greater_equal(actual_rank,
                                                 requested_rank,
                                                 message=assertion_message)
            ]

        with tf.control_dependencies(control_dependencies):
            trimmed_actual_shape = actual_shape[:requested_rank]

        control_dependencies = [
            assert_util.assert_equal(requested_shape,
                                     trimmed_actual_shape,
                                     message=assertion_message)
        ]

        return control_dependencies
Example #19
 def _forward(self, x):
     map_values = tf.convert_to_tensor(self.map_values)
     if self.validate_args:
         with tf.control_dependencies([
                 assert_util.assert_equal(
                     (0 <= x) & (x < tf.size(map_values)),
                     True,
                      message='indices out of bounds')
         ]):
             x = tf.identity(x)
     # If we want batch dims in self.map_values, we can (after broadcasting),
     # use:
     # tf.gather(self.map_values, x, batch_dims=-1, axis=-1)
     return tf.gather(map_values, indices=x)
Example #20
 def _assertions(self, x):
     if not self.validate_args:
         return []
     shape = tf.shape(x)
     is_matrix = assert_util.assert_rank_at_least(
         x, 2, message="Input must have rank at least 2.")
     is_square = assert_util.assert_equal(
         shape[-2], shape[-1], message="Input must be a square matrix.")
     above_diagonal = tf.linalg.band_part(
          tf.linalg.set_diag(x, tf.zeros(shape[:-1], dtype=x.dtype)), 0,
         -1)
     is_lower_triangular = assert_util.assert_equal(
         above_diagonal,
         tf.zeros_like(above_diagonal),
         message="Input must be lower triangular.")
     # A lower triangular matrix is nonsingular iff all its diagonal entries are
     # nonzero.
     diag_part = tf.linalg.diag_part(x)
     is_nonsingular = assert_util.assert_none_equal(
         diag_part,
         tf.zeros_like(diag_part),
         message="Input must have all diagonal entries nonzero.")
     return [is_matrix, is_square, is_lower_triangular, is_nonsingular]
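The nonsingularity comment above follows from the determinant of a triangular matrix being the product of its diagonal entries. A minimal NumPy check with illustrative values:

import numpy as np

L = np.array([[2., 0.], [3., 0.5]])
print(np.linalg.det(L))                 # 1.0 == 2.0 * 0.5, so L is invertible
L_singular = np.array([[2., 0.], [3., 0.]])
print(np.linalg.det(L_singular))        # 0.0, so no inverse exists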
Example #21
 def _validate_dimension(self, x):
     x = tf.convert_to_tensor(x, name='x')
      if tensorshape_util.is_fully_defined(x.shape[-2:]):
          dims = tensorshape_util.dims(x.shape)
          if not (dims[-2] == dims[-1] == self.dimension):
              raise ValueError(
                  'Input dimension mismatch: expected [..., {}, {}], got {}'.
                  format(self.dimension, self.dimension, dims))
     elif self.validate_args:
         msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
             self.dimension, self.dimension, tf.shape(x))
         with tf.control_dependencies([
                 assert_util.assert_equal(tf.shape(x)[-2],
                                          self.dimension,
                                          message=msg),
                 assert_util.assert_equal(tf.shape(x)[-1],
                                          self.dimension,
                                          message=msg)
         ]):
             x = tf.identity(x)
     return x
Example #22
 def _maybe_assert_valid_sample(self, samples):
   """Check counts for proper shape, values, then return tensor version."""
   if not self.validate_args:
     return samples
   with tf.control_dependencies([
       assert_util.assert_near(
           1.,
           tf.linalg.norm(samples, axis=-1),
           message='samples must be unit length'),
       assert_util.assert_equal(
           tf.shape(samples)[-1:],
           self.event_shape_tensor(),
           message=('samples must have innermost dimension matching that of '
                    '`self.mean_direction`')),
   ]):
     return tf.identity(samples)
Example #23
 def _prob(self, x):
     if self.validate_args:
         is_vector_check = assert_util.assert_rank_at_least(x, 1)
         right_vec_space_check = assert_util.assert_equal(
             self.event_shape_tensor(),
             tf.gather(tf.shape(x),
                       tf.rank(x) - 1),
             message=
             "Argument 'x' not defined in the same space R^k as this distribution"
         )
         with tf.control_dependencies([is_vector_check]):
             with tf.control_dependencies([right_vec_space_check]):
                 x = tf.identity(x)
     loc = tf.convert_to_tensor(self.loc)
     return tf.cast(tf.reduce_all(tf.abs(x - loc) <= self._slack(loc),
                                  axis=-1),
                    dtype=self.dtype)
Example #24
def maybe_check_wont_broadcast(flat_xs, validate_args):
  """Verifies that `parts` don't broadcast."""
  flat_xs = tuple(flat_xs)  # So we can receive generators.
  if not validate_args:
    # Note: we don't try static validation because it is theoretically
    # possible that a user wants to take advantage of broadcasting.
    # Only when `validate_args` is `True` do we enforce the validation.
    return flat_xs
  msg = 'Broadcasting probably indicates an error in model specification.'
  s = tuple(prefer_static.shape(x) for x in flat_xs)
  if all(prefer_static.is_numpy(s_) for s_ in s):
    if not all(np.all(a == b) for a, b in zip(s[1:], s[:-1])):
      raise ValueError(msg)
    return flat_xs
  assertions = [assert_util.assert_equal(a, b, message=msg)
                for a, b in zip(s[1:], s[:-1])]
  with tf.control_dependencies(assertions):
    return tuple(tf.identity(x) for x in flat_xs)
Example #25
def _kl_sample(a, b, name='kl_sample'):
    """Batched KL divergence `KL(a || b)` for Sample distributions.

  We can leverage the fact that:

  ```
  KL(Sample(a) || Sample(b)) = sum(KL(a || b))
  ```

  where the sum is over the `sample_shape` dims.

  Args:
    a: Instance of `Sample` distribution.
    b: Instance of `Sample` distribution.
    name: (optional) name to use for created ops.
      Default value: `"kl_sample"`.

  Returns:
    kldiv: Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the `sample_shape` of `a` and `b` don't match.
  """
    assertions = []
    a_ss = tf.get_static_value(a.sample_shape)
    b_ss = tf.get_static_value(b.sample_shape)
    msg = '`a.sample_shape` must be identical to `b.sample_shape`.'
    if a_ss is not None and b_ss is not None:
        if not np.array_equal(a_ss, b_ss):
            raise ValueError(msg)
    elif a.validate_args or b.validate_args:
        assertions.append(
            assert_util.assert_equal(a.sample_shape,
                                     b.sample_shape,
                                     message=msg))
    with tf.control_dependencies(assertions):
        kl = kullback_leibler.kl_divergence(a.distribution,
                                            b.distribution,
                                            name=name)
        n = prefer_static.reduce_prod(a.sample_shape)
        return tf.cast(x=n, dtype=kl.dtype) * kl
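A hedged usage sketch of the identity in the docstring, using `tfd.Sample` and `tfd.Normal` from TensorFlow Probability's public API (parameter values are illustrative):

import tensorflow_probability as tfp
tfd = tfp.distributions

a = tfd.Sample(tfd.Normal(loc=0., scale=1.), sample_shape=3)
b = tfd.Sample(tfd.Normal(loc=1., scale=2.), sample_shape=3)
# KL between the Sample distributions is the base KL summed over sample_shape,
# i.e. 3 * KL(Normal(0, 1) || Normal(1, 2)).
print(tfd.kl_divergence(a, b))
print(3. * tfd.kl_divergence(a.distribution, b.distribution))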
Example #26
def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
    """Returns list of assertions related to `lu_solve` assumptions."""
    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)

    message = 'Input `rhs` must have at least 2 dimensions.'
    if rhs.shape.ndims is not None:
        if rhs.shape.ndims < 2:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank_at_least(rhs, rank=2, message=message))

    message = '`lower_upper.shape[-1]` must equal `rhs.shape[-2]`.'
    if (tf.compat.dimension_value(lower_upper.shape[-1]) is not None
            and tf.compat.dimension_value(rhs.shape[-2]) is not None):
        if lower_upper.shape[-1] != rhs.shape[-2]:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_equal(tf.shape(lower_upper)[-1],
                                     tf.shape(rhs)[-2],
                                     message=message))

    return assertions
Example #27
    def _validate_sample_arg(self, x):
        """Helper which validates sample arg, e.g., input to `log_prob`."""
        with tf.name_scope('validate_sample_arg'):
            x_ndims = (tf.rank(x) if tensorshape_util.rank(x.shape) is None
                       else tensorshape_util.rank(x.shape))
            event_ndims = (tf.size(self.event_shape_tensor())
                           if tensorshape_util.rank(self.event_shape) is None
                           else tensorshape_util.rank(self.event_shape))
            batch_ndims = (tf.size(self._batch_shape_unexpanded)
                           if tensorshape_util.rank(self.batch_shape) is None
                           else tensorshape_util.rank(self.batch_shape))
            expected_batch_event_ndims = batch_ndims + event_ndims

            if (isinstance(x_ndims, int)
                    and isinstance(expected_batch_event_ndims, int)):
                if x_ndims < expected_batch_event_ndims:
                    raise NotImplementedError(
                        'Broadcasting is not supported; too few batch and event dims '
                        '(expected at least {}, saw {}).'.format(
                            expected_batch_event_ndims, x_ndims))
                ndims_assertion = []
            elif self.validate_args:
                ndims_assertion = [
                    assert_util.assert_greater_equal(
                        x_ndims,
                        expected_batch_event_ndims,
                        message=('Broadcasting is not supported; too few '
                                 'batch and event dims.'),
                        name='assert_batch_and_event_ndims_large_enough'),
                ]

            if (tensorshape_util.is_fully_defined(self.batch_shape)
                    and tensorshape_util.is_fully_defined(self.event_shape)):
                expected_batch_event_shape = np.int32(
                    tensorshape_util.concatenate(self.batch_shape,
                                                 self.event_shape))
            else:
                expected_batch_event_shape = tf.concat(
                    [self.batch_shape_tensor(), self.event_shape_tensor()],
                    axis=0)

            sample_ndims = x_ndims - expected_batch_event_ndims
            if isinstance(sample_ndims, int):
                sample_ndims = max(sample_ndims, 0)
            if (isinstance(sample_ndims, int) and
                    tensorshape_util.is_fully_defined(x.shape[sample_ndims:])):
                actual_batch_event_shape = np.int32(x.shape[sample_ndims:])
            else:
                sample_ndims = tf.maximum(sample_ndims, 0)
                actual_batch_event_shape = tf.shape(x)[sample_ndims:]

            if (isinstance(expected_batch_event_shape, np.ndarray)
                    and isinstance(actual_batch_event_shape, np.ndarray)):
                if any(expected_batch_event_shape != actual_batch_event_shape):
                    raise NotImplementedError(
                        'Broadcasting is not supported; '
                        'unexpected batch and event shape '
                        '(expected {}, saw {}).'.format(
                            expected_batch_event_shape,
                            actual_batch_event_shape))
                # We need to set the final runtime assertions to
                # `ndims_assertion` since it's possible this assertion was
                # created. We could add a condition to only do so if
                # `self.validate_args == True`; however, this is redundant, as
                # `ndims_assertion` already encodes this information.
                runtime_assertions = ndims_assertion
            elif self.validate_args:
                # We need to make the `ndims_assertion` a control dep because otherwise
                # TF itself might raise an exception owing to this assertion being
                # ill-defined, i.e., one cannot even compare different rank Tensors.
                with tf.control_dependencies(ndims_assertion):
                    shape_assertion = assert_util.assert_equal(
                        expected_batch_event_shape,
                        actual_batch_event_shape,
                        message=('Broadcasting is not supported; '
                                 'unexpected batch and event shape.'),
                        name='assert_batch_and_event_shape_same')
                runtime_assertions = [shape_assertion]
            else:
                runtime_assertions = []

            return runtime_assertions
Example #28
  def __init__(self,
               cat,
               components,
               validate_args=False,
               allow_nan_stats=True,
               use_static_graph=False,
               name="Mixture"):
    """Initialize a Mixture distribution.

    A `Mixture` is defined by a `Categorical` (`cat`, representing the
    mixture probabilities) and a list of `Distribution` objects
    all having matching dtype, batch shape, event shape, and continuity
    properties (the components).

    The `num_classes` of `cat` must be possible to infer at graph construction
    time and match `len(components)`.

    Args:
      cat: A `Categorical` distribution instance, representing the probabilities
          of `components`.
      components: A list or tuple of `Distribution` instances.
        Each instance must have the same type, be defined on the same domain,
        and have matching `event_shape` and `batch_shape`.
      validate_args: Python `bool`, default `False`. If `True`, raise a runtime
        error if batch or event ranks are inconsistent between cat and any of
        the distributions. This is only checked if the ranks cannot be
        determined statically at graph construction time.
      allow_nan_stats: Boolean, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      use_static_graph: Calls to `sample` will not rely on dynamic tensor
        indexing, allowing for some static graph compilation optimizations, but
        at the expense of sampling all underlying distributions in the mixture.
        (Possibly useful when running on TPUs).
        Default value: `False` (i.e., use dynamic indexing).
      name: A name for this distribution (optional).

    Raises:
      TypeError: If cat is not a `Categorical`, or `components` is not
        a list or tuple, or the elements of `components` are not
        instances of `Distribution`, or do not have matching `dtype`.
      ValueError: If `components` is an empty list or tuple, or its
        elements do not have a statically known event rank.
        If `cat.num_classes` cannot be inferred at graph creation time,
        or the constant value of `cat.num_classes` is not equal to
        `len(components)`, or all `components` and `cat` do not have
        matching static batch shapes, or all components do not
        have matching static event shapes.
    """
    parameters = dict(locals())
    if not isinstance(cat, categorical.Categorical):
      raise TypeError("cat must be a Categorical distribution, but saw: %s" %
                      cat)
    if not components:
      raise ValueError("components must be a non-empty list or tuple")
    if not isinstance(components, (list, tuple)):
      raise TypeError("components must be a list or tuple, but saw: %s" %
                      components)
    if not all(isinstance(c, distribution.Distribution) for c in components):
      raise TypeError(
          "all entries in components must be Distribution instances"
          " but saw: %s" % components)

    dtype = components[0].dtype
    if not all(d.dtype == dtype for d in components):
      raise TypeError("All components must have the same dtype, but saw "
                      "dtypes: %s" % [(d.name, d.dtype) for d in components])
    static_event_shape = components[0].event_shape
    static_batch_shape = cat.batch_shape
    for di, d in enumerate(components):
      if not tensorshape_util.is_compatible_with(static_batch_shape,
                                                 d.batch_shape):
        raise ValueError(
            "components[{}] batch shape must be compatible with cat "
            "shape and other component batch shapes".format(di))
      static_event_shape = tensorshape_util.merge_with(
          static_event_shape, d.event_shape)
      static_batch_shape = tensorshape_util.merge_with(
          static_batch_shape, d.batch_shape)
    if tensorshape_util.rank(static_event_shape) is None:
      raise ValueError(
          "Expected to know rank(event_shape) from components, but "
          "none of the components provide a static number of ndims")

    # Ensure that all batch and event ndims are consistent.
    with tf.name_scope(name) as name:
      num_components = cat._num_categories()
      static_num_components = tf.get_static_value(num_components)
      if static_num_components is None:
        raise ValueError(
            "Could not infer number of classes from cat and unable "
            "to compare this value to the number of components passed in.")
      # Possibly convert from numpy 0-D array.
      static_num_components = int(static_num_components)
      if static_num_components != len(components):
        raise ValueError("cat.num_classes != len(components): %d vs. %d" %
                         (static_num_components, len(components)))

      cat_batch_shape = cat.batch_shape_tensor()
      cat_batch_rank = tf.size(cat_batch_shape)
      if validate_args:
        batch_shapes = [d.batch_shape_tensor() for d in components]
        batch_ranks = [tf.size(bs) for bs in batch_shapes]
        check_message = ("components[%d] batch shape must match cat "
                         "batch shape")
        self._assertions = [
            assert_util.assert_equal(
                cat_batch_rank, batch_ranks[di], message=check_message % di)
            for di in range(len(components))
        ]
        self._assertions += [
            assert_util.assert_equal(
                cat_batch_shape, batch_shapes[di], message=check_message % di)
            for di in range(len(components))
        ]
      else:
        self._assertions = []

      self._cat = cat
      self._components = list(components)
      self._num_components = static_num_components
      self._static_event_shape = static_event_shape
      self._static_batch_shape = static_batch_shape

      self._use_static_graph = use_static_graph
      if use_static_graph and static_num_components is None:
        raise ValueError("Number of categories must be known statically when "
                         "`static_sample=True`.")

    super(Mixture, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
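A hedged construction sketch for the `Mixture` API described in the docstring above (standard TensorFlow Probability usage; values are illustrative):

import tensorflow_probability as tfp
tfd = tfp.distributions

mix = tfd.Mixture(
    cat=tfd.Categorical(probs=[0.3, 0.7]),
    components=[
        tfd.Normal(loc=-1., scale=0.5),
        tfd.Normal(loc=2., scale=1.),
    ],
    validate_args=True)
print(mix.mean())      # 0.3 * (-1.) + 0.7 * 2. = 1.1
print(mix.sample(5))   # five scalar draws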
Example #29
    def __init__(self,
                 mixture_distribution,
                 components_distribution,
                 reparameterize=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MixtureSameFamily"):
        """Construct a `MixtureSameFamily` distribution.

    Args:
      mixture_distribution: `tfp.distributions.Categorical`-like instance.
        Manages the probability of selecting components. The number of
        categories must match the rightmost batch dimension of the
        `components_distribution`. Must have either scalar `batch_shape` or
        `batch_shape` matching `components_distribution.batch_shape[:-1]`.
      components_distribution: `tfp.distributions.Distribution`-like instance.
        Right-most batch dimension indexes components.
      reparameterize: Python `bool`, default `False`. Whether to reparameterize
        samples of the distribution using implicit reparameterization gradients
        [(Figurnov et al., 2018)][1]. The gradients for the mixture logits are
        equivalent to the ones described by [(Graves, 2016)][2]. The gradients
        for the components parameters are also computed using implicit
        reparameterization (as opposed to ancestral sampling), meaning that
        all components are updated every step.
        Only works when:
          (1) components_distribution is fully reparameterized;
          (2) components_distribution is either a scalar distribution or
          fully factorized (tfd.Independent applied to a scalar distribution);
          (3) batch shape has a known rank.
        Experimental, may be slow and produce infs/NaNs.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: `if not dtype_util.is_integer(mixture_distribution.dtype)`.
      ValueError: if mixture_distribution does not have scalar `event_shape`.
      ValueError: if `mixture_distribution.batch_shape` and
        `components_distribution.batch_shape[:-1]` are both fully defined and
        the former is neither scalar nor equal to the latter.
      ValueError: if `mixture_distribution` categories does not equal
        `components_distribution` rightmost batch shape.

    #### References

    [1]: Michael Figurnov, Shakir Mohamed and Andriy Mnih. Implicit
         reparameterization gradients. In _Neural Information Processing
         Systems_, 2018. https://arxiv.org/abs/1805.08498

    [2]: Alex Graves. Stochastic Backpropagation through Mixture Density
         Distributions. _arXiv_, 2016. https://arxiv.org/abs/1607.05690
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._mixture_distribution = mixture_distribution
            self._components_distribution = components_distribution
            self._runtime_assertions = []

            s = components_distribution.event_shape_tensor()
            self._event_ndims = tf.compat.dimension_value(s.shape[0])
            if self._event_ndims is None:
                self._event_ndims = tf.size(s)
            self._event_size = tf.reduce_prod(s)

            if not dtype_util.is_integer(mixture_distribution.dtype):
                raise ValueError(
                    "`mixture_distribution.dtype` ({}) is not over integers".
                    format(dtype_util.name(mixture_distribution.dtype)))

            if (tensorshape_util.rank(mixture_distribution.event_shape)
                    is not None and tensorshape_util.rank(
                        mixture_distribution.event_shape) != 0):
                raise ValueError(
                    "`mixture_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.size(mixture_distribution.event_shape_tensor()),
                        0,
                        message=
                        "`mixture_distribution` must have scalar `event_dim`s"
                    ),
                ]

            mdbs = mixture_distribution.batch_shape
            cdbs = tensorshape_util.with_rank_at_least(
                components_distribution.batch_shape, 1)[:-1]
            if tensorshape_util.is_fully_defined(
                    mdbs) and tensorshape_util.is_fully_defined(cdbs):
                if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
                    raise ValueError(
                        "`mixture_distribution.batch_shape` (`{}`) is not "
                        "compatible with `components_distribution.batch_shape` "
                        "(`{}`)".format(tensorshape_util.as_list(mdbs),
                                        tensorshape_util.as_list(cdbs)))
            elif validate_args:
                mdbs = mixture_distribution.batch_shape_tensor()
                cdbs = components_distribution.batch_shape_tensor()[:-1]
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        distribution_utils.pick_vector(
                            mixture_distribution.is_scalar_batch(), cdbs,
                            mdbs),
                        cdbs,
                        message=
                        ("`mixture_distribution.batch_shape` is not "
                         "compatible with `components_distribution.batch_shape`"
                         ))
                ]

            mixture_dist_param = (mixture_distribution.probs
                                  if mixture_distribution.logits is None else
                                  mixture_distribution.logits)
            km = tf.compat.dimension_value(
                tensorshape_util.with_rank_at_least(mixture_dist_param.shape,
                                                    1)[-1])
            kc = tf.compat.dimension_value(
                tensorshape_util.with_rank_at_least(
                    components_distribution.batch_shape, 1)[-1])
            if km is not None and kc is not None and km != kc:
                raise ValueError(
                    "`mixture_distribution components` ({}) does not "
                    "equal `components_distribution.batch_shape[-1]` "
                    "({})".format(km, kc))
            elif validate_args:
                km = tf.shape(mixture_dist_param)[-1]
                kc = components_distribution.batch_shape_tensor()[-1]
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        km,
                        kc,
                        message=(
                            "`mixture_distribution components` does not equal "
                            "`components_distribution.batch_shape[-1:]`")),
                ]
            elif km is None:
                km = tf.shape(mixture_dist_param)[-1]

            self._num_components = km

            self._reparameterize = reparameterize
            if reparameterize:
                # Note: tfd.Independent passes through the reparameterization type hence
                # we do not need separate logic for Independent.
                if (self._components_distribution.reparameterization_type !=
                        reparameterization.FULLY_REPARAMETERIZED):
                    raise ValueError("Cannot reparameterize a mixture of "
                                     "non-reparameterized components.")
                reparameterization_type = reparameterization.FULLY_REPARAMETERIZED
            else:
                reparameterization_type = reparameterization.NOT_REPARAMETERIZED

            super(MixtureSameFamily, self).__init__(
                dtype=self._components_distribution.dtype,
                reparameterization_type=reparameterization_type,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
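A hedged construction sketch for the `MixtureSameFamily` API described in the docstring above (standard TensorFlow Probability usage; values are illustrative). The rightmost batch dimension of `components_distribution` indexes the two components:

import tensorflow_probability as tfp
tfd = tfp.distributions

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(loc=[-1., 1.], scale=[0.1, 0.5]),
    validate_args=True)
print(gm.mean())         # 0.3 * (-1.) + 0.7 * 1. = 0.4
print(gm.log_prob(0.5))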
Example #30
    def _distributional_transform(self, x):
        """Performs distributional transform of the mixture samples.

    Distributional transform removes the parameters from samples of a
    multivariate distribution by applying conditional CDFs:
      (F(x_1), F(x_2 | x_1), ..., F(x_d | x_1, ..., x_d-1))
    (the indexing is over the "flattened" event dimensions).
    The result is a sample from a product of Uniform[0, 1] distributions.

    We assume that the components are factorized, so the conditional CDFs become
      F(x_i | x_1, ..., x_i-1) = sum_k w_i^k F_k (x_i),
    where w_i^k is the posterior mixture weight: for i > 0
      w_i^k = w_k prob_k(x_1, ..., x_i-1) / sum_k' w_k' prob_k'(x_1, ..., x_i-1)
    and w_0^k = w_k is the mixture probability of the k-th component.

    Arguments:
      x: Sample of mixture distribution

    Returns:
      Result of the distributional transform
    """

        if tensorshape_util.rank(x.shape) is None:
            # tf.math.softmax raises an error when applied to inputs of undefined
            # rank.
            raise ValueError(
                "Distributional transform does not support inputs of "
                "undefined rank.")

        # Obtain factorized components distribution and assert that it's
        # a scalar distribution.
        if isinstance(self._components_distribution, independent.Independent):
            univariate_components = self._components_distribution.distribution
        else:
            univariate_components = self._components_distribution

        with tf.control_dependencies([
                assert_util.assert_equal(
                    univariate_components.is_scalar_event(),
                    True,
                    message="`univariate_components` must have scalar event")
        ]):
            x_padded = self._pad_sample_dims(x)  # [S, B, 1, E]
            log_prob_x = univariate_components.log_prob(
                x_padded)  # [S, B, k, E]
            cdf_x = univariate_components.cdf(x_padded)  # [S, B, k, E]

            # log prob_k (x_1, ..., x_i-1)
            cumsum_log_prob_x = tf.reshape(
                tf.math.cumsum(
                    # [S*prod(B)*k, prod(E)]
                    tf.reshape(log_prob_x, [-1, self._event_size]),
                    exclusive=True,
                    axis=-1),
                tf.shape(log_prob_x))  # [S, B, k, E]

            logits_mix_prob = distribution_utils.pad_mixture_dimensions(
                self.mixture_distribution.logits_parameter(), self,
                self.mixture_distribution, self._event_ndims)  # [B, k, 1]

            # Logits of the posterior weights: log w_k + log prob_k (x_1, ..., x_i-1)
            log_posterior_weights_x = logits_mix_prob + cumsum_log_prob_x

            component_axis = tensorshape_util.rank(x.shape) - self._event_ndims
            posterior_weights_x = tf.math.softmax(log_posterior_weights_x,
                                                  axis=component_axis)
            return tf.reduce_sum(posterior_weights_x * cdf_x,
                                 axis=component_axis)
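A minimal NumPy sketch of the posterior-weight formula in the docstring above, for a two-component mixture of factorized Gaussians over a 2-dimensional event (illustrative values; SciPy is assumed for the Normal log-pdf and CDF):

import numpy as np
from scipy.stats import norm

w = np.array([0.3, 0.7])                     # mixture weights w_k
mu = np.array([[0., 0.], [1., 1.]])          # component means, shape [k, d]
sigma = np.ones([2, 2])                      # component scales, shape [k, d]
x = np.array([0.2, 0.8])                     # a sample, shape [d]

# Exclusive cumulative log prob_k(x_1, ..., x_{i-1}) over event dimensions.
log_prob = norm.logpdf(x, loc=mu, scale=sigma)                      # [k, d]
cum = np.concatenate(
    [np.zeros([2, 1]), np.cumsum(log_prob, axis=-1)[:, :-1]], axis=-1)

# Posterior weights w_i^k proportional to w_k * prob_k(x_1, ..., x_{i-1}).
logits = np.log(w)[:, None] + cum                                   # [k, d]
posterior = np.exp(logits - logits.max(axis=0))
posterior /= posterior.sum(axis=0)

# F(x_i | x_1, ..., x_{i-1}) = sum_k w_i^k F_k(x_i).
transform = (posterior * norm.cdf(x, loc=mu, scale=sigma)).sum(axis=0)
print(transform)                             # two values in (0, 1)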