def __init__(self,
             values_with_sample_dim,
             batch_ndims=0,
             validate_args=False,
             name=None):
  """Initializes an empirical distribution with a list of samples.

  Args:
    values_with_sample_dim: nested structure of `Tensor`s, each of shape
      prefixed by `[num_samples, B1, ..., Bn]`, where `num_samples` as well
      as `B1, ..., Bn` are batch dimensions shared across all `Tensor`s.
    batch_ndims: optional scalar int `Tensor`, or structure matching
      `values_with_sample_dim` of scalar int `Tensor`s, specifying the
      number of batch dimensions. Used to determine the batch and event
      shapes of the distribution.
      Default value: `0`.
    validate_args: Python `bool` indicating whether to perform runtime
      checks that may have performance cost.
      Default value: `False`.
    name: Python `str` name for ops created by this distribution.
  """
  parameters = dict(locals())
  with tf.name_scope(name or 'DeterministicEmpirical') as name:
    # Ensure we don't break if the passed-in structures are externally
    # mutated.
    values_with_sample_dim = _copy_structure(values_with_sample_dim)
    batch_ndims = _copy_structure(batch_ndims)

    # Prevent tf.Module from wrapping passed-in values, because the wrapper
    # breaks JointDistributionNamed (and maybe other JDs). Instead, we save
    # a separate ref to the input that is used only by tf.Module tracking.
    self._values_for_tracking = values_with_sample_dim
    self._values_with_sample_dim = self._no_dependency(
        values_with_sample_dim)

    if not tf.nest.is_nested(batch_ndims):
      batch_ndims = tf.nest.map_structure(lambda _: batch_ndims,
                                          values_with_sample_dim)
    self._batch_ndims = batch_ndims

    self._max_num_samples = prefer_static.reduce_min([
        prefer_static.size0(x)
        for x in tf.nest.flatten(values_with_sample_dim)
    ])

    super(DeterministicEmpirical, self).__init__(
        dtype=tf.nest.map_structure(
            lambda x: x.dtype, self.values_with_sample_dim),
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=True,
        name=name)
    self._parameters = self._no_dependency(parameters)
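# A minimal usage sketch (an illustration, not part of the class; assumes
# `DeterministicEmpirical`, defined above, is importable along with
# TensorFlow). Each part shares the leading `num_samples` dimension; with
# `batch_ndims=1`, the trailing dimension of size 3 is a batch dimension,
# leaving scalar events.
import tensorflow as tf

values = {'loc': tf.random.normal([100, 3]),
          'scale': tf.random.uniform([100, 3])}
dist = DeterministicEmpirical(values, batch_ndims=1)
# Dtypes mirror the input structure, as set in the constructor above.
assert dist.dtype == {'loc': tf.float32, 'scale': tf.float32}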
def _update_forward_min_event_ndims(
    bij,
    downstream_quantities,
    get_forward_min_event_ndims=lambda b: b.forward_min_event_ndims,
    get_inverse_min_event_ndims=lambda b: b.inverse_min_event_ndims,
    inverse_event_ndims_fn=lambda b, nd: b.inverse_event_ndims(nd)):
  """Steps backwards through the graph to infer `forward_min_event_ndims`.

  Args:
    bij: local tfb.Bijector instance at the current graph node.
    downstream_quantities: Instance of the
      `MinEventNdimsInferenceDownstreamQuantities` namedtuple, containing
      event_ndims that satisfy the bijector(s) downstream from `bij` in the
      graph. May be `None` if there are no such bijectors.
    get_forward_min_event_ndims: callable; may be overridden to swap
      forward/inverse direction.
    get_inverse_min_event_ndims: callable; may be overridden to swap
      forward/inverse direction.
    inverse_event_ndims_fn: callable; may be overridden to swap
      forward/inverse direction.

  Returns:
    downstream_quantities: Instance of the
      `MinEventNdimsInferenceDownstreamQuantities` namedtuple containing
      event_ndims that satisfy `bij` and all downstream bijectors.
  """
  if downstream_quantities is None:  # This is a leaf bijector.
    return MinEventNdimsInferenceDownstreamQuantities(
        forward_min_event_ndims=get_forward_min_event_ndims(bij),
        parts_interact=bij._parts_interact)  # pylint: disable=protected-access

  inverse_min_event_ndims = get_inverse_min_event_ndims(bij)
  downstream_min_event_ndims = nest_util.coerce_structure(
      inverse_min_event_ndims,
      downstream_quantities.forward_min_event_ndims)

  # Update the min_event_ndims that is a valid input to downstream bijectors
  # to also be a valid *output* of this bijector, or equivalently, a valid
  # input to `bij.inverse`.
  rank_mismatches = tf.nest.flatten(
      tf.nest.map_structure(
          lambda dim, min_dim: dim - min_dim,
          downstream_min_event_ndims,
          inverse_min_event_ndims))
  if downstream_quantities.parts_interact:
    # If downstream bijectors involve interaction between parts, then a
    # valid input to the downstream bijectors must augment the
    # `downstream_min_event_ndims` by the same rank for every part
    # (otherwise we would induce event shape broadcasting). Hopefully, this
    # will also avoid event-shape broadcasting at the current bijector---if
    # not, the composition is invalid, and the call to
    # `bij.inverse_event_ndims(valid_inverse_min_event_ndims)` below will
    # raise an exception.
    maximum_rank_deficiency = -ps.reduce_min([0] + rank_mismatches)
    valid_inverse_min_event_ndims = tf.nest.map_structure(
        lambda ndims: maximum_rank_deficiency + ndims,
        downstream_min_event_ndims)
  else:
    if bij._parts_interact:  # pylint: disable=protected-access
      # If this bijector does *not* operate independently on its parts, then
      # a valid input to `inverse` cannot require event shape broadcasting.
      # That is, each part must have the same 'excess rank' above the local
      # `inverse_min_event_ndims`; we ensure this by construction.
      maximum_excess_rank = ps.reduce_max([0] + rank_mismatches)
      valid_inverse_min_event_ndims = tf.nest.map_structure(
          lambda ndims: maximum_excess_rank + ndims,
          inverse_min_event_ndims)
    else:
      # If all parts are independent, we can take the pointwise max
      # event_ndims.
      valid_inverse_min_event_ndims = tf.nest.map_structure(
          ps.maximum, downstream_min_event_ndims, inverse_min_event_ndims)

  return MinEventNdimsInferenceDownstreamQuantities(
      # Pull the desired output ndims back through the bijector, to get
      # the ndims of a valid *input*.
      forward_min_event_ndims=inverse_event_ndims_fn(
          bij, valid_inverse_min_event_ndims),
      parts_interact=(
          downstream_quantities.parts_interact or
          bij._parts_interact))  # pylint: disable=protected-access
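# Illustration of the inference this helper performs, via the public API (a
# sketch that exercises `tfb.Chain` rather than calling the private helper
# directly). `Scale` alone requires only 0 event dims, but the downstream
# `SoftmaxCentered` requires 1; pulling that requirement back through
# `Scale` yields a chain-wide `forward_min_event_ndims` of 1.
import tensorflow_probability as tfp

tfb = tfp.bijectors
chain = tfb.Chain([tfb.SoftmaxCentered(), tfb.Scale(2.)])
assert tfb.Scale(2.).forward_min_event_ndims == 0
assert chain.forward_min_event_ndims == 1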
def sanitize_event_ndims(event_ndims):
  """Updates `rolling_offset` when `event_ndims` are negative."""
  nonlocal rolling_offset
  max_missing_ndims = -ps.reduce_min(nest.flatten(event_ndims))
  rolling_offset = ps.maximum(rolling_offset, max_missing_ndims)
  return event_ndims
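# Standalone sketch of the closure pattern above, with Python builtins
# standing in for `ps` and `nest` (hypothetical values). A negative entry
# means a part's rank falls short of its minimum event ndims;
# `rolling_offset` keeps the largest deficiency seen so far, so that all
# parts can later be padded by a common amount.
def _sanitize_demo():
  rolling_offset = 0

  def sanitize_event_ndims(event_ndims):
    nonlocal rolling_offset
    rolling_offset = max(rolling_offset, -min(event_ndims))
    return event_ndims

  sanitize_event_ndims([1, -2, 0])  # One part is 2 dims short.
  sanitize_event_ndims([0, -1, 3])  # Smaller deficit; offset unchanged.
  return rolling_offset

assert _sanitize_demo() == 2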