def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                   validate_args, name):
  """Helper to __init__ which ensures override batch/event_shape are valid."""
  if override_shape is None:
    override_shape = []

  override_shape = tf.convert_to_tensor(
      value=override_shape, dtype=tf.int32, name=name)

  if not dtype_util.is_integer(override_shape.dtype):
    raise TypeError("shape override must be an integer")

  override_is_scalar = _is_scalar_from_shape_tensor(override_shape)
  if tf.get_static_value(override_is_scalar):
    return self._empty

  dynamic_assertions = []

  if tensorshape_util.rank(override_shape.shape) is not None:
    if tensorshape_util.rank(override_shape.shape) != 1:
      raise ValueError("shape override must be a vector")
  elif validate_args:
    dynamic_assertions += [
        assert_util.assert_rank(
            override_shape, 1, message="shape override must be a vector")
    ]

  if tf.get_static_value(override_shape) is not None:
    if any(s < 0 for s in tf.get_static_value(override_shape)):
      raise ValueError("shape override must have non-negative elements")
  elif validate_args:
    dynamic_assertions += [
        assert_util.assert_non_negative(
            override_shape,
            message="shape override must have non-negative elements")
    ]

  is_both_nonscalar = prefer_static.logical_and(
      prefer_static.logical_not(base_is_scalar),
      prefer_static.logical_not(override_is_scalar))
  if tf.get_static_value(is_both_nonscalar) is not None:
    if tf.get_static_value(is_both_nonscalar):
      raise ValueError("base distribution not scalar")
  elif validate_args:
    dynamic_assertions += [
        assert_util.assert_equal(
            is_both_nonscalar, False, message="base distribution not scalar")
    ]

  if not dynamic_assertions:
    return override_shape
  return distribution_util.with_dependencies(
      dynamic_assertions, override_shape)
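# A minimal standalone sketch (hypothetical helper, not part of the class) of
# the validation pattern used above: when a value is known statically, raise a
# Python exception at graph-construction time; otherwise, and only if
# `validate_args` is set, defer to a runtime assertion. Reuses the `tf` and
# `assert_util` names already imported in this module.
def _check_non_negative_vector(shape, validate_args):
  assertions = []
  static_shape = tf.get_static_value(shape)
  if static_shape is not None:
    # Static path: fail fast with a Python exception.
    if any(s < 0 for s in static_shape):
      raise ValueError("shape must have non-negative elements")
  elif validate_args:
    # Dynamic path: attach an assertion evaluated at graph execution.
    assertions.append(assert_util.assert_non_negative(
        shape, message="shape must have non-negative elements"))
  return assertions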
def _needs_rotation(self, override_event_shape, override_batch_shape,
                    base_is_scalar_batch):
  # To convert a scalar distribution into a multivariate distribution we
  # will permute dims from the sample dims, which are otherwise iid. This is
  # easy to do except in the case that the base distribution has nonscalar
  # batch and we're overriding only event shape. Under these conditions, this
  # function returns `True`, indicating that event dims will incorrectly be
  # to the left of the batch dims and we'll need to cyclically permute left
  # the new dims (in `_maybe_rotate_dims`). If these conditions do not hold,
  # this function returns `False` and no rotation is needed.
  return prefer_static.reduce_all([
      self._has_nonzero_rank(override_event_shape),
      prefer_static.logical_not(
          self._has_nonzero_rank(override_batch_shape)),
      prefer_static.logical_not(base_is_scalar_batch)
  ])
def _has_nonzero_rank(self, override_shape):
  return prefer_static.logical_not(
      prefer_static.equal(
          prefer_static.rank_from_shape(override_shape),
          self._zero))
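# A hedged, runnable illustration (hypothetical shapes, not part of the class)
# of the rotation predicate above, inlined with the same `prefer_static` ops:
# overriding only the event shape of a nonscalar-batch base distribution is
# exactly the case that requires rotating the newly drawn dims.
override_event_shape = tf.constant([2], dtype=tf.int32)  # event override set
override_batch_shape = tf.constant([], dtype=tf.int32)   # no batch override
base_is_scalar_batch = False                             # base has batch dims
needs_rotation = prefer_static.reduce_all([
    # The event override is non-empty, i.e. has nonzero rank-from-shape.
    prefer_static.logical_not(prefer_static.equal(
        prefer_static.rank_from_shape(override_event_shape), 0)),
    # The batch override is empty.
    prefer_static.equal(
        prefer_static.rank_from_shape(override_batch_shape), 0),
    # The base distribution has batch dims.
    prefer_static.logical_not(base_is_scalar_batch),
])
# needs_rotation --> True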
def __init__(self,
             distribution,
             bijector,
             batch_shape=None,
             event_shape=None,
             kwargs_split_fn=_default_kwargs_split_fn,
             validate_args=False,
             parameters=None,
             name=None):
  """Construct a Transformed Distribution.

  Args:
    distribution: The base distribution instance to transform. Typically an
      instance of `Distribution`.
    bijector: The object responsible for calculating the transformation.
      Typically an instance of `Bijector`.
    batch_shape: `integer` vector `Tensor` which overrides `distribution`
      `batch_shape`; valid only if `distribution.is_scalar_batch()`.
    event_shape: `integer` vector `Tensor` which overrides `distribution`
      `event_shape`; valid only if `distribution.is_scalar_event()`.
    kwargs_split_fn: Python `callable` which takes a kwargs `dict` and returns
      a tuple of kwargs `dict`s for each of the `distribution` and `bijector`
      parameters respectively.
      Default value: `_default_kwargs_split_fn` (i.e.,
        `lambda kwargs: (kwargs.get('distribution_kwargs', {}),
                         kwargs.get('bijector_kwargs', {}))`)
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    parameters: Locals dict captured by subclass constructor, to be used for
      copy/slice re-instantiation operations.
    name: Python `str` name prefixed to Ops created by this class. Default:
      `bijector.name + distribution.name`.
  """
  parameters = dict(locals()) if parameters is None else parameters
  name = name or (("" if bijector is None else bijector.name) +
                  (distribution.name or ""))
  with tf.name_scope(name) as name:
    self._kwargs_split_fn = (_default_kwargs_split_fn
                             if kwargs_split_fn is None
                             else kwargs_split_fn)
    # For convenience we define some handy constants.
    self._zero = tf.constant(0, dtype=tf.int32, name="zero")
    self._empty = tf.constant([], dtype=tf.int32, name="empty")

    # We will keep track of a static and dynamic version of
    # self._is_{batch,event}_override. This way we can do more prior to graph
    # execution, including possibly raising Python exceptions.

    self._override_batch_shape = self._maybe_validate_shape_override(
        batch_shape, distribution.is_scalar_batch(), validate_args,
        "batch_shape")
    self._is_batch_override = prefer_static.logical_not(
        prefer_static.equal(
            prefer_static.rank_from_shape(self._override_batch_shape),
            self._zero))
    self._is_maybe_batch_override = bool(
        tf.get_static_value(self._override_batch_shape) is None or
        tf.get_static_value(self._override_batch_shape).size != 0)

    self._override_event_shape = self._maybe_validate_shape_override(
        event_shape, distribution.is_scalar_event(), validate_args,
        "event_shape")
    self._is_event_override = prefer_static.logical_not(
        prefer_static.equal(
            prefer_static.rank_from_shape(self._override_event_shape),
            self._zero))
    self._is_maybe_event_override = bool(
        tf.get_static_value(self._override_event_shape) is None or
        tf.get_static_value(self._override_event_shape).size != 0)

    # To convert a scalar distribution into a multivariate distribution we
    # will draw dims from the sample dims, which are otherwise iid. This is
    # easy to do except in the case that the base distribution has batch dims
    # and we're overriding event shape. When that case happens the event dims
    # will incorrectly be to the left of the batch dims. In this case we'll
    # cyclically permute left the new dims.
    self._needs_rotation = prefer_static.reduce_all([
        self._is_event_override,
        prefer_static.logical_not(self._is_batch_override),
        prefer_static.logical_not(distribution.is_scalar_batch())
    ])
    override_event_ndims = prefer_static.rank_from_shape(
        self._override_event_shape)
    self._rotate_ndims = _pick_scalar_condition(
        self._needs_rotation, override_event_ndims, 0)
    # We'll be reducing the head dims (if at all), i.e., this will be []
    # if we don't need to reduce.
    self._reduce_event_indices = tf.range(
        self._rotate_ndims - override_event_ndims, self._rotate_ndims)

  self._distribution = distribution
  self._bijector = bijector
  super(TransformedDistribution, self).__init__(
      dtype=self._distribution.dtype,
      reparameterization_type=self._distribution.reparameterization_type,
      validate_args=validate_args,
      allow_nan_stats=self._distribution.allow_nan_stats,
      parameters=parameters,
      # We let TransformedDistribution access _graph_parents since this class
      # is more like a baseclass than derived.
      graph_parents=(
          distribution._graph_parents +  # pylint: disable=protected-access
          bijector.graph_parents),
      name=name)
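# A hedged usage sketch (not part of this module): the canonical log-normal
# construction, plus an `event_shape` override that promotes a scalar Normal
# to three iid coordinates under the Exp transform. Assumes the standard
# `tfp.distributions` / `tfp.bijectors` aliases and a TFP version in which
# the constructor accepts these overrides.
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

log_normal = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.Exp())

vector_log_normal = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.Exp(),
    event_shape=[3])  # valid: the base distribution is scalar-event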
def trace_scan(loop_fn,
               initial_state,
               elems,
               trace_fn,
               trace_criterion_fn=None,
               static_trace_allocation_size=None,
               parallel_iterations=10,
               name=None):
  """A simplified version of `tf.scan` that has configurable tracing.

  This function repeatedly calls `loop_fn(state, elem)`, where `state` is the
  `initial_state` during the first iteration, and the return value of
  `loop_fn` for every iteration thereafter. `elem` is a slice of `elems` along
  the first dimension, accessed in order. Additionally, it calls `trace_fn` on
  the return value of `loop_fn`. The `Tensor`s in return values of `trace_fn`
  are stacked and returned from this function, such that the first dimension
  of those `Tensor`s matches the size of `elems`.

  Args:
    loop_fn: A callable that takes in a `Tensor` or a nested collection of
      `Tensor`s with the same structure as `initial_state`, a slice of
      `elems`, and returns the same structure as `initial_state`.
    initial_state: A `Tensor` or a nested collection of `Tensor`s passed to
      `loop_fn` in the first iteration.
    elems: A `Tensor` that is split along the first dimension and each element
      of which is passed to `loop_fn`.
    trace_fn: A callable that takes in the return value of `loop_fn` and
      returns a `Tensor` or a nested collection of `Tensor`s.
    trace_criterion_fn: Optional callable that takes in the return value of
      `loop_fn` and returns a boolean `Tensor` indicating whether to trace it.
      If `None`, all steps are traced.
      Default value: `None`.
    static_trace_allocation_size: Optional Python `int` size of trace to
      allocate statically. This should be an upper bound on the number of
      steps traced and is used only when the length cannot be statically
      inferred (for example, if a `trace_criterion_fn` is specified). It is
      primarily intended for contexts where static shapes are required, such
      as in XLA-compiled code.
      Default value: `None`.
    parallel_iterations: Passed to the internal `tf.while_loop`.
    name: Name scope used in this function. Default: 'trace_scan'.

  Returns:
    final_state: The final return value of `loop_fn`.
    trace: The same structure as the return value of `trace_fn`, but with each
      `Tensor` being a stack of the corresponding `Tensor`s in the return
      value of `trace_fn` for each slice of `elems`.
  """
  with tf.name_scope(name or 'trace_scan'), tf1.variable_scope(
      tf1.get_variable_scope()) as vs:
    if vs.caching_device is None and not tf.executing_eagerly():
      vs.set_caching_device(lambda op: op.device)

    initial_state = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x, name='initial_state'),
        initial_state)
    elems = tf.convert_to_tensor(elems, name='elems')

    length = prefer_static.size0(elems)

    # This is a `TensorArray` in part because of XLA, which had trouble with
    # non-statically known indices. I.e. `elems[i]` errored, but
    # `elems_array.read(i)` worked.
    elems_array = tf.TensorArray(
        elems.dtype, size=length, element_shape=elems.shape[1:])
    elems_array = elems_array.unstack(elems)

    # Initialize trace arrays.
    dynamic_size, initial_size = True, 0
    if trace_criterion_fn is None:
      dynamic_size = prefer_static.logical_not(prefer_static.is_numpy(length))
      initial_size = length
    elif static_trace_allocation_size:
      dynamic_size, initial_size = False, static_trace_allocation_size
    trace_arrays = tf.nest.map_structure(
        lambda x: tf.TensorArray(x.dtype,  # pylint: disable=g-long-lambda
                                 size=initial_size,
                                 dynamic_size=dynamic_size,
                                 element_shape=x.shape),
        trace_fn(initial_state))

    # Helper for writing a (structured) state to (structured) arrays.
    def trace_one_step(num_steps_traced, trace_arrays, state):
      return tf.nest.map_structure(
          lambda ta, x: ta.write(num_steps_traced, x),
          trace_arrays,
          trace_fn(state))

    def _body(i, state, num_steps_traced, trace_arrays):
      elem = elems_array.read(i)
      state = loop_fn(state, elem)

      trace_arrays, num_steps_traced = prefer_static.cond(
          trace_criterion_fn(state) if trace_criterion_fn else True,
          lambda: (trace_one_step(num_steps_traced, trace_arrays, state),  # pylint: disable=g-long-lambda
                   num_steps_traced + 1),
          lambda: (trace_arrays, num_steps_traced))

      return i + 1, state, num_steps_traced, trace_arrays

    _, final_state, _, trace_arrays = tf.while_loop(
        cond=lambda i, *_: i < length,
        body=_body,
        loop_vars=(0, initial_state, 0, trace_arrays),
        parallel_iterations=parallel_iterations)

    stacked_trace = tf.nest.map_structure(lambda x: x.stack(), trace_arrays)

    # Restore the static length if we know it.
    static_length = tf.TensorShape(None if dynamic_size else initial_size)

    def _merge_static_length(x):
      tensorshape_util.set_shape(x, static_length.concatenate(x.shape[1:]))
      return x

    stacked_trace = tf.nest.map_structure(_merge_static_length, stacked_trace)
    return final_state, stacked_trace
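# A hedged usage sketch of `trace_scan`: a running sum whose intermediate
# totals are traced at every step. `loop_fn` folds each slice of `elems` into
# the state; `trace_fn` records the state as-is.
final_total, running_totals = trace_scan(
    loop_fn=lambda state, elem: state + elem,
    initial_state=tf.constant(0.),
    elems=tf.constant([1., 2., 3., 4.]),
    trace_fn=lambda state: state)
# final_total    --> 10.
# running_totals --> [1., 3., 6., 10.]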