Esempio n. 1
0
 def _inverse(self, y):
     """Maps every element of `y` to the index of its nearest `map_values` entry.

     `y` is flattened, each element is matched against `self.map_values`, and
     the selected indices are reshaped back to `y.shape`. When
     `self.validate_args` is set, an assertion additionally requires the
     nearest entry to be (approximately) equal to the element.

     Args:
       y: Tensor of values to look up in `self.map_values`.

     Returns:
       Tensor of indices into `self.map_values`, shaped like `y`.
     """
     table = tf.convert_to_tensor(self.map_values)
     y_flat = tf.reshape(y, shape=[-1])
     # The table is expected to be strictly increasing, so the nearest entry to
     # a query is either the first entry strictly greater than it (clamped to
     # the last valid index) or the entry immediately before that one.
     last_index = tf.size(table) - 1
     insertion_points = tf.searchsorted(table, values=y_flat, side='right')
     hi_idx = tf.minimum(last_index, insertion_points)
     lo_idx = tf.maximum(0, hi_idx - 1)
     idx_pairs = tf.stack([lo_idx, hi_idx], axis=-1)
     lo_err = tf.abs(y_flat - self._forward(lo_idx))
     hi_err = tf.abs(y_flat - self._forward(hi_idx))
     if self.validate_args:
         found = assert_util.assert_near(tf.minimum(lo_err, hi_err),
                                         0,
                                         message='inverse value not found')
         with tf.control_dependencies([found]):
             idx_pairs = tf.identity(idx_pairs)
     # Build (row, column) coordinates: row i is the i-th flattened query and
     # the column picks whichever of the two candidates has the smaller error.
     rows = tf.range(tf.size(y_flat), dtype=tf.int32)
     cols = tf.argmin([lo_err, hi_err], output_type=tf.int32)
     coords = tf.stack([rows, cols], axis=-1)
     chosen = tf.gather_nd(idx_pairs, coords)
     return tf.reshape(chosen, shape=y.shape)
Esempio n. 2
0
def _maybe_check_valid_map_values(map_values, validate_args):
    """Checks `map_values` statically where possible, else via assertions.

    Rank and size are verified immediately (raising `ValueError`) whenever
    they are statically known. Runtime assertions are only created when
    `validate_args` is true.

    Args:
      map_values: Tensor expected to be a non-empty, strictly increasing
        vector.
      validate_args: Python `bool`; whether to emit runtime assertions.

    Returns:
      A list of assertion ops (possibly empty).

    Raises:
      ValueError: If the static rank is known and is not 1, or the static
        number of elements is known and is 0.
    """
    checks = []

    rank_msg = 'Rank of map_values must be 1.'
    static_rank = tensorshape_util.rank(map_values.shape)
    if static_rank is None:
        if validate_args:
            checks.append(
                assert_util.assert_rank(map_values, 1, message=rank_msg))
    elif static_rank != 1:
        raise ValueError(rank_msg)

    size_msg = 'Size of map_values must be greater than 0.'
    static_size = tensorshape_util.num_elements(map_values.shape)
    if static_size is None:
        if validate_args:
            checks.append(
                assert_util.assert_greater(tf.size(map_values), 0,
                                           message=size_msg))
    elif static_size == 0:
        raise ValueError(size_msg)

    # Monotonicity depends on the values themselves, so it is only ever
    # checked at runtime.
    if validate_args:
        increasing = tf.math.is_strictly_increasing(map_values)
        checks.append(
            assert_util.assert_equal(
                increasing,
                True,
                message='map_values is not strictly increasing.'))

    return checks
Esempio n. 3
0
def _maybe_validate_perm(perm, validate_args, name=None):
    """Validates that `perm` is an integer permutation vector.

    Properties that are statically known are checked immediately; the rest
    are turned into runtime assertions, but only if `validate_args` is true.

    Args:
      perm: Tensor holding the candidate permutation.
      validate_args: Python `bool`; whether to emit runtime assertions.
      name: Optional name scope; defaults to 'maybe_validate_perm'.

    Returns:
      A list of assertion ops (possibly empty).

    Raises:
      TypeError: If `perm` does not have an integer dtype.
      ValueError: If `perm` is statically known to not be rank 1, or its
        static value is not a permutation of `0..size-1`.
    """
    with tf.name_scope(name or 'maybe_validate_perm'):
        if not dtype_util.is_integer(perm.dtype):
            raise TypeError('`perm` must be integer type')

        checks = []

        msg = '`perm` must be a vector.'
        static_rank = tensorshape_util.rank(perm.shape)
        if static_rank is None:
            if validate_args:
                checks += [assert_util.assert_rank(perm, 1, message=msg)]
        elif static_rank != 1:
            raise ValueError(msg[:-1] +
                             ', saw rank: {}.'.format(static_rank))

        msg = '`perm` must be a valid permutation vector.'
        static_perm = tf.get_static_value(perm)
        if static_perm is None:
            if validate_args:
                checks += [
                    assert_util.assert_equal(tf.sort(perm),
                                             tf.range(tf.size(perm)),
                                             message=msg)
                ]
        else:
            # A permutation, once sorted, is exactly 0, 1, ..., size - 1.
            expected = np.arange(np.size(static_perm))
            if not np.all(expected == np.sort(static_perm)):
                raise ValueError(msg[:-1] + ', saw: {}.'.format(static_perm))

        return checks
Esempio n. 4
0
def _size(input, out_type=tf.int32, name=None):  # pylint: disable=redefined-builtin
  """Returns the number of elements of `input`, statically when possible.

  Args:
    input: A tensor-like object. Objects without a `shape` attribute are first
      converted with `np.array`; if that yields an object-dtype array the
      input is converted with `tf.convert_to_tensor` instead.
    out_type: Integer dtype of the result.
    name: Optional op name, used only when falling back to `tf.size`.

  Returns:
    A NumPy scalar array of dtype `out_type` when the element count is
    statically known, otherwise the result of `tf.size`.
  """
  if not hasattr(input, 'shape'):
    x = np.array(input)
    # Compare by equality against `np.object_`: the previous identity check
    # against the `np.object` alias could never be true (a dtype instance is
    # not the `object` type) and the alias itself no longer exists in recent
    # NumPy releases, where it raises AttributeError.
    input = tf.convert_to_tensor(input) if x.dtype == np.object_ else x
  n = tensorshape_util.num_elements(tf.TensorShape(input.shape))
  if n is None:
    return tf.size(input, out_type=out_type, name=name)
  return np.array(n).astype(_numpy_dtype(out_type))
Esempio n. 5
0
def rank_from_shape(shape_tensor_fn, tensorshape=None):
  """Computes the rank of a `Tensor` from its shape, statically if possible.

  Args:
    shape_tensor_fn: Either the shape vector itself or a zero-argument
      callable returning it.
    tensorshape: Optional static shape. When given, its rank is preferred and
      `shape_tensor_fn` is only evaluated if that rank is unknown.

  Returns:
    A Python/static value for the rank when it can be determined without
    running ops, otherwise the result of `tf.size` on the shape vector.
  """

  def _resolve_shape_vector():
    if callable(shape_tensor_fn):
      return shape_tensor_fn()
    return shape_tensor_fn

  if tensorshape is not None:
    static_rank = tensorshape_util.rank(tensorshape)
    if static_rank is not None:
      return static_rank
    return tf.size(_resolve_shape_vector())

  shape_vector = _resolve_shape_vector()
  # Tensor-like shape vectors expose a static shape with `num_elements`;
  # plain sequences (lists, tuples, NumPy arrays) are measured with `len`.
  if hasattr(getattr(shape_vector, 'shape', None), 'num_elements'):
    static_rank = tensorshape_util.num_elements(shape_vector.shape)
  else:
    static_rank = len(shape_vector)
  if static_rank is None:
    return tf.size(shape_vector)
  return static_rank
Esempio n. 6
0
    def _parameter_control_dependencies(self, is_init):
        """Builds checks that all component batch shapes agree.

        Args:
          is_init: Python `bool`; when true, fully static batch shapes are
            compared immediately and a mismatch raises.

        Returns:
          A list of assertion ops; always empty when `self.validate_args` is
          false.

        Raises:
          ValueError: If `is_init` is true, every cached batch shape is fully
            defined, and they are not all equal.
        """
        message = 'Distributions must have the same `batch_shape`'
        assertions = []

        if is_init:
            static_shapes = tf.nest.flatten(self._cached_batch_shape)
            all_static = all(
                tensorshape_util.is_fully_defined(shape)
                for shape in static_shapes)
            # Comparing the list shifted by one checks every adjacent pair.
            if all_static and static_shapes[1:] != static_shapes[:-1]:
                raise ValueError('{}; found: {}.'.format(
                    message, static_shapes))

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        batch_shapes = self._cached_batch_shape
        everything_static = all(
            tensorshape_util.is_fully_defined(shape)
            for shape in tf.nest.flatten(batch_shapes))
        if not everything_static:

            def _static_or_dynamic(static_shape, shape_tensor):
                # Keep the static shape when complete, else use the tensor.
                if tensorshape_util.is_fully_defined(static_shape):
                    return static_shape
                return shape_tensor

            batch_shapes = tf.nest.map_structure(
                _static_or_dynamic, batch_shapes,
                self._cached_batch_shape_tensor)

        batch_shapes = tf.nest.flatten(batch_shapes)
        adjacent = list(zip(batch_shapes[1:], batch_shapes[:-1]))
        full_message = '{}.'.format(message)
        assertions.extend(
            assert_util.assert_equal(later, earlier, message=full_message)
            for later, earlier in adjacent)
        assertions.extend(
            assert_util.assert_equal(
                tf.size(later), tf.size(earlier), message=full_message)
            for later, earlier in adjacent)

        return assertions
