Example #1
 def _log_prob(self, x, **kwargs):
     batch_ndims = prefer_static.rank_from_shape(
         self.distribution.batch_shape_tensor,
         self.distribution.batch_shape)
     extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
     event_ndims = prefer_static.rank_from_shape(
         self.distribution.event_shape_tensor,
         self.distribution.event_shape)
     ndims = prefer_static.rank(x)
     # (1) Expand x's dims.
     d = ndims - batch_ndims - extra_sample_ndims - event_ndims
     x = tf.reshape(x,
                    shape=tf.pad(
                        tf.shape(x),
                        paddings=[[prefer_static.maximum(0, -d), 0]],
                        constant_values=1))
     sample_ndims = prefer_static.maximum(0, d)
     # (2) Transpose x's dims.
     sample_dims = prefer_static.range(0, sample_ndims)
     batch_dims = prefer_static.range(sample_ndims,
                                      sample_ndims + batch_ndims)
     extra_sample_dims = prefer_static.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = prefer_static.range(
         sample_ndims + batch_ndims + extra_sample_ndims, ndims)
     perm = prefer_static.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     x = tf.transpose(a=x, perm=perm)
     # (3) Compute x's log_prob.
     lp = self.distribution.log_prob(x, **kwargs)
     # (4) Make the final reduction in x.
     axis = prefer_static.range(sample_ndims,
                                sample_ndims + extra_sample_ndims)
     return tf.reduce_sum(lp, axis=axis)
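Below is a minimal public-API sketch (not taken from the library) of what the `_log_prob` above computes: `tfd.Sample` sums the base log-probability over the extra `sample_shape` dims after the transpose.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([3]), scale=1.)        # batch_shape [3]
dist = tfd.Sample(base, sample_shape=[5])             # event_shape [5]
x = dist.sample(2, seed=42)                           # shape [2, 3, 5]
lp = dist.log_prob(x)                                 # shape [2, 3]
# The same reduction by hand: move the extra sample dim left of the batch dim,
# evaluate the base log_prob, then sum it out.
manual = tf.reduce_sum(base.log_prob(tf.transpose(x, perm=[0, 2, 1])), axis=1)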
Example #2
    def _assert_compatible_shape(self, index, sample_shape, samples):
        requested_shape, _ = self._expand_sample_shape_to_vector(
            tf.convert_to_tensor(sample_shape, dtype=tf.int32),
            name='requested_shape')
        actual_shape = prefer_static.shape(samples)
        actual_rank = prefer_static.rank_from_shape(actual_shape)
        requested_rank = prefer_static.rank_from_shape(requested_shape)

        # We test for two properties we expect of yielded distributions:
        # (1) The rank of the tensor of generated samples must be at least
        #     as large as the rank requested.
        # (2) The requested shape must be a prefix of the shape of the
        #     generated tensor of samples.
        # We attempt to perform test (1) statically first.
        # We don't need to do this explicitly for test (2) because
        # `assert_equal` evaluates statically if it can.
        static_actual_rank = tf.get_static_value(actual_rank)
        static_requested_rank = tf.get_static_value(requested_rank)

        assertion_message = ('Samples yielded by distribution #{} are not '
                             'consistent with `sample_shape` passed to '
                             '`JointDistributionCoroutine` '
                             'distribution.'.format(index))

        # TODO Remove this static check (b/138738650)
        if (static_actual_rank is not None
                and static_requested_rank is not None):
            # We're able to statically check the rank
            if static_actual_rank < static_requested_rank:
                raise ValueError(assertion_message)
            else:
                control_dependencies = []
        else:
            # We're not able to statically check the rank
            control_dependencies = [
                assert_util.assert_greater_equal(actual_rank,
                                                 requested_rank,
                                                 message=assertion_message)
            ]

        with tf.control_dependencies(control_dependencies):
            trimmed_actual_shape = actual_shape[:requested_rank]

        control_dependencies = [
            assert_util.assert_equal(requested_shape,
                                     trimmed_actual_shape,
                                     message=assertion_message)
        ]

        return control_dependencies
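A plain-TF sketch (hypothetical shapes, not from the library) of the two properties asserted above: the sample rank is at least the requested rank, and the requested shape is a prefix of the actual shape.
import tensorflow as tf

sample_shape = tf.constant([4, 2], tf.int32)
samples = tf.zeros([4, 2, 3])                  # actual rank 3 >= requested rank 2
requested_rank = sample_shape.shape[0]         # statically known here
tf.debugging.assert_equal(sample_shape, tf.shape(samples)[:requested_rank])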
Example #3
 def _reshape_part(part, dtype, event_shape):
     part = tf.cast(part, dtype)
     static_rank = tf.get_static_value(
         ps.rank_from_shape(event_shape))
     if static_rank == 1:
         return part
     new_shape = ps.concat([ps.shape(part)[:-1], event_shape],
                           axis=-1)
     return tf.reshape(part, ps.cast(new_shape, tf.int32))
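A hypothetical call to the helper above, assuming `ps` is TFP's internal `prefer_static` module as in the snippet: the flat trailing axis is re-expanded into `event_shape` unless the event is already rank 1.
import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps

part = tf.zeros([7, 6])                        # trailing axis holds 6 = 2*3 values
out = _reshape_part(part, tf.float64, tf.constant([2, 3]))
# out.shape == [7, 2, 3]; with a rank-1 event_shape the cast `part` comes back unchanged.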
Example #4
 def _event_shape_tensor(self):
     with tf.control_dependencies(self._runtime_assertions):
         batch_shape = self.distribution.batch_shape_tensor()
         batch_ndims = prefer_static.rank_from_shape(
             batch_shape, self.distribution.batch_shape)
         return prefer_static.concat([
             batch_shape[batch_ndims - self.reinterpreted_batch_ndims:],
             self.distribution.event_shape_tensor(),
         ], axis=0)
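A public-API illustration (a sketch, not the library's own test) of the shape arithmetic above: `tfd.Independent` folds the rightmost batch dims into the event.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([4, 3]), scale=1.)     # batch_shape [4, 3]
ind = tfd.Independent(base, reinterpreted_batch_ndims=1)
# ind.batch_shape_tensor() -> [4]; ind.event_shape_tensor() -> [3]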
Example #5
 def _sample_n(self, n, seed, **kwargs):
     fake_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
     event_ndims = prefer_static.rank_from_shape(
         self.distribution.event_shape_tensor,
         self.distribution.event_shape)
     batch_ndims = prefer_static.rank_from_shape(
         self.distribution.batch_shape_tensor,
         self.distribution.batch_shape)
     perm = prefer_static.concat([
         [0],
         prefer_static.range(1 + fake_sample_ndims,
                             1 + fake_sample_ndims + batch_ndims),
         prefer_static.range(1, 1 + fake_sample_ndims),
         prefer_static.range(
             1 + fake_sample_ndims + batch_ndims,
             1 + fake_sample_ndims + batch_ndims + event_ndims),
     ], axis=0)
     x = self.distribution.sample(
         prefer_static.concat([[n], self.sample_shape], axis=0),
         seed=seed, **kwargs)
     return tf.transpose(a=x, perm=perm)
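A shape-only sketch of the permutation above, via the public API (hypothetical values): after the transpose, the batch dims sit between `n` and the "fake" sample dims.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.Sample(tfd.Normal(loc=tf.zeros([3]), scale=1.), sample_shape=[5])
x = dist.sample(2, seed=7)
# x.shape == [2, 3, 5], i.e. [n] + batch_shape + sample_shape (+ event_shape).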
Example #6
 def _fn(self, **kwargs):
     """Implements summary statistic, eg, mean, stddev, mode."""
     x = getattr(self.distribution, attr)(**kwargs)
     shape = prefer_static.concat([
         self.distribution.batch_shape_tensor(),
         prefer_static.ones(
             prefer_static.rank_from_shape(self.sample_shape),
             dtype=self.sample_shape.dtype),
         self.distribution.event_shape_tensor(),
     ], axis=0)
     x = tf.reshape(x, shape=shape)
     shape = prefer_static.concat([
         self.distribution.batch_shape_tensor(),
         self.sample_shape,
         self.distribution.event_shape_tensor(),
     ], axis=0)
     return tf.broadcast_to(x, shape)
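A short public-API sketch of what the reshape + broadcast above produces: statistics of `tfd.Sample` repeat the base statistic across the `sample_shape` dims.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.Sample(tfd.Normal(loc=tf.ones([3]), scale=1.), sample_shape=[5])
m = dist.mean()   # shape [3, 5]; each slice along the last axis equals the base mean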
Пример #7
0
 def _make_runtime_assertions(self, distribution, reinterpreted_batch_ndims,
                              validate_args):
     assertions = []
     static_reinterpreted_batch_ndims = tf.get_static_value(
         reinterpreted_batch_ndims)
     batch_ndims = tensorshape_util.rank(distribution.batch_shape)
     if batch_ndims is not None and static_reinterpreted_batch_ndims is not None:
         if static_reinterpreted_batch_ndims > batch_ndims:
             raise ValueError("reinterpreted_batch_ndims({}) cannot exceed "
                              "distribution.batch_ndims({})".format(
                                  static_reinterpreted_batch_ndims,
                                  batch_ndims))
     elif validate_args:
         assertions.append(
             assert_util.assert_less_equal(
                 reinterpreted_batch_ndims,
                 prefer_static.rank_from_shape(
                     distribution.batch_shape_tensor,
                     distribution.batch_shape),
                 message=("reinterpreted_batch_ndims cannot exceed "
                          "distribution.batch_ndims")))
     return assertions
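A hedged illustration of the static branch above (hypothetical values, not a library test): asking for more reinterpreted dims than the base batch rank fails at construction time.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([3]), scale=1.)            # batch rank 1
# tfd.Independent(base, reinterpreted_batch_ndims=2)      # -> ValueError (2 > 1)
ind = tfd.Independent(base, reinterpreted_batch_ndims=1)  # fine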
Example #8
def _is_scalar_from_shape_tensor(shape):
    """Returns a `True` `Tensor` if `shape` implies a scalar."""
    return prefer_static.equal(prefer_static.rank_from_shape(shape), 0)
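# For example (shape passed as a plain Python list; behavior sketched, not a
# library doctest):
#   _is_scalar_from_shape_tensor([])      -> True   (rank-0 shape, i.e. a scalar)
#   _is_scalar_from_shape_tensor([3, 2])  -> False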
Example #9
    def __init__(self,
                 distribution,
                 bijector,
                 batch_shape=None,
                 event_shape=None,
                 kwargs_split_fn=_default_kwargs_split_fn,
                 validate_args=False,
                 parameters=None,
                 name=None):
        """Construct a Transformed Distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      bijector: The object responsible for calculating the transformation.
        Typically an instance of `Bijector`.
      batch_shape: `integer` vector `Tensor` which overrides `distribution`
        `batch_shape`; valid only if `distribution.is_scalar_batch()`.
      event_shape: `integer` vector `Tensor` which overrides `distribution`
        `event_shape`; valid only if `distribution.is_scalar_event()`.
      kwargs_split_fn: Python `callable` which takes a kwargs `dict` and returns
        a tuple of kwargs `dict`s for each of the `distribution` and `bijector`
        parameters respectively.
        Default value: `_default_kwargs_split_fn` (i.e.,
            `lambda kwargs: (kwargs.get('distribution_kwargs', {}),
                             kwargs.get('bijector_kwargs', {}))`)
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      parameters: Locals dict captured by subclass constructor, to be used for
        copy/slice re-instantiation operations.
      name: Python `str` name prefixed to Ops created by this class. Default:
        `bijector.name + distribution.name`.
    """
        parameters = dict(locals()) if parameters is None else parameters
        name = name or (("" if bijector is None else bijector.name) +
                        (distribution.name or ""))
        with tf.name_scope(name) as name:
            self._kwargs_split_fn = (_default_kwargs_split_fn
                                     if kwargs_split_fn is None else
                                     kwargs_split_fn)
            # For convenience we define some handy constants.
            self._zero = tf.constant(0, dtype=tf.int32, name="zero")
            self._empty = tf.constant([], dtype=tf.int32, name="empty")

            # We will keep track of a static and dynamic version of
            # self._is_{batch,event}_override. This way we can do more prior to graph
            # execution, including possibly raising Python exceptions.

            self._override_batch_shape = self._maybe_validate_shape_override(
                batch_shape, distribution.is_scalar_batch(), validate_args,
                "batch_shape")
            self._is_batch_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_batch_shape),
                    self._zero))
            self._is_maybe_batch_override = bool(
                tf.get_static_value(self._override_batch_shape) is None
                or tf.get_static_value(self._override_batch_shape).size != 0)

            self._override_event_shape = self._maybe_validate_shape_override(
                event_shape, distribution.is_scalar_event(), validate_args,
                "event_shape")
            self._is_event_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_event_shape),
                    self._zero))
            self._is_maybe_event_override = bool(
                tf.get_static_value(self._override_event_shape) is None
                or tf.get_static_value(self._override_event_shape).size != 0)

            # To convert a scalar distribution into a multivariate distribution we
            # will draw dims from the sample dims, which are otherwise iid. This is
            # easy to do except in the case that the base distribution has batch dims
            # and we're overriding event shape. When that case happens the event dims
            # will incorrectly land to the left of the batch dims; in that case we
            # cyclically permute the new dims to the left.
            self._needs_rotation = prefer_static.reduce_all([
                self._is_event_override,
                prefer_static.logical_not(self._is_batch_override),
                prefer_static.logical_not(distribution.is_scalar_batch())
            ])
            override_event_ndims = prefer_static.rank_from_shape(
                self._override_event_shape)
            self._rotate_ndims = _pick_scalar_condition(
                self._needs_rotation, override_event_ndims, 0)
            # We'll be reducing the head dims (if at all), i.e., this will be []
            # if we don't need to reduce.
            self._reduce_event_indices = prefer_static.range(
                self._rotate_ndims - override_event_ndims, self._rotate_ndims)

        self._distribution = distribution
        self._bijector = bijector
        super(TransformedDistribution, self).__init__(
            dtype=self._distribution.dtype,
            reparameterization_type=self._distribution.reparameterization_type,
            validate_args=validate_args,
            allow_nan_stats=self._distribution.allow_nan_stats,
            parameters=parameters,
            name=name)
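A usage sketch for the constructor above (public API; this reflects the older signature shown here, which still accepts the batch_shape/event_shape overrides):
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

log_normal = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.Exp())
# With the event_shape override, a scalar base becomes multivariate, e.g.:
#   tfd.TransformedDistribution(tfd.Normal(0., 1.), tfb.Exp(), event_shape=[3])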
Example #10
def _kl_independent(a, b, name="kl_independent"):
    """Batched KL divergence `KL(a || b)` for Independent distributions.

  We can leverage the fact that
  ```
  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
  ```
  where the sum is over the `reinterpreted_batch_ndims`.

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default "kl_independent".

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the event shapes of `a` and `b`, or of their underlying
      distributions, do not match.
  """
    p = a.distribution
    q = b.distribution

    # The KL between any two (non)-batched distributions is a scalar.
    # Given that the KL between two factored distributions is the sum, i.e.
    # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
    # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
    if (tensorshape_util.is_fully_defined(a.event_shape)
            and tensorshape_util.is_fully_defined(b.event_shape)):
        if a.event_shape == b.event_shape:
            if p.event_shape == q.event_shape:
                num_reduce_dims = (tensorshape_util.rank(a.event_shape) -
                                   tensorshape_util.rank(p.event_shape))
                reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]

                return tf.reduce_sum(
                    kullback_leibler.kl_divergence(p, q, name=name),
                    axis=reduce_dims)
            else:
                raise NotImplementedError(
                    "KL between Independents with different "
                    "event shapes not supported.")
        else:
            raise ValueError("Event shapes do not match.")
    else:
        with tf.control_dependencies([
                assert_util.assert_equal(a.event_shape_tensor(),
                                         b.event_shape_tensor()),
                assert_util.assert_equal(p.event_shape_tensor(),
                                         q.event_shape_tensor())
        ]):
            num_reduce_dims = (
                prefer_static.rank_from_shape(
                    a.event_shape_tensor, a.event_shape) -
                prefer_static.rank_from_shape(
                    p.event_shape_tensor, p.event_shape))
            reduce_dims = prefer_static.range(-num_reduce_dims, 0, 1)
            return tf.reduce_sum(
                kullback_leibler.kl_divergence(p, q, name=name),
                axis=reduce_dims)
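A small public-API check of the identity in the docstring (a sketch, not the library's test): the KL of two Independents equals the sum of the per-component KLs.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

a = tfd.Independent(tfd.Normal(tf.zeros([3]), 1.), reinterpreted_batch_ndims=1)
b = tfd.Independent(tfd.Normal(tf.ones([3]), 2.), reinterpreted_batch_ndims=1)
kl = tfd.kl_divergence(a, b)                                          # scalar
manual = tf.reduce_sum(tfd.kl_divergence(a.distribution, b.distribution))
# kl and manual agree.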
Example #11
 def _get_default_reinterpreted_batch_ndims(self, distribution):
     """Computes the default value for reinterpreted_batch_ndim __init__ arg."""
     ndims = prefer_static.rank_from_shape(distribution.batch_shape_tensor,
                                           distribution.batch_shape)
     return prefer_static.maximum(0, ndims - 1)
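Sketch of that default via the public API: with no explicit argument, all but the leftmost batch dim are reinterpreted.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([4, 3]), scale=1.)   # batch rank 2
ind = tfd.Independent(base)                          # reinterpreted_batch_ndims defaults to 1
# ind.batch_shape -> [4]; ind.event_shape -> [3]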
Example #12
 def _batch_shape_tensor(self):
     with tf.control_dependencies(self._runtime_assertions):
         batch_shape = self.distribution.batch_shape_tensor()
         batch_ndims = prefer_static.rank_from_shape(
             batch_shape, self.distribution.batch_shape)
         return batch_shape[:batch_ndims - self.reinterpreted_batch_ndims]