def _sample_n(self, n, seed):
  seed = SeedStream(seed, salt='MixtureSameFamily')
  x = self.components_distribution.sample(n, seed=seed())  # [n, B, k, E]
  event_shape = None
  event_ndims = tensorshape_util.rank(self.event_shape)
  if event_ndims is None:
    event_shape = self.components_distribution.event_shape_tensor()
    event_ndims = prefer_static.rank_from_shape(event_shape)
  event_ndims_static = tf.get_static_value(event_ndims)

  num_components = None
  if event_ndims_static is not None:
    num_components = tf.compat.dimension_value(
        x.shape[-1 - event_ndims_static])
  # We could also check if num_components can be computed statically from
  # self.mixture_distribution's logits or probs.
  if num_components is None:
    num_components = tf.shape(x)[-1 - event_ndims]

  # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
  npdt = dtype_util.as_numpy_dtype(x.dtype)
  mask = tf.one_hot(
      indices=self.mixture_distribution.sample(
          n, seed=seed()),  # [n, B] or [n]
      depth=num_components,
      on_value=npdt(1),
      off_value=npdt(0))    # [n, B, k] or [n, k]

  # Pad `mask` to [n, B, k, [1]*e] or [n, [1]*b, k, [1]*e].
  batch_ndims = prefer_static.rank(x) - event_ndims - 1
  mask_batch_ndims = prefer_static.rank(mask) - 1
  pad_ndims = batch_ndims - mask_batch_ndims
  mask_shape = prefer_static.shape(mask)
  mask = tf.reshape(
      mask,
      shape=prefer_static.concat([
          mask_shape[:-1],
          prefer_static.ones([pad_ndims], dtype=tf.int32),
          mask_shape[-1:],
          prefer_static.ones([event_ndims], dtype=tf.int32),
      ], axis=0))

  ret = tf.reduce_sum(x * mask, axis=-1 - event_ndims)  # [n, B, E]

  if self._reparameterize:
    if event_shape is None:
      event_shape = self.components_distribution.event_shape_tensor()
    ret = self._reparameterize_sample(ret, event_shape=event_shape)

  return ret
def _finish_log_prob(self, lp, aux):
  (sample_ndims, extra_sample_ndims, batch_ndims) = aux
  # (1) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has
  #     full sample shape in the sample axes, before we reduce.
  bcast_lp_shape = ps.broadcast_shape(
      ps.shape(lp),
      ps.concat([ps.ones([sample_ndims], tf.int32),
                 ps.reshape(self.sample_shape, shape=[-1]),
                 ps.ones([batch_ndims], tf.int32)], axis=0))
  lp = tf.broadcast_to(lp, bcast_lp_shape)
  # (2) Make the final reduction.
  axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
  return self._sum_fn()(lp, axis=axis)
def _bcast_and_reduce_logdet(self, underlying_ldj):
  # Ensure ldj is fully broadcast in the sample dims, i.e. ensure ldj has
  # full sample shape in the sample axes, before we reduce.
  batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                   self.distribution.batch_shape)
  extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
  sample_ndims = ps.rank(underlying_ldj) - extra_sample_ndims - batch_ndims
  bcast_ldj_shape = ps.broadcast_shape(
      ps.shape(underlying_ldj),
      ps.concat([ps.ones([sample_ndims], tf.int32),
                 ps.ones([batch_ndims], tf.int32),
                 ps.reshape(self.sample_shape, shape=[-1])], axis=0))
  ldj = tf.broadcast_to(underlying_ldj, bcast_ldj_shape)
  return self._sum_fn(ldj, axis=-1 - ps.range(extra_sample_ndims))
def _sample_n(self, n, seed):
  components_seed, mix_seed = samplers.split_seed(
      seed, salt='MixtureSameFamily')
  mixture_distribution, components_distribution = (
      self._get_distributions_with_broadcast_batch_shape())
  x = components_distribution.sample(  # [n, B, k, E]
      n, seed=components_seed)

  event_ndims = ps.rank_from_shape(self.event_shape_tensor, self.event_shape)

  # We could also check if num_components can be computed statically from
  # self.mixture_distribution's logits or probs.
  num_components = ps.dimension_size(x, idx=-1 - event_ndims)

  # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
  npdt = dtype_util.as_numpy_dtype(x.dtype)

  mix_sample = mixture_distribution.sample(n, seed=mix_seed)  # [n, B]
  mask = tf.one_hot(
      indices=mix_sample,  # [n, B]
      depth=num_components,
      on_value=npdt(1),
      off_value=npdt(0))   # [n, B, k]

  # Pad `mask` to [n, B, k, [1]*e].
  batch_ndims = ps.rank(x) - event_ndims - 1
  mask_batch_ndims = ps.rank(mask) - 1
  pad_ndims = batch_ndims - mask_batch_ndims
  mask_shape = ps.shape(mask)
  target_shape = ps.concat([
      mask_shape[:-1],
      ps.ones([pad_ndims], dtype=tf.int32),
      mask_shape[-1:],
      ps.ones([event_ndims], dtype=tf.int32),
  ], axis=0)
  mask = tf.reshape(mask, shape=target_shape)

  if dtype_util.is_floating(x.dtype) or dtype_util.is_complex(x.dtype):
    masked = tf.math.multiply_no_nan(x, mask)
  else:
    masked = x * mask
  ret = tf.reduce_sum(masked, axis=-1 - event_ndims)  # [n, B, E]

  if self._reparameterize:
    ret = self._reparameterize_sample(
        ret, event_shape=components_distribution.event_shape_tensor())

  return ret
def assertDistributionIsApproximatelyStandardNormal(self,
                                                    dist,
                                                    logprob_atol=1e-2,
                                                    grad_atol=1e-2):
  """Verifies that dist's lps and gradients match those of Normal(0., 1.)."""
  event_ndims = ps.rank_from_shape(dist.event_shape_tensor, dist.event_shape)
  batch_ndims = ps.rank_from_shape(dist.batch_shape_tensor, dist.batch_shape)
  dist_shape = ps.concat(
      [dist.batch_shape_tensor(), dist.event_shape_tensor()], axis=0)
  reference_dist = tfd.Independent(
      tfd.Normal(loc=tf.zeros(dist_shape, dtype=dist.dtype), scale=1.),
      reinterpreted_batch_ndims=event_ndims)

  zs = tf.reshape(
      [-4., -2., 0., 2., 4.],
      ps.concat([[5], ps.ones([batch_ndims + event_ndims], dtype=np.int32)],
                axis=0))
  zs = tf.broadcast_to(zs, ps.concat([[5], dist_shape], axis=0))
  lp_dist, grad_dist = tfp.math.value_and_gradient(dist.log_prob, zs)
  lp_reference, grad_reference = tfp.math.value_and_gradient(
      reference_dist.log_prob, zs)
  self.assertAllClose(lp_reference, lp_dist, atol=logprob_atol)
  self.assertAllClose(grad_reference, grad_dist, atol=grad_atol)
def _broadcast_to_full_batch_shape_helper(data, event_ndims, batch_shape,
                                          sample_ndims=0):
  """Broadcasts `[sample, ?, event]` to `[sample, batch, event]`."""
  if data is None:
    return None
  data_shape = ps.shape(data)
  data_rank = ps.rank_from_shape(data_shape)
  batch_ndims = ps.rank_from_shape(batch_shape)

  # Reshape the data to have full batch rank. For example, given
  # `batch_shape==[3, 2]`, this would reshape `data.shape==[S, 2, E]` to
  # `[S, 1, 2, E]`.
  # This reshaping is not necessary when `sample_ndims==0`, since with no
  # sample dimensions the batch shape itself is leftmost and can broadcast.
  # For example, we would not need to reshape `[2, E] -> [1, 2, E]`.
  if sample_ndims != 0:
    padding_ndims = batch_ndims - (data_rank - sample_ndims - event_ndims)
    padded_shape = ps.concat([data_shape[:sample_ndims],
                              ps.ones([padding_ndims], dtype=np.int32),
                              data_shape[sample_ndims:]], axis=0)
    data = tf.reshape(data, padded_shape)
    data_shape = padded_shape
    data_rank = ps.rank_from_shape(data_shape)

  # Broadcast the data to have full batch shape. For example, given
  # `batch_shape==[3, 2]`, this would broadcast `data.shape==[S, 1, 2, E]` to
  # `[S, 3, 2, E]`.
  new_shape = tf.concat([data_shape[:sample_ndims],
                         batch_shape,
                         data_shape[data_rank - event_ndims:]], axis=0)
  return tf.broadcast_to(data, new_shape)
def _design_matrix_for_one_seasonal_effect(num_steps, duration, period, dtype):
  current_period = np.int32(np.arange(num_steps) / duration) % period
  return np.transpose([
      ps.where(current_period == p,  # pylint: disable=g-complex-comprehension
               ps.ones([], dtype=dtype),
               ps.zeros([], dtype=dtype))
      for p in range(period)])
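# Usage sketch (illustrative only, not part of the original source); assumes
# NumPy and TFP's internal `prefer_static` module are in scope as `np` and
# `ps`, as in the helper above.
import numpy as np
from tensorflow_probability.python.internal import prefer_static as ps

def _seasonal_design_matrix_example():
  # 4 steps, each lasting 1 step, alternating between 2 seasonal effects.
  design = _design_matrix_for_one_seasonal_effect(
      num_steps=4, duration=1, period=2, dtype=np.float32)
  # Rows are timesteps, columns are seasonal effects:
  # [[1., 0.], [0., 1.], [1., 0.], [0., 1.]]
  return design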
def adjacent_swaps(num_replica, batch_shape=(), step_count=None, seed=None):
  """Make random shuffle using only one time swaps."""
  del step_count  # Unused for this function.
  with tf.name_scope(name or 'adjacent_swaps'):
    parity_seed, proposal_seed = samplers.split_seed(seed)
    # u selects parity.  E.g.,
    #  u==False ==> [1, 0, 3, 2, 4] even parity swaps
    #  u==True ==>  [0, 2, 1, 4, 3] odd parity swaps
    # If there are only 2 replicas, then the "True" swaps are null
    # swaps...which would contradict the user provided `prob_swap`.
    # So special case num_replica==2, forcing u==False in this case.
    u_shape = ps.concat((ps.ones(1, dtype=tf.int32),
                         ps.cast(batch_shape, tf.int32)), axis=0)
    u = samplers.uniform(u_shape, seed=parity_seed) < 0.5
    u = tf.where(num_replica > 2, u, False)

    x = bu.left_justified_expand_dims_to(
        ps.range(num_replica, dtype=tf.int64), rank=ps.size(u_shape))
    y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1, x - 1)
    y = tf.clip_by_value(y, 0, num_replica - 1)
    # TODO(b/142689785): Consider using tf.cond and returning an empty list
    # then in REMC consider using a tf.cond for short-circuiting.
    return tf.where(
        samplers.uniform(batch_shape, seed=proposal_seed) < prob_swap, y, x)
def _variance(self):
  probs = self.mixture_distribution.probs_parameter()       # [B, k] or [k]
  component_means = self.components_distribution.mean()     # [B, k, E]
  component_vars = self.components_distribution.variance()  # [B, k, E]
  event_ndims = self._event_ndims()

  # reshape probs to [B, k, [1]*e] or [k, [1]*e]
  probs = tf.reshape(
      probs,
      prefer_static.concat([
          prefer_static.shape(probs),
          prefer_static.ones([event_ndims], dtype=tf.int32)
      ], axis=0))

  # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
  mean_cond_var = tf.reduce_sum(probs * component_vars,
                                axis=-1 - event_ndims)  # [B, E]
  mean = tf.reduce_sum(probs * component_means,
                       axis=-1 - event_ndims,
                       keepdims=True)  # [B, 1, E]
  var_cond_mean = tf.reduce_sum(
      probs * tf.math.squared_difference(component_means, mean),
      axis=-1 - event_ndims)  # [B, E]
  return mean_cond_var + var_cond_mean
def _init_momentum(initial_transformed_position):
  """Initialize momentum so trace_fn can be concatenated."""
  event_shape = ps.shape(initial_transformed_position)[-1]
  return dmma._make_momentum_distribution(  # pylint: disable=protected-access
      running_variance_parts=[ps.ones(event_shape)],
      state_parts=tf.nest.flatten(initial_transformed_position),
      batch_ndims=1)
def _expand_dims_under_batch_dim(tensor, new_rank):
  """Adds size-1 dimensions below the first until `tensor` has `new_rank`."""
  ones = prefer_static.ones([new_rank - prefer_static.rank(tensor)],
                            dtype=tf.int32)
  shape = prefer_static.shape(tensor)
  new_shape = prefer_static.concat([shape[:1], ones, shape[1:]], axis=0)
  return tf.reshape(tensor, new_shape)
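# Usage sketch (illustrative, not from the original source): the helper above
# keeps the leading (batch) dimension and inserts size-1 dims just after it.
# Assumes TensorFlow is imported as `tf`.
import tensorflow as tf

def _expand_dims_under_batch_dim_example():
  tensor = tf.zeros([7, 3])
  expanded = _expand_dims_under_batch_dim(tensor, new_rank=4)
  # New dims are added below the first axis: [7, 3] -> [7, 1, 1, 3].
  return expanded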
def _right_pad(x, final_rank):
  """Pads the shape of x to the right to be of rank final_rank.

  Expands the dims of `x` to the right such that its rank is equal to
  final_rank. For example, if `x` is of shape [1, 5, 7, 2] and `final_rank`
  is 7, we return padded_x, which is of shape [1, 5, 7, 2, 1, 1, 1].

  Args:
    x: The tensor whose shape is to be padded.
    final_rank: Scalar int32 `Tensor` or Python `int`. The desired rank of x.

  Returns:
    padded_x: A tensor of rank final_rank.
  """
  padded_shape = ps.concat(
      [ps.shape(x), ps.ones(final_rank - ps.rank(x), dtype=tf.int32)], axis=0)
  static_padded_shape = None
  if tensorshape_util.is_fully_defined(x.shape) and isinstance(
      final_rank, int):
    static_padded_shape = tensorshape_util.as_list(x.shape)
    extra_dims = final_rank - len(static_padded_shape)
    static_padded_shape.extend([1] * extra_dims)
  padded_x = tf.reshape(x, static_padded_shape or padded_shape)
  return padded_x
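# Usage sketch (illustrative, not from the original source), following the
# example in the docstring above.
import tensorflow as tf

def _right_pad_example():
  x = tf.zeros([1, 5, 7, 2])
  padded_x = _right_pad(x, final_rank=7)
  # Size-1 dims are appended on the right:
  # [1, 5, 7, 2] -> [1, 5, 7, 2, 1, 1, 1].
  return padded_x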
def _init_momentum(initial_transformed_position):
  """Initialize momentum so trace_fn can be concatenated."""
  event_shape = ps.shape(initial_transformed_position)[-1]
  return preconditioning_utils.make_momentum_distribution(
      state_parts=tf.nest.flatten(initial_transformed_position),
      batch_ndims=1,
      running_variance_parts=[ps.ones(event_shape)])
def _transpose_and_reshape_result(self, x, sample_shape, event_shape=None):
  if event_shape is None:
    event_shape = self.event_shape_tensor()

  batch_shape = self.batch_shape_tensor()
  batch_rank = ps.rank_from_shape(batch_shape)

  underlying_batch_shape = self.distribution.batch_shape_tensor()
  underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)

  # Continuing the example from `_augment_sample_shape`, suppose we have:
  #   - sample shape of `[n]`,
  #   - underlying distribution batch shape of `[2, 1]`,
  #   - final broadcast batch shape of `[4, 2, 3]`.
  # and have drawn an `x` of shape `[n, 12, 2, 1] + event_shape`, which we
  # ultimately want to have shape `[n, 4, 2, 3] + event_shape`.

  # First, we reshape to expand out the batch elements:
  # `shape_with_doubled_batch == [n] + [4, 1, 3] + [1, 2, 1] + event_shape`,
  # where `[1, 2, 1]` is the fully-expanded underlying batch shape, and
  # `[4, 1, 3]` is the shape of the elements being added by broadcasting.
  underlying_bcast_shp = ps.concat([
      ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
              dtype=underlying_batch_shape.dtype),
      underlying_batch_shape
  ], axis=0)
  is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)
  x_with_doubled_batch = tf.reshape(
      x,
      ps.concat([sample_shape,
                 ps.where(is_dim_bcast, batch_shape, 1),
                 underlying_bcast_shp,
                 event_shape], axis=0))

  # Next, construct the permutation that interleaves the batch dimensions,
  # resulting in samples with shape
  # `[n] + [4, 1] + [1, 2] + [3, 1] + event_shape`.
  # Note that each interleaved pair of batch dimensions contains exactly one
  # dim of size `1` and one of size `>= 1`.
  sample_ndims = ps.rank_from_shape(sample_shape)
  x_with_interleaved_batch = tf.transpose(
      x_with_doubled_batch,
      perm=ps.concat([
          ps.range(sample_ndims),
          sample_ndims + ps.reshape(
              ps.stack([ps.range(batch_rank),
                        ps.range(batch_rank) + batch_rank], axis=-1),
              [-1]),
          sample_ndims + 2 * batch_rank + ps.range(
              ps.rank_from_shape(event_shape))
      ], axis=0))

  # Final reshape to remove the spurious `1` dimensions.
  return tf.reshape(
      x_with_interleaved_batch,
      ps.concat([sample_shape, batch_shape, event_shape], axis=0))
def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape):
  """Slices a single parameter of a distribution.

  Args:
    param: A `Tensor`, the original parameter to slice.
    param_event_ndims: `int` event parameterization rank for this parameter.
    slices: A `tuple` of normalized slices.
    dist_batch_shape: The distribution's batch shape `Tensor`.

  Returns:
    new_param: A `Tensor`, batch-sliced according to slices.
  """
  # Extend param shape with ones on the left to match dist_batch_shape.
  param_shape = ps.shape(param)
  insert_ones = ps.ones(
      [ps.size(dist_batch_shape) + param_event_ndims - ps.rank(param)],
      dtype=param_shape.dtype)
  new_param_shape = ps.concat([insert_ones, param_shape], axis=0)
  full_batch_param = tf.reshape(param, new_param_shape)

  param_slices = []
  # We separately track the batch axis from the parameter axis because we
  # want them to align for positive indexing, and be offset by
  # param_event_ndims for negative indexing.
  param_dim_idx = 0
  batch_dim_idx = 0
  for slc in slices:
    if slc is tf.newaxis:
      param_slices.append(slc)
      continue
    if slc is Ellipsis:
      if batch_dim_idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      param_slices.append(slc)
      # Switch over to negative indexing for the broadcast check.
      num_remaining_non_newaxis_slices = sum(
          [s is not tf.newaxis for s in slices[slices.index(Ellipsis) + 1:]])
      batch_dim_idx = -num_remaining_non_newaxis_slices
      param_dim_idx = batch_dim_idx - param_event_ndims
      continue
    # Find the batch dimension sizes for both parameter and distribution.
    param_dim_size = new_param_shape[param_dim_idx]
    batch_dim_size = dist_batch_shape[batch_dim_idx]
    is_broadcast = batch_dim_size > param_dim_size
    # Slices are denoted by start:stop:step.
    if isinstance(slc, slice):
      start, stop, step = slc.start, slc.stop, slc.step
      if start is not None:
        start = ps.where(is_broadcast, 0, start)
      if stop is not None:
        stop = ps.where(is_broadcast, 1, stop)
      if step is not None:
        step = ps.where(is_broadcast, 1, step)
      param_slices.append(slice(start, stop, step))
    else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
      param_slices.append(ps.where(is_broadcast, 0, slc))
    param_dim_idx += 1
    batch_dim_idx += 1
  param_slices.extend([ALL_SLICE] * param_event_ndims)
  return full_batch_param.__getitem__(tuple(param_slices))
def left_justified_expand_dims_to(x, rank, name=None):
  """Right pads `x` with `rank - rank(x)` ones."""
  with tf.name_scope(name or 'left_justified_expand_dims_to'):
    expand_ndims = ps.maximum(rank - ps.rank(x), 0)
    expand_shape = ps.concat(
        [ps.shape(x), ps.ones(shape=[expand_ndims], dtype=tf.int32)], axis=0)
    return tf.reshape(x, expand_shape)
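# Usage sketch (illustrative, not from the original source): padding goes on
# the right, so the existing dims stay left-justified and broadcasting aligns
# on the leading rather than trailing dimensions.
import tensorflow as tf

def _left_justified_expand_dims_example():
  x = tf.zeros([2, 3])
  y = left_justified_expand_dims_to(x, rank=4)
  # [2, 3] -> [2, 3, 1, 1]; `y` now broadcasts against a [2, 3, 5, 7] tensor.
  return y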
def expand_right_dims(x, broadcast=False):
  """Expand x so it can bcast w/ tensors of output shape."""
  expanded_shape_left = ps.broadcast_shape(
      ps.shape(x)[:-1],
      ps.ones([ps.size(y_ref_shape_left)], dtype=tf.int32))
  expanded_shape = ps.concat(
      (expanded_shape_left, ps.shape(x)[-1:],
       ps.ones([ps.size(y_ref_shape_right)], dtype=tf.int32)),
      axis=0)
  x_expanded = tf.reshape(x, expanded_shape)

  if broadcast:
    broadcast_shape_left = ps.broadcast_shape(
        ps.shape(x)[:-1], y_ref_shape_left)
    broadcast_shape = ps.concat(
        (broadcast_shape_left, ps.shape(x)[-1:], y_ref_shape_right), axis=0)
    x_expanded = _broadcast_with(x_expanded, broadcast_shape)
  return x_expanded
def _add_event_dims_to_mask(validity_mask, *, dist=None, event_ndims=None):
  validity_mask = tf.convert_to_tensor(validity_mask)
  if event_ndims is None:
    event_ndims = ps.rank_from_shape(dist.event_shape_tensor())
  return tf.reshape(
      validity_mask,
      ps.concat(
          [ps.shape(validity_mask), ps.ones(event_ndims, dtype=tf.int32)],
          axis=0))
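# Usage sketch (illustrative, not from the original source): reshaping a
# per-batch-member validity mask so it broadcasts against event dimensions.
import tensorflow as tf

def _add_event_dims_to_mask_example():
  validity_mask = tf.constant([True, False, True])  # batch shape [3]
  expanded = _add_event_dims_to_mask(validity_mask, event_ndims=2)
  # [3] -> [3, 1, 1], ready to broadcast against values of shape [3, d1, d2].
  return expanded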
def _log_prob(self, x, **kwargs):
  batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                   self.distribution.batch_shape)
  extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
  event_ndims = ps.rank_from_shape(self.distribution.event_shape_tensor,
                                   self.distribution.event_shape)
  ndims = ps.rank(x)
  # (1) Expand x's dims.
  d = ndims - batch_ndims - extra_sample_ndims - event_ndims
  x = tf.reshape(
      x,
      shape=ps.pad(
          ps.shape(x), paddings=[[ps.maximum(0, -d), 0]], constant_values=1))
  ndims = ps.rank(x)
  sample_ndims = ps.maximum(0, d)
  # (2) Transpose x's dims.
  sample_dims = ps.range(0, sample_ndims)
  batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
  extra_sample_dims = ps.range(
      sample_ndims + batch_ndims,
      sample_ndims + batch_ndims + extra_sample_ndims)
  event_dims = ps.range(sample_ndims + batch_ndims + extra_sample_ndims,
                        ndims)
  perm = ps.concat(
      [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
  x = tf.transpose(a=x, perm=perm)
  # (3) Compute x's log_prob.
  lp = self.distribution.log_prob(x, **kwargs)
  # (4) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has
  #     full sample shape in the sample axes, before we reduce.
  bcast_lp_shape = ps.broadcast_shape(
      ps.shape(lp),
      ps.concat([ps.ones([sample_ndims], tf.int32),
                 ps.reshape(self.sample_shape, shape=[-1]),
                 ps.ones([batch_ndims], tf.int32)], axis=0))
  lp = tf.broadcast_to(lp, bcast_lp_shape)
  # (5) Make the final reduction in x.
  axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
  return tf.reduce_sum(lp, axis=axis)
def left_justified_expand_dims_to(x, rank, name=None):
  """Right pads `x` with `rank - rank(x)` ones."""
  with tf.name_scope(name or 'left_justified_expand_dims_to'):
    rank = tf.convert_to_tensor(rank, dtype=tf.int32)
    expand_ndims = prefer_static.maximum(rank - prefer_static.rank(x), 0)
    expand_shape = prefer_static.concat([
        prefer_static.shape(x),
        prefer_static.ones(shape=[expand_ndims], dtype=tf.int32)
    ], axis=0)
    return prefer_static.reshape(x, expand_shape)
def _broadcast_transition_probs(self, sample_and_batch_shape) -> tf.Tensor:
  transition_probs_shape = ps.shape(self.transition_probs_tree.branch_lengths)
  transition_probs_batch_shape = transition_probs_shape[:-3]
  additional_dims = (
      ps.shape(sample_and_batch_shape)[0]
      - ps.shape(transition_probs_batch_shape)[0])
  new_shape = ps.concat(
      [ps.ones(additional_dims, dtype=tf.int32), transition_probs_shape],
      axis=0)
  return tf.reshape(self.transition_probs_tree.branch_lengths, new_shape)
def _dummy_indices_like(indices):
  """Returns dummy indices ([0, 1, 2, ...]) with batch shape like `indices`."""
  indices_shape = ps.shape(indices)
  num_particles = indices_shape[0]
  return tf.broadcast_to(
      ps.reshape(
          ps.range(num_particles),
          ps.concat([[num_particles],
                     ps.ones([ps.rank_from_shape(indices_shape) - 1],
                             dtype=np.int32)],
                    axis=0)),
      indices_shape)
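# Usage sketch (illustrative, not from the original source): each particle
# index is repeated across the batch dimensions of `indices`.
import tensorflow as tf

def _dummy_indices_like_example():
  indices = tf.zeros([4, 2], dtype=tf.int32)  # [num_particles, batch]
  dummy = _dummy_indices_like(indices)
  # [[0, 0], [1, 1], [2, 2], [3, 3]]
  return dummy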
def do_padding(observed_time_series_tensor):
  current_sample_shape = ps.shape(
      observed_time_series_tensor)[:-(model_batch_ndims + event_ndims)]
  current_batch_and_event_shape = ps.shape(
      observed_time_series_tensor)[-(model_batch_ndims + event_ndims):]
  return tf.reshape(
      tensor=observed_time_series_tensor,
      shape=ps.concat([
          current_sample_shape,
          ps.ones([chain_batch_ndims], dtype=tf.int32),
          current_batch_and_event_shape
      ], axis=0))
def even_odd_swaps(num_replica, batch_shape=(), step_count=None, seed=None):
  """Make deterministic even_odd one time swaps."""
  if step_count is None:
    raise ValueError('`step_count` must be supplied. Found `None`.')
  del seed  # Unused for this function.
  with tf.name_scope(name or 'even_odd_swaps'):
    # Period is 1 / frequency, and we want period = Inf if frequency = 0.
    # safe_swap_period is the correct swap period in case swap_frequency > 0.
    # If swap_frequency == 0, safe_swap_period is set to 1 (to avoid integer
    # div by zero below). We will hard-set this case to "null swap."
    swap_freq = tf.convert_to_tensor(swap_frequency, name='swap_frequency')
    safe_swap_period = tf.cast(
        tf.where(swap_freq > 0,
                 tf.math.ceil(tf.math.reciprocal_no_nan(swap_freq)),
                 1),
        # Although period = 1 / frequency may have roundoff error, and result
        # in a period different than what the user intended, the user will
        # end up with a single integer period, and thus well defined
        # deterministic swaps.
        tf.int32,
    )
    # u selects parity.  E.g.,
    #  u==False ==> [1, 0, 3, 2, 4] even parity swaps
    #  u==True ==>  [0, 2, 1, 4, 3] odd parity swaps
    # If there are 2 replicas, then the "True" swaps are null
    # swaps...which would contradict the user provided `swap_frequency`.
    # So special case num_replica==2, forcing u==False in this case.
    u_shape = ps.concat((ps.ones(1, dtype=tf.int32),
                         ps.cast(batch_shape, tf.int32)), axis=0)
    u = tf.fill(u_shape,
                tf.cast((step_count // safe_swap_period) % 2, tf.bool))
    u = tf.where(num_replica > 2, u, False)

    x = bu.left_justified_expand_dims_to(
        tf.range(num_replica, dtype=tf.int64), rank=ps.size(u_shape))
    y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1, x - 1)
    y = tf.clip_by_value(y, 0, num_replica - 1)
    # TODO(b/142689785): Consider using tf.cond and returning an empty list
    # then in REMC consider using a tf.cond for short-circuiting.
    return tf.where(
        (tf.cast(step_count % safe_swap_period, tf.bool)
         | tf.math.equal(swap_freq, 0)),
        x,  # Don't swap
        y,  # Swap
    )
def _mean(self):
  probs = self.mixture_distribution.probs_parameter()    # [B, k] or [k]
  component_means = self.components_distribution.mean()  # [B, k, E]
  event_ndims = self._event_ndims()

  # reshape probs to [B, k, [1]*e] or [k, [1]*e]
  probs = tf.reshape(probs, ps.concat([
      ps.shape(probs),
      ps.ones([event_ndims], dtype=tf.int32)
  ], axis=0))

  return tf.reduce_sum(probs * component_means,
                       axis=-1 - event_ndims)  # [B, E]
def weighted_reduce_sum(x, axis=0):
  """Weighted sum over an axis of `x`."""
  # Extend the weights to broadcast over any event dimensions of `x`.
  # This assumes that `weights` and `x` have the same sample and batch
  # dimensions, e.g., that they come from the same `sample_and_log_prob`
  # call.
  event_ndims = ps.rank(x) - ps.rank(weights)
  aligned_weights = tf.reshape(
      weights,
      ps.concat(
          [ps.shape(weights), ps.ones([event_ndims], dtype=tf.int32)],
          axis=0))
  return tf.reduce_sum(aligned_weights * tf.cast(x, weights.dtype), axis=axis)
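# Usage sketch (illustrative, not from the original source): `weights` above
# is a free variable supplied by the enclosing scope. This hypothetical
# stand-alone version shows the intended alignment, where `weights` shares
# the leading sample/batch dims of `x` and its shape is right-padded with
# ones so it broadcasts over the trailing event dims.
import tensorflow as tf

def _weighted_reduce_sum_example():
  weights = tf.constant([0.25, 0.75])  # sample shape [2]
  x = tf.ones([2, 3, 4])               # sample shape [2] + event shape [3, 4]
  aligned_weights = tf.reshape(weights, [2, 1, 1])
  # Equivalent to weighted_reduce_sum(x, axis=0) with this `weights` in scope:
  return tf.reduce_sum(aligned_weights * x, axis=0)  # shape [3, 4], all 1.0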
def _fn(self):
  """Implements summary statistic, eg, mean, stddev, mode."""
  x = getattr(self.distribution, attr)()
  shape = prefer_static.concat([
      self.distribution.batch_shape_tensor(),
      prefer_static.ones(prefer_static.rank_from_shape(self.sample_shape),
                         dtype=self.sample_shape.dtype),
      self.distribution.event_shape_tensor(),
  ], axis=0)
  x = tf.reshape(x, shape=shape)
  shape = prefer_static.concat([
      self.distribution.batch_shape_tensor(),
      self.sample_shape,
      self.distribution.event_shape_tensor(),
  ], axis=0)
  return tf.broadcast_to(x, shape)
def _fn(self, **kwargs):
  """Implements summary statistic, eg, mean, stddev, mode."""
  sample_shape = ps.reshape(self.sample_shape, shape=[-1])
  x = getattr(self.distribution, attr)(**kwargs)
  shape = ps.concat([
      self.distribution.batch_shape_tensor(),
      ps.ones(ps.rank_from_shape(sample_shape), dtype=sample_shape.dtype),
      self.distribution.event_shape_tensor(),
  ], axis=0)
  x = tf.reshape(x, shape=shape)
  shape = ps.concat([
      self.distribution.batch_shape_tensor(),
      sample_shape,
      self.distribution.event_shape_tensor(),
  ], axis=0)
  return tf.broadcast_to(x, shape)
def test_dynamic_shape(self):
  x = tf.Variable(ps.ones([7, 3]), shape=[7, None])
  self.evaluate(x.initializer)
  # Check that the shape is actually `None`.
  if not tf.executing_eagerly():
    last_shape = x.shape[-1]
    if last_shape is not None:
      # This is a `tf.Dimension` in tf1.
      last_shape = last_shape.value
    self.assertIsNone(last_shape)
  dynamic_dist = tfd_e.MultivariateNormalPrecisionFactorLinearOperator(
      precision_factor=tf.linalg.LinearOperatorDiag(tf.ones_like(x)))
  static_dist = tfd_e.MultivariateNormalPrecisionFactorLinearOperator(
      precision_factor=tf.linalg.LinearOperatorDiag(tf.ones([7, 3])))
  in_ = tf.zeros([7, 3])
  self.assertAllClose(self.evaluate(dynamic_dist.log_prob(in_)),
                      static_dist.log_prob(in_))
def _mean(self):
  mixture_distribution, components_distribution = (
      self._get_distributions_with_broadcast_batch_shape())
  probs = mixture_distribution.probs_parameter()    # [B, k]
  component_means = components_distribution.mean()  # [B, k, E]
  event_ndims = self._event_ndims()

  # reshape probs to [B, k, [1]*e]
  probs = tf.reshape(
      probs,
      ps.concat(
          [ps.shape(probs), ps.ones([event_ndims], dtype=tf.int32)],
          axis=0))
  return tf.reduce_sum(probs * component_means,
                       axis=-1 - event_ndims)  # [B, E]