Esempio n. 7
0
    def _maybe_get_static_event_ndims(self):
        """Returns the event rank, as a static value whenever one is available.

        Preference order: the rank of the static `event_shape`; then the
        statically-evaluated size of `event_shape_tensor()`; finally the
        dynamic size tensor itself.
        """
        static_rank = tensorshape_util.rank(self.event_shape)
        if static_rank is not None:
            return static_rank

        dynamic_rank = tf.size(self.event_shape_tensor())
        folded_rank = distribution_util.maybe_get_static_value(dynamic_rank)
        return dynamic_rank if folded_rank is None else folded_rank
 def _sample_shape(self, x):
     """Derives the sample shape of `x`, both dynamically and statically.

     The sample dimensions are the leading dimensions of `x` that remain
     after removing the batch and event dimensions.

     Args:
       x: Tensor whose trailing dimensions are batch and event dimensions.

     Returns:
       A pair `(sample_shape, static_sample_shape)`. `sample_shape` is a
       NumPy int32 value when the static sample shape is fully defined and a
       slice of `tf.shape(x)` otherwise; `static_sample_shape` is the
       matching static shape (unknown if the sample rank is not a Python int).
     """
     static_x_rank = tensorshape_util.rank(x.shape)
     x_ndims = tf.rank(x) if static_x_rank is None else static_x_rank

     static_event_rank = tensorshape_util.rank(self.event_shape)
     if static_event_rank is None:
         event_ndims = tf.size(self.event_shape_tensor())
     else:
         event_ndims = static_event_rank

     static_batch_rank = tensorshape_util.rank(self.batch_shape)
     if static_batch_rank is None:
         batch_ndims = tf.size(self._batch_shape_unexpanded)
     else:
         batch_ndims = static_batch_rank

     sample_ndims = x_ndims - batch_ndims - event_ndims
     # Only a Python int can be used to slice the static shape.
     if isinstance(sample_ndims, int):
         static_sample_shape = x.shape[:sample_ndims]
     else:
         static_sample_shape = tf.TensorShape(None)

     if tensorshape_util.is_fully_defined(static_sample_shape):
         return np.int32(static_sample_shape), static_sample_shape
     return tf.shape(x)[:sample_ndims], static_sample_shape
Esempio n. 9
0
 def _forward(self, x):
     """Looks up `self.map_values` at the indices `x`.

     Args:
       x: Tensor of indices into `self.map_values`.

     Returns:
       The gathered values. When `self.validate_args` is set, the lookup
       depends on an assertion that every index lies in `[0, size)`.
     """
     table = tf.convert_to_tensor(self.map_values)
     if self.validate_args:
         in_bounds = (0 <= x) & (x < tf.size(table))
         bounds_check = assert_util.assert_equal(
             in_bounds, True, message='indices out of bound')
         with tf.control_dependencies([bounds_check]):
             x = tf.identity(x)
     # Should batch dimensions in self.map_values ever be needed, one option
     # (after broadcasting) is:
     # tf.gather(self.map_values, x, batch_dims=-1, axis=-1)
     return tf.gather(table, indices=x)
Esempio n. 10
0
    def _entropy(self, **kwargs):
        """Computes the entropy of the transformed distribution.

        Only available when the bijector has a constant Jacobian and is
        injective.

        Args:
          **kwargs: Keyword arguments, split by `self._kwargs_split_fn` into
            those for the base distribution and those for the bijector.

        Returns:
          The entropy tensor, with static shape set to `self.batch_shape`.

        Raises:
          NotImplementedError: If the bijector's Jacobian is not constant or
            the bijector is not injective.
        """
        if not self.bijector.is_constant_jacobian:
            raise NotImplementedError("entropy is not implemented")
        if not self.bijector._is_injective:  # pylint: disable=protected-access
            raise NotImplementedError("entropy is not implemented when "
                                      "bijector is not injective.")
        distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
        # For Y = g(X) with g a diffeomorphism and X a continuous rv:
        #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
        # With a constant Jacobian the expectation collapses to
        #   (log o abs o det o J o g)(c)
        # for an arbitrary point c, so any dummy input may be used below.
        entropy = self.distribution.entropy(**distribution_kwargs)

        if self._is_maybe_event_override:
            # H[X] = sum_i H[X_i] for mutually independent X_i, so summing
            # over the overridden event dims is just a multiplication.
            event_size = tf.cast(
                tf.reduce_prod(self._override_event_shape),
                dtype=dtype_util.base_dtype(entropy.dtype))
            entropy = entropy * event_size

        if self._is_maybe_batch_override:
            # Prepend singleton dims for the override batch shape, then tile
            # along those dims to broadcast the entropy across them.
            expanded_shape = tf.concat([
                prefer_static.ones_like(self._override_batch_shape),
                self.distribution.batch_shape_tensor()
            ], 0)
            entropy = tf.reshape(entropy, expanded_shape)
            tile_counts = tf.concat([
                self._override_batch_shape,
                prefer_static.ones_like(self.distribution.batch_shape_tensor())
            ], 0)
            entropy = tf.tile(entropy, tile_counts)

        full_shape = tf.concat(
            [self.batch_shape_tensor(),
             self.event_shape_tensor()], 0)
        dummy = prefer_static.zeros(shape=full_shape, dtype=self.dtype)

        if tensorshape_util.rank(self.event_shape) is not None:
            event_ndims = tensorshape_util.rank(self.event_shape)
        else:
            event_ndims = tf.size(self.event_shape_tensor())

        ildj = self.bijector.inverse_log_det_jacobian(dummy,
                                                      event_ndims=event_ndims,
                                                      **bijector_kwargs)

        entropy = entropy - tf.cast(ildj, entropy.dtype)
        tensorshape_util.set_shape(entropy, self.batch_shape)
        return entropy
Esempio n. 11
0
 def _log_prob(self, x):
     """Returns the log-probability of `x` under the discrete outcomes.

     Each element of `x` is matched against its two neighbouring entries of
     `self.outcomes`. Elements that are equal or close (per
     `self._is_equal_or_close`) to one of them receive the corresponding
     categorical log-probability; all other elements receive `-inf`.

     Args:
       x: Tensor of values to score.

     Returns:
       Tensor of log-probabilities.
     """
     x = tf.convert_to_tensor(x, name='x')
     last_index = tf.size(self.outcomes) - 1
     flat_positions = tf.searchsorted(self.outcomes,
                                      values=tf.reshape(x, shape=[-1]),
                                      side='right')
     positions = tf.reshape(flat_positions,
                            dist_util.prefer_static_shape(x))
     # Clamp so that values beyond the last outcome still index validly.
     right_indices = tf.minimum(last_index, positions)
     right_outcomes = tf.gather(self.outcomes, indices=right_indices)
     matches_right = self._is_equal_or_close(x, right_outcomes)

     left_indices = tf.maximum(0, right_indices - 1)
     left_outcomes = tf.gather(self.outcomes, indices=left_indices)
     matches_left = self._is_equal_or_close(x, left_outcomes)

     chosen_indices = tf.where(matches_left, left_indices, right_indices)
     log_probs = self._categorical.log_prob(chosen_indices)

     no_match = tf.logical_not(matches_left | matches_right)
     neg_inf = dtype_util.as_numpy_dtype(log_probs.dtype)(-np.inf)
     return tf.where(no_match, neg_inf, log_probs)
Esempio n. 12
0
    def _forward_log_det_jacobian(self, x, **kwargs):
        """Accumulates the forward log-det-Jacobian over the bijector chain.

        Bijectors are applied in reverse order of `self.bijectors`. After
        each one, the event shape and event rank are propagated (statically
        when possible) so the next bijector is called with the right
        `event_ndims`.

        Args:
          x: Input tensor.
          **kwargs: Per-bijector keyword arguments, keyed by bijector name.

        Returns:
          The summed forward log-det-Jacobian; a scalar zero of `x`'s base
          dtype if there are no bijectors.
        """
        x = tf.convert_to_tensor(x, name="x")
        total = tf.cast(0., dtype=dtype_util.base_dtype(x.dtype))
        if not self.bijectors:
            return total

        event_ndims = self._maybe_get_static_event_ndims(
            self.forward_min_event_ndims)

        if _use_static_shape(x, event_ndims):
            first_event_axis = tensorshape_util.rank(x.shape) - event_ndims
            event_shape = x.shape[first_event_axis:]
        else:
            event_shape = tf.shape(x)[tf.rank(x) - event_ndims:]

        # TODO(b/129973548): Document and simplify.
        for bij in reversed(self.bijectors):
            bij_kwargs = kwargs.get(bij.name, {})
            total = total + bij.forward_log_det_jacobian(
                x, event_ndims=event_ndims, **bij_kwargs)

            if _use_static_shape(x, event_ndims):
                event_shape = bij.forward_event_shape(event_shape)
                event_ndims = self._maybe_get_static_event_ndims(
                    tensorshape_util.rank(event_shape))
            else:
                event_shape = bij.forward_event_shape_tensor(event_shape)
                static_event_shape = (
                    distribution_util.maybe_get_static_value(event_shape))
                event_ndims = tf.size(event_shape)
                static_event_ndims = self._maybe_get_static_event_ndims(
                    event_ndims)
                # Switch to static values only if both could be resolved.
                if (static_event_ndims is not None and
                        static_event_shape is not None):
                    event_ndims = static_event_ndims
                    event_shape = static_event_shape

            x = bij.forward(x, **bij_kwargs)

        return total
