def _fn(x):
  """MADE parameterized via `masked_autoregressive_default_template`."""
  # TODO(b/67594795): Better support of dynamic shape.
  input_depth = tf.compat.dimension_value(
      tensorshape_util.with_rank_at_least(x.shape, 1)[-1])
  if input_depth is None:
    raise NotImplementedError(
        'Rightmost dimension must be known prior to graph execution.')
  input_shape = (
      np.int32(tensorshape_util.as_list(x.shape))
      if tensorshape_util.is_fully_defined(x.shape)
      else tf.shape(x))
  if tensorshape_util.rank(x.shape) == 1:
    x = x[tf.newaxis, ...]
  for i, units in enumerate(hidden_layers):
    x = masked_dense(
        inputs=x,
        units=units,
        num_blocks=input_depth,
        exclusive=True if i == 0 else False,
        activation=activation,
        *args,  # pylint: disable=keyword-arg-before-vararg
        **kwargs)
  x = masked_dense(
      inputs=x,
      units=(1 if shift_only else 2) * input_depth,
      num_blocks=input_depth,
      activation=None,
      *args,  # pylint: disable=keyword-arg-before-vararg
      **kwargs)
  if shift_only:
    x = tf.reshape(x, shape=input_shape)
    return x, None
  x = tf.reshape(x, shape=tf.concat([input_shape, [2]], axis=0))
  shift, log_scale = tf.unstack(x, num=2, axis=-1)
  which_clip = (
      tf.clip_by_value
      if log_scale_clip_gradient
      else clip_by_value_preserve_gradient)
  log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip)
  return shift, log_scale
def _parameter_control_dependencies(self, is_init): if tensorshape_util.is_fully_defined(self.distribution.batch_shape): if self.to_shape is not None: static_to_shape = tf.get_static_value(self.to_shape) if static_to_shape is not None: bcast_shp = tf.broadcast_static_shape( tf.TensorShape(static_to_shape), self.distribution.batch_shape) if bcast_shp != static_to_shape: raise ValueError(f'Argument `to_shape` ({static_to_shape}) ' 'is incompatible with underlying distribution ' f'batch shape ({self.distribution.batch_shape}).') else: static_with_shape = tf.get_static_value(self.with_shape) if static_with_shape is not None: tf.broadcast_static_shape( # Ensure compatible. tf.TensorShape(static_with_shape), self.distribution.batch_shape) underlying = self.distribution._parameter_control_dependencies(is_init) # pylint: disable=protected-access if not self.validate_args: return underlying checks = [] if self.to_shape is not None: if tensor_util.is_ref(self.to_shape) != is_init: checks += [assert_util.assert_equal( self.to_shape, ps.broadcast_shape(self.distribution.batch_shape_tensor(), self.to_shape), message='Argument `to_shape` is incompatible with underlying ' 'distribution batch shape.')] else: if tensor_util.is_ref(self.with_shape) != is_init: checks += [tf.broadcast_dynamic_shape( self.distribution.batch_shape_tensor(), self.with_shape)] return tuple(checks) + tuple(underlying)
def _flatten_and_concat_event(self, x): def _reshape_part(part, event_shape): part = tf.cast(part, self.dtype) static_rank = tf.get_static_value(ps.rank_from_shape(event_shape)) if static_rank == 1: return part new_shape = ps.concat([ ps.shape(part)[:ps.size(ps.shape(part)) - ps.size(event_shape)], [-1] ], axis=-1) return tf.reshape(part, ps.cast(new_shape, tf.int32)) if all( tensorshape_util.is_fully_defined(s) for s in tf.nest.flatten(self._distribution.event_shape)): x = tf.nest.map_structure(_reshape_part, x, self._distribution.event_shape) else: x = tf.nest.map_structure(_reshape_part, x, self._distribution.event_shape_tensor()) return tf.concat(tf.nest.flatten(x), axis=-1)
def _broadcast_cat_event_and_params(event, params, base_dtype): """Broadcasts the event or distribution parameters.""" if dtype_util.is_integer(event.dtype): pass elif dtype_util.is_floating(event.dtype): # When `validate_args=True` we've already ensured int/float casting # is closed. event = tf.cast(event, dtype=tf.int32) else: raise TypeError("`value` should have integer `dtype` or " "`self.dtype` ({})".format(base_dtype)) shape_known_statically = (tensorshape_util.rank(params.shape) is not None and params.shape[:-1].is_fully_defined() and tensorshape_util.is_fully_defined(event.shape)) if not shape_known_statically or params.shape[:-1] != event.shape: params *= tf.ones_like(event[..., tf.newaxis], dtype=params.dtype) params_shape = tf.shape(input=params)[:-1] event *= tf.ones(params_shape, dtype=event.dtype) if tensorshape_util.rank(params.shape) is not None: tensorshape_util.set_shape(event, params.shape[:-1]) return event, params
def _slice_params_to_dict(dist, params_event_ndims, slices): """Computes the override dictionary of sliced parameters. Args: dist: The tfd.Distribution being batch-sliced. params_event_ndims: Per-event parameter ranks, a `str->int` `dict`. slices: Slices as received by __getitem__. Returns: overrides: `str->Tensor` `dict` of batch-sliced parameter overrides. """ override_dict = {} for param_name, param_event_ndims in params_event_ndims.items(): # Verify that either None or a legit value is in the parameters dict. if param_name not in dist.parameters: raise ValueError('Distribution {} is missing advertised ' 'parameter {}'.format(dist, param_name)) param = dist.parameters[param_name] if param is None: # some distributions have multiple possible parameterizations; this # param was not provided continue dtype = None if hasattr(dist, param_name): attr = getattr(dist, param_name) dtype = getattr(attr, 'dtype', None) if dtype is None: dtype = dist.dtype warnings.warn('Unable to find property getter for parameter Tensor {} ' 'on {}, falling back to Distribution.dtype {}'.format( param_name, dist, dtype)) param = tf.convert_to_tensor(value=param, dtype=dtype) dist_batch_shape = dist.batch_shape if not tensorshape_util.is_fully_defined(dist_batch_shape): dist_batch_shape = dist.batch_shape_tensor() override_dict[param_name] = _slice_single_param(param, param_event_ndims, slices, dist_batch_shape) return override_dict
def _validate_block_sizes(block_sizes, bijectors, validate_args): """Helper to validate block sizes.""" block_sizes_shape = block_sizes.shape if tensorshape_util.is_fully_defined(block_sizes_shape): if (tensorshape_util.rank(block_sizes_shape) != 1 or (tensorshape_util.num_elements(block_sizes_shape) != len(bijectors))): raise ValueError( '`block_sizes` must be `None`, or a vector of the same length as ' '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of ' 'length {}'.format(block_sizes_shape, len(bijectors))) return block_sizes elif validate_args: message = ('`block_sizes` must be `None`, or a vector of the same length ' 'as `bijectors`.') with tf.control_dependencies([ assert_util.assert_equal( tf.size(block_sizes), len(bijectors), message=message), assert_util.assert_equal(tf.rank(block_sizes), 1) ]): return tf.identity(block_sizes) else: return block_sizes
def _split_and_reshape_event(self, x): assertions = [] message = 'Input must have at least one dimension.' if tensorshape_util.rank(x.shape) is not None: if tensorshape_util.rank(x.shape) == 0: raise ValueError(message) elif self.validate_args: assertions.append( assert_util.assert_rank_at_least(x, 1, message=message)) with tf.control_dependencies(assertions): event_tensors = self._distribution.event_shape_tensor() splits = [ ps.maximum(1, ps.reduce_prod(s)) for s in tf.nest.flatten(event_tensors) ] x = tf.nest.pack_sequence_as(event_tensors, tf.split(x, splits, axis=-1)) def _reshape_part(part, dtype, event_shape): part = tf.cast(part, dtype) static_rank = tf.get_static_value( ps.rank_from_shape(event_shape)) if static_rank == 1: return part new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1) return tf.reshape(part, ps.cast(new_shape, tf.int32)) if all( tensorshape_util.is_fully_defined(s) for s in tf.nest.flatten(self._distribution.event_shape)): x = tf.nest.map_structure(_reshape_part, x, self._distribution.dtype, self._distribution.event_shape) else: x = tf.nest.map_structure( _reshape_part, x, self._distribution.dtype, self._distribution.event_shape_tensor()) return x
def _sample_control_dependencies(self, x): assertions = [] if tensorshape_util.is_fully_defined(x.shape[-2:]): if not (tensorshape_util.dims(x.shape)[-2] == tensorshape_util.dims(x.shape)[-1] == self.dimension): raise ValueError( 'Input dimension mismatch: expected [..., {}, {}], got {}'.format( self.dimension, self.dimension, tensorshape_util.dims(x.shape))) elif self.validate_args: msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format( self.dimension, self.dimension, tf.shape(x)) assertions.append(assert_util.assert_equal( tf.shape(x)[-2], self.dimension, message=msg)) assertions.append(assert_util.assert_equal( tf.shape(x)[-1], self.dimension, message=msg)) if self.validate_args: assertions.append(assert_util.assert_near( x, tf.linalg.band_part(x, -1, 0), message='Cholesky factors must be lower triangular.')) return assertions
def _has_valid_dimensions(self, x): if tensorshape_util.is_fully_defined(x.shape[-2:]): if (tensorshape_util.dims(x.shape)[-2] == tensorshape_util.dims( x.shape)[-1] == self.dimension): return [] else: raise ValueError( 'Input dimension mismatch: expected [..., {}, {}], got {}'. format(self.dimension, self.dimension, tensorshape_util.dims(x.shape))) elif self.validate_args: msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format( self.dimension, self.dimension, tf.shape(x)) return [ assert_util.assert_equal(tf.shape(x)[-2], self.dimension, message=msg), assert_util.assert_equal(tf.shape(x)[-1], self.dimension, message=msg) ] return []
def _validate_dimension(self, x): x = tf.convert_to_tensor(x, name='x') if tensorshape_util.is_fully_defined(x.shape[-2:]): if (tensorshape_util.dims(x.shape)[-2] == tensorshape_util.dims(x.shape)[-1] == self.dimension): pass else: raise ValueError( 'Input dimension mismatch: expected [..., {}, {}], got {}'.format( self.dimension, self.dimension, tensorshape_util.dims(x.shape))) elif self.validate_args: msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format( self.dimension, self.dimension, tf.shape(x)) with tf.control_dependencies([ assert_util.assert_equal( tf.shape(x)[-2], self.dimension, message=msg), assert_util.assert_equal( tf.shape(x)[-1], self.dimension, message=msg) ]): x = tf.identity(x) return x
def _sample_control_dependencies(self, x): assertions = [] if tensorshape_util.is_fully_defined(x.shape[-2:]): if not (tensorshape_util.dims(x.shape)[-2] == tensorshape_util.dims(x.shape)[-1] == self.dimension): raise ValueError( 'Input dimension mismatch: expected [..., {}, {}], got {}'.format( self.dimension, self.dimension, tensorshape_util.dims(x.shape))) elif self.validate_args: msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format( self.dimension, self.dimension, tf.shape(x)) assertions.append(assert_util.assert_equal( tf.shape(x)[-2], self.dimension, message=msg)) assertions.append(assert_util.assert_equal( tf.shape(x)[-1], self.dimension, message=msg)) if self.validate_args and not self.input_output_cholesky: assertions.append(assert_util.assert_less_equal( dtype_util.as_numpy_dtype(x.dtype)(-1), x, message='Correlations must be >= -1.', summarize=30)) assertions.append(assert_util.assert_less_equal( x, dtype_util.as_numpy_dtype(x.dtype)(1), message='Correlations must be <= 1.', summarize=30)) assertions.append(assert_util.assert_near( tf.linalg.diag_part(x), dtype_util.as_numpy_dtype(x.dtype)(1), message='Self-correlations must be = 1.', summarize=30)) assertions.append(assert_util.assert_near( x, tf.linalg.matrix_transpose(x), message='Correlation matrices must be symmetric.', summarize=30)) return assertions
def _sub_diag(nonmatrix):
  """Get the first sub-diagonal of a shape [N, N, ...] 'non matrix'."""
  with tf.name_scope('sub_matrix'):
    # TODO(b/143702351) Once array_ops.matrix_diag_part_v3 is ready and
    # exposed, replace the call to matrix_diag_part_v2 below with
    # tf.linalg.matrix_diag. We can also stop special casing for
    # matrix_dim < 2 at that point. Until then, an OpError is raised for 1x1
    # matrices without static shape. In fact, non-static shape breaks
    # matrix_diag_part_v2, so we must raise this message now.
    # See http://b/138403336 for the TF issue tracker.
    if not tensorshape_util.is_fully_defined(nonmatrix.shape[:2]):
      raise ValueError(
          '`inverse_temperatures` did not have statically defined shape, '
          'which breaks tracking of is_swap_{proposed,accepted}. '
          'Please provide an `inverse_temperatures` with statically known '
          'shape.')

    # The sub-matrix of a 1x1 matrix is not defined (throws exception), so in
    # this special case return an empty matrix.
    # TODO(b/143702351) Remove this special case handling once
    # matrix_diag_part_v3 is ready.
    matrix_dim = ps.size0(nonmatrix)
    if matrix_dim is not None and matrix_dim < 2:
      # Shape is [..., 0], so returned tensor is empty, thus contains no
      # values... and therefore the fact that we use 'ones' doesn't matter.
      shape = ps.pad(
          ps.shape(nonmatrix)[2:], paddings=[[0, 1]], constant_values=0)
      matrix_sub_diag = tf.cast(tf.ones(shape), nonmatrix.dtype)
    else:
      # Get first sub-diagonal. `padding_value` is not used (since matrix is
      # square), but is required for the API since this is raw gen_array_ops.
      matrix_sub_diag = tf.raw_ops.MatrixDiagPartV2(
          input=distribution_util.rotate_transpose(nonmatrix, shift=-2),
          k=ps.convert_to_shape_tensor(-1, dtype=tf.int32),
          padding_value=tf.cast(0.0, dtype=nonmatrix.dtype))
    return distribution_util.rotate_transpose(matrix_sub_diag, shift=1)
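# Illustrative sketch (not library code): roughly what `_sub_diag` computes,
# written with plain NumPy for a small [N, N, ...] array. The result stacks
# the entries [i + 1, i, ...] along a leading axis of size N - 1; all names
# below are ours, chosen only for the example.
import numpy as np

nonmatrix = np.arange(3 * 3 * 2).reshape(3, 3, 2)  # leading [N, N] "matrix" dims
sub_diag = np.stack(
    [nonmatrix[i + 1, i] for i in range(nonmatrix.shape[0] - 1)])
assert sub_diag.shape == (2, 2)  # [N - 1] + trailing shape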
def get_broadcast_shape(*tensors):
  """Get broadcast shape as a Python list of integers (preferred) or `Tensor`.

  Args:
    *tensors: One or more `Tensor` objects (already converted!).

  Returns:
    broadcast shape: Python list (if shapes determined statically), otherwise
      an `int32` `Tensor`.
  """
  # Try static.
  s_shape = tensors[0].shape
  for t in tensors[1:]:
    s_shape = tf.broadcast_static_shape(s_shape, t.shape)
  if tensorshape_util.is_fully_defined(s_shape):
    return tensorshape_util.as_list(s_shape)

  # Fallback on dynamic.
  d_shape = tf.shape(tensors[0])
  for t in tensors[1:]:
    d_shape = tf.broadcast_dynamic_shape(d_shape, tf.shape(t))
  return d_shape
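# Illustrative sketch (not library code): the static/dynamic split that
# `get_broadcast_shape` relies on, using only public TF ops. Assumes eager
# execution; variable names are ours.
import tensorflow as tf

# Fully static shapes: broadcasting resolves to a plain Python list.
a = tf.zeros([3, 1, 5])
b = tf.zeros([4, 1])
print(tf.broadcast_static_shape(a.shape, b.shape).as_list())  # [3, 4, 5]

# With a partially unknown shape the static result is not fully defined, so a
# helper like the one above would fall back to `tf.broadcast_dynamic_shape`.
partial = tf.broadcast_static_shape(
    tf.TensorShape([None, 5]), tf.TensorShape([5]))
print(partial.is_fully_defined())  # False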
def _move_dims_to_flat_end(x, axis, x_ndims, right_end=True): """Move dims corresponding to `axis` in `x` to the end, then flatten. Args: x: `Tensor` with shape `[B0,B1,...,Bb]`. axis: Python list of indices into dimensions of `x`. x_ndims: Python integer holding number of dimensions in `x`. right_end: Python bool. Whether to move dims to the right end (else left). Returns: `Tensor` with value from `x` and dims in `axis` moved to end into one single dimension. """ if not axis: return x # Suppose x.shape = [a, b, c, d] # Suppose axis = [1, 3] # other_dims = [0, 2] in example above. other_dims = sorted(set(range(x_ndims)).difference(axis)) # x_permed.shape = [a, c, b, d] perm = other_dims + list(axis) if right_end else list(axis) + other_dims x_permed = tf.transpose(a=x, perm=perm) if tensorshape_util.is_fully_defined(x.shape): x_shape = tensorshape_util.as_list(x.shape) # other_shape = [a, c], end_shape = [b * d] other_shape = [x_shape[i] for i in other_dims] end_shape = [np.prod([x_shape[i] for i in axis])] full_shape = (other_shape + end_shape if right_end else end_shape + other_shape) else: other_shape = ps.gather(ps.shape(x), ps.cast(other_dims, tf.int64)) full_shape = ps.concat( [other_shape, [-1]] if right_end else [[-1], other_shape], axis=0) return tf.reshape(x_permed, shape=full_shape)
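# Illustrative sketch (not library code): the same permute-then-flatten idea in
# plain NumPy, for the `[a, b, c, d]` example in the comments above.
import numpy as np

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)      # shape [a, b, c, d]
axis = [1, 3]                                         # dims to move and flatten
other_dims = sorted(set(range(x.ndim)) - set(axis))   # [0, 2]
x_permed = np.transpose(x, other_dims + axis)         # shape [2, 4, 3, 5]
flat = x_permed.reshape(x_permed.shape[:2] + (-1,))   # shape [2, 4, 15] = [a, c, b*d]
assert flat.shape == (2, 4, 15)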
def _calculate_new_shape(self): # Try to get the old shape statically if available. original_shape = self._distribution.batch_shape if not tensorshape_util.is_fully_defined(original_shape): original_shape = self._distribution.batch_shape_tensor() # This is not a check for falseness, it's a check for exactly that shape. if original_shape == (): # pylint: disable=g-explicit-bool-comparison # Force the size to be an integer, not a float, when the shape contains no # dtype information. original_size = 1 else: original_size = ps.reduce_prod(original_shape) original_size = ps.cast(original_size, tf.int32) # Compute the new shape, filling in the `-1` dimension if present. new_shape = self._batch_shape_unexpanded implicit_dim_mask = ps.equal(new_shape, -1) size_implicit_dim = (original_size // ps.maximum(1, -ps.reduce_prod(new_shape))) expanded_new_shape = ps.where( # Assumes exactly one `-1`. implicit_dim_mask, size_implicit_dim, new_shape) # Return the original size on the side because one caller would otherwise # have to recompute it. return expanded_new_shape, original_size
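# Illustrative sketch (not library code): the arithmetic used above to fill in
# a single implicit `-1` entry of the requested batch shape, written with
# NumPy. All names and values here are made up for the example.
import numpy as np

original_size = 12                 # prod of the old batch shape, e.g. [3, 4]
new_shape = np.array([2, -1, 3])   # requested shape with one implicit dim
implicit_dim_mask = np.equal(new_shape, -1)
# prod(new_shape) == -6, so negating it gives the product of the explicit dims.
size_implicit_dim = original_size // max(1, -np.prod(new_shape))
expanded_new_shape = np.where(implicit_dim_mask, size_implicit_dim, new_shape)
assert list(expanded_new_shape) == [2, 2, 3]  # 2 * 2 * 3 == 12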
def event_shape_tensor(self, name='event_shape_tensor'): """Shape of a single sample from a single batch as a 1-D int32 `Tensor`. Args: name: name to give to the op Returns: event_shape: `Tensor`. """ with self._name_and_control_scope(name): if all([tensorshape_util.is_fully_defined(s) for s in nest.flatten(self.event_shape)]): event_shape = nest.map_structure_up_to( self.dtype, tensorshape_util.as_list, self.event_shape, check_types=False) else: event_shape = self._event_shape_tensor() return nest.map_structure_up_to( self.dtype, lambda s: tf.identity( # pylint: disable=g-long-lambda tf.convert_to_tensor(s, dtype=tf.int32), name='event_shape'), event_shape, check_types=False)
def _sample_shape(self, x): """Computes graph and static `sample_shape`.""" x_ndims = ( tf.rank(x) if tensorshape_util.rank(x.shape) is None else tensorshape_util.rank(x.shape)) event_ndims = ( tf.size(self.event_shape_tensor()) if tensorshape_util.rank(self.event_shape) is None else tensorshape_util.rank(self.event_shape)) batch_ndims = ( tf.size(self._batch_shape_unexpanded) if tensorshape_util.rank(self.batch_shape) is None else tensorshape_util.rank(self.batch_shape)) sample_ndims = x_ndims - batch_ndims - event_ndims if isinstance(sample_ndims, int): static_sample_shape = x.shape[:sample_ndims] else: static_sample_shape = tf.TensorShape(None) if tensorshape_util.is_fully_defined(static_sample_shape): sample_shape = np.int32(static_sample_shape) else: sample_shape = tf.shape(x)[:sample_ndims] return sample_shape, static_sample_shape
def param_static_shapes(cls, sample_shape): """param_shapes with static (i.e. `TensorShape`) shapes. This is a class method that describes what key/value arguments are required to instantiate the given `Distribution` so that a particular shape is returned for that instance's call to `sample()`. Assumes that the sample's shape is known statically. Subclasses should override class method `_param_shapes` to return constant-valued tensors when constant values are fed. Args: sample_shape: `TensorShape` or python list/tuple. Desired shape of a call to `sample()`. Returns: `dict` of parameter name to `TensorShape`. Raises: ValueError: if `sample_shape` is a `TensorShape` and is not fully defined. """ if isinstance(sample_shape, tf.TensorShape): if not tensorshape_util.is_fully_defined(sample_shape): raise ValueError('TensorShape sample_shape must be fully defined') sample_shape = tensorshape_util.as_list(sample_shape) params = cls.param_shapes(sample_shape) static_params = {} for name, shape in params.items(): static_shape = tf.get_static_value(shape) if static_shape is None: raise ValueError( 'sample_shape must be a fully-defined TensorShape or list/tuple') static_params[name] = tf.TensorShape(static_shape) return static_params
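# Usage sketch (assuming TensorFlow Probability is importable as below); the
# distribution is arbitrary and only chosen to show the returned dict of
# fully-defined `TensorShape`s.
import tensorflow_probability as tfp
tfd = tfp.distributions

shapes = tfd.Normal.param_static_shapes(sample_shape=[100])
print(shapes)  # => {'loc': TensorShape([100]), 'scale': TensorShape([100])}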
def _replace_event_shape_in_tensorshape( input_tensorshape, event_shape_in, event_shape_out): """Replaces the event shape dims of a `TensorShape`. Args: input_tensorshape: a `TensorShape` instance in which to attempt replacing event shape. event_shape_in: `Tensor` shape representing the event shape expected to be present in (rightmost dims of) `tensorshape_in`. Must be compatible with the rightmost dims of `tensorshape_in`. event_shape_out: `Tensor` shape representing the new event shape, i.e., the replacement of `event_shape_in`, Returns: output_tensorshape: `TensorShape` with the rightmost `event_shape_in` replaced by `event_shape_out`. Might be partially defined, i.e., `TensorShape(None)`. is_validated: Python `bool` indicating static validation happened. Raises: ValueError: if we can determine the event shape portion of `tensorshape_in` as well as `event_shape_in` both statically, and they are not compatible. "Compatible" here means that they are identical on any dims that are not -1 in `event_shape_in`. """ event_shape_in_ndims = tensorshape_util.num_elements(event_shape_in.shape) if tensorshape_util.rank( input_tensorshape) is None or event_shape_in_ndims is None: return tf.TensorShape(None), False # Not is_validated. input_non_event_ndims = tensorshape_util.rank( input_tensorshape) - event_shape_in_ndims if input_non_event_ndims < 0: raise ValueError( 'Input has fewer ndims ({}) than event shape ndims ({}).'.format( tensorshape_util.rank(input_tensorshape), event_shape_in_ndims)) input_non_event_tensorshape = input_tensorshape[:input_non_event_ndims] input_event_tensorshape = input_tensorshape[input_non_event_ndims:] # Check that `input_event_shape_` and `event_shape_in` are compatible in the # sense that they have equal entries in any position that isn't a `-1` in # `event_shape_in`. Note that our validations at construction time ensure # there is at most one such entry in `event_shape_in`. event_shape_in_ = tf.get_static_value(event_shape_in) is_validated = ( tensorshape_util.is_fully_defined(input_event_tensorshape) and event_shape_in_ is not None) if is_validated: input_event_shape_ = np.int32(input_event_tensorshape) mask = event_shape_in_ >= 0 explicit_input_event_shape_ = input_event_shape_[mask] explicit_event_shape_in_ = event_shape_in_[mask] if not all(explicit_input_event_shape_ == explicit_event_shape_in_): raise ValueError( 'Input `event_shape` does not match `event_shape_in`. ' '({} vs {}).'.format(input_event_shape_, event_shape_in_)) event_tensorshape_out = tensorshape_util.constant_value_as_shape( event_shape_out) if tensorshape_util.rank(event_tensorshape_out) is None: output_tensorshape = tf.TensorShape(None) else: output_tensorshape = tensorshape_util.concatenate( input_non_event_tensorshape, event_tensorshape_out) return output_tensorshape, is_validated
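# Illustrative sketch (not a call into the private helper above): the shape
# replacement expressed with `tf.TensorShape` arithmetic alone, for a fully
# static case. Values are made up for the example.
import tensorflow as tf

input_tensorshape = tf.TensorShape([4, 2, 3])  # batch [4] + event [2, 3]
event_shape_in = [2, 3]                        # rightmost dims to replace
event_shape_out = [6]                          # replacement event shape

non_event = input_tensorshape[:input_tensorshape.rank - len(event_shape_in)]
output_tensorshape = non_event.concatenate(event_shape_out)
assert output_tensorshape.as_list() == [4, 6]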
def _replace_event_shape_in_shape_tensor( input_shape, event_shape_in, event_shape_out, validate_args): """Replaces the rightmost dims in a `Tensor` representing a shape. Args: input_shape: a rank-1 `Tensor` of integers event_shape_in: the event shape expected to be present in rightmost dims of `shape_in`. event_shape_out: the event shape with which to replace `event_shape_in` in the rightmost dims of `input_shape`. validate_args: Python `bool` indicating whether arguments should be checked for correctness. Returns: output_shape: A rank-1 integer `Tensor` with the same contents as `input_shape` except for the event dims, which are replaced with `event_shape_out`. """ output_tensorshape, is_validated = _replace_event_shape_in_tensorshape( tensorshape_util.constant_value_as_shape(input_shape), event_shape_in, event_shape_out) # TODO(b/124240153): Remove map(tf.identity, deps) once tf.function # correctly supports control_dependencies. validation_dependencies = ( map(tf.identity, (event_shape_in, event_shape_out)) if validate_args else ()) if (tensorshape_util.is_fully_defined(output_tensorshape) and (is_validated or not validate_args)): with tf.control_dependencies(validation_dependencies): output_shape = tf.convert_to_tensor( output_tensorshape, name='output_shape', dtype_hint=tf.int32) return output_shape, output_tensorshape with tf.control_dependencies(validation_dependencies): event_shape_in_ndims = ( tf.size(event_shape_in) if tensorshape_util.num_elements(event_shape_in.shape) is None else tensorshape_util.num_elements(event_shape_in.shape)) input_non_event_shape, input_event_shape = tf.split( input_shape, num_or_size_splits=[-1, event_shape_in_ndims]) additional_assertions = [] if is_validated: pass elif validate_args: # Check that `input_event_shape` and `event_shape_in` are compatible in the # sense that they have equal entries in any position that isn't a `-1` in # `event_shape_in`. Note that our validations at construction time ensure # there is at most one such entry in `event_shape_in`. mask = event_shape_in >= 0 explicit_input_event_shape = tf.boolean_mask(input_event_shape, mask=mask) explicit_event_shape_in = tf.boolean_mask(event_shape_in, mask=mask) additional_assertions.append( assert_util.assert_equal( explicit_input_event_shape, explicit_event_shape_in, message='Input `event_shape` does not match `event_shape_in`.')) # We don't explicitly additionally verify # `tf.size(input_shape) > tf.size(event_shape_in)` since `tf.split` # already makes this assertion. with tf.control_dependencies(additional_assertions): output_shape = tf.concat([input_non_event_shape, event_shape_out], axis=0, name='output_shape') return output_shape, output_tensorshape
def __init__(self, image_shape: tuple, conditional_shape: tuple = None, num_resnet: int = 5, num_hierarchies: int = 3, num_filters: int = 160, num_logistic_mix: int = 10, receptive_field_dims: tuple = (3, 3), dropout_p: float = 0.5, resnet_activation: str = 'concat_elu', l2_weight: float = 0., use_weight_norm: bool = True, use_data_init: bool = True, high: int = 255, low: int = 0, dtype=tf.float32, name: str = 'PixelCNN') -> None: """ Construct Pixel CNN++ distribution. Parameters ---------- image_shape 3D `TensorShape` or tuple for the `[height, width, channels]` dimensions of the image. conditional_shape `TensorShape` or tuple for the shape of the conditional input, or `None` if there is no conditional input. num_resnet The number of layers (shown in Figure 2 of [2]) within each highest-level block of Figure 2 of [1]. num_hierarchies The number of highest-level blocks (separated by expansions/contractions of dimensions in Figure 2 of [1].) num_filters The number of convolutional filters. num_logistic_mix Number of components in the logistic mixture distribution. receptive_field_dims Height and width in pixels of the receptive field of the convolutional layers above and to the left of a given pixel. The width (second element of the tuple) should be odd. Figure 1 (middle) of [2] shows a receptive field of (3, 5) (the row containing the current pixel is included in the height). The default of (3, 3) was used to produce the results in [1]. dropout_p The dropout probability. Should be between 0 and 1. resnet_activation The type of activation to use in the resnet blocks. May be 'concat_elu', 'elu', or 'relu'. use_weight_norm If `True` then use weight normalization (works only in Eager mode). use_data_init If `True` then use data-dependent initialization (has no effect if `use_weight_norm` is `False`). high The maximum value of the input data (255 for an 8-bit image). low The minimum value of the input data. dtype Data type of the `Distribution`. name The name of the `Distribution`. """ parameters = dict(locals()) with tf.name_scope(name) as name: super(PixelCNN, self).__init__( dtype=dtype, reparameterization_type=reparameterization.NOT_REPARAMETERIZED, validate_args=False, allow_nan_stats=True, parameters=parameters, name=name) if not tensorshape_util.is_fully_defined(image_shape): raise ValueError('`image_shape` must be fully defined.') if conditional_shape is not None and not tensorshape_util.is_fully_defined( conditional_shape): raise ValueError('`conditional_shape` must be fully defined`') if tensorshape_util.rank(image_shape) != 3: raise ValueError( '`image_shape` must have length 3, representing [height, width, channels] dimensions.' ) self._high = tf.cast(high, self.dtype) self._low = tf.cast(low, self.dtype) self._num_logistic_mix = num_logistic_mix self.network = _PixelCNNNetwork( dropout_p=dropout_p, num_resnet=num_resnet, num_hierarchies=num_hierarchies, num_filters=num_filters, num_logistic_mix=num_logistic_mix, receptive_field_dims=receptive_field_dims, resnet_activation=resnet_activation, l2_weight=l2_weight, use_weight_norm=use_weight_norm, use_data_init=use_data_init, dtype=dtype) image_input_shape = tensorshape_util.concatenate([None], image_shape) if conditional_shape is None: input_shape = image_input_shape else: conditional_input_shape = tensorshape_util.concatenate( [None], conditional_shape) input_shape = [image_input_shape, conditional_input_shape] self.image_shape = image_shape self.conditional_shape = conditional_shape self.network.build(input_shape)
def auto_correlation(x, axis=-1, max_lags=None, center=True, normalize=True, name='auto_correlation'): """Auto correlation along one axis. Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation `RXX` may be defined as (with `E` expectation and `Conj` complex conjugate) ``` RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) }, W[n] := (X[n] - MU) / S, MU := E{ X[0] }, S**2 := E{ (X[0] - MU) Conj(X[0] - MU) }. ``` This function takes the viewpoint that `x` is (along one axis) a finite sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an estimate of `RXX[m]` as follows: After extending `x` from length `L` to `inf` by zero padding, the auto correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as ``` rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]), w[n] := (x[n] - mu) / s, mu := L**-1 sum_n x[n], s**2 := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu) ``` The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users often set `max_lags` small enough so that the entire output is meaningful. Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation contains a slight bias, which goes to zero as `len(x) - m --> infinity`. Args: x: `float32` or `complex64` `Tensor`. axis: Python `int`. The axis number along which to compute correlation. Other dimensions index different batch members. max_lags: Positive `int` tensor. The maximum value of `m` to consider (in equation above). If `max_lags >= x.shape[axis]`, we effectively re-set `max_lags` to `x.shape[axis] - 1`. center: Python `bool`. If `False`, do not subtract the mean estimate `mu` from `x[n]` when forming `w[n]`. normalize: Python `bool`. If `False`, do not divide by the variance estimate `s**2` when forming `w[n]`. name: `String` name to prepend to created ops. Returns: `rxx`: `Tensor` of same `dtype` as `x`. `rxx.shape[i] = x.shape[i]` for `i != axis`, and `rxx.shape[axis] = max_lags + 1`. Raises: TypeError: If `x` is not a supported type. """ # Implementation details: # Extend length N / 2 1-D array x to length N by zero padding onto the end. # Then, set # F[x]_k := sum_n x_n exp{-i 2 pi k n / N }. # It is not hard to see that # F[x]_k Conj(F[x]_k) = F[R]_k, where # R_m := sum_n x_n Conj(x_{(n - m) mod N}). # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m]. # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT # based version of estimating RXX. # Note that this is a special case of the Wiener-Khinchin Theorem. with tf.name_scope(name): x = tf.convert_to_tensor(x, name='x') # Rotate dimensions of x in order to put axis at the rightmost dim. # FFT op requires this. rank = ps.rank(x) if axis < 0: axis = rank + axis shift = rank - 1 - axis # Suppose x.shape[axis] = T, so there are T 'time' steps. # ==> x_rotated.shape = B + [T], # where B is x_rotated's batch shape. x_rotated = distribution_util.rotate_transpose(x, shift) if center: x_rotated = x_rotated - tf.reduce_mean( x_rotated, axis=-1, keepdims=True) # x_len = N / 2 from above explanation. The length of x along axis. # Get a value for x_len that works in all cases. x_len = ps.shape(x_rotated)[-1] # TODO(langmore) Investigate whether this zero padding helps or hurts. At # the moment is necessary so that all FFT implementations work. # Zero pad to the next power of 2 greater than 2 * x_len, which equals # 2**(ceil(Log_2(2 * x_len))). Note: Log_2(X) = Log_e(X) / Log_e(2). 
x_len_float64 = ps.cast(x_len, np.float64) target_length = ps.pow(np.float64(2.), ps.ceil(ps.log(x_len_float64 * 2) / np.log(2.))) pad_length = ps.cast(target_length - x_len_float64, np.int32) # We should have: # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length] # = B + [T + pad_length] x_rotated_pad = distribution_util.pad(x_rotated, axis=-1, back=True, count=pad_length) dtype = x.dtype if not dtype_util.is_complex(dtype): if not dtype_util.is_floating(dtype): raise TypeError( 'Argument x must have either float or complex dtype' ' found: {}'.format(dtype)) x_rotated_pad = tf.complex( x_rotated_pad, dtype_util.as_numpy_dtype(dtype_util.real_dtype(dtype))(0.)) # Autocorrelation is IFFT of power-spectral density (up to some scaling). fft_x_rotated_pad = tf.signal.fft(x_rotated_pad) spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad) # shifted_product is R[m] from above detailed explanation. # It is the inner product sum_n X[n] * Conj(X[n - m]). shifted_product = tf.signal.ifft(spectral_density) # Cast back to real-valued if x was real to begin with. shifted_product = tf.cast(shifted_product, dtype) # Figure out if we can deduce the final static shape, and set max_lags. # Use x_rotated as a reference, because it has the time dimension in the far # right, and was created before we performed all sorts of crazy shape # manipulations. know_static_shape = True if not tensorshape_util.is_fully_defined(x_rotated.shape): know_static_shape = False if max_lags is None: max_lags = x_len - 1 else: max_lags = tf.convert_to_tensor(max_lags, name='max_lags') max_lags_ = tf.get_static_value(max_lags) if max_lags_ is None or not know_static_shape: know_static_shape = False max_lags = tf.minimum(x_len - 1, max_lags) else: max_lags = min(x_len - 1, max_lags_) # Chop off the padding. # We allow users to provide a huge max_lags, but cut it off here. # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags] shifted_product_chopped = shifted_product[..., :max_lags + 1] # If possible, set shape. if know_static_shape: chopped_shape = tensorshape_util.as_list(x_rotated.shape) chopped_shape[-1] = min(x_len, max_lags + 1) tensorshape_util.set_shape(shifted_product_chopped, chopped_shape) # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]). The # other terms were zeros arising only due to zero padding. # `denominator = (N / 2 - m)` (defined below) is the proper term to # divide by to make this an unbiased estimate of the expectation # E[X[n] Conj(X[n - m])]. x_len = ps.cast(x_len, dtype_util.real_dtype(dtype)) max_lags = ps.cast(max_lags, dtype_util.real_dtype(dtype)) denominator = x_len - ps.range(0., max_lags + 1.) denominator = ps.cast(denominator, dtype) shifted_product_rotated = shifted_product_chopped / denominator if normalize: shifted_product_rotated /= shifted_product_rotated[..., :1] # Transpose dimensions back to those of x. return distribution_util.rotate_transpose(shifted_product_rotated, -shift)
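# Cross-check sketch (not library code): a direct O(L * max_lags) NumPy
# implementation of the `rxx[m]` estimator as written in the docstring above
# (the centered, `s`-normalized form), handy for sanity-checking the FFT-based
# path on small inputs. `np_auto_correlation` is our own name.
import numpy as np

def np_auto_correlation(x, max_lags):
  x = np.asarray(x, dtype=np.float64)
  n = len(x)
  w = (x - x.mean()) / x.std()  # mu and s as defined in the docstring
  return np.array([np.dot(w[m:], w[:n - m]) / (n - m)
                   for m in range(max_lags + 1)])

rxx = np_auto_correlation(np.random.RandomState(0).randn(1000), max_lags=3)
print(rxx[0])  # == 1.0 by construction; later lags hover near 0 for white noise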
def fill_triangular(x, upper=False, name=None):
  """Creates a (batch of) triangular matrix from a vector of inputs.

  Created matrix can be lower- or upper-triangular. (It is more efficient to
  create the matrix as upper or lower, rather than transpose.)

  Triangular matrix elements are filled in a clockwise spiral. See example,
  below.

  If `x.shape` is `[b1, b2, ..., bB, d]` then the output shape is
  `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

  Example:

  ```python
  fill_triangular([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]
  ```

  The key trick is to create an upper triangular matrix by concatenating `x`
  and a tail of itself, then reshaping.

  Suppose that we are filling the upper triangle of an `n`-by-`n` matrix `M`
  from a vector `x`. The matrix `M` contains `n**2` entries total. The vector
  `x` contains `n * (n+1) / 2` entries. For concreteness, we'll consider
  `n = 5` (so `x` has `15` entries and `M` has `25`). We'll concatenate `x`
  and `x` with the first (`n = 5`) elements removed and reversed:

  ```python
  x = np.arange(15) + 1
  xc = np.concatenate([x, x[5:][::-1]])
  # ==> array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 15,
  #            14, 13, 12, 11, 10,  9,  8,  7,  6])

  # (We add one to the arange result to disambiguate the zeros below the
  # diagonal of our upper-triangular matrix from the first entry in `x`.)

  # Now, lay this out as a matrix:
  y = np.reshape(xc, [5, 5])
  # ==> array([[ 1,  2,  3,  4,  5],
  #            [ 6,  7,  8,  9, 10],
  #            [11, 12, 13, 14, 15],
  #            [15, 14, 13, 12, 11],
  #            [10,  9,  8,  7,  6]])

  # Finally, zero the elements below the diagonal:
  y = np.triu(y, k=0)
  # ==> array([[ 1,  2,  3,  4,  5],
  #            [ 0,  7,  8,  9, 10],
  #            [ 0,  0, 13, 14, 15],
  #            [ 0,  0,  0, 12, 11],
  #            [ 0,  0,  0,  0,  6]])
  ```

  From this example we see that the resulting matrix is upper-triangular, and
  contains all the entries of `x`, as desired. The rest is details:

  - If `n` is even, `x` doesn't exactly fill an even number of rows (it fills
    `n / 2` rows and half of an additional row), but the whole scheme still
    works.
  - If we want a lower triangular matrix instead of an upper triangular, we
    remove the first `n` elements from `x` rather than from the reversed `x`.

  For additional comparisons, a pure numpy version of this function can be
  found in `distribution_util_test.py`, function `_fill_triangular`.

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

  Raises:
    ValueError: if `x` cannot be mapped to a triangular matrix.
  """
  with tf.name_scope(name or 'fill_triangular'):
    x = tf.convert_to_tensor(x, name='x')
    m = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(x.shape, 1)[-1])
    if m is not None:
      # Formula derived by solving for n: m = n(n+1)/2.
      m = np.int32(m)
      n = np.sqrt(0.25 + 2. * m) - 0.5
      if n != np.floor(n):
        raise ValueError(
            'Input right-most shape ({}) does not '
            'correspond to a triangular matrix.'.format(m))
      n = np.int32(n)
      static_final_shape = tensorshape_util.concatenate(x.shape[:-1], [n, n])
    else:
      m = tf.shape(x)[-1]
      # For derivation, see above. Casting automatically lops off the 0.5, so
      # we omit it.
      # We don't validate n is an integer because this has graph-execution
      # cost; an error will be thrown from the reshape, below.
      n = tf.cast(
          tf.sqrt(0.25 + tf.cast(2 * m, dtype=tf.float32)), dtype=tf.int32)
      static_final_shape = tensorshape_util.concatenate(
          tensorshape_util.with_rank_at_least(x.shape, 1)[:-1], [None, None])

    # Try it out in numpy:
    # n = 3
    # x = np.arange(n * (n + 1) / 2)
    # m = x.shape[0]
    # n = np.int32(np.sqrt(.25 + 2 * m) - .5)
    # x_tail = x[(m - (n**2 - m)):]
    # np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
    # # ==> array([[3, 4, 5],
    # #            [5, 4, 3],
    # #            [2, 1, 0]])
    # np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
    # # ==> array([[0, 1, 2],
    # #            [3, 4, 5],
    # #            [5, 4, 3]])
    #
    # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
    # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
    # Furthermore observe that:
    #   m - (n**2 - m)
    #   = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 + n / 2)
    #   = 2 (n**2 / 2 + n / 2) - n**2
    #   = n**2 + n - n**2
    #   = n
    ndims = prefer_static.rank(x)
    if upper:
      x_list = [x, tf.reverse(x[..., n:], axis=[ndims - 1])]
    else:
      x_list = [x[..., n:], tf.reverse(x, axis=[ndims - 1])]
    new_shape = (
        tensorshape_util.as_list(static_final_shape)
        if tensorshape_util.is_fully_defined(static_final_shape)
        else tf.concat([tf.shape(x)[:-1], [n, n]], axis=0))
    x = tf.reshape(tf.concat(x_list, axis=-1), new_shape)
    x = tf.linalg.band_part(
        x,
        num_lower=(0 if upper else -1),
        num_upper=(-1 if upper else 0))
    tensorshape_util.set_shape(x, static_final_shape)
    return x
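# Illustrative sketch (not library code): a minimal, unbatched NumPy version of
# the concatenate-and-reshape trick described in `fill_triangular`'s docstring.
# The docstring's reference implementation lives in `distribution_util_test.py`
# as `_fill_triangular`; `np_fill_triangular` below is our own simplified name.
import numpy as np

def np_fill_triangular(x, upper=False):
  m = len(x)
  n = np.int32(np.sqrt(0.25 + 2. * m) - 0.5)  # solve m = n (n + 1) / 2
  if upper:
    xc = np.concatenate([x, x[n:][::-1]])
  else:
    xc = np.concatenate([x[n:], x[::-1]])
  y = np.reshape(xc, [n, n])
  return np.triu(y) if upper else np.tril(y)

print(np_fill_triangular(np.array([1, 2, 3, 4, 5, 6])))
# [[4 0 0]
#  [6 5 0]
#  [3 2 1]]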
def _sample_n(self, n, seed=None): stream = seed_stream.SeedStream(seed, salt="VectorDiffeomixture") x = self.distribution.sample(sample_shape=concat_vectors( [n], self.batch_shape_tensor(), self.event_shape_tensor()), seed=stream()) # shape: [n, B, e] x = [aff.forward(x) for aff in self.endpoint_affine] # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get # ids as a [n]-shaped vector. batch_size = tensorshape_util.num_elements(self.batch_shape) if batch_size is None: batch_size = tf.reduce_prod(input_tensor=self.batch_shape_tensor()) mix_batch_size = tensorshape_util.num_elements( self.mixture_distribution.batch_shape) if mix_batch_size is None: mix_batch_size = tf.reduce_prod( input_tensor=self.mixture_distribution.batch_shape_tensor()) ids = self.mixture_distribution.sample(sample_shape=concat_vectors( [n], distribution_util.pick_vector(self.is_scalar_batch(), np.int32([]), [batch_size // mix_batch_size])), seed=stream()) # We need to flatten batch dims in case mixture_distribution has its own # batch dims. ids = tf.reshape(ids, shape=concat_vectors([n], distribution_util.pick_vector( self.is_scalar_batch(), np.int32([]), np.int32([-1])))) # Stride `components * quadrature_size` for `batch_size` number of times. stride = tensorshape_util.num_elements( tensorshape_util.with_rank_at_least(self.grid.shape, 2)[-2:]) if stride is None: stride = tf.reduce_prod(input_tensor=tf.shape( input=self.grid)[-2:]) offset = tf.range(start=0, limit=batch_size * stride, delta=stride, dtype=ids.dtype) weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset) # At this point, weight flattened all batch dims into one. # We also need to append a singleton to broadcast with event dims. if tensorshape_util.is_fully_defined(self.batch_shape): new_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1] else: new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]), axis=0) weight = tf.reshape(weight, shape=new_shape) if len(x) != 2: # We actually should have already triggered this exception. However as a # policy we're putting this exception wherever we exploit the bimixture # assumption. raise NotImplementedError( "Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(x))) # Alternatively: # x = weight * x[0] + (1. - weight) * x[1] x = weight * (x[0] - x[1]) + x[1] return x
def minimize(loss_fn, num_steps, optimizer, trainable_variables=None, trace_fn=_trace_loss, name='minimize'): """Minimize a loss function using a provided optimizer. Args: loss_fn: Python callable with signature `loss = loss_fn()`, where `loss` is a `Tensor` loss to be minimized. num_steps: Python `int` number of steps to run the optimizer. optimizer: Optimizer instance to use. This may be a TF1-style `tf.train.Optimizer`, TF2-style `tf.optimizers.Optimizer`, or any Python object that implements `optimizer.apply_gradients(grads_and_vars)`. trainable_variables: list of `tf.Variable` instances to optimize with respect to. If `None`, defaults to the set of all variables accessed during the execution of `loss_fn()`. Default value: `None`. trace_fn: Python callable with signature `state = trace_fn( loss, grads, variables)`, where `state` may be a `Tensor` or nested structure of `Tensor`s. The state values are accumulated (by `tf.scan`) and returned. The default `trace_fn` simply returns the loss, but in general can depend on the gradients and variables (if `trainable_variables` is not `None` then `variables==trainable_variables`; otherwise it is the list of all variables accessed during execution of `loss_fn()`), as well as any other quantities captured in the closure of `trace_fn`, for example, statistics of a variational distribution. Default value: `lambda loss, grads, variables: loss`. name: Python `str` name prefixed to ops created by this function. Default value: 'minimize'. Returns: trace: `Tensor` or nested structure of `Tensor`s, according to the return type of `trace_fn`. Each `Tensor` has an added leading dimension of size `num_steps`, packing the trajectory of the result over the course of the optimization. ### Examples To minimize the scalar function `(x - 5)**2`: ```python x = tf.Variable(0.) loss_fn = lambda: (x - 5.)**2 losses = tfp.math.minimize(loss_fn, num_steps=100, optimizer=tf.optimizers.Adam(learning_rate=0.1)) # In TF2/eager mode, the optimization runs immediately. print("optimized value is {} with loss {}".format(x, losses[-1])) ``` In graph mode (e.g., inside of `tf.function` wrapping), retrieving any Tensor that depends on the minimization op will trigger the optimization: ```python with tf.control_dependencies([losses]): optimized_x = tf.identity(x) # Use a dummy op to attach the dependency. ``` In some cases, we may want to track additional context inside the optimization. We can do this by defining a custom `trace_fn`. Note that the `trace_fn` is passed the loss and gradients, but it may also report the values of trainable variables or other derived quantities by capturing them in its closure. For example, we can capture `x` and track its value over the optimization: ```python # `x` is the tf.Variable instance defined above. 
  trace_fn = lambda loss, grads, variables: {'loss': loss, 'x': x}
  trace = tfp.math.minimize(loss_fn, num_steps=100,
                            optimizer=tf.optimizers.Adam(0.1),
                            trace_fn=trace_fn)
  print(trace['loss'].shape,  # => [100]
        trace['x'].shape)     # => [100]
  ```
  """

  @tf.function(autograph=False)
  def train_loop_body(old_result, step):  # pylint: disable=unused-argument
    """Run a single optimization step."""
    with tf.GradientTape(
        watch_accessed_variables=trainable_variables is None) as tape:
      for v in trainable_variables or []:
        tape.watch(v)
      loss = loss_fn()
    watched_variables = tape.watched_variables()
    grads = tape.gradient(loss, watched_variables)
    train_op = optimizer.apply_gradients(zip(grads, watched_variables))
    with tf.control_dependencies([train_op]):
      state = trace_fn(tf.identity(loss),
                       [tf.identity(g) for g in grads],
                       [tf.identity(v) for v in watched_variables])
    return state

  with tf.name_scope(name) as name:
    # Compute the shape of the trace without executing the graph, if possible.
    concrete_loop_body = train_loop_body.get_concrete_function(
        tf.TensorSpec([]), tf.TensorSpec([]))  # Inputs ignored.
    if all([tensorshape_util.is_fully_defined(shape)
            for shape in tf.nest.flatten(concrete_loop_body.output_shapes)]):
      state_initializer = tf.nest.map_structure(
          lambda shape, dtype: tf.zeros(shape, dtype=dtype),
          concrete_loop_body.output_shapes,
          concrete_loop_body.output_dtypes)
      initial_trace_step = None
    else:
      state_initializer = concrete_loop_body(
          tf.convert_to_tensor(0.),
          tf.convert_to_tensor(0.))  # Inputs ignored.
      num_steps = num_steps - 1
      initial_trace_step = state_initializer

    # TODO(b/136103064): Rewrite as explicit `while_loop` to support custom
    # convergence criteria and Tensor-valued `num_steps`, and avoid re-tracing
    # the train loop body.
    trace = tf.scan(train_loop_body,
                    elems=np.arange(num_steps),
                    initializer=state_initializer)
    if initial_trace_step is not None:
      trace = tf.nest.map_structure(
          lambda a, b: tf.concat([a[tf.newaxis, ...], b], axis=0),
          initial_trace_step, trace)
    return trace
def pad_batch_dimension_for_multiple_chains(
    observed_time_series, model, chain_batch_shape):
  """Expand the observed time series with extra batch dimension(s)."""
  # Running with multiple chains introduces an extra batch dimension. In
  # general we also need to pad the observed time series with a matching batch
  # dimension.
  #
  # For example, suppose our model has batch shape [3, 4] and
  # the observed time series has shape `concat([[5], [3, 4], [100]])`,
  # corresponding to `sample_shape`, `batch_shape`, and `num_timesteps`
  # respectively. The model will produce distributions with batch shape
  # `concat([chain_batch_shape, [3, 4]])`, so we pad `observed_time_series` to
  # have matching shape `[5, 1, 3, 4, 100]`, where the added `1` dimension
  # between the sample and batch shapes will broadcast to `chain_batch_shape`.

  [  # Extract mask and guarantee `event_ndims=2`.
      observed_time_series,
      is_missing
  ] = canonicalize_observed_time_series_with_mask(observed_time_series)

  event_ndims = 2  # event_shape = [num_timesteps, observation_size=1]

  model_batch_ndims = (
      tensorshape_util.rank(model.batch_shape)
      if tensorshape_util.rank(model.batch_shape) is not None
      else tf.shape(model.batch_shape_tensor())[0])

  # Compute ndims from chain_batch_shape.
  chain_batch_shape = tf.convert_to_tensor(
      value=chain_batch_shape, name='chain_batch_shape', dtype=tf.int32)
  if not tensorshape_util.is_fully_defined(chain_batch_shape.shape):
    raise ValueError('Batch shape must have static rank. (given: {})'.format(
        chain_batch_shape))
  if tensorshape_util.rank(chain_batch_shape.shape) == 0:
    # Expand int `k` to `[k]`.
    chain_batch_shape = chain_batch_shape[tf.newaxis]
  chain_batch_ndims = tf.compat.dimension_value(chain_batch_shape.shape[0])

  def do_padding(observed_time_series_tensor):
    current_sample_shape = ps.shape(
        observed_time_series_tensor)[:-(model_batch_ndims + event_ndims)]
    current_batch_and_event_shape = ps.shape(
        observed_time_series_tensor)[-(model_batch_ndims + event_ndims):]
    return tf.reshape(
        tensor=observed_time_series_tensor,
        shape=ps.concat([
            current_sample_shape,
            ps.ones([chain_batch_ndims], dtype=tf.int32),
            current_batch_and_event_shape
        ], axis=0))

  # Padding is only needed if the observed time series has sample shape.
  observed_time_series = ps.cond(
      ps.rank(observed_time_series) > model_batch_ndims + event_ndims,
      lambda: do_padding(observed_time_series),
      lambda: observed_time_series)

  if is_missing is not None:
    is_missing = ps.cond(
        ps.rank(is_missing) > model_batch_ndims + event_ndims,
        lambda: do_padding(is_missing),
        lambda: is_missing)

  return missing_values_util.MaskedTimeSeries(
      observed_time_series, is_missing=is_missing)
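# Illustrative sketch (not library code): the shape arithmetic performed by
# `do_padding` above, for the sample/batch/event example in the comments
# (sample [5], model batch [3, 4], canonicalized event [100, 1]). NumPy only;
# the concrete numbers are made up for the example.
import numpy as np

observed = np.zeros([5, 3, 4, 100, 1])
model_batch_ndims, event_ndims, chain_batch_ndims = 2, 2, 1
padded = observed.reshape(
    observed.shape[:-(model_batch_ndims + event_ndims)]
    + (1,) * chain_batch_ndims
    + observed.shape[-(model_batch_ndims + event_ndims):])
assert padded.shape == (5, 1, 3, 4, 100, 1)  # the 1 broadcasts to chain_batch_shape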
def _sample_n(self, n, seed=None): if self._use_static_graph: # This sampling approach is almost the same as the approach used by # `MixtureSameFamily`. The differences are due to having a list of # `Distribution` objects rather than a single object, and maintaining # random seed management that is consistent with the non-static code # path. samples = [] cat_samples = self.cat.sample(n, seed=seed) stream = SeedStream(seed, salt='Mixture') for c in range(self.num_components): samples.append(self.components[c].sample(n, seed=stream())) stack_axis = -1 - tensorshape_util.rank(self._static_event_shape) x = tf.stack(samples, axis=stack_axis) # [n, B, k, E] npdt = dtype_util.as_numpy_dtype(x.dtype) mask = tf.one_hot( indices=cat_samples, # [n, B] depth=self._num_components, # == k on_value=npdt(1), off_value=npdt(0)) # [n, B, k] mask = distribution_util.pad_mixture_dimensions( mask, self, self._cat, tensorshape_util.rank( self._static_event_shape)) # [n, B, k, [1]*e] return tf.reduce_sum(x * mask, axis=stack_axis) # [n, B, E] n = tf.convert_to_tensor(n, name='n') static_n = tf.get_static_value(n) n = int(static_n) if static_n is not None else n cat_samples = self.cat.sample(n, seed=seed) static_samples_shape = cat_samples.shape if tensorshape_util.is_fully_defined(static_samples_shape): samples_shape = tensorshape_util.as_list(static_samples_shape) samples_size = tensorshape_util.num_elements(static_samples_shape) else: samples_shape = tf.shape(cat_samples) samples_size = tf.size(cat_samples) static_batch_shape = self.batch_shape if tensorshape_util.is_fully_defined(static_batch_shape): batch_shape = tensorshape_util.as_list(static_batch_shape) batch_size = tensorshape_util.num_elements(static_batch_shape) else: batch_shape = tf.shape(cat_samples)[1:] batch_size = tf.reduce_prod(batch_shape) static_event_shape = self.event_shape if tensorshape_util.is_fully_defined(static_event_shape): event_shape = np.array( tensorshape_util.as_list(static_event_shape), dtype=np.int32) else: event_shape = None # Get indices into the raw cat sampling tensor. We will # need these to stitch sample values back out after sampling # within the component partitions. samples_raw_indices = tf.reshape(tf.range(0, samples_size), samples_shape) # Partition the raw indices so that we can use # dynamic_stitch later to reconstruct the samples from the # known partitions. partitioned_samples_indices = tf.dynamic_partition( data=samples_raw_indices, partitions=cat_samples, num_partitions=self.num_components) # Copy the batch indices n times, as we will need to know # these to pull out the appropriate rows within the # component partitions. batch_raw_indices = tf.reshape(tf.tile(tf.range(0, batch_size), [n]), samples_shape) # Explanation of the dynamic partitioning below: # batch indices are i.e., [0, 1, 0, 1, 0, 1] # Suppose partitions are: # [1 1 0 0 1 1] # After partitioning, batch indices are cut as: # [batch_indices[x] for x in 2, 3] # [batch_indices[x] for x in 0, 1, 4, 5] # i.e. # [1 1] and [0 0 0 0] # Now we sample n=2 from part 0 and n=4 from part 1. # For part 0 we want samples from batch entries 1, 1 (samples 0, 1), # and for part 1 we want samples from batch entries 0, 0, 0, 0 # (samples 0, 1, 2, 3). 
partitioned_batch_indices = tf.dynamic_partition( data=batch_raw_indices, partitions=cat_samples, num_partitions=self.num_components) samples_class = [None for _ in range(self.num_components)] stream = SeedStream(seed, salt='Mixture') for c in range(self.num_components): n_class = tf.size(partitioned_samples_indices[c]) samples_class_c = self.components[c].sample(n_class, seed=stream()) if event_shape is None: batch_ndims = prefer_static.rank_from_shape(batch_shape) event_shape = tf.shape(samples_class_c)[1 + batch_ndims:] # Pull out the correct batch entries from each index. # To do this, we may have to flatten the batch shape. # For sample s, batch element b of component c, we get the # partitioned batch indices from # partitioned_batch_indices[c]; and shift each element by # the sample index. The final lookup can be thought of as # a matrix gather along locations (s, b) in # samples_class_c where the n_class rows correspond to # samples within this component and the batch_size columns # correspond to batch elements within the component. # # Thus the lookup index is # lookup[c, i] = batch_size * s[i] + b[c, i] # for i = 0 ... n_class[c] - 1. lookup_partitioned_batch_indices = ( batch_size * tf.range(n_class) + partitioned_batch_indices[c]) samples_class_c = tf.reshape( samples_class_c, tf.concat([[n_class * batch_size], event_shape], 0)) samples_class_c = tf.gather(samples_class_c, lookup_partitioned_batch_indices, name='samples_class_c_gather') samples_class[c] = samples_class_c # Stitch back together the samples across the components. lhs_flat_ret = tf.dynamic_stitch(indices=partitioned_samples_indices, data=samples_class) # Reshape back to proper sample, batch, and event shape. ret = tf.reshape(lhs_flat_ret, tf.concat([samples_shape, event_shape], 0)) tensorshape_util.set_shape( ret, tensorshape_util.concatenate(static_samples_shape, self.event_shape)) return ret
def make_momentum_distribution(state_parts, batch_shape, running_variance_parts=None, shard_axis_names=None): """Construct a momentum distribution from the running variance. This uses a running variance to construct a momentum distribution with the correct batch_shape and event_shape. Args: state_parts: List of `Tensor`. batch_shape: Batch shape. running_variance_parts: Optional, list of `Tensor` outputs of `tfp.experimental.stats.RunningVariance.variance()`. Defaults to ones with the same shape as state_parts. shard_axis_names: A structure of string names indicating how members of the state are sharded. Returns: `tfd.Distribution` where `.sample` has the same structure as `state_parts`, and `.log_prob` of the sample will have the rank of `batch_ndims` """ if running_variance_parts is None: running_variance_parts = tf.nest.map_structure(tf.ones_like, state_parts) distributions = [] batch_ndims = ps.rank_from_shape(batch_shape) use_sharded_jd = True if shard_axis_names is None: use_sharded_jd = False shard_axis_names = [None] * len(state_parts) for variance_part, state_part, shard_axes in zip(running_variance_parts, state_parts, shard_axis_names): event_shape = state_part.shape[batch_ndims:] if not tensorshape_util.is_fully_defined(event_shape): event_shape = ps.shape(state_part, name='state_part_shp')[batch_ndims:] variance_tiled = tf.broadcast_to( variance_part, ps.concat([batch_shape, event_shape], axis=0)) nevt = ps.cast(ps.reduce_prod(event_shape), tf.int32) variance_flattened = tf.reshape( variance_tiled, ps.concat([batch_shape, [nevt]], axis=0)) distribution = _CompositeTransformedDistribution( bijector=reshape.Reshape(event_shape_out=event_shape, name='reshape_mvnpfl'), distribution=( _CompositeMultivariateNormalPrecisionFactorLinearOperator( precision_factor=tf.linalg.LinearOperatorDiag( tf.math.sqrt(variance_flattened)), precision=tf.linalg.LinearOperatorDiag(variance_flattened), name='momentum'))) if shard_axes: distribution = sharded.Sharded(distribution, shard_axis_name=shard_axes) distributions.append(distribution) if use_sharded_jd: jd = _CompositeShardedJointDistributionSequential(distributions) else: jd = _CompositeJointDistributionSequential(distributions) return maybe_make_list_and_batch_broadcast(jd, batch_shape)
def testBijector(self, bijector_name, data): tfp_hps.guitar_skip_if_matches('Tanh', bijector_name, 'b/144163991') bijector, event_dim = self._draw_bijector(bijector_name, data) # Forward mapping: Check differentiation through forward mapping with # respect to the input and parameter variables. Also check that any # variables are not referenced overmuch. xs = self._draw_domain_tensor(bijector, data, event_dim) wrt_vars = [xs] + [ v for v in bijector.trainable_variables if v.dtype.is_floating ] with tf.GradientTape() as tape: with tfp_hps.assert_no_excessive_var_usage( 'method `forward` of {}'.format(bijector)): tape.watch(wrt_vars) # TODO(b/73073515): Fix graph mode gradients with bijector caching. ys = bijector.forward(xs + 0) grads = tape.gradient(ys, wrt_vars) assert_no_none_grad(bijector, 'forward', wrt_vars, grads) # For scalar bijectors, verify correctness of the _is_increasing method. # TODO(b/148459057): Except, don't verify Softfloor on Guitar because # of numerical problem. def exception(bijector): if not tfp_hps.running_under_guitar(): return False if isinstance(bijector, tfb.Softfloor): return True if is_invert(bijector): return exception(bijector.bijector) return False if (bijector.forward_min_event_ndims == 0 and bijector.inverse_min_event_ndims == 0 and not exception(bijector)): dydx = grads[0] hp.note('dydx: {}'.format(dydx)) isfinite = tf.math.is_finite(dydx) incr_or_slope_eq0 = bijector._internal_is_increasing() | tf.equal( dydx, 0) # pylint: disable=protected-access self.assertAllEqual( isfinite & incr_or_slope_eq0, isfinite & (dydx >= 0) | tf.zeros_like(incr_or_slope_eq0)) # FLDJ: Check differentiation through forward log det jacobian with # respect to the input and parameter variables. Also check that any # variables are not referenced overmuch. event_ndims = data.draw( hps.integers(min_value=bijector.forward_min_event_ndims, max_value=xs.shape.ndims)) with tf.GradientTape() as tape: max_permitted = _ldj_tensor_conversions_allowed(bijector, is_forward=True) with tfp_hps.assert_no_excessive_var_usage( 'method `forward_log_det_jacobian` of {}'.format(bijector), max_permissible=max_permitted): tape.watch(wrt_vars) # TODO(b/73073515): Fix graph mode gradients with bijector caching. ldj = bijector.forward_log_det_jacobian( xs + 0, event_ndims=event_ndims) grads = tape.gradient(ldj, wrt_vars) assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars, grads) # Inverse mapping: Check differentiation through inverse mapping with # respect to the codomain "input" and parameter variables. Also check that # any variables are not referenced overmuch. ys = self._draw_codomain_tensor(bijector, data, event_dim) wrt_vars = [ys] + [ v for v in bijector.trainable_variables if v.dtype.is_floating ] with tf.GradientTape() as tape: with tfp_hps.assert_no_excessive_var_usage( 'method `inverse` of {}'.format(bijector)): tape.watch(wrt_vars) # TODO(b/73073515): Fix graph mode gradients with bijector caching. xs = bijector.inverse(ys + 0) grads = tape.gradient(xs, wrt_vars) assert_no_none_grad(bijector, 'inverse', wrt_vars, grads) # ILDJ: Check differentiation through inverse log det jacobian with respect # to the codomain "input" and parameter variables. Also check that any # variables are not referenced overmuch. 
event_ndims = data.draw( hps.integers(min_value=bijector.inverse_min_event_ndims, max_value=ys.shape.ndims)) with tf.GradientTape() as tape: max_permitted = _ldj_tensor_conversions_allowed(bijector, is_forward=False) with tfp_hps.assert_no_excessive_var_usage( 'method `inverse_log_det_jacobian` of {}'.format(bijector), max_permissible=max_permitted): tape.watch(wrt_vars) # TODO(b/73073515): Fix graph mode gradients with bijector caching. ldj = bijector.inverse_log_det_jacobian( ys + 0, event_ndims=event_ndims) grads = tape.gradient(ldj, wrt_vars) assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars, grads) # Verify that `_is_permutation` implies constant zero Jacobian. if bijector._is_permutation: self.assertTrue(bijector._is_constant_jacobian) self.assertAllEqual(ldj, 0.) # Verify correctness of batch shape. xs_batch_shapes = tf.nest.map_structure( lambda x, nd: ps.shape(x)[:ps.rank(x) - nd], xs, bijector.inverse_event_ndims(event_ndims)) empirical_batch_shape = functools.reduce( ps.broadcast_shape, nest.flatten_up_to(bijector.forward_min_event_ndims, xs_batch_shapes)) batch_shape = bijector.experimental_batch_shape( y_event_ndims=event_ndims) if tensorshape_util.is_fully_defined(batch_shape): self.assertAllEqual(empirical_batch_shape, batch_shape) self.assertAllEqual( empirical_batch_shape, bijector.experimental_batch_shape_tensor( y_event_ndims=event_ndims)) # Check that the outputs of forward_dtype and inverse_dtype match the dtypes # of the outputs of forward and inverse. self.assertAllEqualNested(ys.dtype, bijector.forward_dtype(xs.dtype)) self.assertAllEqualNested(xs.dtype, bijector.inverse_dtype(ys.dtype))
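# Hedged sketch (not part of the test suite): the slope-sign check above boils
# down to the invariant that, for a scalar bijector, the sign of dy/dx agrees
# with the bijector's reported monotonicity. For a known increasing bijector
# such as tfb.Sigmoid, the forward derivative computed with a GradientTape is
# nonnegative everywhere it is finite.
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors


def _demo_scalar_slope_sign():
  bij = tfb.Sigmoid()
  x = tf.linspace(-4., 4., 9)
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = bij.forward(x)
  dydx = tape.gradient(y, x)
  return tf.reduce_all(dydx >= 0.)   # ==> True for an increasing bijector.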
def _setup_mcmc(model, n_chains, *, init_position=None, seed=None, **pins): """Construct bijector and transforms needed for windowed MCMC. This pins the initial model, constructs a bijector that unconstrains and flattens each dimension and adds a leading batch shape of `n_chains`, initializes a point in the unconstrained space, and constructs a transformed log probability using the bijector. Note that we must manually construct this target log probability instead of using a transformed transition kernel because the TTK assumes the shape in is the same as the shape out. Args: model: `tfd.JointDistribution` The model to sample from. n_chains: list of ints Number of chains (independent examples) to run. init_position: Optional Structure of tensors at which to initialize sampling. Should have the same shape and structure as `model.experimental_pin(**pins).sample_unpinned(n_chains)`. seed: PRNG seed; see `tfp.random.sanitize_seed` for details. **pins: Values passed to `model.experimental_pin`. Returns: target_log_prob_fn: Callable on the transformed space. initial_transformed_position: `tf.Tensor`, sampled from a uniform (-2, 2). bijector: `tfb.Bijector` instance, which unconstrains and flattens. step_broadcast_fn: Callable to broadcast step size over latent structure. batch_shape: Batch shape of the model. shard_axis_names: Shard axis names for the model """ pinned_model = model.experimental_pin(**pins) if pins else model bijector, step_bijector = _get_flat_unconstraining_bijector(pinned_model) if init_position is None: raw_init_dist = initialization.init_near_unconstrained_zero( pinned_model) init_position = initialization.retry_init( raw_init_dist.sample, target_fn=pinned_model.unnormalized_log_prob, sample_shape=n_chains, seed=seed) initial_transformed_position = tf.nest.map_structure( tf.identity, bijector.forward(init_position)) batch_shape = pinned_model.batch_shape if tf.nest.is_nested(batch_shape): batch_shape = functools.reduce(tf.broadcast_static_shape, tf.nest.flatten(batch_shape)) if not tensorshape_util.is_fully_defined(batch_shape): batch_shape = pinned_model.batch_shape_tensor() if tf.nest.is_nested(batch_shape): batch_shape = functools.reduce(tf.broadcast_dynamic_shape, tf.nest.flatten(batch_shape)) # This tf.function is not redundant with the ones on _fast_window # and _slow_window because the various kernels (like HMC) may invoke # `target_log_prob_fn` multiple times within one window. @tf.function(autograph=False) def target_log_prob_fn(*args): lp = pinned_model.unnormalized_log_prob(bijector.inverse(args)) ldj = bijector.inverse_log_det_jacobian( args, event_ndims=[1 for _ in initial_transformed_position]) return lp + ldj def step_broadcast(step_size): # Only apply the bijector to nested step sizes or non-scalar batches. if tf.nest.is_nested(step_size): return step_bijector( nest_util.broadcast_structure( pinned_model.event_shape_tensor(), step_size)) else: return step_size shard_axis_names = pinned_model.experimental_shard_axis_names if any(tf.nest.flatten(shard_axis_names)): shard_axis_names = nest.flatten_up_to( initial_transformed_position, list(pinned_model._model_flatten(shard_axis_names))) # pylint: disable=protected-access else: # No active shard axis names shard_axis_names = None return (target_log_prob_fn, initial_transformed_position, bijector, step_broadcast, ps.convert_to_shape_tensor(batch_shape, name='batch_shape'), shard_axis_names)
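# Hedged usage sketch (toy model, not from the library): wire a small
# JointDistribution through `_setup_mcmc` and evaluate the transformed target
# at the returned initial position. The model, the pinned value, and the exact
# part shapes are assumptions for illustration; the log prob is computed in
# the flat, unconstrained space, so each chain contributes one scalar.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def _demo_setup_mcmc(seed):
  model = tfd.JointDistributionNamed(dict(
      scale=tfd.HalfNormal(1.),
      loc=lambda scale: tfd.Normal(0., scale)))
  # Pin `loc`; the remaining unpinned part (`scale`) is what gets sampled.
  (target_log_prob_fn, initial_transformed_position, unconstraining_bijector,
   step_broadcast, batch_shape, shard_axis_names) = _setup_mcmc(
       model, n_chains=8, seed=seed, loc=0.5)
  del unconstraining_bijector, step_broadcast, shard_axis_names  # Unused here.
  lp = target_log_prob_fn(*initial_transformed_position)  # One value per chain.
  return lp, batch_shape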