def _log_prob(self, x):
  batch_ndims = prefer_static.rank_from_shape(
      self.distribution.batch_shape_tensor,
      self.distribution.batch_shape)
  extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
  event_ndims = prefer_static.rank_from_shape(
      self.distribution.event_shape_tensor,
      self.distribution.event_shape)
  ndims = prefer_static.rank(x)
  # (1) Expand x's dims.
  d = ndims - batch_ndims - extra_sample_ndims - event_ndims
  x = tf.reshape(x, shape=tf.pad(
      tensor=tf.shape(input=x),
      paddings=[[prefer_static.maximum(0, -d), 0]],
      constant_values=1))
  # Re-read the rank after the reshape, since padding may have added dims.
  ndims = prefer_static.rank(x)
  sample_ndims = prefer_static.maximum(0, d)
  # (2) Transpose x's dims.
  sample_dims = prefer_static.range(0, sample_ndims)
  batch_dims = prefer_static.range(sample_ndims, sample_ndims + batch_ndims)
  extra_sample_dims = prefer_static.range(
      sample_ndims + batch_ndims,
      sample_ndims + batch_ndims + extra_sample_ndims)
  event_dims = prefer_static.range(
      sample_ndims + batch_ndims + extra_sample_ndims,
      ndims)
  perm = prefer_static.concat(
      [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
  x = tf.transpose(a=x, perm=perm)
  # (3) Compute x's log_prob.
  lp = self.distribution.log_prob(x)
  # (4) Make the final reduction in x.
  axis = prefer_static.range(sample_ndims, sample_ndims + extra_sample_ndims)
  return tf.reduce_sum(input_tensor=lp, axis=axis)
def _prepare_for_underlying(self, x):
  batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                   self.distribution.batch_shape)
  extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
  event_ndims = ps.rank_from_shape(self.distribution.event_shape_tensor,
                                   self.distribution.event_shape)
  ndims = ps.rank(x)
  # (1) Expand x's dims.
  d = ndims - batch_ndims - extra_sample_ndims - event_ndims
  x = tf.reshape(x, shape=ps.pad(ps.shape(x),
                                 paddings=[[ps.maximum(0, -d), 0]],
                                 constant_values=1))
  ndims = ps.rank(x)
  sample_ndims = ps.maximum(0, d)
  # (2) Transpose x's dims.
  sample_dims = ps.range(0, sample_ndims)
  batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
  extra_sample_dims = ps.range(
      sample_ndims + batch_ndims,
      sample_ndims + batch_ndims + extra_sample_ndims)
  event_dims = ps.range(sample_ndims + batch_ndims + extra_sample_ndims,
                        ndims)
  perm = ps.concat(
      [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
  x = tf.transpose(x, perm=perm)
  return x, (sample_ndims, extra_sample_ndims, batch_ndims)
def loop_body(i_, event_ind):
  i = i_ // strides
  j = i_ % strides

  i_ind = ps.range(i * fw, ps.maximum(i, fh) * fw,
                   delta=strides * fw, dtype=dtype)
  j_ind = ps.range(j, ps.maximum(j, fw), delta=strides, dtype=dtype)

  nc = cartesian_add([i_ind, j_ind])
  ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0])

  k = ps.reshape(cartesian_add([
      ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype),
      ps.range(ps.shape(nc)[1], dtype=dtype)
  ]), shape=[-1])
  last_j = strides - (fw - j - 1) % strides - 1
  last_i = strides - (fh - i - 1) % strides - 1
  kernel_ind = ps.stack(
      [k, ps.ones_like(k) * last_i * strides + last_j], axis=1)
  event_ind = ps.tensor_scatter_nd_update(event_ind, ind[..., tf.newaxis],
                                          kernel_ind)

  return i_ + 1, event_ind
def _split_and_reshape_event(self, x):
  event_tensors = self._distribution.event_shape_tensor()
  splits = [
      ps.maximum(1, ps.reduce_prod(s))
      for s in tf.nest.flatten(event_tensors)
  ]
  x = tf.nest.pack_sequence_as(event_tensors, tf.split(x, splits, axis=-1))

  def _reshape_part(part, dtype, event_shape):
    part = tf.cast(part, dtype)
    static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
    if static_rank == 1:
      return part
    new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
    return tf.reshape(part, ps.cast(new_shape, tf.int32))

  if all(
      tensorshape_util.is_fully_defined(s)
      for s in tf.nest.flatten(self._distribution.event_shape)):
    x = tf.nest.map_structure(_reshape_part, x, self._distribution.dtype,
                              self._distribution.event_shape)
  else:
    x = tf.nest.map_structure(_reshape_part, x, self._distribution.dtype,
                              self._distribution.event_shape_tensor())
  return x
def log_joint_fn(*param_vals, **param_kwargs):
  """Generated log-density function."""
  if param_kwargs:
    if param_vals:
      raise ValueError(
          'log_joint_fn saw both positional args ({}) and named args ({}). '
          'This is not supported: you have to choose!'.format(
              param_vals, param_kwargs))
    param_vals = [param_kwargs[p.name] for p in self.parameters]

  param_lp = parameter_prior.log_prob(*param_vals)

  # Build a linear Gaussian state space model and evaluate the marginal
  # log_prob on observations.
  lgssm = self.make_state_space_model(
      param_vals=param_vals, num_timesteps=num_timesteps)
  observation_lp = lgssm.log_prob(observed_time_series, mask=mask)

  # Sum over likelihoods from iid observations. Without this sum,
  # adding `param_lp + observation_lp` would broadcast the param priors
  # over the sample shape, which incorrectly multi-counts the param
  # priors.
  sample_ndims = ps.maximum(0, ps.rank(observation_lp) - ps.rank(param_lp))
  observation_lp = tf.reduce_sum(observation_lp, axis=ps.range(sample_ndims))

  return param_lp + observation_lp
def _transpose_and_reshape_result(self, x, sample_shape, event_shape=None):
  if event_shape is None:
    event_shape = self.event_shape_tensor()

  batch_shape = self.batch_shape_tensor()
  batch_rank = ps.rank_from_shape(batch_shape)

  underlying_batch_shape = self.distribution.batch_shape_tensor()
  underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)

  # Continuing the example from `_augment_sample_shape`, suppose we have:
  #   - sample shape of `[n]`,
  #   - underlying distribution batch shape of `[2, 1]`,
  #   - final broadcast batch shape of `[4, 2, 3]`.
  # and have drawn an `x` of shape `[n, 12, 2, 1] + event_shape`, which we
  # ultimately want to have shape `[n, 4, 2, 3] + event_shape`.

  # First, we reshape to expand out the batch elements:
  # `shape_with_doubled_batch == [n] + [4, 1, 3] + [1, 2, 1] + event_shape`,
  # where `[1, 2, 1]` is the fully-expanded underlying batch shape, and
  # `[4, 1, 3]` is the shape of the elements being added by broadcasting.
  underlying_bcast_shp = ps.concat([
      ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
              dtype=underlying_batch_shape.dtype),
      underlying_batch_shape
  ], axis=0)
  is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)
  x_with_doubled_batch = tf.reshape(
      x,
      ps.concat([sample_shape,
                 ps.where(is_dim_bcast, batch_shape, 1),
                 underlying_bcast_shp,
                 event_shape], axis=0))

  # Next, construct the permutation that interleaves the batch dimensions,
  # resulting in samples with shape
  # `[n] + [4, 1] + [1, 2] + [3, 1] + event_shape`.
  # Note that each interleaved pair of batch dimensions contains exactly one
  # dim of size `1` and one of size `>= 1`.
  sample_ndims = ps.rank_from_shape(sample_shape)
  x_with_interleaved_batch = tf.transpose(
      x_with_doubled_batch,
      perm=ps.concat([
          ps.range(sample_ndims),
          sample_ndims + ps.reshape(
              ps.stack([ps.range(batch_rank),
                        ps.range(batch_rank) + batch_rank], axis=-1),
              [-1]),
          sample_ndims + 2 * batch_rank +
          ps.range(ps.rank_from_shape(event_shape))
      ], axis=0))

  # Final reshape to remove the spurious `1` dimensions.
  return tf.reshape(
      x_with_interleaved_batch,
      ps.concat([sample_shape, batch_shape, event_shape], axis=0))
def left_justified_expand_dims_to(x, rank, name=None):
  """Right pads `x` with `rank - rank(x)` ones."""
  with tf.name_scope(name or 'left_justified_expand_dims_to'):
    expand_ndims = ps.maximum(rank - ps.rank(x), 0)
    expand_shape = ps.concat(
        [ps.shape(x), ps.ones(shape=[expand_ndims], dtype=tf.int32)],
        axis=0)
    return tf.reshape(x, expand_shape)
def left_justified_expand_dims_to(x, rank, name=None):
  """Right pads `x` with `rank - rank(x)` ones."""
  with tf.name_scope(name or 'left_justified_expand_dims_to'):
    rank = tf.convert_to_tensor(rank, dtype=tf.int32)
    expand_ndims = prefer_static.maximum(rank - prefer_static.rank(x), 0)
    expand_shape = prefer_static.pad(
        prefer_static.shape(x),
        paddings=[[0, expand_ndims]],
        constant_values=1)
    return prefer_static.reshape(x, expand_shape)
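A minimal usage sketch for the two `left_justified_expand_dims_to` variants above, assuming TensorFlow and TFP's `prefer_static` module are importable and one of the definitions above is in scope; the tensor values are purely illustrative:

import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static

ps = prefer_static  # Both variants' free names resolve to prefer_static.

# Right-pad a rank-1 vector to rank 3: the data stays left-justified and two
# trailing singleton dims are appended, so it broadcasts against dimensions
# added on the right rather than on the left.
x = tf.constant([1., 2., 3.])                    # shape [3]
y = left_justified_expand_dims_to(x, rank=3)     # shape [3, 1, 1]
print(y.shape)                                   # (3, 1, 1)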
def loop_body(i_, kernels_ind):
  i = i_ // sw
  j = i_ % sw

  i_ind = ps.range(i * fw, ps.maximum(i, fh) * fw,
                   delta=sh * fw, dtype=dtype)
  j_ind = ps.range(j, ps.maximum(j, fw), delta=sw, dtype=dtype)

  last_j = sw - (fw - j - 1) % sw - 1
  last_i = sh - (fh - i - 1) % sh - 1
  pos = last_i * sw + last_j

  nc = cartesian_add([i_ind, j_ind])
  kernels_ind = kernels_ind.write(
      pos, ps.reverse(ps.reverse(nc, [0]), [1]))

  return i_ + 1, kernels_ind
def _prepare_for_underlying(self, x):
  batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                   self.distribution.batch_shape)
  extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
  event_ndims = ps.rank_from_shape(self.distribution.event_shape_tensor,
                                   self.distribution.event_shape)
  ndims = ps.rank(x)
  # (1) Expand x's dims.
  d = ndims - batch_ndims - extra_sample_ndims - event_ndims
  x = tf.reshape(x, shape=ps.pad(ps.shape(x),
                                 paddings=[[ps.maximum(0, -d), 0]],
                                 constant_values=1))
  sample_ndims = ps.maximum(0, d)
  x = tf.transpose(
      x, perm=ps.invert_permutation(
          self._sampling_permutation(sample_ndims)))
  return x, (sample_ndims, extra_sample_ndims, batch_ndims)
def _rightmost_expand_to_rank(tensor, new_rank):
  """Expands `tensor`'s rank by `new_rank - tensor.rank` rightmost dims."""
  return tf.reshape(
      tensor,
      shape=prefer_static.pad(
          prefer_static.shape(tensor),
          paddings=[[0, prefer_static.maximum(
              new_rank - prefer_static.rank(tensor), 0)]],
          constant_values=1))
def update_event_ndims(input_event_ndims,
                       input_min_event_ndims,
                       output_min_event_ndims):
  """Returns output_event_ndims and updates rolling_offset as needed."""
  nonlocal rolling_offset
  ldj_reduce_ndims = bijector_lib.ldj_reduction_ndims(
      input_event_ndims, input_min_event_ndims)
  # Update rolling_offset when batch_ndims are negative.
  rolling_offset = ps.maximum(rolling_offset, -ldj_reduce_ndims)
  return nest.map_structure(lambda nd: ldj_reduce_ndims + nd,
                            output_min_event_ndims)
def _sample_n(self, n, seed=None):
  batch_shape = self.batch_shape_tensor()
  batch_rank = ps.rank_from_shape(batch_shape)
  n_batch = ps.reduce_prod(batch_shape)

  underlying_batch_shape = self.distribution.batch_shape_tensor()
  underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)
  underlying_n_batch = ps.reduce_prod(underlying_batch_shape)

  # Left pad underlying shape with any necessary ones.
  underlying_bcast_shp = ps.concat([
      ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
              dtype=underlying_batch_shape.dtype),
      underlying_batch_shape
  ], axis=0)

  # Determine how many underlying samples to produce.
  n_bcast_samples = ps.maximum(0, n_batch // underlying_n_batch)
  samps = self.distribution.sample([n, n_bcast_samples], seed=seed)

  is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)

  event_shape = self.event_shape_tensor()
  event_rank = ps.rank_from_shape(event_shape)
  shp = ps.concat([[n],
                   ps.where(is_dim_bcast, batch_shape, 1),
                   underlying_bcast_shp,
                   event_shape], axis=0)
  # Reshape to expand n_bcast_samples and ones-padded underlying_bcast_shp.
  samps = tf.reshape(samps, shp)
  # Interleave broadcast and underlying axis indices for transpose.
  interleaved_batch_axes = ps.reshape(
      ps.stack([ps.range(batch_rank),
                ps.range(batch_rank) + batch_rank], axis=-1),
      [-1]) + 1

  event_axes = ps.range(event_rank) + (1 + 2 * batch_rank)
  perm = ps.concat([[0], interleaved_batch_axes, event_axes], axis=0)
  samps = tf.transpose(samps, perm=perm)
  # Finally, reshape to the fully-broadcast batch shape.
  return tf.reshape(samps,
                    ps.concat([[n], batch_shape, event_shape], axis=0))
def _log_prob(self, x, **kwargs):
  batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                   self.distribution.batch_shape)
  extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
  event_ndims = ps.rank_from_shape(self.distribution.event_shape_tensor,
                                   self.distribution.event_shape)
  ndims = ps.rank(x)
  # (1) Expand x's dims.
  d = ndims - batch_ndims - extra_sample_ndims - event_ndims
  x = tf.reshape(x, shape=ps.pad(ps.shape(x),
                                 paddings=[[ps.maximum(0, -d), 0]],
                                 constant_values=1))
  ndims = ps.rank(x)
  sample_ndims = ps.maximum(0, d)
  # (2) Transpose x's dims.
  sample_dims = ps.range(0, sample_ndims)
  batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
  extra_sample_dims = ps.range(
      sample_ndims + batch_ndims,
      sample_ndims + batch_ndims + extra_sample_ndims)
  event_dims = ps.range(sample_ndims + batch_ndims + extra_sample_ndims,
                        ndims)
  perm = ps.concat(
      [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
  x = tf.transpose(a=x, perm=perm)
  # (3) Compute x's log_prob.
  lp = self.distribution.log_prob(x, **kwargs)
  # (4) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has
  # full sample shape in the sample axes, before we reduce.
  bcast_lp_shape = ps.broadcast_shape(
      ps.shape(lp),
      ps.concat([ps.ones([sample_ndims], tf.int32),
                 ps.reshape(self.sample_shape, shape=[-1]),
                 ps.ones([batch_ndims], tf.int32)], axis=0))
  lp = tf.broadcast_to(lp, bcast_lp_shape)
  # (5) Make the final reduction in x.
  axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
  return tf.reduce_sum(lp, axis=axis)
def left_justified_expand_dims_to(x, rank, name=None):
  """Right pads `x` with `rank - rank(x)` ones."""
  with tf.name_scope(name or 'left_justified_expand_dims_to'):
    rank = tf.convert_to_tensor(rank, dtype=tf.int32)
    expand_ndims = prefer_static.maximum(rank - prefer_static.rank(x), 0)
    expand_shape = prefer_static.concat([
        prefer_static.shape(x),
        prefer_static.ones(shape=[expand_ndims], dtype=tf.int32)
    ], axis=0)
    return prefer_static.reshape(x, expand_shape)
def _get_transpose_conv_dilated_padding(filter_dim, stride, dilation, padding):
  """Zero-padding for inputs dilated by strides."""
  tot_filter_dim = filter_dim + (filter_dim - 1) * (dilation - 1)
  if padding == 'VALID':
    tot_pad = tot_filter_dim + stride - 2 + ps.maximum(
        tot_filter_dim - stride, 0)
  elif padding == 'SAME':
    tot_pad = tot_filter_dim + stride - 2
  return ps.cond(
      filter_dim >= stride,
      lambda: (tot_pad - tot_pad // 2 - stride + 1, tot_pad // 2),
      lambda: (filter_dim - stride, tot_pad - filter_dim + 1))
def augmented_fn(step, *args, **kwargs):
  with tf.name_scope('augment_with_observation_history'):
    observation_idx = step // num_transitions_per_observation
    observation_history_indices = ps.range(
        ps.maximum(0, observation_idx - history_size),
        observation_idx)
    return fn(step, *args,
              observation_history=tf.gather(
                  observations, observation_history_indices),
              **kwargs)
def _get_reinterpreted_batch_ndims(self,
                                   distribution_batch_shape_tensor=None):
  if self._static_reinterpreted_batch_ndims is not None:
    return self._static_reinterpreted_batch_ndims
  if self._reinterpreted_batch_ndims is not None:
    return tf.convert_to_tensor(self._reinterpreted_batch_ndims)

  if distribution_batch_shape_tensor is None:
    distribution_batch_shape_tensor = self.distribution.batch_shape_tensor()
  return ps.cast(
      ps.maximum(0, ps.size(distribution_batch_shape_tensor) - 1),
      np.int32)
def _augment_sample_shape(self, sample_shape):
  # Suppose we have:
  #   - sample shape of `[n]`,
  #   - underlying distribution batch shape of `[2, 1]`,
  #   - final broadcast batch shape of `[4, 2, 3]`.
  # Then we must draw `sample_shape + [12]` samples, where
  # `12 == n_batch // underlying_n_batch`.
  batch_shape = self.batch_shape_tensor()
  n_batch = ps.reduce_prod(batch_shape)
  underlying_batch_shape = self.distribution.batch_shape_tensor()
  underlying_n_batch = ps.reduce_prod(underlying_batch_shape)
  return ps.concat(
      [sample_shape,
       [ps.maximum(0, n_batch // underlying_n_batch)]],
      axis=0)
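To make the worked example in the comment above concrete: with a broadcast batch shape of `[4, 2, 3]` and an underlying batch shape of `[2, 1]`, `n_batch = 4 * 2 * 3 = 24` and `underlying_n_batch = 2 * 1 = 2`, so `24 // 2 == 12` extra samples are appended to the requested sample shape. A minimal standalone sketch of that shape arithmetic with `prefer_static` (the concrete shapes here are illustrative, not part of the class above):

from tensorflow_probability.python.internal import prefer_static as ps

batch_shape = [4, 2, 3]            # final broadcast batch shape
underlying_batch_shape = [2, 1]    # underlying distribution batch shape
n_batch = ps.reduce_prod(batch_shape)                        # 24
underlying_n_batch = ps.reduce_prod(underlying_batch_shape)  # 2
augmented = ps.concat(
    [[7], [ps.maximum(0, n_batch // underlying_n_batch)]], axis=0)
print(augmented)  # [7 12]: a requested sample shape of [7] becomes [7, 12].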
def _split_and_reshape_event(x, model):
  """Splits and reshapes a flat event `x` to match the structure of `model`."""
  splits = [
      ps.maximum(1, ps.reduce_prod(s))
      for s in tf.nest.flatten(model.event_shape)
  ]
  x = tf.nest.pack_sequence_as(model.event_shape,
                               tf.split(x, splits, axis=-1))

  def _reshape_part(part, dtype, event_shape):
    part = tf.cast(part, dtype)
    new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
    return tf.reshape(part, ps.cast(new_shape, tf.int32))

  x = tf.nest.map_structure(_reshape_part, x, model.dtype, model.event_shape)
  return x
def _split_and_reshape_event(self, x):
  splits = [
      ps.maximum(1, ps.reduce_prod(s))
      for s in tf.nest.flatten(self._model.event_shape)
  ]
  x = tf.nest.pack_sequence_as(self._model.event_shape,
                               tf.split(x, splits, axis=-1))

  def _reshape_part(part, dtype, event_shape):
    part = tf.cast(part, dtype)
    new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
    return tf.reshape(part, ps.cast(new_shape, tf.int32))

  x = tf.nest.map_structure(_reshape_part, x, self._model.dtype,
                            self._model.event_shape)
  return x
def _compute_fans_from_shape(shape, batch_ndims=0):
  """Extracts `fan_in, fan_out` from specified shape `Tensor`."""
  # Ensure shape is a vector of length >=2.
  num_pad = prefer_static.maximum(0, 2 - prefer_static.size(shape))
  shape = prefer_static.pad(
      shape, paddings=[[0, num_pad]], constant_values=1)
  (
      batch_shape,  # pylint: disable=unused-variable
      extra_shape,
      fan_in,
      fan_out,
  ) = prefer_static.split(shape, [batch_ndims, -1, 1, 1])
  # The following logic is primarily intended for convolutional layers which
  # have spatial semantics in addition to input/output channels.
  receptive_field_size = prefer_static.reduce_prod(extra_shape)
  fan_in = fan_in[0] * receptive_field_size
  fan_out = fan_out[0] * receptive_field_size
  return fan_in, fan_out
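An illustrative sketch of the fan computation above, assuming the definition above (and its `prefer_static` import) is in scope; the kernel shape is hypothetical. For a 3x3 convolution kernel with 8 input and 16 output channels and `batch_ndims=0`, the split yields `extra_shape=[3, 3]`, `fan_in=[8]`, `fan_out=[16]`, so the receptive field size is 9 and the fans come out to 72 and 144:

kernel_shape = [3, 3, 8, 16]   # [height, width, in_channels, out_channels]
fan_in, fan_out = _compute_fans_from_shape(kernel_shape, batch_ndims=0)
print(int(fan_in), int(fan_out))   # 72 144, i.e. 3*3*8 and 3*3*16.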
def _log_prob(self, x, **kwargs):
  batch_ndims = prefer_static.rank_from_shape(
      self.distribution.batch_shape_tensor,
      self.distribution.batch_shape)
  extra_batch_ndims = prefer_static.rank_from_shape(self.batch_stack)
  event_ndims = prefer_static.rank_from_shape(
      self.distribution.event_shape_tensor,
      self.distribution.event_shape)
  ndims = prefer_static.rank(x)
  # (1) Expand x's dims.
  d = ndims - extra_batch_ndims - batch_ndims - event_ndims
  x = tf.reshape(
      x,
      shape=tf.pad(
          tf.shape(x),
          paddings=[[prefer_static.maximum(0, -d), 0]],
          constant_values=1))
  # (2) Compute x's log_prob.
  return self.distribution.log_prob(x, **kwargs)
def _log_prob(self, x):
  assertions = []
  message = 'Input must have at least one dimension.'
  if tensorshape_util.rank(x.shape) is not None:
    if tensorshape_util.rank(x.shape) == 0:
      raise ValueError(message)
  elif self.validate_args:
    assertions.append(
        assert_util.assert_rank_at_least(x, 1, message=message))
  with tf.control_dependencies(assertions):
    event_tensors = self._distribution.event_shape_tensor()
    splits = [
        ps.maximum(1, ps.reduce_prod(s))
        for s in tf.nest.flatten(event_tensors)
    ]
    x = tf.nest.pack_sequence_as(event_tensors, tf.split(x, splits, axis=-1))

    def _reshape_part(part, dtype, event_shape):
      part = tf.cast(part, dtype)
      static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
      if static_rank == 1:
        return part
      new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
      return tf.reshape(part, ps.cast(new_shape, tf.int32))

    if all(
        tensorshape_util.is_fully_defined(s)
        for s in tf.nest.flatten(self._distribution.event_shape)):
      x = tf.nest.map_structure(_reshape_part, x, self._distribution.dtype,
                                self._distribution.event_shape)
    else:
      x = tf.nest.map_structure(
          _reshape_part, x, self._distribution.dtype,
          self._distribution.event_shape_tensor())
    return self._distribution.log_prob(x)
def _calculate_new_shape(self):
  # Try to get the old shape statically if available.
  original_shape = self._distribution.batch_shape
  if not tensorshape_util.is_fully_defined(original_shape):
    original_shape = self._distribution.batch_shape_tensor()
  # This is not a check for falseness, it's a check for exactly that shape.
  if original_shape == ():  # pylint: disable=g-explicit-bool-comparison
    # Force the size to be an integer, not a float, when the shape contains
    # no dtype information.
    original_size = 1
  else:
    original_size = ps.reduce_prod(original_shape)
  original_size = ps.cast(original_size, tf.int32)
  # Compute the new shape, filling in the `-1` dimension if present.
  new_shape = self._batch_shape_unexpanded
  implicit_dim_mask = ps.equal(new_shape, -1)
  size_implicit_dim = (
      original_size // ps.maximum(1, -ps.reduce_prod(new_shape)))
  expanded_new_shape = ps.where(  # Assumes exactly one `-1`.
      implicit_dim_mask, size_implicit_dim, new_shape)
  # Return the original size on the side because one caller would otherwise
  # have to recompute it.
  return expanded_new_shape, original_size
def _deconv_output_length(input_size, filter_size, padding, output_padding,
                          stride, dilation):
  """Determines output length of a transposed convolution given input length.

  Args:
    input_size: `int`.
    filter_size: `int`.
    padding: one of `"SAME"`, `"VALID"`, `"FULL"`.
    output_padding: `int`, amount of padding along the output dimension. Can
      be set to `None` in which case the output length is inferred.
    stride: `int`.
    dilation: `int`.

  Returns:
    output_length: The output length (`int`).
  """
  assert padding in {'SAME', 'VALID', 'FULL'}
  if input_size is None:
    return None
  # Get the dilated kernel size.
  filter_size = filter_size + (filter_size - 1) * (dilation - 1)
  # Infer length if output padding is None, else compute the exact length.
  if output_padding is None:
    if padding == 'VALID':
      return input_size * stride + ps.maximum(filter_size - stride, 0)
    elif padding == 'FULL':
      return input_size * stride - (stride + filter_size - 2)
    elif padding == 'SAME':
      return input_size * stride

  if padding == 'SAME':
    pad = filter_size // 2
  elif padding == 'VALID':
    pad = 0
  elif padding == 'FULL':
    pad = filter_size - 1
  return (input_size - 1) * stride + filter_size - 2 * pad + output_padding
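A quick arithmetic check of the inference branch above, assuming the definition above is in scope; the numbers are illustrative. With `input_size=4`, `filter_size=3`, `stride=2`, `dilation=1`, and `output_padding=None`, the dilated kernel size stays 3, so 'SAME' gives 4 * 2 = 8 while 'VALID' gives 4 * 2 + max(3 - 2, 0) = 9:

print(_deconv_output_length(4, 3, padding='SAME', output_padding=None,
                            stride=2, dilation=1))   # 8
print(_deconv_output_length(4, 3, padding='VALID', output_padding=None,
                            stride=2, dilation=1))   # 9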
def one_step(self, state, kernel_results, seed=None):
  """Takes one Sequential Monte Carlo inference step.

  Args:
    state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
      the current particles with (log) weights. The `log_weights` must be a
      float `Tensor` of shape `[num_particles, b1, ..., bN]`. The `particles`
      may be any structure of `Tensor`s, each of which must have shape
      `concat([log_weights.shape, event_shape])` for some `event_shape`,
      which may vary across components.
    kernel_results: instance of
      `tfp.experimental.mcmc.SequentialMonteCarloResults` representing results
      from a previous step.
    seed: Optional seed for reproducible sampling.

  Returns:
    state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
      new particles with (log) weights.
    kernel_results: instance of
      `tfp.experimental.mcmc.SequentialMonteCarloResults`.
  """
  with tf.name_scope(self.name):
    with tf.name_scope('one_step'):
      seed = samplers.sanitize_seed(seed)
      proposal_seed, resample_seed = samplers.split_seed(seed)

      state = WeightedParticles(*state)  # Canonicalize.
      num_particles = ps.size0(state.log_weights)

      # Propose new particles and update weights for this step, unless it's
      # the initial step, in which case, use the user-provided initial
      # particles and weights.
      proposed_state = self.propose_and_update_log_weights_fn(
          # Propose state[t] from state[t - 1].
          ps.maximum(0, kernel_results.steps - 1),
          state,
          seed=proposal_seed)
      is_initial_step = ps.equal(kernel_results.steps, 0)
      # TODO(davmre): this `where` assumes the state size didn't change.
      state = tf.nest.map_structure(
          lambda a, b: tf.where(is_initial_step, a, b),
          state,
          proposed_state)

      normalized_log_weights = tf.nn.log_softmax(state.log_weights, axis=0)
      # Every entry of `log_weights` differs from `normalized_log_weights`
      # by the same normalizing constant. We extract that constant by
      # examining an arbitrary entry.
      incremental_log_marginal_likelihood = (
          state.log_weights[0] - normalized_log_weights[0])

      do_resample = self.resample_criterion_fn(state)

      # Some batch elements may require resampling and others not, so
      # we first do the resampling for all elements, then select whether to
      # use the resampled values for each batch element according to
      # `do_resample`. If there were no batching, we might prefer to use
      # `tf.cond` to avoid the resampling computation on steps where it's not
      # needed---but we're ultimately interested in adaptive resampling
      # for statistical (not computational) purposes, so this isn't a
      # dealbreaker.
      resampled_particles, resample_indices = weighted_resampling.resample(
          state.particles, state.log_weights, self.resample_fn,
          seed=resample_seed)

      uniform_weights = tf.fill(
          ps.shape(state.log_weights),
          value=-tf.math.log(
              tf.cast(num_particles, state.log_weights.dtype)))
      (resampled_particles,
       resample_indices,
       log_weights) = tf.nest.map_structure(
           lambda r, p: ps.where(do_resample, r, p),
           (resampled_particles, resample_indices, uniform_weights),
           (state.particles,
            _dummy_indices_like(resample_indices),
            normalized_log_weights))

      return (
          WeightedParticles(particles=resampled_particles,
                            log_weights=log_weights),
          SequentialMonteCarloResults(
              steps=kernel_results.steps + 1,
              parent_indices=resample_indices,
              incremental_log_marginal_likelihood=(
                  incremental_log_marginal_likelihood),
              accumulated_log_marginal_likelihood=(
                  kernel_results.accumulated_log_marginal_likelihood +
                  incremental_log_marginal_likelihood),
              seed=seed))
def _log_prob(self, x):
  if self.input_output_cholesky:
    x_sqrt = x
  else:
    # Complexity: O(nbk**3)
    x_sqrt = tf.linalg.cholesky(x)

  df = tf.convert_to_tensor(self.df)
  batch_shape = self._batch_shape_tensor(df)
  event_shape = self._event_shape_tensor()
  dimension = self._dimension()
  x_ndims = ps.rank(x_sqrt)
  num_singleton_axes_to_prepend = (
      ps.maximum(ps.size(batch_shape) + 2, x_ndims) - x_ndims)
  x_with_prepended_singletons_shape = ps.concat([
      ps.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
      ps.shape(x_sqrt)
  ], 0)
  x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
  ndims = ps.rank(x_sqrt)
  # sample_ndims = ndims - batch_ndims - event_ndims
  sample_ndims = ndims - ps.size(batch_shape) - 2
  sample_shape = ps.shape(x_sqrt)[:sample_ndims]

  # We need to be able to pre-multiply each matrix by its corresponding
  # batch scale matrix. Since a Distribution Tensor supports multiple
  # samples per batch, this means we need to reshape the input matrix `x`
  # so that the first b dimensions are batch dimensions and the last two
  # are of shape [dimension, dimensions*number_of_samples]. Doing these
  # gymnastics allows us to do a batch_solve.
  #
  # After we're done with sqrt_solve (the batch operation) we need to undo
  # this reshaping so what we're left with is a Tensor partitionable by
  # sample, batch, event dimensions.

  # Complexity: O(nbk**2) since transpose must access every element.
  scale_sqrt_inv_x_sqrt = x_sqrt
  perm = ps.concat(
      [ps.range(sample_ndims, ndims), ps.range(0, sample_ndims)], 0)
  scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)
  last_dim_size = (
      ps.cast(dimension, dtype=tf.int32) *
      ps.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims]))
  shape = ps.concat([
      x_with_prepended_singletons_shape[sample_ndims:-2],
      [ps.cast(dimension, dtype=tf.int32), last_dim_size]
  ], axis=0)
  scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

  # Complexity: O(nbM*k) where M is the complexity of the operator solving a
  # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2)
  # so this step has complexity O(nbk^3).
  scale_sqrt_inv_x_sqrt = self._scale.solve(scale_sqrt_inv_x_sqrt)

  # Undo make batch-op ready.
  # Complexity: O(nbk**2)
  shape = ps.concat(
      [ps.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape],
      axis=0)
  scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
  perm = ps.concat([
      ps.range(ndims - sample_ndims, ndims),
      ps.range(0, ndims - sample_ndims)
  ], 0)
  scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)

  # Write V = SS', X = LL'. Then:
  # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
  #              = tr[inv(S) L L' inv(S)']
  #              = tr[(inv(S) L) (inv(S) L)']
  #              = sum_{ik} (inv(S) L)_{ik}**2
  # The second equality follows from the cyclic permutation property.
  # Complexity: O(nbk**2)
  trace_scale_inv_x = tf.reduce_sum(
      tf.square(scale_sqrt_inv_x_sqrt), axis=[-2, -1])

  # Complexity: O(nbk)
  half_log_det_x = tf.reduce_sum(
      tf.math.log(tf.linalg.diag_part(x_sqrt)), axis=[-1])

  # Complexity: O(nbk**2)
  log_prob = ((df - dimension - 1.) * half_log_det_x -
              0.5 * trace_scale_inv_x -
              self._log_normalization(df=df, scale=self._scale))

  # Set shape hints.
  # Try to merge what we know from the input x with what we know from the
  # parameters of this distribution.
  if (tensorshape_util.rank(x.shape) is not None and
      tensorshape_util.rank(self.batch_shape) is not None):
    tensorshape_util.set_shape(
        log_prob,
        tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

  return log_prob
def infected(new_infections, new_recoveries):
  return tfd.Deterministic(
      prefer_static.maximum(
          0.,
          previous_state['infected'] + new_infections - new_recoveries))
def susceptible(new_infections):
  return tfd.Deterministic(
      prefer_static.maximum(
          0., previous_state['susceptible'] - new_infections))
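A small illustrative check of the clamping in the two state-update functions above, assuming `tfd = tfp.distributions`, the `prefer_static` import, and a `previous_state` dict as in the surrounding model are in scope (the numbers here are made up): if 12 new infections are drawn but only 10 susceptibles remain, the susceptible count is clamped at 0 rather than going negative.

import tensorflow_probability as tfp
from tensorflow_probability.python.internal import prefer_static

tfd = tfp.distributions
previous_state = {'susceptible': 10., 'infected': 3.}

dist = susceptible(new_infections=12.)  # Deterministic(max(0., 10. - 12.))
print(float(dist.loc))                  # 0.0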