Esempio n. 13
0
def _validate_block_sizes(block_sizes, bijectors, validate_args):
  """Validates that `block_sizes` is a vector with one entry per bijector.

  Args:
    block_sizes: Tensor of block sizes.
    bijectors: Sequence of bijectors; only its length is used.
    validate_args: Python `bool`; whether to add runtime assertions when the
      shape of `block_sizes` is not fully known statically.

  Returns:
    `block_sizes`, wrapped in `tf.identity` with assertion control
    dependencies when runtime validation is performed.

  Raises:
    ValueError: If the static shape of `block_sizes` is fully defined and is
      not a vector of length `len(bijectors)`.
  """
  block_sizes_shape = block_sizes.shape
  if tensorshape_util.is_fully_defined(block_sizes_shape):
    if (tensorshape_util.rank(block_sizes_shape) != 1 or
        (tensorshape_util.num_elements(block_sizes_shape) != len(bijectors))):
      raise ValueError(
          '`block_sizes` must be `None`, or a vector of the same length as '
          '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of '
          'length {}'.format(block_sizes_shape, len(bijectors)))
    return block_sizes

  if not validate_args:
    return block_sizes

  message = ('`block_sizes` must be `None`, or a vector of the same length '
             'as `bijectors`.')
  # Both assertions carry the same message so that a rank failure is reported
  # as clearly as a size failure.
  with tf.control_dependencies([
      assert_util.assert_equal(
          tf.size(block_sizes), len(bijectors), message=message),
      assert_util.assert_equal(tf.rank(block_sizes), 1, message=message)
  ]):
    return tf.identity(block_sizes)
Esempio n. 14
0
    def _inverse_log_det_jacobian(self, y, **kwargs):
        """Accumulates the inverse log-det-Jacobian over the bijector chain.

        Bijectors are inverted in the order of `self.bijectors`. After each
        one, the event shape and event rank are propagated (statically when
        possible) so the next bijector is called with the right
        `event_ndims`.

        Args:
          y: Output-space tensor.
          **kwargs: Per-bijector keyword arguments, keyed by bijector name.

        Returns:
          The summed inverse log-det-Jacobian; a scalar zero of `y`'s base
          dtype if there are no bijectors.
        """
        y = tf.convert_to_tensor(y, name="y")
        total = tf.cast(0., dtype=dtype_util.base_dtype(y.dtype))
        if not self.bijectors:
            return total

        event_ndims = self._maybe_get_static_event_ndims(
            self.inverse_min_event_ndims)

        if _use_static_shape(y, event_ndims):
            first_event_axis = tensorshape_util.rank(y.shape) - event_ndims
            event_shape = y.shape[first_event_axis:]
        else:
            event_shape = tf.shape(y)[tf.rank(y) - event_ndims:]

        # TODO(b/129973548): Document and simplify.
        for bij in self.bijectors:
            bij_kwargs = kwargs.get(bij.name, {})
            total = total + bij.inverse_log_det_jacobian(
                y, event_ndims=event_ndims, **bij_kwargs)

            if _use_static_shape(y, event_ndims):
                event_shape = bij.inverse_event_shape(event_shape)
                event_ndims = self._maybe_get_static_event_ndims(
                    tensorshape_util.rank(event_shape))
            else:
                event_shape = bij.inverse_event_shape_tensor(event_shape)
                static_event_shape = (
                    distribution_util.maybe_get_static_value(event_shape))
                event_ndims = tf.size(event_shape)
                static_event_ndims = self._maybe_get_static_event_ndims(
                    event_ndims)
                # Switch to static values only if both could be resolved.
                if (static_event_ndims is not None and
                        static_event_shape is not None):
                    event_ndims = static_event_ndims
                    event_shape = static_event_shape

            y = bij.inverse(y, **bij_kwargs)
        return total
Esempio n. 15
0
 def _inverse_event_shape_tensor(self, output_shape):
     """Reorders `output_shape` by the argsort of `self.perm`."""
     ndims = tf.size(output_shape)
     inverse_perm = tf.argsort(self.perm)
     gather_order = self._make_perm(ndims, inverse_perm)
     return tf.gather(output_shape, gather_order)
Esempio n. 16
0
 def _forward_event_shape_tensor(self, input_shape):
     """Reorders `input_shape` according to `self.perm`."""
     ndims = tf.size(input_shape)
     gather_order = self._make_perm(ndims, self.perm)
     return tf.gather(input_shape, gather_order)
Esempio n. 17
0
def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape):
    """Applies batch slices to one parameter of a distribution.

    Args:
      param: A `Tensor`, the original parameter to slice.
      param_event_ndims: `int` event parameterization rank for this parameter.
      slices: A `tuple` of normalized slices.
      dist_batch_shape: The distribution's batch shape `Tensor`.

    Returns:
      new_param: A `Tensor`, batch-sliced according to slices.

    Raises:
      ValueError: If `slices` contains more than one `...`.
    """
    # Left-pad the parameter's shape with ones so that its batch rank lines up
    # with `dist_batch_shape`.
    param_shape = tf.shape(input=param)
    num_missing_dims = (tf.size(input=dist_batch_shape) + param_event_ndims -
                        tf.rank(param))
    leading_ones = tf.ones([num_missing_dims], dtype=param_shape.dtype)
    padded_shape = tf.concat([leading_ones, param_shape], axis=0)
    padded_param = tf.reshape(param, padded_shape)

    translated = []
    # The batch axis and the parameter axis are tracked separately: they
    # coincide under positive indexing but differ by `param_event_ndims`
    # once indexing turns negative (after an Ellipsis).
    param_axis = 0
    batch_axis = 0
    for slc in slices:
        if slc is tf.newaxis:
            translated.append(slc)
            continue

        if slc is Ellipsis:
            # A negative batch axis means an Ellipsis was already consumed.
            if batch_axis < 0:
                raise ValueError(
                    'Found multiple `...` in slices {}'.format(slices))
            translated.append(slc)
            # From here on, count axes from the right for the broadcast check.
            trailing = slices[slices.index(Ellipsis) + 1:]
            batch_axis = -sum(s is not tf.newaxis for s in trailing)
            param_axis = batch_axis - param_event_ndims
            continue

        # The parameter is broadcast along this axis when the distribution's
        # batch dimension is larger than the parameter's.
        param_dim = padded_shape[param_axis]
        batch_dim = dist_batch_shape[batch_axis]
        is_broadcast = batch_dim > param_dim

        if isinstance(slc, slice):
            # start:stop:step; on a broadcast axis fall back to 0:1:1, leaving
            # unspecified (None) components untouched.
            start = (None if slc.start is None else
                     tf.where(is_broadcast, 0, slc.start))
            stop = (None if slc.stop is None else
                    tf.where(is_broadcast, 1, slc.stop))
            step = (None if slc.step is None else
                    tf.where(is_broadcast, 1, slc.step))
            translated.append(slice(start, stop, step))
        else:
            # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
            translated.append(tf.where(is_broadcast, 0, slc))

        param_axis += 1
        batch_axis += 1

    # Event dimensions of the parameter are always kept whole.
    translated.extend([ALL_SLICE] * param_event_ndims)
    return padded_param.__getitem__(translated)
