def _log_prob(self, x, **kwargs):
  batch_ndims = prefer_static.rank_from_shape(
      self.distribution.batch_shape_tensor,
      self.distribution.batch_shape)
  extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
  event_ndims = prefer_static.rank_from_shape(
      self.distribution.event_shape_tensor,
      self.distribution.event_shape)
  ndims = prefer_static.rank(x)
  # (1) Expand x's dims.
  d = ndims - batch_ndims - extra_sample_ndims - event_ndims
  x = tf.reshape(
      x,
      shape=tf.pad(
          tf.shape(x),
          paddings=[[prefer_static.maximum(0, -d), 0]],
          constant_values=1))
  sample_ndims = prefer_static.maximum(0, d)
  # (2) Transpose x's dims.
  sample_dims = prefer_static.range(0, sample_ndims)
  batch_dims = prefer_static.range(sample_ndims, sample_ndims + batch_ndims)
  extra_sample_dims = prefer_static.range(
      sample_ndims + batch_ndims,
      sample_ndims + batch_ndims + extra_sample_ndims)
  event_dims = prefer_static.range(
      sample_ndims + batch_ndims + extra_sample_ndims,
      ndims)
  perm = prefer_static.concat(
      [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
  x = tf.transpose(a=x, perm=perm)
  # (3) Compute x's log_prob.
  lp = self.distribution.log_prob(x, **kwargs)
  # (4) Make the final reduction in x.
  axis = prefer_static.range(sample_ndims, sample_ndims + extra_sample_ndims)
  return tf.reduce_sum(lp, axis=axis)
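# Usage sketch for the log_prob above, assuming this is TFP's `tfd.Sample`
# wrapper (the `tensorflow`/`tensorflow_probability` imports are assumptions,
# not part of this module): the `sample_shape` dims become event dims, so
# their per-component log-densities are reduce-summed away.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.Sample(tfd.Normal(loc=0., scale=1.), sample_shape=[5])
lp = dist.log_prob(tf.zeros([5]))  # scalar: sum of 5 iid Normal log-probs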
def _assert_compatible_shape(self, index, sample_shape, samples):
  requested_shape, _ = self._expand_sample_shape_to_vector(
      tf.convert_to_tensor(sample_shape, dtype=tf.int32),
      name='requested_shape')
  actual_shape = prefer_static.shape(samples)
  actual_rank = prefer_static.rank_from_shape(actual_shape)
  requested_rank = prefer_static.rank_from_shape(requested_shape)

  # We test for two properties we expect of yielded distributions:
  # (1) The rank of the tensor of generated samples must be at least
  #     as large as the rank requested.
  # (2) The requested shape must be a prefix of the shape of the
  #     generated tensor of samples.
  # We attempt to perform test (1) statically first.
  # We don't need to do this explicitly for test (2) because
  # `assert_equal` evaluates statically if it can.
  static_actual_rank = tf.get_static_value(actual_rank)
  static_requested_rank = tf.get_static_value(requested_rank)

  assertion_message = ('Samples yielded by distribution #{} are not '
                       'consistent with `sample_shape` passed to '
                       '`JointDistributionCoroutine` '
                       'distribution.'.format(index))

  # TODO(b/138738650): Remove this static check.
  if (static_actual_rank is not None and
      static_requested_rank is not None):
    # We're able to statically check the rank.
    if static_actual_rank < static_requested_rank:
      raise ValueError(assertion_message)
    else:
      control_dependencies = []
  else:
    # We're not able to statically check the rank.
    control_dependencies = [
        assert_util.assert_greater_equal(
            actual_rank, requested_rank,
            message=assertion_message)
    ]

  with tf.control_dependencies(control_dependencies):
    trimmed_actual_shape = actual_shape[:requested_rank]

  control_dependencies = [
      assert_util.assert_equal(
          requested_shape, trimmed_actual_shape,
          message=assertion_message)
  ]

  return control_dependencies
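# Minimal sketch of the two properties asserted above, written with plain
# TensorFlow (all names here are illustrative, not part of this module):
# (1) actual rank >= requested rank, (2) the requested shape is a prefix of
# the actual shape.
import tensorflow as tf

requested_shape = tf.constant([2, 3], dtype=tf.int32)
samples = tf.zeros([2, 3, 4])  # rank 3 >= requested rank 2
actual_prefix = tf.shape(samples)[:tf.size(requested_shape)]
tf.debugging.assert_equal(requested_shape, actual_prefix)  # passes: [2, 3]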
def _reshape_part(part, dtype, event_shape):
  part = tf.cast(part, dtype)
  static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
  if static_rank == 1:
    return part
  new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
  return tf.reshape(part, ps.cast(new_shape, tf.int32))
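# Sketch of the reshape performed above, using plain TensorFlow (names are
# illustrative, not part of this module): the trailing flat dim is expanded
# into `event_shape`, e.g. a [7, 6] part with event_shape [2, 3] becomes
# [7, 2, 3].
import tensorflow as tf

part = tf.zeros([7, 6])
event_shape = tf.constant([2, 3], dtype=tf.int32)
new_shape = tf.concat([tf.shape(part)[:-1], event_shape], axis=-1)
reshaped = tf.reshape(part, new_shape)  # shape: [7, 2, 3]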
def _event_shape_tensor(self):
  with tf.control_dependencies(self._runtime_assertions):
    batch_shape = self.distribution.batch_shape_tensor()
    batch_ndims = prefer_static.rank_from_shape(
        batch_shape, self.distribution.batch_shape)
    return prefer_static.concat([
        batch_shape[batch_ndims - self.reinterpreted_batch_ndims:],
        self.distribution.event_shape_tensor(),
    ], axis=0)
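# Usage sketch, assuming this is TFP's `tfd.Independent` (imports are
# assumptions, not part of this module): the rightmost
# `reinterpreted_batch_ndims` batch dims are moved into the event shape.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

ind = tfd.Independent(tfd.Normal(loc=tf.zeros([3, 4]), scale=1.),
                      reinterpreted_batch_ndims=1)
print(ind.event_shape_tensor())  # [4]: last batch dim is now an event dim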
def _sample_n(self, n, seed, **kwargs):
  fake_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
  event_ndims = prefer_static.rank_from_shape(
      self.distribution.event_shape_tensor,
      self.distribution.event_shape)
  batch_ndims = prefer_static.rank_from_shape(
      self.distribution.batch_shape_tensor,
      self.distribution.batch_shape)
  perm = prefer_static.concat([
      [0],
      prefer_static.range(1 + fake_sample_ndims,
                          1 + fake_sample_ndims + batch_ndims),
      prefer_static.range(1, 1 + fake_sample_ndims),
      prefer_static.range(1 + fake_sample_ndims + batch_ndims,
                          1 + fake_sample_ndims + batch_ndims + event_ndims),
  ], axis=0)
  x = self.distribution.sample(
      prefer_static.concat([[n], self.sample_shape], axis=0),
      seed=seed, **kwargs)
  return tf.transpose(a=x, perm=perm)
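# Shape sketch for the transpose above, assuming TFP's `tfd.Sample` (imports
# are assumptions): draws come back as [n] + batch_shape + sample_shape +
# event_shape, i.e. the batch dims land before the injected sample dims.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.Sample(tfd.Normal(loc=tf.zeros([3]), scale=1.), sample_shape=[5])
x = dist.sample(2, seed=42)
print(x.shape)  # (2, 3, 5): n=2, batch_shape=[3], sample_shape=[5]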
def _fn(self, **kwargs):
  """Implements summary statistic, e.g., mean, stddev, mode."""
  x = getattr(self.distribution, attr)(**kwargs)
  shape = prefer_static.concat([
      self.distribution.batch_shape_tensor(),
      prefer_static.ones(
          prefer_static.rank_from_shape(self.sample_shape),
          dtype=self.sample_shape.dtype),
      self.distribution.event_shape_tensor(),
  ], axis=0)
  x = tf.reshape(x, shape=shape)
  shape = prefer_static.concat([
      self.distribution.batch_shape_tensor(),
      self.sample_shape,
      self.distribution.event_shape_tensor(),
  ], axis=0)
  return tf.broadcast_to(x, shape)
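# Sketch of the broadcast above, assuming TFP's `tfd.Sample` (imports are
# assumptions): a base statistic such as the mean is tiled across the
# injected sample dims.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.Sample(tfd.Normal(loc=tf.zeros([3]), scale=1.), sample_shape=[5])
print(dist.mean().shape)  # (3, 5): batch_shape + sample_shape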
def _make_runtime_assertions(self, distribution, reinterpreted_batch_ndims,
                             validate_args):
  assertions = []
  static_reinterpreted_batch_ndims = tf.get_static_value(
      reinterpreted_batch_ndims)
  batch_ndims = tensorshape_util.rank(distribution.batch_shape)
  if batch_ndims is not None and static_reinterpreted_batch_ndims is not None:
    if static_reinterpreted_batch_ndims > batch_ndims:
      raise ValueError("reinterpreted_batch_ndims({}) cannot exceed "
                       "distribution.batch_ndims({})".format(
                           static_reinterpreted_batch_ndims, batch_ndims))
  elif validate_args:
    assertions.append(
        assert_util.assert_less_equal(
            reinterpreted_batch_ndims,
            prefer_static.rank_from_shape(
                distribution.batch_shape_tensor,
                distribution.batch_shape),
            message=("reinterpreted_batch_ndims cannot exceed "
                     "distribution.batch_ndims")))
  return assertions
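# Sketch of the static check above, assuming TFP's `tfd.Independent`
# (imports are assumptions): a statically-known violation raises at
# construction time rather than at graph execution.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

try:
  tfd.Independent(tfd.Normal(loc=tf.zeros([3]), scale=1.),
                  reinterpreted_batch_ndims=2)  # batch rank is only 1
except ValueError as e:
  print(e)  # reinterpreted_batch_ndims(2) cannot exceed ...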
def _is_scalar_from_shape_tensor(shape):
  """Returns a `True` `Tensor` if the given shape `Tensor` implies a scalar."""
  return prefer_static.equal(prefer_static.rank_from_shape(shape), 0)
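# Equivalent check in plain TensorFlow (illustrative only, not part of this
# module): a shape vector with zero entries describes a rank-0, i.e. scalar,
# Tensor.
import tensorflow as tf

shape = tf.constant([], dtype=tf.int32)  # the shape of a scalar
is_scalar = tf.equal(tf.size(shape), 0)  # True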
def __init__(self,
             distribution,
             bijector,
             batch_shape=None,
             event_shape=None,
             kwargs_split_fn=_default_kwargs_split_fn,
             validate_args=False,
             parameters=None,
             name=None):
  """Construct a Transformed Distribution.

  Args:
    distribution: The base distribution instance to transform. Typically an
      instance of `Distribution`.
    bijector: The object responsible for calculating the transformation.
      Typically an instance of `Bijector`.
    batch_shape: `integer` vector `Tensor` which overrides `distribution`
      `batch_shape`; valid only if `distribution.is_scalar_batch()`.
    event_shape: `integer` vector `Tensor` which overrides `distribution`
      `event_shape`; valid only if `distribution.is_scalar_event()`.
    kwargs_split_fn: Python `callable` which takes a kwargs `dict` and returns
      a tuple of kwargs `dict`s for each of the `distribution` and `bijector`
      parameters respectively.
      Default value: `_default_kwargs_split_fn` (i.e.,
        `lambda kwargs: (kwargs.get('distribution_kwargs', {}),
                         kwargs.get('bijector_kwargs', {}))`)
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    parameters: Locals dict captured by subclass constructor, to be used for
      copy/slice re-instantiation operations.
    name: Python `str` name prefixed to Ops created by this class. Default:
      `bijector.name + distribution.name`.
  """
  parameters = dict(locals()) if parameters is None else parameters
  name = name or (("" if bijector is None else bijector.name) +
                  (distribution.name or ""))
  with tf.name_scope(name) as name:
    self._kwargs_split_fn = (_default_kwargs_split_fn
                             if kwargs_split_fn is None
                             else kwargs_split_fn)

    # For convenience we define some handy constants.
    self._zero = tf.constant(0, dtype=tf.int32, name="zero")
    self._empty = tf.constant([], dtype=tf.int32, name="empty")

    # We will keep track of a static and dynamic version of
    # self._is_{batch,event}_override. This way we can do more prior to graph
    # execution, including possibly raising Python exceptions.
    self._override_batch_shape = self._maybe_validate_shape_override(
        batch_shape, distribution.is_scalar_batch(), validate_args,
        "batch_shape")
    self._is_batch_override = prefer_static.logical_not(
        prefer_static.equal(
            prefer_static.rank_from_shape(self._override_batch_shape),
            self._zero))
    self._is_maybe_batch_override = bool(
        tf.get_static_value(self._override_batch_shape) is None or
        tf.get_static_value(self._override_batch_shape).size != 0)

    self._override_event_shape = self._maybe_validate_shape_override(
        event_shape, distribution.is_scalar_event(), validate_args,
        "event_shape")
    self._is_event_override = prefer_static.logical_not(
        prefer_static.equal(
            prefer_static.rank_from_shape(self._override_event_shape),
            self._zero))
    self._is_maybe_event_override = bool(
        tf.get_static_value(self._override_event_shape) is None or
        tf.get_static_value(self._override_event_shape).size != 0)

    # To convert a scalar distribution into a multivariate distribution we
    # will draw dims from the sample dims, which are otherwise iid. This is
    # easy to do except in the case that the base distribution has batch dims
    # and we're overriding event shape. When that case happens the event dims
    # will incorrectly be to the left of the batch dims. In this case we'll
    # cyclically permute left the new dims.
    self._needs_rotation = prefer_static.reduce_all([
        self._is_event_override,
        prefer_static.logical_not(self._is_batch_override),
        prefer_static.logical_not(distribution.is_scalar_batch())])
    override_event_ndims = prefer_static.rank_from_shape(
        self._override_event_shape)
    self._rotate_ndims = _pick_scalar_condition(
        self._needs_rotation, override_event_ndims, 0)
    # We'll be reducing the head dims (if at all), i.e., this will be []
    # if we don't need to reduce.
    self._reduce_event_indices = prefer_static.range(
        self._rotate_ndims - override_event_ndims, self._rotate_ndims)

    self._distribution = distribution
    self._bijector = bijector
    super(TransformedDistribution, self).__init__(
        dtype=self._distribution.dtype,
        reparameterization_type=self._distribution.reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=self._distribution.allow_nan_stats,
        parameters=parameters,
        name=name)
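# Construction sketch, assuming this is TFP's `tfd.TransformedDistribution`
# (imports are assumptions, not part of this module): pushing a standard
# Normal through an `Exp` bijector yields a log-normal distribution.
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

log_normal = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.Exp())
print(log_normal.sample(3))  # three positive draws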
def _kl_independent(a, b, name="kl_independent"):
  """Batched KL divergence `KL(a || b)` for Independent distributions.

  We can leverage the fact that

  ```
  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
  ```

  where the sum is over the `reinterpreted_batch_ndims`.

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default "kl_independent".

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the event space for `a` and `b`, or their underlying
      distributions don't match.
  """
  p = a.distribution
  q = b.distribution

  # The KL between any two (non)-batched distributions is a scalar.
  # Given that the KL between two factored distributions is the sum, i.e.
  # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
  # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
  if (tensorshape_util.is_fully_defined(a.event_shape) and
      tensorshape_util.is_fully_defined(b.event_shape)):
    if a.event_shape == b.event_shape:
      if p.event_shape == q.event_shape:
        num_reduce_dims = (tensorshape_util.rank(a.event_shape) -
                           tensorshape_util.rank(p.event_shape))
        reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]
        return tf.reduce_sum(
            kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
      else:
        raise NotImplementedError("KL between Independents with different "
                                  "event shapes not supported.")
    else:
      raise ValueError("Event shapes do not match.")
  else:
    with tf.control_dependencies([
        assert_util.assert_equal(a.event_shape_tensor(),
                                 b.event_shape_tensor()),
        assert_util.assert_equal(p.event_shape_tensor(),
                                 q.event_shape_tensor())
    ]):
      num_reduce_dims = (
          prefer_static.rank_from_shape(a.event_shape_tensor, a.event_shape) -
          prefer_static.rank_from_shape(p.event_shape_tensor, p.event_shape))
      # Reduce over the rightmost `num_reduce_dims` dims, matching the static
      # branch above (i.e. axes [-num_reduce_dims, ..., -1]).
      reduce_dims = prefer_static.range(-num_reduce_dims, 0, 1)
      return tf.reduce_sum(
          kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
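# Numeric sketch, assuming TFP (imports are assumptions): the KL between two
# Independent Normals is the sum of the component-wise KLs, matching the
# reduce_sum above. KL(N(0,1) || N(1,1)) = 0.5 per component, so the total
# over 3 components is 1.5.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

a = tfd.Independent(tfd.Normal(loc=tf.zeros([3]), scale=1.),
                    reinterpreted_batch_ndims=1)
b = tfd.Independent(tfd.Normal(loc=tf.ones([3]), scale=1.),
                    reinterpreted_batch_ndims=1)
print(tfd.kl_divergence(a, b))  # 1.5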
def _get_default_reinterpreted_batch_ndims(self, distribution):
  """Computes the default value for the `reinterpreted_batch_ndims` __init__ arg."""
  ndims = prefer_static.rank_from_shape(distribution.batch_shape_tensor,
                                        distribution.batch_shape)
  return prefer_static.maximum(0, ndims - 1)
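# Default sketch, assuming TFP's `tfd.Independent` (imports are assumptions):
# with batch shape [3, 4], the default is rank - 1 = 1, so only the rightmost
# batch dim is reinterpreted.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

ind = tfd.Independent(tfd.Normal(loc=tf.zeros([3, 4]), scale=1.))
print(ind.reinterpreted_batch_ndims)  # 1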
def _batch_shape_tensor(self):
  with tf.control_dependencies(self._runtime_assertions):
    batch_shape = self.distribution.batch_shape_tensor()
    batch_ndims = prefer_static.rank_from_shape(
        batch_shape, self.distribution.batch_shape)
    return batch_shape[:batch_ndims - self.reinterpreted_batch_ndims]
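# Shape sketch, assuming TFP's `tfd.Independent` (imports are assumptions):
# the reinterpreted dims are trimmed from the right of the batch shape,
# mirroring how `_event_shape_tensor` above absorbs them into the event.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

ind = tfd.Independent(tfd.Normal(loc=tf.zeros([3, 4]), scale=1.),
                      reinterpreted_batch_ndims=1)
print(ind.batch_shape_tensor())  # [3]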