def __init__(self, model):
  """Constructs the adapter.

  Args:
    model: An Inference Gym model.

  Raises:
    TypeError: If `model` has more than one unique Tensor dtype.
  """
  self._model = model
  dtypes = set(
      tf.nest.flatten(tf.nest.map_structure(tf.as_dtype, self._model.dtype)))
  if len(dtypes) > 1:
    raise TypeError('Model must have only one Tensor dtype, saw: {}'.format(
        self._model.dtype))
  dtype = dtypes.pop()

  # TODO(siege): Make this work with multi-part default_event_bijector.
  def _make_reshaped_bijector(b, s):
    return tfb.Chain([
        tfb.Reshape(event_shape_in=s, event_shape_out=[ps.reduce_prod(s)]),
        b,
        tfb.Reshape(event_shape_out=b.inverse_event_shape(s)),
    ])

  reshaped_bijector = tf.nest.map_structure(
      _make_reshaped_bijector, self._model.default_event_space_bijector,
      self._model.event_shape)

  bijector = tfb.Blockwise(
      bijectors=tf.nest.flatten(reshaped_bijector),
      block_sizes=tf.nest.flatten(
          tf.nest.map_structure(
              lambda b, s: ps.reduce_prod(b.inverse_event_shape(s)),  # pylint: disable=g-long-lambda
              self._model.default_event_space_bijector,
              self._model.event_shape)))

  event_sizes = tf.nest.map_structure(
      lambda b, s: ps.reduce_prod(b.inverse_event_shape(s)),
      self._model.default_event_space_bijector, self._model.event_shape)
  event_shape = tf.TensorShape([sum(tf.nest.flatten(event_sizes))])

  sample_transformations = collections.OrderedDict()

  def make_flattened_transform(transform):
    # We yank this out to avoid capturing the loop variable.
    return transform._replace(
        fn=lambda x: transform(self._split_and_reshape_event(x)))

  for key, transform in self._model.sample_transformations.items():
    sample_transformations[key] = make_flattened_transform(transform)

  super(VectorModel, self).__init__(
      default_event_space_bijector=bijector,
      event_shape=event_shape,
      dtype=dtype,
      name='vector_' + self._model.name,
      pretty_name=str(self._model),
      sample_transformations=sample_transformations,
  )
def make_conditional_linear_gaussian(y_event_shape,
                                     x,
                                     x_event_ndims,
                                     variables=None):
  """Builds a trainable distribution `p(y | x)` conditioned on an input Tensor `x`.

  The distribution is an independent Gaussian whose mean is a linear
  transform of `x`:

  `y ~ N(loc=matvec(matrix, x) + loc, scale_diag=scale)`

  Args:
    y_event_shape: int `Tensor` event shape.
    x: `Tensor` input to condition on.
    x_event_ndims: int number of dimensions in `x`'s `event_shape`.
    variables: Optional `LinearGaussianVariables` instance, or `None`.
      Default value: `None`.

  Returns:
    dist: Instance of `tfd.Distribution` representing the conditional
      distribution `p(y | x)`.
    variables: Instance of `LinearGaussianVariables` used to parameterize
      `dist`. If a `variables` arg was passed, it is returned unmodified;
      otherwise new variables are created.
  """
  x_shape = ps.shape(x)
  x_ndims = ps.rank_from_shape(x_shape)
  y_event_ndims = ps.rank_from_shape(y_event_shape)
  batch_shape, x_event_shape = (x_shape[:x_ndims - x_event_ndims],
                                x_shape[x_ndims - x_event_ndims:])

  x_event_size = ps.reduce_prod(x_event_shape)
  y_event_size = ps.reduce_prod(y_event_shape)

  x_flat_shape = ps.concat([batch_shape, [x_event_size]], axis=0)
  y_flat_shape = ps.concat([batch_shape, [y_event_size]], axis=0)
  y_full_shape = ps.concat([batch_shape, y_event_shape], axis=0)

  if variables is None:
    variables = LinearGaussianVariables(
        matrix=tf.Variable(
            tf.random.normal(
                ps.concat([batch_shape, [y_event_size, x_event_size]],
                          axis=0),
                dtype=x.dtype),
            name='matrix'),
        loc=tf.Variable(
            tf.random.normal(y_flat_shape, dtype=x.dtype), name='loc'),
        scale=tfp_util.TransformedVariable(
            tf.ones(y_full_shape, dtype=x.dtype),
            bijector=tfb.Softplus(),
            name='scale'))

  flat_x = tf.reshape(x, x_flat_shape)
  dist = tfd.Normal(
      loc=tf.reshape(
          tf.linalg.matvec(variables.matrix, flat_x) + variables.loc,
          y_full_shape),
      scale=variables.scale)
  if y_event_ndims != 0:
    dist = tfd.Independent(dist, reinterpreted_batch_ndims=y_event_ndims)
  dist._also_track = variables  # pylint: disable=protected-access
  return dist, variables
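# A minimal usage sketch for `make_conditional_linear_gaussian`, assuming the
# aliases used above (`tfd`, `tfb`, `ps`, `tfp_util`) resolve to their usual
# TFP modules and that `LinearGaussianVariables` is the (matrix, loc, scale)
# namedtuple the code implies. The point here is the shapes, not the API.
import tensorflow as tf

x = tf.random.normal([32, 4])  # batch_shape [32], x event shape [4]
dist, variables = make_conditional_linear_gaussian(
    y_event_shape=[3], x=x, x_event_ndims=1)
y = dist.sample()            # shape [32, 3]: batch [32], y event [3]
log_prob = dist.log_prob(y)  # shape [32]: one log-density per batch member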
def _make_reshaped_bijector(b, s):
  return tfb.Chain([
      tfb.Reshape(event_shape_in=s, event_shape_out=[ps.reduce_prod(s)]),
      b,
      tfb.Reshape(
          event_shape_in=[ps.reduce_prod(b.inverse_event_shape(s))],
          event_shape_out=b.inverse_event_shape(s)),
  ])
def pairwise_square_distance_matrix(x1, x2, feature_ndims):
  """Returns pairwise square distance between x1 and x2.

  Given `x1` and `x2`, Tensors with shape `[..., N, D1, ... Dk]` and
  `[..., M, D1, ... Dk]`, compute the pairwise distance matrix `a_ij` of
  shape `[..., N, M]`, where each entry `a_ij` is the square of the euclidean
  norm of `x1[..., i, ...] - x2[..., j, ...]`.

  The approach uses the following expansion (shown here for `k = 1`):

  ```none
  a_ij = sum_d (x1[i, d] - x2[j, d]) ** 2
       = sum_d x1[i, d] ** 2 + x2[j, d] ** 2 - 2 * x1[i, d] * x2[j, d]
  ```

  The latter term can be written as a matmul between `x1` and `x2`. This
  reduces the memory from the naive approach of computing the squared
  difference of `x1` and `x2` by a factor of `(prod_k D_k) ** 2`. This is at
  the cost of the computation being more numerically unstable.

  Args:
    x1: Floating point `Tensor` with shape `B1 + [N] + [D1, ..., Dk]`, where
      `B1` is a (possibly empty) batch shape.
    x2: Floating point `Tensor` with shape `B2 + [M] + [D1, ..., Dk]`, where
      `B2` is a (possibly empty) batch shape that broadcasts with `B1`.
    feature_ndims: The number of dimensions to consider for the euclidean
      norm. This is `k` from above.

  Returns:
    `Tensor` of shape `[..., N, M]` representing the pairwise square distance
    matrix.
  """
  row_norm_x1 = sum_rightmost_ndims_preserving_shape(
      tf.square(x1), feature_ndims)[..., tf.newaxis]
  row_norm_x2 = sum_rightmost_ndims_preserving_shape(
      tf.square(x2), feature_ndims)[..., tf.newaxis, :]

  x1 = tf.reshape(x1, ps.concat(
      [ps.shape(x1)[:-feature_ndims],
       [ps.reduce_prod(ps.shape(x1)[-feature_ndims:])]], axis=0))
  x2 = tf.reshape(x2, ps.concat(
      [ps.shape(x2)[:-feature_ndims],
       [ps.reduce_prod(ps.shape(x2)[-feature_ndims:])]], axis=0))
  pairwise_sq = row_norm_x1 + row_norm_x2 - 2 * tf.linalg.matmul(
      x1, x2, transpose_b=True)
  pairwise_sq = tf.clip_by_value(pairwise_sq, 0., np.inf)
  return pairwise_sq
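# A small self-contained numerical check of the expansion above (a sketch,
# not part of the library): ||a - b||^2 == ||a||^2 + ||b||^2 - 2<a, b>.
import numpy as np

rng = np.random.default_rng(0)
x1 = rng.normal(size=(5, 3))   # N=5, D=3
x2 = rng.normal(size=(7, 3))   # M=7, D=3

naive = ((x1[:, None, :] - x2[None, :, :]) ** 2).sum(-1)  # [N, M]
expanded = ((x1 ** 2).sum(-1)[:, None]
            + (x2 ** 2).sum(-1)[None, :]
            - 2 * x1 @ x2.T)                               # [N, M]
np.testing.assert_allclose(naive, expanded, atol=1e-12)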
def _augment_sample_shape(self, sample_shape):
  # Suppose we have:
  # - sample shape of `[n]`,
  # - underlying distribution batch shape of `[2, 1]`,
  # - final broadcast batch shape of `[4, 2, 3]`.
  # Then we must draw `sample_shape + [12]` samples, where
  # `12 == n_batch // underlying_n_batch`.
  batch_shape = self.batch_shape_tensor()
  n_batch = ps.reduce_prod(batch_shape)
  underlying_batch_shape = self.distribution.batch_shape_tensor()
  underlying_n_batch = ps.reduce_prod(underlying_batch_shape)
  return ps.concat(
      [sample_shape, [ps.maximum(0, n_batch // underlying_n_batch)]],
      axis=0)
def _split_and_reshape_event(self, x):
  event_tensors = self._distribution.event_shape_tensor()
  splits = [
      ps.maximum(1, ps.reduce_prod(s))
      for s in tf.nest.flatten(event_tensors)
  ]
  x = tf.nest.pack_sequence_as(event_tensors, tf.split(x, splits, axis=-1))

  def _reshape_part(part, dtype, event_shape):
    part = tf.cast(part, dtype)
    static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
    if static_rank == 1:
      return part
    new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
    return tf.reshape(part, ps.cast(new_shape, tf.int32))

  if all(
      tensorshape_util.is_fully_defined(s)
      for s in tf.nest.flatten(self._distribution.event_shape)):
    x = tf.nest.map_structure(_reshape_part, x, self._distribution.dtype,
                              self._distribution.event_shape)
  else:
    x = tf.nest.map_structure(_reshape_part, x, self._distribution.dtype,
                              self._distribution.event_shape_tensor())
  return x
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    # Avoid computing intermediates needed to construct the assertions.
    return []
  assertions = []
  if is_init != tensor_util.is_ref(self._batch_shape_unexpanded):
    implicit_dim_mask = ps.equal(self._batch_shape_unexpanded, -1)
    assertions.append(
        assert_util.assert_rank(
            self._batch_shape_unexpanded, 1,
            message='New shape must be a vector.'))
    assertions.append(
        assert_util.assert_less_equal(
            tf.math.count_nonzero(implicit_dim_mask, dtype=tf.int32), 1,
            message='At most one dimension can be unknown.'))
    assertions.append(
        assert_util.assert_non_negative(
            self._batch_shape_unexpanded + 1,
            message='Shape elements must be >=-1.'))
    # Check that the old and new shapes are the same size.
    expanded_new_shape, original_size = self._calculate_new_shape()
    new_size = ps.reduce_prod(expanded_new_shape)
    assertions.append(
        assert_util.assert_equal(
            new_size, tf.cast(original_size, new_size.dtype),
            message='Shape sizes do not match.'))
  return assertions
def _axis_size(x, axis=None):
  """Get number of elements of `x` in `axis`, as type `x.dtype`."""
  if axis is None:
    return prefer_static.cast(prefer_static.size(x), x.dtype)
  return prefer_static.cast(
      prefer_static.reduce_prod(
          prefer_static.gather(prefer_static.shape(x), axis)), x.dtype)
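# Illustrative calls (a sketch, assuming `_axis_size` above and that
# `prefer_static` is TFP's prefer_static module):
import tensorflow as tf

x = tf.zeros([2, 3, 5], dtype=tf.float32)
_axis_size(x)               # 30.0: total number of elements, cast to float32
_axis_size(x, axis=[1])     # 3.0: size along axis 1
_axis_size(x, axis=[0, 2])  # 10.0: product of sizes along axes 0 and 2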
def _sample_direction_part(state_part, part_seed):
  state_part_shape = ps.shape(state_part)
  batch_shape = state_part_shape[:batch_rank]
  dimension = ps.reduce_prod(state_part_shape[batch_rank:])
  return ps.reshape(
      random_ops.spherical_uniform(
          shape=batch_shape,
          dimension=dimension,
          dtype=state_part.dtype,
          seed=part_seed),
      state_part_shape)
def iid_sample(sample_fn, sample_shape):
  """Lift a sampling function to one that draws multiple iid samples.

  Args:
    sample_fn: Python `callable` that returns a (possibly nested) structure
      of `Tensor`s. May optionally take a `seed` named arg: if so, any `int`
      seeds (for stateful samplers) are passed through directly, while any
      pair-of-`int` seeds (for stateless samplers) are split into independent
      seeds for each sample.
    sample_shape: `int` `Tensor` shape of iid samples to draw.

  Returns:
    iid_sample_fn: Python `callable` taking the same arguments as `sample_fn`
      and returning iid samples. Each returned `Tensor` will have shape
      `concat([sample_shape, shape_of_original_returned_tensor])`.
  """
  sample_shape = distribution_util.expand_to_vector(
      ps.cast(sample_shape, np.int32), tensor_name='sample_shape')
  n = ps.cast(ps.reduce_prod(sample_shape), dtype=np.int32)

  def unflatten(x):
    unflattened_shape = ps.cast(
        ps.concat([sample_shape, ps.shape(x)[1:]], axis=0), dtype=np.int32)
    return tf.reshape(x, unflattened_shape)

  def iid_sample_fn(*args, **kwargs):
    """Draws iid samples from `fn`."""
    with tf.name_scope('iid_sample_fn'):

      seed = kwargs.pop('seed', None)
      if samplers.is_stateful_seed(seed):
        kwargs = dict(kwargs, seed=SeedStream(seed, salt='iid_sample')())

        def pfor_loop_body(_):
          with tf.name_scope('iid_sample_fn_stateful_body'):
            return sample_fn(*args, **kwargs)

      else:
        # If a stateless seed arg is passed, split it into `n` different
        # stateless seeds, so that we don't just get a bunch of copies of the
        # same sample.
        if not JAX_MODE:
          warnings.warn(
              'Saw Tensor seed {}, implying stateless sampling. '
              'Autovectorized functions that use stateless sampling may be '
              'quite slow because the current implementation falls back to '
              'an explicit loop. This will be fixed in the future. For now, '
              'you will likely see better performance from stateful '
              'sampling, which you can invoke by passing a Python `int` '
              'seed.'.format(seed))
        seed = samplers.split_seed(seed, n=n, salt='iid_sample_stateless')

        def pfor_loop_body(i):
          with tf.name_scope('iid_sample_fn_stateless_body'):
            return sample_fn(*args, seed=tf.gather(seed, i), **kwargs)

      draws = parallel_for.pfor(pfor_loop_body, n)
      return tf.nest.map_structure(unflatten, draws, expand_composites=True)

  return iid_sample_fn
def iid_sample(sample_fn, sample_shape):
  """Lift a sampling function to one that draws multiple iid samples.

  Args:
    sample_fn: Python `callable` that returns a (possibly nested) structure
      of `Tensor`s. May optionally take a `seed` named arg: if so, any `int`
      seeds (for stateful samplers) are passed through directly, while any
      pair-of-`int` seeds (for stateless samplers) are split into independent
      seeds for each sample.
    sample_shape: `int` `Tensor` shape of iid samples to draw.

  Returns:
    iid_sample_fn: Python `callable` taking the same arguments as `sample_fn`
      and returning iid samples. Each returned `Tensor` will have shape
      `concat([sample_shape, shape_of_original_returned_tensor])`.
  """
  sample_shape = distribution_util.expand_to_vector(
      prefer_static.cast(sample_shape, np.int32), tensor_name='sample_shape')
  n = prefer_static.cast(
      prefer_static.reduce_prod(sample_shape), dtype=np.int32)

  def unflatten(x):
    unflattened_shape = prefer_static.cast(
        prefer_static.concat([sample_shape, prefer_static.shape(x)[1:]],
                             axis=0),
        dtype=np.int32)
    return tf.reshape(x, unflattened_shape)

  def iid_sample_fn(*args, **kwargs):
    """Draws iid samples from `fn`."""

    pfor_loop_body = lambda _: sample_fn(*args, **kwargs)

    seed = kwargs.pop('seed', None)
    try:  # Assume that `seed` is a valid stateful seed (Python `int`).
      kwargs = dict(kwargs, seed=SeedStream(seed, salt='iid_sample')())
      pfor_loop_body = lambda _: sample_fn(*args, **kwargs)
    except TypeError as e:
      # If a stateless seed arg is passed, split it into `n` different
      # stateless seeds, so that we don't just get a bunch of copies of the
      # same sample.
      if TENSOR_SEED_MSG_PREFIX not in str(e):
        raise
      warnings.warn(
          'Saw non-`int` seed {}, implying stateless sampling. '
          'Autovectorized functions that use stateless sampling '
          'may be quite slow because the current implementation '
          'falls back to an explicit loop. This will be fixed in the '
          'future. For now, you will likely see better performance '
          'from stateful sampling, which you can invoke by passing a '
          'traditional Python `int` seed.'.format(seed))
      seed = samplers.split_seed(seed, n=n, salt='iid_sample_stateless')
      pfor_loop_body = (
          lambda i: sample_fn(*args, seed=tf.gather(seed, i), **kwargs))

    draws = parallel_for.pfor(pfor_loop_body, n)
    return tf.nest.map_structure(unflatten, draws, expand_composites=True)

  return iid_sample_fn
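# A usage sketch covering both `iid_sample` variants above, assuming TFP's
# convention that a stateless seed is a shape-[2] integer Tensor. The wrapped
# function runs once per element of `sample_shape`, with split seeds, so the
# draws are independent rather than copies of one sample.
import tensorflow as tf

sample_scalar = lambda seed: tf.random.stateless_normal([], seed=seed)
iid_fn = iid_sample(sample_scalar, sample_shape=[4, 2])
draws = iid_fn(seed=tf.constant([0, 1], dtype=tf.int32))  # shape [4, 2]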
def _sample_n(self, n, seed=None):
  batch_shape = self.batch_shape_tensor()
  batch_rank = ps.rank_from_shape(batch_shape)
  n_batch = ps.reduce_prod(batch_shape)

  underlying_batch_shape = self.distribution.batch_shape_tensor()
  underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)
  underlying_n_batch = ps.reduce_prod(underlying_batch_shape)

  # Left pad underlying shape with any necessary ones.
  underlying_bcast_shp = ps.concat(
      [ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
               dtype=underlying_batch_shape.dtype),
       underlying_batch_shape],
      axis=0)

  # Determine how many underlying samples to produce.
  n_bcast_samples = ps.maximum(0, n_batch // underlying_n_batch)
  samps = self.distribution.sample([n, n_bcast_samples], seed=seed)

  is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)

  event_shape = self.event_shape_tensor()
  event_rank = ps.rank_from_shape(event_shape)
  shp = ps.concat([[n],
                   ps.where(is_dim_bcast, batch_shape, 1),
                   underlying_bcast_shp,
                   event_shape],
                  axis=0)
  # Reshape to expand n_bcast_samples and ones-padded underlying_bcast_shp.
  samps = tf.reshape(samps, shp)
  # Interleave broadcast and underlying axis indices for transpose.
  interleaved_batch_axes = ps.reshape(
      ps.stack([ps.range(batch_rank),
                ps.range(batch_rank) + batch_rank],
               axis=-1),
      [-1]) + 1

  event_axes = ps.range(event_rank) + (1 + 2 * batch_rank)
  perm = ps.concat([[0], interleaved_batch_axes, event_axes], axis=0)
  samps = tf.transpose(samps, perm=perm)
  # Finally, reshape to the fully-broadcast batch shape.
  return tf.reshape(samps,
                    ps.concat([[n], batch_shape, event_shape], axis=0))
def vectorize_over_batch_dims(fn,
                              elems,
                              event_shape,
                              batch_shape,
                              vectorized_map=True,
                              fn_output_signature=None):
  """Maps `fn` over `elems` with batch dims flattened to one vectorized dim."""
  flat_batch_shape = tf.expand_dims(ps.reduce_prod(batch_shape), 0)
  flat_structure = reshape_structure(elems, event_shape, flat_batch_shape)
  if vectorized_map:
    result = tf.vectorized_map(
        fn, flat_structure, fallback_to_while_loop=False)
  else:
    assert fn_output_signature is not None
    result = tf.map_fn(
        fn, flat_structure, fn_output_signature=fn_output_signature)
  new_event_shape = tf.nest.map_structure(
      lambda elem: tf.shape(elem)[1:], result)
  return reshape_structure(result, new_event_shape, batch_shape)
def update_running_variance():
  diags = [variance_part.variance() for variance_part in variance_parts]
  new_state_parts = tf.nest.flatten(new_state)
  new_variance_parts = []
  for variance_part, diag, state_part in zip(variance_parts, diags,
                                             new_state_parts):
    # Compute new variance for each variance part, accounting for partial
    # batching of the variance calculation across chains (ie, some, all,
    # or none of the chains may share the estimated mass matrix).
    #
    # For example, say
    #
    #   state_part has shape    [2, 3, 4] + [5, 6]  (batch + event)
    #   variance_part has shape       [4] + [5, 6]
    #   log_prob has shape      [2, 3, 4]
    #
    # i.e., we have a batch of chains of shape [2, 3, 4], and 4 mass
    # matrices, each being shared across a [2, 3]-batch of chains. Note
    # this division is inferred from the shapes of the state part, the
    # log_prob, and the user-provided initial running variances.
    #
    # Until RunningVariance supports rank > 1 chunking, we need to flatten
    # the states that go into updating the variance estimates. In the
    # above example, `state_part` will be reshaped to `[6, 4, 5, 6]`, and
    # fed to `RunningVariance.update(state_part, axis=0)`, recording
    # 6 new observations in the running variance calculation.
    # `RunningVariance.variance()` will then be of shape `[4, 5, 6]`, and
    # the resulting momentum distribution will have batch shape of
    # `[2, 3, 4]` and event_shape of `[5, 6]`, matching the state_part.
    state_rank = ps.rank(state_part)
    variance_rank = ps.rank(diag)
    num_reduce_dims = state_rank - variance_rank

    state_part_shape = ps.shape(state_part)
    # This reshape adds a 1 when reduce_dims==0, and collapses all the lead
    # dimensions to a single one otherwise.
    reshaped_state = ps.reshape(
        state_part,
        ps.concat(
            [[ps.reduce_prod(state_part_shape[:num_reduce_dims])],
             state_part_shape[num_reduce_dims:]],
            axis=0))

    # The `axis=0` here removes the leading dimension we got from the
    # reshape above, so the new_variance_parts have the correct shape again.
    new_variance_parts.append(variance_part.update(reshaped_state, axis=0))
  return new_variance_parts
def _calculate_new_shape(self):
  # Try to get the old shape statically if available.
  original_shape = self._distribution.batch_shape
  if not tensorshape_util.is_fully_defined(original_shape):
    original_shape = self._distribution.batch_shape_tensor()
  # This is not a check for falseness, it's a check for exactly that shape.
  if original_shape == ():  # pylint: disable=g-explicit-bool-comparison
    # Force the size to be an integer, not a float, when the shape contains
    # no dtype information.
    original_size = 1
  else:
    original_size = ps.reduce_prod(original_shape)
  original_size = ps.cast(original_size, tf.int32)
  # Compute the new shape, filling in the `-1` dimension if present.
  new_shape = self._batch_shape_unexpanded
  implicit_dim_mask = ps.equal(new_shape, -1)
  size_implicit_dim = (
      original_size // ps.maximum(1, -ps.reduce_prod(new_shape)))
  expanded_new_shape = ps.where(  # Assumes exactly one `-1`.
      implicit_dim_mask, size_implicit_dim, new_shape)
  # Return the original size on the side because one caller would otherwise
  # have to recompute it.
  return expanded_new_shape, original_size
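# The same `-1`-filling arithmetic as `_calculate_new_shape`, standalone (a
# sketch of the logic, not the method itself): reshaping a size-6 batch to
# `[-1, 2]` must fill the implicit dimension with 6 // 2 = 3.
import numpy as np

new_shape = np.array([-1, 2])
original_size = 6
size_implicit_dim = original_size // max(1, -np.prod(new_shape))
expanded_new_shape = np.where(new_shape == -1, size_implicit_dim, new_shape)
# expanded_new_shape == [3, 2]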
def resample(log_weights, current_state, particle_info, seed=None):
  """Resample particles based on importance weights."""
  with tf.name_scope('resample_particles'):
    seed = SeedStream(seed, salt='resample_particles')
    resampling_indexes = tf.random.categorical(
        [log_weights], ps.reduce_prod(ps.shape(log_weights)), seed=seed())
    next_state = tf.nest.map_structure(
        lambda x: tf.reshape(tf.gather(x, resampling_indexes), ps.shape(x)),
        current_state)
    next_particle_info = tf.nest.map_structure(
        lambda x: tf.reshape(tf.gather(x, resampling_indexes), ps.shape(x)),
        particle_info)
    return next_state, next_particle_info
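# The core resampling step in isolation (a sketch, not the function above):
# draw indices in proportion to the normalized weights, then gather, so that
# high-weight particles are duplicated and low-weight ones dropped.
import tensorflow as tf

log_weights = tf.math.log([0.1, 0.2, 0.3, 0.4])
indexes = tf.random.categorical([log_weights], num_samples=4)  # shape [1, 4]
particles = tf.constant([10., 11., 12., 13.])
resampled = tf.gather(particles, tf.reshape(indexes, [-1]))    # shape [4]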
def _split_and_reshape_event(x, model):
  """Splits and reshapes a flat event `x` to match the structure of `model`."""
  splits = [
      ps.maximum(1, ps.reduce_prod(s))
      for s in tf.nest.flatten(model.event_shape)
  ]
  x = tf.nest.pack_sequence_as(model.event_shape,
                               tf.split(x, splits, axis=-1))

  def _reshape_part(part, dtype, event_shape):
    part = tf.cast(part, dtype)
    new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
    return tf.reshape(part, ps.cast(new_shape, tf.int32))

  x = tf.nest.map_structure(_reshape_part, x, model.dtype, model.event_shape)
  return x
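# A self-contained sketch of the split-and-reshape idea for a model whose
# event structure is (w: [2, 3], b: []): a flat trailing axis of size
# 2*3 + 1 = 7 splits into chunks of sizes [6, 1] (the `ps.maximum(1, ...)`
# above is what gives scalar events a length-1 chunk), and each chunk is
# then reshaped back to its structured event shape.
import tensorflow as tf

x = tf.range(7.)                # flat event of size 7
w_flat, b_flat = tf.split(x, [6, 1], axis=-1)
w = tf.reshape(w_flat, [2, 3])  # structured part, shape [2, 3]
b = tf.reshape(b_flat, [])      # scalar part, shape []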
def _split_and_reshape_event(self, x):
  splits = [
      ps.maximum(1, ps.reduce_prod(s))
      for s in tf.nest.flatten(self._model.event_shape)
  ]
  x = tf.nest.pack_sequence_as(self._model.event_shape,
                               tf.split(x, splits, axis=-1))

  def _reshape_part(part, dtype, event_shape):
    part = tf.cast(part, dtype)
    new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
    return tf.reshape(part, ps.cast(new_shape, tf.int32))

  x = tf.nest.map_structure(_reshape_part, x, self._model.dtype,
                            self._model.event_shape)
  return x
def make_momentum_distribution(state_parts,
                               batch_shape,
                               running_variance_parts=None):
  """Construct a momentum distribution from the running variance.

  This uses a running variance to construct a momentum distribution with the
  correct batch_shape and event_shape.

  Args:
    state_parts: List of `Tensor`.
    batch_shape: Batch shape.
    running_variance_parts: Optional, list of `Tensor` outputs of
      `tfp.experimental.stats.RunningVariance.variance()`. Defaults to ones
      with the same shape as state_parts.

  Returns:
    `tfd.Distribution` where `.sample` has the same structure as
    `state_parts`, and `.log_prob` of the sample will have the rank of
    `batch_ndims`.
  """
  if running_variance_parts is None:
    running_variance_parts = tf.nest.map_structure(tf.ones_like, state_parts)
  distributions = []
  batch_ndims = ps.rank_from_shape(batch_shape)
  for variance_part, state_part in zip(running_variance_parts, state_parts):
    event_shape = state_part.shape[batch_ndims:]
    if not tensorshape_util.is_fully_defined(event_shape):
      event_shape = ps.shape(state_part, name='state_part_shp')[batch_ndims:]
    variance_tiled = tf.broadcast_to(
        variance_part, ps.concat([batch_shape, event_shape], axis=0))
    nevt = ps.cast(ps.reduce_prod(event_shape), tf.int32)
    variance_flattened = tf.reshape(
        variance_tiled, ps.concat([batch_shape, [nevt]], axis=0))

    distribution = _CompositeTransformedDistribution(
        bijector=_CompositeReshape(
            event_shape_out=event_shape, name='reshape_mvnpfl'),
        distribution=(
            _CompositeMultivariateNormalPrecisionFactorLinearOperator(
                precision_factor=_CompositeLinearOperatorDiag(
                    tf.math.sqrt(variance_flattened)),
                precision=_CompositeLinearOperatorDiag(variance_flattened),
                name='momentum')))
    distributions.append(distribution)
  return maybe_make_list_and_batch_broadcast(
      _CompositeJointDistributionSequential(distributions), batch_shape)
def _make_momentum_distribution(running_variance_parts, state_parts,
                                batch_ndims):
  """Construct a momentum distribution from the running variance.

  This uses a running variance to construct a momentum distribution with the
  correct batch_shape and event_shape.

  Args:
    running_variance_parts: List of `Tensor`, outputs of
      `tfp.experimental.stats.RunningVariance.variance()`.
    state_parts: List of `Tensor`.
    batch_ndims: Scalar, for leading batch dimensions.

  Returns:
    `tfd.Distribution` where `.sample` has the same structure as
    `state_parts`, and `.log_prob` of the sample will have the rank of
    `batch_ndims`.
  """
  distributions = []
  for variance_part, state_part in zip(running_variance_parts, state_parts):
    running_variance_rank = ps.rank(variance_part)
    state_rank = ps.rank(state_part)
    event_shape = ps.shape(state_part)[batch_ndims:]
    nevt = ps.reduce_prod(event_shape)
    # Pad dimensions and tile by multiplying by tf.ones to add a batch shape.
    ones = tf.ones(
        ps.shape(state_part)[:-(state_rank - running_variance_rank)],
        dtype=variance_part.dtype)
    ones = bu.left_justified_expand_dims_like(ones, state_part)
    variance_tiled = ones * variance_part
    variance_flattened = tf.reshape(
        variance_tiled,
        ps.concat([ps.shape(variance_tiled)[:batch_ndims], [nevt]], axis=0))

    distributions.append(
        _CompositeTransformedDistribution(
            bijector=_CompositeReshape(
                event_shape_out=event_shape, event_shape_in=[nevt]),
            distribution=(
                _CompositeMultivariateNormalPrecisionFactorLinearOperator(
                    precision_factor=_CompositeLinearOperatorDiag(
                        tf.math.sqrt(variance_flattened)),
                    precision=_CompositeLinearOperatorDiag(
                        variance_flattened)))))
  return _CompositeJointDistributionSequential(distributions)
def _compute_fans_from_shape(shape, batch_ndims=0):
  """Extracts `fan_in, fan_out` from specified shape `Tensor`."""
  # Ensure shape is a vector of length >=2.
  num_pad = prefer_static.maximum(0, 2 - prefer_static.size(shape))
  shape = prefer_static.pad(
      shape, paddings=[[0, num_pad]], constant_values=1)
  (
      batch_shape,  # pylint: disable=unused-variable
      extra_shape,
      fan_in,
      fan_out,
  ) = prefer_static.split(shape, [batch_ndims, -1, 1, 1])
  # The following logic is primarily intended for convolutional layers which
  # have spatial semantics in addition to input/output channels.
  receptive_field_size = prefer_static.reduce_prod(extra_shape)
  fan_in = fan_in[0] * receptive_field_size
  fan_out = fan_out[0] * receptive_field_size
  return fan_in, fan_out
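# Worked example (assuming `_compute_fans_from_shape` above): for a conv
# kernel of shape [3, 3, 64, 128] (spatial dims + in/out channels), the
# receptive_field_size is 3 * 3 = 9, so fan_in = 64 * 9 = 576 and
# fan_out = 128 * 9 = 1152.
fan_in, fan_out = _compute_fans_from_shape([3, 3, 64, 128])
# fan_in == 576, fan_out == 1152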
def _check_at_least_two_chains(accept_prob, reduce_chain_axis_names,
                               validate_args, message):
  """Checks that the number of chains is at least 2."""
  # The total number of chains is local batch size * distributed axis size.
  local_axis_size = ps.size(accept_prob)
  distributed_axis_size = int(
      ps.reduce_prod([
          distribute_lib.get_axis_size(a) for a in reduce_chain_axis_names
      ]))
  num_chains = local_axis_size * distributed_axis_size
  num_chains_ = tf.get_static_value(num_chains)
  if num_chains_ is not None:
    if num_chains_ < 2:
      raise ValueError('{} Got: {}'.format(message, num_chains_))
  elif validate_args:
    with tf.control_dependencies(
        [assert_util.assert_greater_equal(num_chains, 2, message)]):
      accept_prob = tf.identity(accept_prob)
  return accept_prob
def preprocess_state(init_state):
  """Initial preprocessing at Stage 0."""
  # Note: `current_state`, `make_kernel_fn`, `seed_stream`, etc. are closed
  # over from the enclosing sampler.
  dimension = ps.reduce_sum(
      [ps.reduce_prod(ps.shape(x)[1:]) for x in init_state])
  likelihood_log_prob = likelihood_log_prob_fn(*init_state)

  # Default to the optimal for normal distributed targets.
  # TODO(b/152412213): Revisit this default parameter.
  scale_start = (
      tf.constant(2.38**2, dtype=likelihood_log_prob.dtype) /
      tf.constant(dimension, dtype=likelihood_log_prob.dtype))
  # TODO(b/152412213): Enable batch of batches style by using non-scalar
  # inverse_temperature.
  inverse_temperature = tf.zeros([], dtype=likelihood_log_prob.dtype)
  scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(scale_start, 1.)
  kernel = make_kernel_fn(
      _make_tempered_target_log_prob_fn(
          prior_log_prob_fn, likelihood_log_prob_fn, inverse_temperature),
      init_state, scalings, seed=seed_stream())
  pkr = kernel.bootstrap_results(current_state)
  _, kernel_target_log_prob = gather_mh_like_result(pkr)

  particle_info = ParticleInfo(
      log_accept_prob=ps.zeros_like(likelihood_log_prob),
      log_scalings=tf.math.log(scalings),
      tempered_log_prob=kernel_target_log_prob,
      likelihood_log_prob=likelihood_log_prob,
  )

  return SMCResults(
      num_steps=tf.convert_to_tensor(
          max_num_steps, dtype=tf.int32, name='num_steps'),
      inverse_temperature=inverse_temperature,
      log_marginal_likelihood=tf.constant(
          0., dtype=likelihood_log_prob.dtype),
      particle_info=particle_info)
def sample(self, sample_shape=(), seed=None, name=None):
  with tf.name_scope(name or 'sample'):
    # Grab the required number of values from the provided tensors.
    sample_shape = dist_util.expand_to_vector(sample_shape)
    n = ps.cast(ps.reduce_prod(sample_shape), dtype=tf.int32)

    # Check that we're not trying to draw too many samples.
    assertions = []
    will_overflow_ = tf.get_static_value(n > self.max_num_samples)
    if will_overflow_:
      raise ValueError('Trying to draw {} samples from a '
                       '`DeterministicEmpirical` instance for which only {} '
                       'samples were provided.'.format(
                           tf.get_static_value(n),
                           tf.get_static_value(self.max_num_samples)))
    elif (will_overflow_ is None  # Couldn't determine statically.
          and self.validate_args):
      assertions.append(
          tf.debugging.assert_less_equal(
              n, self.max_num_samples,
              message='Number of samples to draw from a '
              '`DeterministicEmpirical` instance must not exceed the '
              'number provided at construction.'))

    # Extract the appropriate number of sampled values.
    with tf.control_dependencies(assertions):
      sampled = tf.nest.map_structure(lambda x: x[:n, ...],
                                      self.values_with_sample_dim)

    # Reshape the values to the appropriate sample shape.
    return tf.nest.map_structure(
        lambda x: tf.reshape(  # pylint: disable=g-long-lambda
            x,
            ps.concat([
                ps.cast(sample_shape, tf.int32),
                ps.cast(ps.shape(x)[1:], tf.int32)
            ], axis=0)),
        sampled)
def _kl_sample(a, b, name='kl_sample'):
  """Batched KL divergence `KL(a || b)` for Sample distributions.

  We can leverage the fact that:

  ```
  KL(Sample(a) || Sample(b)) = sum(KL(a || b))
  ```

  where the sum is over the `sample_shape` dims.

  Args:
    a: Instance of `Sample` distribution.
    b: Instance of `Sample` distribution.
    name: (optional) name to use for created ops.
      Default value: `'kl_sample'`.

  Returns:
    kldiv: Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the `sample_shape` of `a` and `b` don't match.
  """
  assertions = []
  a_ss = tf.get_static_value(a.sample_shape)
  b_ss = tf.get_static_value(b.sample_shape)
  msg = '`a.sample_shape` must be identical to `b.sample_shape`.'
  if a_ss is not None and b_ss is not None:
    if not np.array_equal(a_ss, b_ss):
      raise ValueError(msg)
  elif a.validate_args or b.validate_args:
    assertions.append(
        assert_util.assert_equal(a.sample_shape, b.sample_shape, message=msg))
  with tf.control_dependencies(assertions):
    kl = kullback_leibler.kl_divergence(a.distribution, b.distribution,
                                        name=name)
    n = ps.reduce_prod(a.sample_shape)
    return tf.cast(x=n, dtype=kl.dtype) * kl
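# A numerical check of the identity above (a sketch, assuming the usual
# alias `tfd = tfp.distributions`): the KL between two `Sample` wrappers is
# the base KL summed over the 5 iid sample dimensions.
a = tfd.Sample(tfd.Normal(0., 1.), sample_shape=5)
b = tfd.Sample(tfd.Normal(1., 2.), sample_shape=5)
kl_wrapped = tfd.kl_divergence(a, b)
kl_base = tfd.kl_divergence(a.distribution, b.distribution)
# kl_wrapped == 5. * kl_base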
def _get_flat_unconstraining_bijector(jd_model):
  """Create a bijector from a joint distribution that flattens and unconstrains.

  The intention is (loosely) to go from a model joint distribution supported
  on U_1 x U_2 x ... U_n, with U_j a subset of R^{n_j}, to a model supported
  on R^N, with N = sum(n_j). (This is "loose" in the sense of base measures:
  some distribution may be supported on an m-dimensional subset of R^n, and
  the default transform for that distribution may then have support on R^m.
  See [1] for details.)

  Args:
    jd_model: subclass of `tfd.JointDistribution` A JointDistribution for a
      model.

  Returns:
    A `tfb.Bijector` where the `.forward` method flattens and unconstrains
    points.
  """
  # TODO(b/180396233): This bijector is in general point-dependent.
  to_chain = [jd_model.experimental_default_event_space_bijector()]
  flat_bijector = restructure.pack_sequence_as(jd_model.event_shape_tensor())
  to_chain.append(flat_bijector)

  unconstrained_shapes = flat_bijector.inverse_event_shape_tensor(
      jd_model.event_shape_tensor())

  # This reshaping is required because `split` can produce a tensor of shape
  # [1] when the distribution event shape is [].
  reshapers = [
      reshape.Reshape(event_shape_out=x, event_shape_in=[-1])
      for x in unconstrained_shapes
  ]
  to_chain.append(joint_map.JointMap(bijectors=reshapers))

  size_splits = [ps.reduce_prod(x) for x in unconstrained_shapes]
  to_chain.append(split.Split(num_or_size_splits=size_splits))

  return invert.Invert(chain.Chain(to_chain))
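# A usage sketch, assuming the usual `tfd` alias. For this two-part model the
# resulting bijector's `.forward` maps a structured, constrained sample (a
# positive scale plus an unconstrained loc) to a single flat vector in R^2.
jd = tfd.JointDistributionSequential([
    tfd.LogNormal(0., 1.),                # constrained: positive
    lambda scale: tfd.Normal(0., scale),  # unconstrained
])
bij = _get_flat_unconstraining_bijector(jd)
flat = bij.forward(list(jd.sample(seed=42)))  # shape [2], unconstrained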
def _log_prob(self, x):
  assertions = []
  message = 'Input must have at least one dimension.'
  if tensorshape_util.rank(x.shape) is not None:
    if tensorshape_util.rank(x.shape) == 0:
      raise ValueError(message)
  elif self.validate_args:
    assertions.append(
        assert_util.assert_rank_at_least(x, 1, message=message))
  with tf.control_dependencies(assertions):
    event_tensors = self._distribution.event_shape_tensor()
    splits = [
        ps.maximum(1, ps.reduce_prod(s))
        for s in tf.nest.flatten(event_tensors)
    ]
    x = tf.nest.pack_sequence_as(event_tensors,
                                 tf.split(x, splits, axis=-1))

    def _reshape_part(part, dtype, event_shape):
      part = tf.cast(part, dtype)
      static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
      if static_rank == 1:
        return part
      new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
      return tf.reshape(part, ps.cast(new_shape, tf.int32))

    if all(
        tensorshape_util.is_fully_defined(s)
        for s in tf.nest.flatten(self._distribution.event_shape)):
      x = tf.nest.map_structure(_reshape_part, x, self._distribution.dtype,
                                self._distribution.event_shape)
    else:
      x = tf.nest.map_structure(_reshape_part, x, self._distribution.dtype,
                                self._distribution.event_shape_tensor())
    return self._distribution.log_prob(x)
def _make_flatten_unflatten_fns(batch_shape):
  """Builds functions for flattening and unflattening batch dimensions."""
  batch_shape = tuple(batch_shape)
  batch_rank = len(batch_shape)
  ndims = ps.cast(ps.reduce_prod(batch_shape), tf.int32)

  def flatten_fn(x):
    x_shape = tuple(x.shape)
    if x_shape[:batch_rank] != batch_shape:
      raise ValueError('Expected batch-shape=%s; received array of shape=%s'
                       % (batch_shape, x_shape))
    flat_shape = (ndims,) + x_shape[batch_rank:]
    return tf.reshape(x, flat_shape)

  def unflatten_fn(x):
    x_shape = tuple(x.shape)
    if x_shape[0] != ndims:
      raise ValueError('Expected batch-size=%d; received shape=%s' %
                       (ndims, x_shape))
    return tf.reshape(x, batch_shape + x_shape[1:])

  return flatten_fn, unflatten_fn
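# Usage sketch (assuming `_make_flatten_unflatten_fns` above): flatten a
# [2, 3] batch of 4-vectors into a single [6, 4] batch, then invert.
import tensorflow as tf

flatten_fn, unflatten_fn = _make_flatten_unflatten_fns([2, 3])
x = tf.zeros([2, 3, 4])
flat = flatten_fn(x)           # shape [6, 4]
restored = unflatten_fn(flat)  # shape [2, 3, 4]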
def _make_vector_event_space_bijector(model):
  """Creates a vector bijector that constrains like the structured model."""
  # TODO(siege): Make this work with multi-part default_event_bijector.

  def _make_reshaped_bijector(b, s):
    return tfb.Chain([
        tfb.Reshape(event_shape_in=s, event_shape_out=[ps.reduce_prod(s)]),
        b,
        tfb.Reshape(
            event_shape_in=[ps.reduce_prod(b.inverse_event_shape(s))],
            event_shape_out=b.inverse_event_shape(s)),
    ])

  reshaped_bijector = tf.nest.map_structure(
      _make_reshaped_bijector, model.default_event_space_bijector,
      model.event_shape)

  return tfb.Blockwise(
      bijectors=tf.nest.flatten(reshaped_bijector),
      block_sizes=tf.nest.flatten(
          tf.nest.map_structure(
              lambda b, s: ps.reduce_prod(b.inverse_event_shape(s)),  # pylint: disable=g-long-lambda
              model.default_event_space_bijector, model.event_shape)))
def _log_prob(self, x):
  if self.input_output_cholesky:
    x_sqrt = x
  else:
    # Complexity: O(nbk**3)
    x_sqrt = tf.linalg.cholesky(x)

  df = tf.convert_to_tensor(self.df)
  batch_shape = self._batch_shape_tensor(df)
  event_shape = self._event_shape_tensor()
  dimension = self._dimension()
  x_ndims = ps.rank(x_sqrt)
  num_singleton_axes_to_prepend = (
      ps.maximum(ps.size(batch_shape) + 2, x_ndims) - x_ndims)
  x_with_prepended_singletons_shape = ps.concat([
      ps.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
      ps.shape(x_sqrt)
  ], 0)
  x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
  ndims = ps.rank(x_sqrt)
  # sample_ndims = ndims - batch_ndims - event_ndims
  sample_ndims = ndims - ps.size(batch_shape) - 2
  sample_shape = ps.shape(x_sqrt)[:sample_ndims]

  # We need to be able to pre-multiply each matrix by its corresponding
  # batch scale matrix. Since a Distribution Tensor supports multiple
  # samples per batch, this means we need to reshape the input matrix `x`
  # so that the first b dimensions are batch dimensions and the last two
  # are of shape [dimension, dimensions*number_of_samples]. Doing these
  # gymnastics allows us to do a batch_solve.
  #
  # After we're done with sqrt_solve (the batch operation) we need to undo
  # this reshaping so what we're left with is a Tensor partitionable by
  # sample, batch, event dimensions.

  # Complexity: O(nbk**2) since transpose must access every element.
  scale_sqrt_inv_x_sqrt = x_sqrt
  perm = ps.concat(
      [ps.range(sample_ndims, ndims), ps.range(0, sample_ndims)], 0)
  scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)
  last_dim_size = (
      ps.cast(dimension, dtype=tf.int32) *
      ps.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims]))
  shape = ps.concat(
      [x_with_prepended_singletons_shape[sample_ndims:-2],
       [ps.cast(dimension, dtype=tf.int32), last_dim_size]],
      axis=0)
  scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

  # Complexity: O(nbM*k) where M is the complexity of the operator solving a
  # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2)
  # so this step has complexity O(nbk^3).
  scale_sqrt_inv_x_sqrt = self._scale.solve(scale_sqrt_inv_x_sqrt)

  # Undo make batch-op ready.
  # Complexity: O(nbk**2)
  shape = ps.concat(
      [ps.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape],
      axis=0)
  scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
  perm = ps.concat(
      [ps.range(ndims - sample_ndims, ndims),
       ps.range(0, ndims - sample_ndims)], 0)
  scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)

  # Write V = SS', X = LL'. Then:
  # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
  #              = tr[inv(S) L L' inv(S)']
  #              = tr[(inv(S) L) (inv(S) L)']
  #              = sum_{ik} (inv(S) L)_{ik}**2
  # The second equality follows from the cyclic permutation property.
  # Complexity: O(nbk**2)
  trace_scale_inv_x = tf.reduce_sum(
      tf.square(scale_sqrt_inv_x_sqrt), axis=[-2, -1])

  # Complexity: O(nbk)
  half_log_det_x = tf.reduce_sum(
      tf.math.log(tf.linalg.diag_part(x_sqrt)), axis=[-1])

  # Complexity: O(nbk**2)
  log_prob = ((df - dimension - 1.) * half_log_det_x -
              0.5 * trace_scale_inv_x -
              self._log_normalization(df=df, scale=self._scale))

  # Set shape hints.
  # Try to merge what we know from the input x with what we know from the
  # parameters of this distribution.
  if (tensorshape_util.rank(x.shape) is not None and
      tensorshape_util.rank(self.batch_shape) is not None):
    tensorshape_util.set_shape(
        log_prob, tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

  return log_prob
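# A NumPy check (a sketch, not part of the library) of the trace identity
# used above: with V = S S' and X = L L',
# tr[inv(V) X] = sum_{ik} (inv(S) L)_{ik}**2.
import numpy as np

rng = np.random.default_rng(0)
k = 4
s = np.tril(rng.normal(size=(k, k))) + k * np.eye(k)  # S, lower triangular
l = np.tril(rng.normal(size=(k, k))) + k * np.eye(k)  # L, lower triangular
v = s @ s.T
x = l @ l.T
lhs = np.trace(np.linalg.solve(v, x))
rhs = np.sum(np.linalg.solve(s, l) ** 2)
np.testing.assert_allclose(lhs, rhs, rtol=1e-10)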