Esempio n. 18
0
    def __init__(self,
                 cat,
                 components,
                 validate_args=False,
                 allow_nan_stats=True,
                 use_static_graph=False,
                 name="Mixture"):
        """Initialize a Mixture distribution.

        A `Mixture` is defined by a `Categorical` (`cat`, representing the
        mixture probabilities) and a list of `Distribution` objects
        all having matching dtype, batch shape, event shape, and continuity
        properties (the components).

        The `num_classes` of `cat` must be possible to infer at graph
        construction time and match `len(components)`.

        Args:
          cat: A `Categorical` distribution instance, representing the
            probabilities of `components`.
          components: A list or tuple of `Distribution` instances.
            Each instance must have the same type, be defined on the same
            domain, and have matching `event_shape` and `batch_shape`.
          validate_args: Python `bool`, default `False`. If `True`, raise a
            runtime error if batch or event ranks are inconsistent between
            cat and any of the distributions. This is only checked if the
            ranks cannot be determined statically at graph construction time.
          allow_nan_stats: Boolean, default `True`. If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member. If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this
            statistic.
          use_static_graph: Calls to `sample` will not rely on dynamic tensor
            indexing, allowing for some static graph compilation
            optimizations, but at the expense of sampling all underlying
            distributions in the mixture. (Possibly useful when running on
            TPUs).
            Default value: `False` (i.e., use dynamic indexing).
          name: A name for this distribution (optional).

        Raises:
          TypeError: If cat is not a `Categorical`, or `components` is not
            a list or tuple, or the elements of `components` are not
            instances of `Distribution`, or do not have matching `dtype`.
          ValueError: If `components` is an empty list or tuple, or its
            elements do not have a statically known event rank.
            If `cat.num_classes` cannot be inferred at graph creation time,
            or the constant value of `cat.num_classes` is not equal to
            `len(components)`, or all `components` and `cat` do not have
            matching static batch shapes, or all components do not
            have matching static event shapes.
        """
        # Must stay the first statement so that only the constructor
        # arguments (and `self`) are captured.
        parameters = dict(locals())
        if not isinstance(cat, categorical.Categorical):
            raise TypeError(
                "cat must be a Categorical distribution, but saw: %s" % cat)
        if not components:
            raise ValueError("components must be a non-empty list or tuple")
        if not isinstance(components, (list, tuple)):
            raise TypeError("components must be a list or tuple, but saw: %s" %
                            components)
        if not all(
                isinstance(c, distribution.Distribution) for c in components):
            raise TypeError(
                "all entries in components must be Distribution instances"
                " but saw: %s" % components)

        dtype = components[0].dtype
        if not all(d.dtype == dtype for d in components):
            raise TypeError("All components must have the same dtype, but saw "
                            "dtypes: %s" % [(d.name, d.dtype)
                                            for d in components])

        # Merge the static shapes of `cat` and every component; merging fails
        # on incompatible shapes, and the batch shape is checked explicitly
        # first to give a more helpful message.
        static_event_shape = components[0].event_shape
        static_batch_shape = cat.batch_shape
        for di, d in enumerate(components):
            if not tensorshape_util.is_compatible_with(static_batch_shape,
                                                       d.batch_shape):
                raise ValueError(
                    "components[{}] batch shape must be compatible with cat "
                    "shape and other component batch shapes".format(di))
            static_event_shape = tensorshape_util.merge_with(
                static_event_shape, d.event_shape)
            static_batch_shape = tensorshape_util.merge_with(
                static_batch_shape, d.batch_shape)
        if tensorshape_util.rank(static_event_shape) is None:
            raise ValueError(
                "Expected to know rank(event_shape) from components, but "
                "none of the components provide a static number of ndims")

        # Ensure that all batch and event ndims are consistent.
        with tf.name_scope(name) as name:
            num_components = cat._num_categories()
            static_num_components = tf.get_static_value(num_components)
            if static_num_components is None:
                raise ValueError(
                    "Could not infer number of classes from cat and unable "
                    "to compare this value to the number of components passed in."
                )
            # Possibly convert from numpy 0-D array.
            static_num_components = int(static_num_components)
            if static_num_components != len(components):
                raise ValueError(
                    "cat.num_classes != len(components): %d vs. %d" %
                    (static_num_components, len(components)))

            cat_batch_shape = cat.batch_shape_tensor()
            cat_batch_rank = tf.size(cat_batch_shape)
            if validate_args:
                batch_shapes = [d.batch_shape_tensor() for d in components]
                check_message = ("components[%d] batch shape must match cat "
                                 "batch shape")
                # Rank assertions come first, followed by the element-wise
                # shape assertions.
                self._assertions = [
                    assert_util.assert_equal(cat_batch_rank,
                                             tf.size(batch_shape),
                                             message=check_message % di)
                    for di, batch_shape in enumerate(batch_shapes)
                ]
                self._assertions += [
                    assert_util.assert_equal(cat_batch_shape,
                                             batch_shape,
                                             message=check_message % di)
                    for di, batch_shape in enumerate(batch_shapes)
                ]
            else:
                self._assertions = []

            self._cat = cat
            self._components = list(components)
            self._num_components = static_num_components
            self._static_event_shape = static_event_shape
            self._static_batch_shape = static_batch_shape

            # The number of components is always statically known at this
            # point (an unknown count raised above), so `use_static_graph`
            # needs no further validation.
            self._use_static_graph = use_static_graph

        super(Mixture, self).__init__(
            dtype=dtype,
            reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name)
Esempio n. 19
0
    def _sample_n(self, n, seed=None):
        if self._use_static_graph:
            with tf.control_dependencies(self._assertions):
                # This sampling approach is almost the same as the approach used by
                # `MixtureSameFamily`. The differences are due to having a list of
                # `Distribution` objects rather than a single object, and maintaining
                # random seed management that is consistent with the non-static code
                # path.
                samples = []
                cat_samples = self.cat.sample(n, seed=seed)
                stream = SeedStream(seed, salt="Mixture")

                for c in range(self.num_components):
                    samples.append(self.components[c].sample(n, seed=stream()))
                stack_axis = -1 - tensorshape_util.rank(
                    self._static_event_shape)
                x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
                npdt = dtype_util.as_numpy_dtype(x.dtype)
                mask = tf.one_hot(
                    indices=cat_samples,  # [n, B]
                    depth=self._num_components,  # == k
                    on_value=npdt(1),
                    off_value=npdt(0))  # [n, B, k]
                mask = distribution_util.pad_mixture_dimensions(
                    mask, self, self._cat,
                    tensorshape_util.rank(
                        self._static_event_shape))  # [n, B, k, [1]*e]
                return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

        with tf.control_dependencies(self._assertions):
            n = tf.convert_to_tensor(n, name="n")
            static_n = tf.get_static_value(n)
            n = int(static_n) if static_n is not None else n
            cat_samples = self.cat.sample(n, seed=seed)

            static_samples_shape = cat_samples.shape
            if tensorshape_util.is_fully_defined(static_samples_shape):
                samples_shape = tensorshape_util.as_list(static_samples_shape)
                samples_size = tensorshape_util.num_elements(
                    static_samples_shape)
            else:
                samples_shape = tf.shape(cat_samples)
                samples_size = tf.size(cat_samples)
            static_batch_shape = self.batch_shape
            if tensorshape_util.is_fully_defined(static_batch_shape):
                batch_shape = tensorshape_util.as_list(static_batch_shape)
                batch_size = tensorshape_util.num_elements(static_batch_shape)
            else:
                batch_shape = self.batch_shape_tensor()
                batch_size = tf.reduce_prod(batch_shape)
            static_event_shape = self.event_shape
            if tensorshape_util.is_fully_defined(static_event_shape):
                event_shape = np.array(
                    tensorshape_util.as_list(static_event_shape),
                    dtype=np.int32)
            else:
                event_shape = self.event_shape_tensor()

            # Get indices into the raw cat sampling tensor. We will
            # need these to stitch sample values back out after sampling
            # within the component partitions.
            samples_raw_indices = tf.reshape(tf.range(0, samples_size),
                                             samples_shape)

            # Partition the raw indices so that we can use
            # dynamic_stitch later to reconstruct the samples from the
            # known partitions.
            partitioned_samples_indices = tf.dynamic_partition(
                data=samples_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)

            # Copy the batch indices n times, as we will need to know
            # these to pull out the appropriate rows within the
            # component partitions.
            batch_raw_indices = tf.reshape(
                tf.tile(tf.range(0, batch_size), [n]), samples_shape)

            # Explanation of the dynamic partitioning below:
            #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
            # Suppose partitions are:
            #     [1 1 0 0 1 1]
            # After partitioning, batch indices are cut as:
            #     [batch_indices[x] for x in 2, 3]
            #     [batch_indices[x] for x in 0, 1, 4, 5]
            # i.e.
            #     [1 1] and [0 0 0 0]
            # Now we sample n=2 from part 0 and n=4 from part 1.
            # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
            # and for part 1 we want samples from batch entries 0, 0, 0, 0
            #   (samples 0, 1, 2, 3).
            partitioned_batch_indices = tf.dynamic_partition(
                data=batch_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)
            samples_class = [None for _ in range(self.num_components)]

            stream = SeedStream(seed, salt="Mixture")

            for c in range(self.num_components):
                n_class = tf.size(partitioned_samples_indices[c])
                samples_class_c = self.components[c].sample(n_class,
                                                            seed=stream())

                # Pull out the correct batch entries from each index.
                # To do this, we may have to flatten the batch shape.

                # For sample s, batch element b of component c, we get the
                # partitioned batch indices from
                # partitioned_batch_indices[c]; and shift each element by
                # the sample index. The final lookup can be thought of as
                # a matrix gather along locations (s, b) in
                # samples_class_c where the n_class rows correspond to
                # samples within this component and the batch_size columns
                # correspond to batch elements within the component.
                #
                # Thus the lookup index is
                #   lookup[c, i] = batch_size * s[i] + b[c, i]
                # for i = 0 ... n_class[c] - 1.
                lookup_partitioned_batch_indices = (
                    batch_size * tf.range(n_class) +
                    partitioned_batch_indices[c])
                samples_class_c = tf.reshape(
                    samples_class_c,
                    tf.concat([[n_class * batch_size], event_shape], 0))
                samples_class_c = tf.gather(samples_class_c,
                                            lookup_partitioned_batch_indices,
                                            name="samples_class_c_gather")
                samples_class[c] = samples_class_c

            # Stitch back together the samples across the components.
            lhs_flat_ret = tf.dynamic_stitch(
                indices=partitioned_samples_indices, data=samples_class)
            # Reshape back to proper sample, batch, and event shape.
            ret = tf.reshape(
                lhs_flat_ret,
                tf.concat(
                    [samples_shape, self.event_shape_tensor()], 0))
            tensorshape_util.set_shape(
                ret,
                tensorshape_util.concatenate(static_samples_shape,
                                             self.event_shape))
            return ret
Esempio n. 20
0
    def _validate_sample_arg(self, x):
        """Builds checks that `x` carries the expected batch and event dims.

        Conditions that can be decided from static shape information are
        enforced immediately by raising. Conditions that depend on dynamic
        shapes are turned into assertion ops, and only when
        `self.validate_args` is set.

        Args:
          x: Sample argument, e.g., the input to `log_prob`. Its `shape`
            attribute is read, and `tf.rank` / `tf.shape` are applied to it
            when the static shape is not informative enough.

        Returns:
          A Python list of assertion ops; empty when nothing has to be
          checked at run time.

        Raises:
          NotImplementedError: if `x` statically has fewer dims than the
            batch and event dims combined, or if its statically known
            rightmost dims differ from the expected batch and event shape.
        """
        with tf.name_scope('validate_sample_arg'):
            # Prefer Python ints for every rank; fall back to Tensors only
            # when the static rank is unknown.
            static_x_rank = tensorshape_util.rank(x.shape)
            x_ndims = tf.rank(x) if static_x_rank is None else static_x_rank

            static_event_rank = tensorshape_util.rank(self.event_shape)
            if static_event_rank is None:
                event_ndims = tf.size(self.event_shape_tensor())
            else:
                event_ndims = static_event_rank

            static_batch_rank = tensorshape_util.rank(self.batch_shape)
            if static_batch_rank is None:
                batch_ndims = tf.size(self._batch_shape_unexpanded)
            else:
                batch_ndims = static_batch_rank

            expected_ndims = batch_ndims + event_ndims

            ranks_are_static = (isinstance(x_ndims, int)
                                and isinstance(expected_ndims, int))
            if ranks_are_static:
                if x_ndims < expected_ndims:
                    raise NotImplementedError(
                        'Broadcasting is not supported; too few batch and event dims '
                        '(expected at least {}, saw {}).'.format(
                            expected_ndims, x_ndims))
                ndims_assertion = []
            elif self.validate_args:
                ndims_assertion = [
                    assert_util.assert_greater_equal(
                        x_ndims,
                        expected_ndims,
                        message=('Broadcasting is not supported; too few '
                                 'batch and event dims.'),
                        name='assert_batch_and_event_ndims_large_enough'),
                ]

            # Expected batch+event shape: a numpy array when fully static,
            # otherwise a Tensor.
            params_fully_static = (
                tensorshape_util.is_fully_defined(self.batch_shape)
                and tensorshape_util.is_fully_defined(self.event_shape))
            if params_fully_static:
                expected_shape = np.int32(
                    tensorshape_util.concatenate(self.batch_shape,
                                                 self.event_shape))
            else:
                expected_shape = tf.concat(
                    [self.batch_shape_tensor(),
                     self.event_shape_tensor()],
                    axis=0)

            # Actual batch+event shape of `x`: everything to the right of the
            # sample dims, again preferring a static numpy array.
            sample_ndims = x_ndims - expected_ndims
            actual_shape = None
            if isinstance(sample_ndims, int):
                sample_ndims = max(sample_ndims, 0)
                static_tail = x.shape[sample_ndims:]
                if tensorshape_util.is_fully_defined(static_tail):
                    actual_shape = np.int32(static_tail)
            if actual_shape is None:
                sample_ndims = tf.maximum(sample_ndims, 0)
                actual_shape = tf.shape(x)[sample_ndims:]

            shapes_are_static = (isinstance(expected_shape, np.ndarray)
                                 and isinstance(actual_shape, np.ndarray))
            if shapes_are_static:
                if any(expected_shape != actual_shape):
                    raise NotImplementedError(
                        'Broadcasting is not supported; '
                        'unexpected batch and event shape '
                        '(expected {}, saw {}).'.format(
                            expected_shape, actual_shape))
                # The shape check is already settled statically, so the only
                # run-time work left is whatever rank assertion was created
                # above (it already reflects `self.validate_args`).
                return ndims_assertion

            if not self.validate_args:
                return []

            # The rank assertion must run first: comparing shape vectors of
            # different lengths is itself ill-defined and could make TF raise
            # an unrelated error.
            with tf.control_dependencies(ndims_assertion):
                shape_assertion = assert_util.assert_equal(
                    expected_shape,
                    actual_shape,
                    message=('Broadcasting is not supported; '
                             'unexpected batch and event shape.'),
                    name='assert_batch_and_event_shape_same')
            return [shape_assertion]
Esempio n. 21
0
    def __init__(self, permutation, axis=-1, validate_args=False, name=None):
        """Creates the `Permute` bijector.

        Args:
          permutation: An `int`-like vector-shaped `Tensor` representing the
            permutation to apply to the `axis` dimension of the transformed
            `Tensor`.
          axis: Scalar `int` `Tensor` representing the dimension over which
            to `tf.gather`. `axis` must be relative to the end (reading left
            to right) and therefore negative.
            Default value: `-1` (i.e., right-most).
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness. When `permutation` has no static value,
            its validity is only checked if this is `True`.
          name: Python `str`, name given to ops managed by this object.
            Default value: `None` (i.e., "permute").

        Raises:
          TypeError: if `axis.dtype` or `permutation.dtype` is not integer.
          ValueError: if the static value of `permutation` is not a
            rearrangement of `np.arange(permutation.size)`.
          NotImplementedError: if `axis` is not known prior to graph
            execution.
          NotImplementedError: if `axis` is not negative.
        """
        with tf.name_scope(name or "permute") as name:
            axis = tf.convert_to_tensor(axis, name="axis")
            if not dtype_util.is_integer(axis.dtype):
                raise TypeError("axis.dtype ({}) should be `int`-like.".format(
                    dtype_util.name(axis.dtype)))

            permutation = tf.convert_to_tensor(permutation, name="permutation")
            if not dtype_util.is_integer(permutation.dtype):
                raise TypeError(
                    "permutation.dtype ({}) should be `int`-like.".format(
                        dtype_util.name(permutation.dtype)))

            invalid_permutation_msg = (
                "Permutation over `d` must contain exactly one of "
                "each of `{0, 1, ..., d}`.")

            static_permutation = tf.get_static_value(permutation)
            if static_permutation is not None:
                # Statically known: validate eagerly in Python.
                expected_entries = set(np.arange(static_permutation.size))
                if set(static_permutation) != expected_entries:
                    raise ValueError(invalid_permutation_msg)
            elif validate_args:
                # `top_k` on the negated values sorts the permutation in
                # ascending order once negated back; a valid permutation then
                # equals `range(size)`.
                negated_sorted, _ = tf.math.top_k(
                    -permutation,
                    k=tf.shape(permutation)[-1],
                    sorted=True)
                permutation = distribution_util.with_dependencies([
                    assert_util.assert_equal(
                        -negated_sorted,
                        tf.range(tf.size(negated_sorted)),
                        message=invalid_permutation_msg),
                ], permutation)

            static_axis = tf.get_static_value(axis)
            if static_axis is None:
                raise NotImplementedError(
                    "`axis` must be known prior to graph "
                    "execution.")
            if static_axis >= 0:
                raise NotImplementedError(
                    "`axis` must be relative the rightmost "
                    "dimension, i.e., negative.")
            # A negative axis -k touches the k rightmost dims.
            min_event_ndims = int(np.abs(static_axis))

            self._permutation = permutation
            self._axis = axis
            super(Permute, self).__init__(
                forward_min_event_ndims=min_event_ndims,
                is_constant_jacobian=True,
                validate_args=validate_args,
                name=name)
    def _sample_n(self, n, seed=None):
        """Draws `n` observation sequences from the hidden Markov model.

        The hidden chain is sampled first (initial state, then one
        transition per remaining step), and the observation belonging to
        each hidden state is then selected from samples of the observation
        distribution via one-hot masking.

        Args:
          n: Number of sequences to draw.
          seed: Optional seed used to build the `SeedStream`. When it is not
            `None` the scan over steps is forced to run with
            `parallel_iterations=1` for reproducibility.

        Returns:
          observations: `Tensor` laid out as
            `[n] + batch_shape + [num_steps] + observation event shape`.
        """
        with tf.control_dependencies(self._runtime_assertions):
            strm = SeedStream(seed, salt="HiddenMarkovModel")

            num_states = self._num_states

            batch_shape = self.batch_shape_tensor()
            batch_size = tf.reduce_prod(batch_shape)

            # The batch sizes of the underlying initial distributions and
            # transition distributions might not match the batch size of
            # the HMM distribution.
            # As a result we need to ask for more samples from the
            # underlying distributions and then reshape the results into
            # the correct batch size for the HMM.
            init_repeat = batch_size // tf.reduce_prod(
                self._initial_distribution.batch_shape_tensor())
            init_state = self._initial_distribution.sample(n * init_repeat,
                                                           seed=strm())
            init_state = tf.reshape(init_state, [n, batch_size])
            # init_state :: n batch_size

            transition_repeat = batch_size // tf.reduce_prod(
                self._transition_distribution.batch_shape_tensor()[:-1])

            def generate_step(state, _):
                """Take a single step in Markov chain."""

                gen = self._transition_distribution.sample(
                    n * transition_repeat, seed=strm())
                # gen :: (n * transition_repeat) transition_batch

                new_states = tf.reshape(gen, [n, batch_size, num_states])

                # new_states :: n batch_size num_states

                old_states_one_hot = tf.one_hot(state,
                                                num_states,
                                                dtype=tf.int32)

                # old_states :: n batch_size num_states

                # Keep, for every chain, only the candidate next state that
                # was drawn conditioned on its current state.
                return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

            def _scan_multiple_steps():
                """Take multiple steps with tf.scan."""
                dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
                if seed is not None:
                    # Force parallel_iterations to 1 to ensure reproducibility
                    # b/139210489
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state,
                                            parallel_iterations=1)
                else:
                    # Invoke default parallel_iterations behavior
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state)

                # TODO(b/115618503): add/use prepend_initializer to tf.scan
                return tf.concat([[init_state], hidden_states], axis=0)

            hidden_states = prefer_static.cond(
                self._num_steps > 1, _scan_multiple_steps,
                lambda: init_state[tf.newaxis, ...])

            hidden_one_hot = tf.one_hot(
                hidden_states,
                num_states,
                dtype=self._observation_distribution.dtype)
            # hidden_one_hot :: num_steps n batch_size num_states

            # The observation distribution batch size might not match
            # the required batch size so as with the initial and
            # transition distributions we generate more samples and
            # reshape.
            observation_repeat = (batch_size // tf.reduce_prod(
                self._observation_distribution.batch_shape_tensor()[:-1]))

            possible_observations = self._observation_distribution.sample(
                [self._num_steps, observation_repeat * n], seed=strm())

            # Use the dynamic event shape: the static `event_shape` may be
            # only partially known, in which case it cannot be concatenated
            # into the shape tensors built below.
            inner_shape = self._observation_distribution.event_shape_tensor()

            # possible_observations :: num_steps (observation_repeat * n)
            #                          observation_batch[:-1] num_states inner_shape

            possible_observations = tf.reshape(
                possible_observations,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           inner_shape],
                          axis=0))

            # possible_observations :: steps n batch_size num_states inner_shape

            hidden_one_hot = tf.reshape(
                hidden_one_hot,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           tf.ones_like(inner_shape)],
                          axis=0))

            # hidden_one_hot :: steps n batch_size num_states "inner_shape"

            # Sum out the `num_states` axis, which sits just left of the
            # event dims.
            observations = tf.reduce_sum(hidden_one_hot *
                                         possible_observations,
                                         axis=-1 - tf.size(inner_shape))

            # observations :: steps n batch_size inner_shape

            observations = distribution_util.move_dimension(
                observations, 0, 1 + tf.size(batch_shape))

            # returned :: n batch_shape steps inner_shape

            return observations
    def __init__(self,
                 initial_distribution,
                 transition_distribution,
                 observation_distribution,
                 num_steps,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="HiddenMarkovModel"):
        """Initialize hidden Markov model.

        Args:
          initial_distribution: A `Categorical`-like instance.
            Determines probability of first hidden state in Markov chain.
            The number of categories must match the number of categories of
            `transition_distribution` as well as both the rightmost batch
            dimension of `transition_distribution` and the rightmost batch
            dimension of `observation_distribution`.
          transition_distribution: A `Categorical`-like instance.
            The rightmost batch dimension indexes the probability
            distribution of each hidden state conditioned on the previous
            hidden state.
          observation_distribution: A `tfp.distributions.Distribution`-like
            instance.  The rightmost batch dimension indexes the
            distribution of each observation conditioned on the
            corresponding hidden state.
          num_steps: The number of steps taken in Markov chain. Converted
            with `tf.convert_to_tensor`.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite
            possibly degrading runtime performance. When `False` invalid
            inputs may silently render incorrect outputs.
            Default value: `False`.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
            Default value: `True`.
          name: Python `str` name prefixed to Ops created by this class.
            Default value: "HiddenMarkovModel".

        Raises:
          ValueError: if `initial_distribution` or `transition_distribution`
            has a statically known, non-scalar `event_shape`.
          ValueError: if `transition_distribution` or
            `observation_distribution` has a statically known scalar
            `batch_shape`.
          ValueError: if the rightmost batch dimensions of
            `transition_distribution` and `observation_distribution` are
            both statically known and differ.

        Conditions that cannot be decided statically (including `num_steps`
        being a scalar that is at least 1) are collected as assertion ops in
        `self._runtime_assertions` when `validate_args` is `True`.
        """

        # Must remain the first statement so that only the constructor
        # arguments are captured.
        parameters = dict(locals())

        # pylint: disable=protected-access
        with tf.name_scope(name) as name:
            self._runtime_assertions = []  # pylint: enable=protected-access

            num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps")
            if validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.rank(num_steps),
                        0,
                        message="`num_steps` must be a scalar")
                ]
                self._runtime_assertions += [
                    assert_util.assert_greater_equal(
                        num_steps,
                        1,
                        message="`num_steps` must be at least 1.")
                ]

            self._initial_distribution = initial_distribution
            self._observation_distribution = observation_distribution
            self._transition_distribution = transition_distribution

            # An unknown static rank (`None`) must not be treated as
            # "non-scalar"; it is left to the runtime assertion instead.
            if (initial_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        initial_distribution.event_shape) not in (None, 0)):
                raise ValueError(
                    "`initial_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(initial_distribution.event_shape_tensor())[0],
                        0,
                        message="`initial_distribution` must have scalar "
                        "`event_dim`s")
                ]

            if (transition_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        transition_distribution.event_shape) not in (None,
                                                                     0)):
                raise ValueError(
                    "`transition_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(
                            transition_distribution.event_shape_tensor())[0],
                        0,
                        message="`transition_distribution` must have scalar "
                        "`event_dim`s")
                ]

            # For the batch checks an unknown rank compares unequal to 0 and
            # therefore already falls through to the runtime assertion.
            if (transition_distribution.batch_shape is not None
                    and tensorshape_util.rank(
                        transition_distribution.batch_shape) == 0):
                raise ValueError(
                    "`transition_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(transition_distribution.batch_shape_tensor()),
                        0,
                        message="`transition_distribution` can't have scalar "
                        "batches")
                ]

            if (observation_distribution.batch_shape is not None
                    and tensorshape_util.rank(
                        observation_distribution.batch_shape) == 0):
                raise ValueError(
                    "`observation_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(observation_distribution.batch_shape_tensor()),
                        0,
                        message="`observation_distribution` can't have scalar "
                        "batches")
                ]

            # Infer number of hidden states and check consistency
            # between transitions and observations. The static value is used
            # when it is available and truthy, otherwise the dynamic one.
            with tf.control_dependencies(self._runtime_assertions):
                self._num_states = (
                    (transition_distribution.batch_shape
                     and transition_distribution.batch_shape[-1])
                    or transition_distribution.batch_shape_tensor()[-1])

                observation_states = (
                    (observation_distribution.batch_shape
                     and observation_distribution.batch_shape[-1])
                    or observation_distribution.batch_shape_tensor()[-1])

            if (tf.is_tensor(self._num_states)
                    or tf.is_tensor(observation_states)):
                if validate_args:
                    self._runtime_assertions += [
                        assert_util.assert_equal(
                            self._num_states,
                            observation_states,
                            message="`transition_distribution` and "
                            "`observation_distribution` must agree on "
                            "last dimension of batch size")
                    ]
            elif self._num_states != observation_states:
                raise ValueError("`transition_distribution` and "
                                 "`observation_distribution` must agree on "
                                 "last dimension of batch size")

            self._log_init = _extract_log_probs(self._num_states,
                                                initial_distribution)
            self._log_trans = _extract_log_probs(self._num_states,
                                                 transition_distribution)

            self._num_steps = num_steps
            # From here on the number of states is read off the extracted
            # initial log-probabilities.
            self._num_states = tf.shape(self._log_init)[-1]

            self._underlying_event_rank = tf.size(
                self._observation_distribution.event_shape_tensor())

            # The static event shape is only available when `num_steps` has
            # a static value.
            num_steps_ = tf.get_static_value(num_steps)
            if num_steps_ is not None:
                self.static_event_shape = tf.TensorShape([
                    num_steps_
                ]).concatenate(self._observation_distribution.event_shape)
            else:
                self.static_event_shape = None

            with tf.control_dependencies(self._runtime_assertions):
                self.static_batch_shape = tf.broadcast_static_shape(
                    self._initial_distribution.batch_shape,
                    tf.broadcast_static_shape(
                        self._transition_distribution.batch_shape[:-1],
                        self._observation_distribution.batch_shape[:-1]))

            # pylint: disable=protected-access
            super(HiddenMarkovModel, self).__init__(
                dtype=self._observation_distribution.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
            # pylint: enable=protected-access

            self._parameters = parameters
Esempio n. 24
0
    def _log_prob(self, x):
        """Computes the log-density at `x`.

        Args:
          x: Matrix-valued `Tensor`. When `self.input_output_cholesky` is
            set it is used directly as the Cholesky factor; otherwise its
            Cholesky factor is computed here.

        Returns:
          log_prob: `Tensor` combining the log-determinant term, the trace
            term `tr[inv(V) X]` and `self.log_normalization()`.
        """
        # Complexity: O(nbk**3) when the factorization has to be computed.
        chol_x = x if self.input_output_cholesky else tf.linalg.cholesky(x)

        batch_shape = self.batch_shape_tensor()
        event_shape = self.event_shape_tensor()

        # Left-pad with singleton axes so the input has at least
        # batch_ndims + 2 dims.
        input_ndims = tf.rank(chol_x)
        num_padding_axes = (
            tf.maximum(tf.size(batch_shape) + 2, input_ndims) - input_ndims)
        padded_shape = tf.concat([
            tf.ones([num_padding_axes], dtype=tf.int32),
            tf.shape(chol_x),
        ], 0)
        chol_x = tf.reshape(chol_x, padded_shape)

        ndims = tf.rank(chol_x)
        # sample_ndims = ndims - batch_ndims - event_ndims
        sample_ndims = ndims - tf.size(batch_shape) - 2
        sample_shape = tf.shape(chol_x)[:sample_ndims]

        # Each matrix has to be pre-multiplied by the inverse scale factor
        # of its batch member. Because several samples may share a batch
        # member, the sample dims are rotated to the back and folded into
        # the last axis, giving [batch..., dimension, dimension * n_samples]
        # so that a single batched solve handles everything. The folding is
        # undone afterwards so the result is again partitionable by sample,
        # batch and event dims.

        # Complexity: O(nbk**2) since transpose must access every element.
        samples_to_back = tf.concat(
            [tf.range(sample_ndims, ndims),
             tf.range(0, sample_ndims)], 0)
        solved = tf.transpose(a=chol_x, perm=samples_to_back)

        dim = tf.cast(self.dimension, dtype=tf.int32)
        folded_last_dim = dim * tf.reduce_prod(padded_shape[:sample_ndims])
        folded_shape = tf.concat(
            [padded_shape[sample_ndims:-2], [dim, folded_last_dim]], axis=0)
        solved = tf.reshape(solved, folded_shape)

        # Complexity: O(nbM*k) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so
        # this step has complexity O(nbk^3).
        solved = self.scale_operator.solve(solved)

        # Undo make batch-op ready.
        # Complexity: O(nbk**2)
        unfolded_shape = tf.concat(
            [tf.shape(solved)[:-2], event_shape, sample_shape], axis=0)
        solved = tf.reshape(solved, unfolded_shape)
        samples_to_front = tf.concat([
            tf.range(ndims - sample_ndims, ndims),
            tf.range(0, ndims - sample_ndims),
        ], 0)
        solved = tf.transpose(a=solved, perm=samples_to_front)

        # Write V = SS', X = LL'. Then:
        # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
        #              = tr[inv(S) L L' inv(S)']
        #              = tr[(inv(S) L) (inv(S) L)']
        #              = sum_{ik} (inv(S) L)_{ik}**2
        # The second equality follows from the cyclic permutation property.
        # Complexity: O(nbk**2)
        trace_term = tf.reduce_sum(tf.square(solved), axis=[-2, -1])

        # Complexity: O(nbk)
        half_log_det_x = tf.reduce_sum(
            tf.math.log(tf.linalg.diag_part(chol_x)), axis=[-1])

        # Complexity: O(nbk**2)
        log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
                    0.5 * trace_term - self.log_normalization())

        # Set shape hints: merge what is known statically about `x` with
        # what is known about this distribution's batch shape. Nothing can
        # be merged if either static rank is unknown.
        if (tensorshape_util.rank(x.shape) is None
                or tensorshape_util.rank(self.batch_shape) is None):
            return log_prob

        tensorshape_util.set_shape(
            log_prob,
            tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))
        return log_prob
Esempio n. 25
0
    def __init__(self,
                 perm=None,
                 rightmost_transposed_ndims=None,
                 validate_args=False,
                 name='transpose'):
        """Instantiates the `Transpose` bijector.

        Exactly one of `perm` and `rightmost_transposed_ndims` must be
        given; the other one is derived from it.

        Args:
          perm: `int32` vector-shaped `Tensor` representing the permutation
            of the rightmost dims (for the forward transformation). Index
            `0` refers to the first of the rightmost dims.
            Default value (when `rightmost_transposed_ndims` is given):
            `tf.range(start=rightmost_transposed_ndims - 1, limit=-1,
            delta=-1)`, i.e., a full reversal.
          rightmost_transposed_ndims: `int32` scalar-shaped `Tensor`
            representing the number of rightmost dimensions to permute.
            Default value (when `perm` is given): `tf.size(perm)`.
          validate_args: Python `bool` indicating whether arguments should
            be checked for correctness.
          name: Python `str` name given to ops managed by this object.

        Raises:
          ValueError: if both or neither of `perm` and
            `rightmost_transposed_ndims` are specified.
          NotImplementedError: if `rightmost_transposed_ndims` is not known
            prior to graph execution.
        """
        with tf.name_scope(name) as name:
            if (perm is None) == (rightmost_transposed_ndims is None):
                raise ValueError('Must specify exactly one of '
                                 '`rightmost_transposed_ndims` and `perm`.')

            if perm is None:
                # Only the number of dims was given: build the reversing
                # permutation from it.
                rightmost_transposed_ndims = tf.convert_to_tensor(
                    rightmost_transposed_ndims,
                    dtype_hint=np.int32,
                    name='rightmost_transposed_ndims')
                static_ndims = tf.get_static_value(rightmost_transposed_ndims)
                checks = _maybe_validate_rightmost_transposed_ndims(
                    rightmost_transposed_ndims, validate_args)
                if checks:
                    with tf.control_dependencies(checks):
                        rightmost_transposed_ndims = tf.identity(
                            rightmost_transposed_ndims)
                last_index = distribution_util.prefer_static_value(
                    rightmost_transposed_ndims) - 1
                perm = tf.range(start=last_index,
                                limit=-1,
                                delta=-1,
                                name='perm')
            else:
                # Only the permutation was given: the number of dims is its
                # size.
                perm = tf.convert_to_tensor(perm,
                                            dtype_hint=np.int32,
                                            name='perm')
                rightmost_transposed_ndims = tf.size(
                    perm, name='rightmost_transposed_ndims')
                static_ndims = tf.get_static_value(rightmost_transposed_ndims)
                checks = _maybe_validate_perm(perm, validate_args)
                if checks:
                    with tf.control_dependencies(checks):
                        perm = tf.identity(perm)

            # TODO(b/110828604): If bijector base class ever supports dynamic
            # `min_event_ndims`, then this class already works dynamically and
            # the static-value requirement below can be removed.
            if static_ndims is None:
                raise NotImplementedError(
                    '`rightmost_transposed_ndims` must be '
                    'known prior to graph execution.')
            min_event_ndims = int(static_ndims)

            self._perm = perm
            self._rightmost_transposed_ndims = rightmost_transposed_ndims
            super(Transpose, self).__init__(
                forward_min_event_ndims=min_event_ndims,
                is_constant_jacobian=True,
                validate_args=validate_args,
                name=name)
Esempio n. 26
0
def _replace_event_shape_in_shape_tensor(input_shape, event_shape_in,
                                         event_shape_out, validate_args):
    """Swaps the rightmost (event) dims of a shape `Tensor` for a new event shape.

  Args:
    input_shape: a rank-1 `Tensor` of integers.
    event_shape_in: the event shape expected to occupy the rightmost dims of
      `input_shape`.
    event_shape_out: the event shape that takes the place of `event_shape_in`
      in the rightmost dims of `input_shape`.
    validate_args: Python `bool` indicating whether arguments should
      be checked for correctness.

  Returns:
    output_shape: A rank-1 integer `Tensor` with the same contents as
      `input_shape` except for the event dims, which are replaced with
      `event_shape_out`.
    output_tensorshape: the static shape produced by
      `_replace_event_shape_in_tensorshape` for the same replacement.
  """
    static_input_shape = tensorshape_util.constant_value_as_shape(input_shape)
    output_tensorshape, is_validated = _replace_event_shape_in_tensorshape(
        static_input_shape, event_shape_in, event_shape_out)

    # TODO(b/124240153): Remove map(tf.identity, deps) once tf.function
    # correctly supports control_dependencies.
    if validate_args:
        validation_dependencies = map(tf.identity,
                                      (event_shape_in, event_shape_out))
    else:
        validation_dependencies = ()

    # A runtime check is still owed only when validation was requested and the
    # static pass above did not already validate the shapes.
    needs_runtime_check = validate_args and not is_validated

    # Fast path: the result is fully known statically, so it can be emitted as
    # a constant without splitting `input_shape` at runtime.
    if (tensorshape_util.is_fully_defined(output_tensorshape)
            and not needs_runtime_check):
        static_dims = tensorshape_util.as_list(output_tensorshape)
        with tf.control_dependencies(validation_dependencies):
            output_shape = tf.convert_to_tensor(static_dims,
                                                dtype_hint=tf.int32,
                                                name='output_shape')
        return output_shape, output_tensorshape

    with tf.control_dependencies(validation_dependencies):
        # Use the static number of event dims when it is known, otherwise
        # fall back to a dynamic size.
        static_event_ndims = tensorshape_util.num_elements(
            event_shape_in.shape)
        if static_event_ndims is None:
            event_shape_in_ndims = tf.size(event_shape_in)
        else:
            event_shape_in_ndims = static_event_ndims
        input_non_event_shape, input_event_shape = tf.split(
            input_shape, num_or_size_splits=[-1, event_shape_in_ndims])

    additional_assertions = []
    if needs_runtime_check:
        # `input_event_shape` and `event_shape_in` must agree in every position
        # that is not a `-1` in `event_shape_in`. Construction-time validation
        # guarantees there is at most one such `-1` entry.
        is_explicit = event_shape_in >= 0
        explicit_input_event_shape = tf.boolean_mask(input_event_shape,
                                                     mask=is_explicit)
        explicit_event_shape_in = tf.boolean_mask(event_shape_in,
                                                  mask=is_explicit)
        shapes_match = assert_util.assert_equal(
            explicit_input_event_shape,
            explicit_event_shape_in,
            message='Input `event_shape` does not match `event_shape_in`.')
        additional_assertions.append(shapes_match)
        # `tf.size(input_shape) > tf.size(event_shape_in)` is not verified
        # separately because `tf.split` already makes this assertion.

    with tf.control_dependencies(additional_assertions):
        output_shape = tf.concat([input_non_event_shape, event_shape_out],
                                 axis=0,
                                 name='output_shape')

    return output_shape, output_tensorshape
Esempio n. 27
0
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        # For `logits` and `probs`, we only want to have an assertion on what the
        # user actually passed. For now, we access the underlying categorical's
        # _logits and _probs directly. After the 2019-10-01 deprecation, it would
        # also work to use .logits() and .probs().
        logits = self._categorical._logits
        probs = self._categorical._probs
        outcomes = self._outcomes
        validate_args = self._validate_args

        # Build all shape and dtype checks during the `is_init` call.
        if is_init:

            def validate_equal_last_dim(tensor_a, tensor_b, message):
                event_size_a = tf.compat.dimension_value(tensor_a.shape[-1])
                event_size_b = tf.compat.dimension_value(tensor_b.shape[-1])
                if event_size_a is not None and event_size_b is not None:
                    if event_size_a != event_size_b:
                        raise ValueError(message)
                elif validate_args:
                    return assert_util.assert_equal(tf.shape(tensor_a)[-1],
                                                    tf.shape(tensor_b)[-1],
                                                    message=message)

            message = 'Size of outcomes must be greater than 0.'
            if tensorshape_util.num_elements(outcomes.shape) is not None:
                if tensorshape_util.num_elements(outcomes.shape) == 0:
                    raise ValueError(message)
            elif validate_args:
                assertions.append(
                    tf.assert_greater(tf.size(outcomes), 0, message=message))

            if logits is not None:
                maybe_assert = validate_equal_last_dim(
                    outcomes,
                    # pylint: disable=protected-access
                    self._categorical._logits,
                    # pylint: enable=protected-access
                    message=
                    'Last dimension of outcomes and logits must be equal size.'
                )
                if maybe_assert:
                    assertions.append(maybe_assert)

            if probs is not None:
                maybe_assert = validate_equal_last_dim(
                    outcomes,
                    probs,
                    message=
                    'Last dimension of outcomes and probs must be equal size.')
                if maybe_assert:
                    assertions.append(maybe_assert)

            message = 'Rank of outcomes must be 1.'
            ndims = tensorshape_util.rank(outcomes.shape)
            if ndims is not None:
                if ndims != 1:
                    raise ValueError(message)
            elif validate_args:
                assertions.append(
                    assert_util.assert_rank(outcomes, 1, message=message))

        if not validate_args:
            assert not assertions  # Should never happen.
            return []

        if is_init != tensor_util.is_ref(outcomes):
            assertions.append(
                assert_util.assert_equal(
                    tf.math.is_strictly_increasing(outcomes),
                    True,
                    message='outcomes is not strictly increasing.'))

        return assertions
    def __init__(self,
                 mixture_distribution,
                 components_distribution,
                 reparameterize=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MixtureSameFamily"):
        """Construct a `MixtureSameFamily` distribution.

    Args:
      mixture_distribution: `tfp.distributions.Categorical`-like instance.
        Manages the probability of selecting components. The number of
        categories must match the rightmost batch dimension of the
        `components_distribution`. Must have either scalar `batch_shape` or
        `batch_shape` matching `components_distribution.batch_shape[:-1]`.
      components_distribution: `tfp.distributions.Distribution`-like instance.
        Right-most batch dimension indexes components.
      reparameterize: Python `bool`, default `False`. Whether to reparameterize
        samples of the distribution using implicit reparameterization gradients
        [(Figurnov et al., 2018)][1]. The gradients for the mixture logits are
        equivalent to the ones described by [(Graves, 2016)][2]. The gradients
        for the components parameters are also computed using implicit
        reparameterization (as opposed to ancestral sampling), meaning that
        all components are updated every step.
        Only works when:
          (1) components_distribution is fully reparameterized;
          (2) components_distribution is either a scalar distribution or
          fully factorized (tfd.Independent applied to a scalar distribution);
          (3) batch shape has a known rank.
        Experimental, may be slow and produce infs/NaNs.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if `mixture_distribution.dtype` is not an integer dtype
        (i.e. `not dtype_util.is_integer(mixture_distribution.dtype)`).
      ValueError: if mixture_distribution does not have scalar `event_shape`.
      ValueError: if `mixture_distribution.batch_shape` and
        `components_distribution.batch_shape[:-1]` are both fully defined and
        the former is neither scalar nor equal to the latter.
      ValueError: if `mixture_distribution` categories does not equal
        `components_distribution` rightmost batch shape.
      ValueError: if `reparameterize` is `True` but
        `components_distribution.reparameterization_type` is not
        `FULLY_REPARAMETERIZED`.

    #### References

    [1]: Michael Figurnov, Shakir Mohamed and Andriy Mnih. Implicit
         reparameterization gradients. In _Neural Information Processing
         Systems_, 2018. https://arxiv.org/abs/1805.08498

    [2]: Alex Graves. Stochastic Backpropagation through Mixture Density
         Distributions. _arXiv_, 2016. https://arxiv.org/abs/1607.05690
    """
        # Must remain the first statement: it snapshots the call arguments
        # (plus `self`) before any other local variable exists.
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._mixture_distribution = mixture_distribution
            self._components_distribution = components_distribution
            # Assertions collected here are only populated when
            # `validate_args` is set and a check cannot be settled statically.
            self._runtime_assertions = []

            # Number of event dims: use the static length of the event shape
            # tensor when known, otherwise fall back to a dynamic `tf.size`.
            s = components_distribution.event_shape_tensor()
            self._event_ndims = tf.compat.dimension_value(s.shape[0])
            if self._event_ndims is None:
                self._event_ndims = tf.size(s)
            # Total number of elements in a single event.
            self._event_size = tf.reduce_prod(s)

            if not dtype_util.is_integer(mixture_distribution.dtype):
                raise ValueError(
                    "`mixture_distribution.dtype` ({}) is not over integers".
                    format(dtype_util.name(mixture_distribution.dtype)))

            # The mixture's event shape must be scalar. A statically known,
            # non-zero rank fails immediately. In every other case (including
            # a statically known rank of 0) a runtime assertion is added when
            # `validate_args` is set.
            if (tensorshape_util.rank(mixture_distribution.event_shape)
                    is not None and tensorshape_util.rank(
                        mixture_distribution.event_shape) != 0):
                raise ValueError(
                    "`mixture_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.size(mixture_distribution.event_shape_tensor()),
                        0,
                        message=
                        "`mixture_distribution` must have scalar `event_dim`s"
                    ),
                ]

            # Batch-shape compatibility: `cdbs` is the components' batch shape
            # with its last (component-indexing) dimension dropped.
            mdbs = mixture_distribution.batch_shape
            cdbs = tensorshape_util.with_rank_at_least(
                components_distribution.batch_shape, 1)[:-1]
            if tensorshape_util.is_fully_defined(
                    mdbs) and tensorshape_util.is_fully_defined(cdbs):
                # Static check: the mixture batch shape must be scalar or equal
                # to the components' leading batch shape.
                if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
                    raise ValueError(
                        "`mixture_distribution.batch_shape` (`{}`) is not "
                        "compatible with `components_distribution.batch_shape` "
                        "(`{}`)".format(tensorshape_util.as_list(mdbs),
                                        tensorshape_util.as_list(cdbs)))
            elif validate_args:
                # Shapes are not fully known statically, so the comparison is
                # done on the dynamic shape tensors instead.
                mdbs = mixture_distribution.batch_shape_tensor()
                cdbs = components_distribution.batch_shape_tensor()[:-1]
                # NOTE(review): presumably `pick_vector(cond, a, b)` yields `a`
                # when `cond` holds, so a scalar-batch mixture compares `cdbs`
                # with itself and trivially passes -- confirm against
                # `distribution_utils.pick_vector`.
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        distribution_utils.pick_vector(
                            mixture_distribution.is_scalar_batch(), cdbs,
                            mdbs),
                        cdbs,
                        message=
                        ("`mixture_distribution.batch_shape` is not "
                         "compatible with `components_distribution.batch_shape`"
                         ))
                ]

            # The mixture is parameterized by either `logits` or `probs`;
            # `logits` is used whenever it is not None.
            mixture_dist_param = (mixture_distribution.probs
                                  if mixture_distribution.logits is None else
                                  mixture_distribution.logits)
            # `km`: static number of mixture categories (last dim of the
            # parameter). `kc`: static size of the components' last batch dim.
            # Either may be None when not statically known.
            km = tf.compat.dimension_value(
                tensorshape_util.with_rank_at_least(mixture_dist_param.shape,
                                                    1)[-1])
            kc = tf.compat.dimension_value(
                tensorshape_util.with_rank_at_least(
                    components_distribution.batch_shape, 1)[-1])
            if km is not None and kc is not None and km != kc:
                raise ValueError(
                    "`mixture_distribution components` ({}) does not "
                    "equal `components_distribution.batch_shape[-1]` "
                    "({})".format(km, kc))
            elif validate_args:
                # Note: this branch also runs when both static values are known
                # and equal, in which case `km` is replaced by a dynamic value.
                km = tf.shape(mixture_dist_param)[-1]
                kc = components_distribution.batch_shape_tensor()[-1]
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        km,
                        kc,
                        message=(
                            "`mixture_distribution components` does not equal "
                            "`components_distribution.batch_shape[-1:]`")),
                ]
            elif km is None:
                # No validation requested, but the component count is still
                # needed; fetch it dynamically.
                km = tf.shape(mixture_dist_param)[-1]

            # Either a static Python value or a dynamic value, depending on
            # which branch above assigned `km`.
            self._num_components = km

            self._reparameterize = reparameterize
            if reparameterize:
                # Note: tfd.Independent passes through the reparameterization type hence
                # we do not need separate logic for Independent.
                if (self._components_distribution.reparameterization_type !=
                        reparameterization.FULLY_REPARAMETERIZED):
                    raise ValueError("Cannot reparameterize a mixture of "
                                     "non-reparameterized components.")
                reparameterization_type = reparameterization.FULLY_REPARAMETERIZED
            else:
                reparameterization_type = reparameterization.NOT_REPARAMETERIZED

            # The mixture takes its dtype from the components distribution.
            super(MixtureSameFamily, self).__init__(
                dtype=self._components_distribution.dtype,
                reparameterization_type=reparameterization_type,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)