def moments_of_masked_time_series(time_series_tensor, broadcast_mask):
  """Compute mean and variance, accounting for a mask.

  Args:
    time_series_tensor: float `Tensor` time series of shape
      `concat([batch_shape, [num_timesteps]])`.
    broadcast_mask: bool `Tensor` of the same shape as `time_series`.

  Returns:
    mean: float `Tensor` of shape `batch_shape`.
    variance: float `Tensor` of shape `batch_shape`.
  """
  num_unmasked_entries = ps.cast(
      ps.reduce_sum(ps.cast(~broadcast_mask, np.int32), axis=-1),
      time_series_tensor.dtype)

  # Manually compute mean and variance, excluding masked entries.
  mean = (tf.reduce_sum(
      tf.where(broadcast_mask,
               tf.zeros([], dtype=time_series_tensor.dtype),
               time_series_tensor),
      axis=-1) / num_unmasked_entries)
  variance = (tf.reduce_sum(
      tf.where(broadcast_mask,
               tf.zeros([], dtype=time_series_tensor.dtype),
               (time_series_tensor - mean[..., tf.newaxis])**2),
      axis=-1) / num_unmasked_entries)
  return mean, variance
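# Added illustration (not part of the original module), assuming the module's
# usual `tf` import is in scope. Masked entries are excluded from both
# moments, so only [1., 2., 3.] contribute below.
def _example_masked_moments():
  series = tf.constant([1., 2., 100., 3., -50.])
  mask = tf.constant([False, False, True, False, True])
  mean, variance = moments_of_masked_time_series(series, mask)
  # mean -> 2.0; variance -> ~0.667 (population variance of [1., 2., 3.]).
  return mean, variance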
def _axis_size(x, axis=None):
  """Get number of elements of `x` in `axis`, as type `x.dtype`."""
  if axis is None:
    return prefer_static.cast(prefer_static.size(x), x.dtype)
  return prefer_static.cast(
      prefer_static.reduce_prod(
          prefer_static.gather(prefer_static.shape(x), axis)),
      x.dtype)
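# Added illustration: `_axis_size` returns an element count cast to `x.dtype`,
# so it can divide a same-dtype sum directly (e.g. when computing means).
def _example_axis_size():
  x = tf.zeros([2, 3, 4])
  return _axis_size(x), _axis_size(x, axis=[1, 2])  # -> 24.0, 12.0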
def iid_sample(sample_fn, sample_shape):
  """Lift a sampling function to one that draws multiple iid samples.

  Args:
    sample_fn: Python `callable` that returns a (possibly nested) structure of
      `Tensor`s. May optionally take a `seed` named arg: if so, any `int`
      seeds (for stateful samplers) are passed through directly, while any
      pair-of-`int` seeds (for stateless samplers) are split into independent
      seeds for each sample.
    sample_shape: `int` `Tensor` shape of iid samples to draw.

  Returns:
    iid_sample_fn: Python `callable` taking the same arguments as `sample_fn`
      and returning iid samples. Each returned `Tensor` will have shape
      `concat([sample_shape, shape_of_original_returned_tensor])`.
  """
  sample_shape = distribution_util.expand_to_vector(
      ps.cast(sample_shape, np.int32), tensor_name='sample_shape')
  n = ps.cast(ps.reduce_prod(sample_shape), dtype=np.int32)

  def unflatten(x):
    unflattened_shape = ps.cast(
        ps.concat([sample_shape, ps.shape(x)[1:]], axis=0), dtype=np.int32)
    return tf.reshape(x, unflattened_shape)

  def iid_sample_fn(*args, **kwargs):
    """Draws iid samples from `fn`."""
    with tf.name_scope('iid_sample_fn'):

      seed = kwargs.pop('seed', None)
      if samplers.is_stateful_seed(seed):
        kwargs = dict(kwargs, seed=SeedStream(seed, salt='iid_sample')())

        def pfor_loop_body(_):
          with tf.name_scope('iid_sample_fn_stateful_body'):
            return sample_fn(*args, **kwargs)
      else:
        # If a stateless seed arg is passed, split it into `n` different
        # stateless seeds, so that we don't just get a bunch of copies of the
        # same sample.
        if not JAX_MODE:
          warnings.warn(
              'Saw Tensor seed {}, implying stateless sampling. '
              'Autovectorized functions that use stateless sampling may be '
              'quite slow because the current implementation falls back to an '
              'explicit loop. This will be fixed in the future. For now, you '
              'will likely see better performance from stateful sampling, '
              'which you can invoke by passing a Python `int` seed.'.format(
                  seed))
        seed = samplers.split_seed(seed, n=n, salt='iid_sample_stateless')

        def pfor_loop_body(i):
          with tf.name_scope('iid_sample_fn_stateless_body'):
            return sample_fn(*args, seed=tf.gather(seed, i), **kwargs)

      draws = parallel_for.pfor(pfor_loop_body, n)
      return tf.nest.map_structure(unflatten, draws, expand_composites=True)

  return iid_sample_fn
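# Hedged usage sketch (added; not from the original source), assuming `tf` and
# `samplers` follow the module's existing imports: lift a scalar normal
# sampler into one that returns a [4, 2] batch of independent draws.
def _example_iid_sample(seed=None):
  single_draw = lambda seed=None: samplers.normal([], seed=seed)
  batched_draw = iid_sample(single_draw, sample_shape=[4, 2])
  return batched_draw(seed=seed)  # `Tensor` of shape [4, 2].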
def iid_sample(sample_fn, sample_shape):
  """Lift a sampling function to one that draws multiple iid samples.

  Args:
    sample_fn: Python `callable` that returns a (possibly nested) structure of
      `Tensor`s. May optionally take a `seed` named arg: if so, any `int`
      seeds (for stateful samplers) are passed through directly, while any
      pair-of-`int` seeds (for stateless samplers) are split into independent
      seeds for each sample.
    sample_shape: `int` `Tensor` shape of iid samples to draw.

  Returns:
    iid_sample_fn: Python `callable` taking the same arguments as `sample_fn`
      and returning iid samples. Each returned `Tensor` will have shape
      `concat([sample_shape, shape_of_original_returned_tensor])`.
  """
  sample_shape = distribution_util.expand_to_vector(
      prefer_static.cast(sample_shape, np.int32), tensor_name='sample_shape')
  n = prefer_static.cast(
      prefer_static.reduce_prod(sample_shape), dtype=np.int32)

  def unflatten(x):
    unflattened_shape = prefer_static.cast(
        prefer_static.concat([sample_shape, prefer_static.shape(x)[1:]],
                             axis=0),
        dtype=np.int32)
    return tf.reshape(x, unflattened_shape)

  def iid_sample_fn(*args, **kwargs):
    """Draws iid samples from `fn`."""

    pfor_loop_body = lambda _: sample_fn(*args, **kwargs)

    seed = kwargs.pop('seed', None)
    try:  # Assume that `seed` is a valid stateful seed (Python `int`).
      kwargs = dict(kwargs, seed=SeedStream(seed, salt='iid_sample')())
      pfor_loop_body = lambda _: sample_fn(*args, **kwargs)
    except TypeError as e:
      # If a stateless seed arg is passed, split it into `n` different
      # stateless seeds, so that we don't just get a bunch of copies of the
      # same sample.
      if TENSOR_SEED_MSG_PREFIX not in str(e):
        raise
      warnings.warn(
          'Saw non-`int` seed {}, implying stateless sampling. '
          'Autovectorized functions that use stateless sampling '
          'may be quite slow because the current implementation '
          'falls back to an explicit loop. This will be fixed in the '
          'future. For now, you will likely see better performance '
          'from stateful sampling, which you can invoke by passing a '
          'traditional Python `int` seed.'.format(seed))
      seed = samplers.split_seed(seed, n=n, salt='iid_sample_stateless')
      pfor_loop_body = (
          lambda i: sample_fn(*args, seed=tf.gather(seed, i), **kwargs))

    draws = parallel_for.pfor(pfor_loop_body, n)
    return tf.nest.map_structure(unflatten, draws, expand_composites=True)

  return iid_sample_fn
def _summarize_fans(fan_in, fan_out, mode, dtype):
  """Combines `fan_in`, `fan_out` per specified `mode`."""
  fan_in = prefer_static.cast(fan_in, dtype)
  fan_out = prefer_static.cast(fan_out, dtype)
  mode = str(mode).lower()
  if mode == 'fan_in':
    return fan_in
  elif mode == 'fan_out':
    return fan_out
  elif mode == 'fan_avg':
    return (fan_in + fan_out) / 2.
  raise ValueError('Unrecognized mode: "{}".'.format(mode))
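# Added illustration: the three supported modes applied to fan_in=64,
# fan_out=32 give 64., 32., and 48. respectively (in the requested dtype).
def _example_summarize_fans():
  return (_summarize_fans(64, 32, 'fan_in', tf.float32),
          _summarize_fans(64, 32, 'fan_out', tf.float32),
          _summarize_fans(64, 32, 'fan_avg', tf.float32))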
def _calculate_batch_shape(self):
  """Computes fully defined batch shape for the new distribution."""
  all_batch_shapes = [d.batch_shape.as_list()
                      if tensorshape_util.is_fully_defined(d.batch_shape)
                      else d.batch_shape_tensor()
                      for d in self.distributions]
  original_shape = ps.stack(all_batch_shapes, axis=0)
  index_mask = ps.cast(
      ps.one_hot(self._axis, ps.shape(original_shape)[1]),
      dtype=tf.bool)
  new_concat_dim = ps.cast(
      ps.reduce_sum(original_shape, axis=0)[self._axis], dtype=tf.int32)
  return ps.where(index_mask, new_concat_dim,
                  ps.reduce_max(original_shape, axis=0))
def _validate_elem_length(max_num_levels, elems_flat):
  """Checks that elems all have the same length, and returns that length."""
  assertions = []

  elem_length = prefer_static.shape(elems_flat[0])[0]

  # The default size limit will overflow a 32-bit int, so make sure we're
  # using 64-bit.
  size_limit = 2**(prefer_static.cast(max_num_levels, np.int64) + 1)
  enough_levels = prefer_static.less(
      prefer_static.cast(elem_length, np.int64), size_limit)
  enough_levels_ = tf.get_static_value(enough_levels)
  if enough_levels_ is None:
    assertions.append(
        tf.debugging.assert_equal(
            enough_levels, True,
            message='Input `Tensor`s must have first axis dimension less than'
            ' `2**(max_num_levels + 1)`'
            ' (saw: {} which is not less than 2**{} == {})'.format(
                elem_length, max_num_levels, size_limit)))
  elif not enough_levels_:
    raise ValueError(
        'Input `Tensor`s must have first axis dimension less than'
        ' `2**(max_num_levels + 1)`'
        ' (saw: {} which is not less than 2**{} == {})'.format(
            elem_length, max_num_levels, size_limit))

  is_consistent = prefer_static.reduce_all([
      prefer_static.equal(prefer_static.shape(elem)[0], elem_length)
      for elem in elems_flat[1:]])
  is_consistent_ = tf.get_static_value(is_consistent)
  if is_consistent_ is None:
    assertions.append(
        tf.debugging.assert_equal(
            is_consistent, True,
            message='Input `Tensor`s must have the same first dimension.'
            ' (saw: {})'.format([elem.shape for elem in elems_flat])))
  elif not is_consistent_:
    raise ValueError(
        'Input `Tensor`s must have the same first dimension.'
        ' (saw: {})'.format([elem.shape for elem in elems_flat]))
  return elem_length, assertions
def expand_dims(x, axis, name=None):
  """Like `tf.expand_dims` but accepts a vector of axes to expand."""
  with tf.name_scope(name or 'expand_dims'):
    x = tf.convert_to_tensor(x, name='x')
    axis = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
    nx = prefer_static.rank(x)
    na = prefer_static.size(axis)
    is_neg_axis = axis < 0
    k = prefer_static.reduce_sum(
        prefer_static.cast(is_neg_axis, axis.dtype))
    axis = prefer_static.where(is_neg_axis, axis + nx, axis)
    axis = prefer_static.sort(axis)
    axis_neg, axis_pos = prefer_static.split(axis, [k, -1])
    idx = prefer_static.argsort(prefer_static.concat([
        axis_pos,
        prefer_static.range(nx),
        axis_neg,
    ], axis=0), stable=True)
    shape = prefer_static.pad(prefer_static.shape(x),
                              paddings=[[na - k, k]],
                              constant_values=1)
    shape = prefer_static.gather(shape, idx)
    return tf.reshape(x, shape)
def _reshape_part(part):
  # Note: `dtype` and `batch_shape` are not defined locally; in the original
  # source this helper is nested inside an enclosing function that supplies
  # them.
  part = tf.cast(part, dtype)
  new_shape = ps.concat([batch_shape, [-1]], axis=-1)
  return tf.reshape(part, ps.cast(new_shape, tf.int32))
def _update_loop_variables(step,
                           current_step_results,
                           accumulated_traced_results,
                           trace_fn,
                           step_indices_to_trace,
                           num_steps_traced):
  """Update the loop state to reflect a step of filtering."""

  # Write particles, indices, and likelihoods to their respective arrays.
  trace_this_step = True
  if step_indices_to_trace is not None:
    trace_this_step = ps.equal(
        step_indices_to_trace[ps.minimum(
            num_steps_traced,
            ps.cast(ps.size0(step_indices_to_trace) - 1, dtype=np.int32))],
        step)
  num_steps_traced, accumulated_traced_results = ps.cond(
      trace_this_step,
      lambda: (num_steps_traced + 1,  # pylint: disable=g-long-lambda
               tf.nest.map_structure(
                   lambda x, y: x.write(num_steps_traced, y),
                   accumulated_traced_results,
                   trace_fn(current_step_results))),
      lambda: (num_steps_traced, accumulated_traced_results))

  return ParticleFilterLoopVariables(
      step=step + 1,
      previous_step_results=current_step_results,
      accumulated_traced_results=accumulated_traced_results,
      num_steps_traced=num_steps_traced)
def _reshape_part(part, dtype, event_shape):
  part = tf.cast(part, dtype)
  static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
  if static_rank == 1:
    return part
  new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
  return tf.reshape(part, ps.cast(new_shape, tf.int32))
def _im2row_index(input_shape,
                  block_shape,
                  slice_step=(1, 1),
                  data_format='NHWC',
                  padding='VALID',
                  dtype=tf.int64,
                  name=None):
  """Computes indexes into a flattened image for building `im2col`."""
  with tf.name_scope(name or 'im2row_index'):
    # 1) Process input arguments.
    batch_shape, s3, s2, s1 = prefer_static.split(
        prefer_static.cast(input_shape, tf.int32),
        num_or_size_splits=[-1, 1, 1, 1])
    fh, fw = _split_pair(block_shape)
    sh, sw = _split_pair(slice_step)
    data_format = _validate_data_format(data_format)
    padding = _validate_padding(padding)

    # 2) Assemble all block start positions as indexes into the flattened
    #    image.
    if data_format == 'NHWC':
      h, w, c = s3[0], s2[0], s1[0]
      # start_idx.shape = [fh, fw, c]
      start_idx = _cartesian_add([
          prefer_static.range(c * w * fh, delta=c * w, dtype=dtype),
          prefer_static.range(c * fw, delta=c, dtype=dtype),
          prefer_static.range(c, delta=1, dtype=dtype),
      ])
    elif data_format == 'NCHW':
      c, h, w = s3[0], s2[0], s1[0]
      # start_idx.shape = [c, fh, fw]
      start_idx = _cartesian_add([
          prefer_static.range(w * h * c, delta=w * h, dtype=dtype),
          prefer_static.range(w * fh, delta=w, dtype=dtype),
          prefer_static.range(fw, delta=1, dtype=dtype),
      ])
    else:
      assert False  # Can't be here.

    # 3) Assemble all block offsets (into flattened image).
    if padding == 'VALID':
      eh = h - fh + 1  # extent height
      ew = w - fw + 1  # extent width
      # offset_idx.shape = [eh // sh, ew // sw]
      offset_idx = _cartesian_add([
          prefer_static.range(w * eh, delta=w * sh, dtype=dtype),
          prefer_static.range(ew, delta=sw, dtype=dtype),
      ])
      if data_format == 'NHWC':
        offset_idx *= c
      oh = eh // sh  # out height
      ow = ew // sw  # out width
    else:
      assert False  # Can't be here.

    # 4) Combine block start/offset pairs.
    # shape = [(eh // sh) * (ew // sw), fh * fw * c]
    idx = _cartesian_add([offset_idx, start_idx])
    new_shape = [oh, ow, fh * fw * c]
    new_shape = prefer_static.concat([batch_shape, new_shape], axis=0)
    return idx, new_shape
def adjacent_swaps(num_replica, batch_shape=(), step_count=None, seed=None):
  """Make random shuffle using only one time swaps."""
  del step_count  # Unused for this function.
  # Note: `prob_swap` and `name` are free variables here; in the original
  # source this function is defined inside a swap-proposal factory that
  # supplies them.
  with tf.name_scope(name or 'adjacent_swaps'):
    parity_seed, proposal_seed = samplers.split_seed(seed)
    # u selects parity.  E.g.,
    #  u==False ==> [1, 0, 3, 2, 4] even parity swaps
    #  u==True  ==> [0, 2, 1, 4, 3] odd parity swaps
    # If there are only 2 replicas, then the "True" swaps are null
    # swaps...which would contradict the user provided `prob_swap`.
    # So special case num_replica==2, forcing u==False in this case.
    u_shape = ps.concat(
        (ps.ones(1, dtype=tf.int32), ps.cast(batch_shape, tf.int32)), axis=0)
    u = samplers.uniform(u_shape, seed=parity_seed) < 0.5
    u = tf.where(num_replica > 2, u, False)

    x = bu.left_justified_expand_dims_to(
        ps.range(num_replica, dtype=tf.int64), rank=ps.size(u_shape))
    y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1, x - 1)
    y = tf.clip_by_value(y, 0, num_replica - 1)
    # TODO(b/142689785): Consider using tf.cond and returning an empty list
    # then in REMC consider using a tf.cond for short-circuiting.
    return tf.where(
        samplers.uniform(batch_shape, seed=proposal_seed) < prob_swap, y, x)
def _multi_gamma_sequence(self, a, p, name='multi_gamma_sequence'):
  """Creates sequence used in multivariate (di)gamma; shape = shape(a)+[p]."""
  with tf.name_scope(name):
    # Linspace only takes scalars, so we'll add in the offset afterwards.
    seq = ps.linspace(
        tf.constant(0., dtype=self.dtype), 0.5 - 0.5 * p,
        ps.cast(p, tf.int32))
    return seq + a[..., tf.newaxis]
def expand_dims_(x):
  """Implementation of `expand_dims`."""
  # Note: `name` and `axis` are not defined locally; in the original source
  # they come from the enclosing `expand_dims` function.
  with tf.name_scope(name or 'expand_dims'):
    x = tf.convert_to_tensor(x, name='x')
    new_axis = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
    nx = prefer_static.rank(x)
    na = prefer_static.size(new_axis)
    is_neg_axis = new_axis < 0
    k = prefer_static.reduce_sum(
        prefer_static.cast(is_neg_axis, new_axis.dtype))
    new_axis = prefer_static.where(is_neg_axis, new_axis + nx, new_axis)
    new_axis = prefer_static.sort(new_axis)
    axis_neg, axis_pos = prefer_static.split(new_axis, [k, -1])
    idx = prefer_static.argsort(prefer_static.concat([
        axis_pos,
        prefer_static.range(nx),
        axis_neg,
    ], axis=0), stable=True)
    shape = prefer_static.pad(prefer_static.shape(x),
                              paddings=[[na - k, k]],
                              constant_values=1)
    shape = prefer_static.gather(shape, idx)
    return tf.reshape(x, shape)
def _apply_with_distance(
    self, x1, x2, pairwise_square_distance, example_ndims=0):
  exponent = -2. * pairwise_square_distance
  locs = util.pad_shape_with_ones(
      self.locs, ndims=example_ndims, start=-(self.feature_ndims + 1))
  cos_coeffs = tf.math.cos(2 * np.pi * (x1 - x2) * locs)
  feature_ndims = ps.cast(self.feature_ndims, ps.rank(cos_coeffs).dtype)
  reduction_axes = ps.range(
      ps.rank(cos_coeffs) - feature_ndims, ps.rank(cos_coeffs))
  coeff_sign = tf.math.reduce_prod(
      tf.math.sign(cos_coeffs), axis=reduction_axes)
  log_cos_coeffs = tf.math.reduce_sum(
      tf.math.log(tf.math.abs(cos_coeffs)), axis=reduction_axes)

  logits = util.pad_shape_with_ones(self.logits, ndims=example_ndims, start=-1)

  log_result, sign = tfp_math.reduce_weighted_logsumexp(
      exponent + log_cos_coeffs + logits, coeff_sign, return_sign=True,
      axis=-(example_ndims + 1))
  return sign * tf.math.exp(log_result)
def _joint_sample_n(self, n, seed=None):
  """Draw a joint sample from the prior over latents and observations.

  This sampler is specific to LocalLevel models and is faster than the
  generic LinearGaussianStateSpaceModel implementation.

  Args:
    n: `int` `Tensor` number of samples to draw.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    latents: `float` `Tensor` of shape `concat([[n], self.batch_shape,
      [self.num_timesteps, self.latent_size]], axis=0)` representing samples
      of latent trajectories.
    observations: `float` `Tensor` of shape `concat([[n], self.batch_shape,
      [self.num_timesteps, self.observation_size]], axis=0)` representing
      samples of observed series generated from the sampled `latents`.
  """
  with tf.name_scope('joint_sample_n'):
    (initial_level_seed,
     level_jumps_seed,
     prior_observation_seed) = samplers.split_seed(
         seed, n=3, salt='LocalLevelStateSpaceModel_joint_sample_n')

    if self.batch_shape.is_fully_defined():
      batch_shape = self.batch_shape
    else:
      batch_shape = self.batch_shape_tensor()
    sample_and_batch_shape = ps.cast(
        ps.concat([[n], batch_shape], axis=0), tf.int32)

    # Sample the initial timestep from the prior. Since we want this sample
    # to have full batch shape (not just the batch shape of the
    # self.initial_state_prior object which might in general be smaller), we
    # augment the sample shape to include whatever extra batch dimensions are
    # required.
    initial_level = self.initial_state_prior.sample(
        linear_gaussian_ssm._augment_sample_shape(  # pylint: disable=protected-access
            self.initial_state_prior,
            sample_and_batch_shape,
            self.validate_args),
        seed=initial_level_seed)

    # Sample the latent random walk and observed noise, more efficiently than
    # the generic loop in `LinearGaussianStateSpaceModel`.
    level_jumps = self.level_scale[..., tf.newaxis] * samplers.normal(
        ps.concat([sample_and_batch_shape, [self.num_timesteps - 1]], axis=0),
        dtype=self.dtype, seed=level_jumps_seed)
    prior_level_sample = tf.cumsum(
        tf.concat([initial_level, level_jumps], axis=-1), axis=-1)
    prior_observation_sample = prior_level_sample + (  # Sample noise.
        self.observation_noise_scale[..., tf.newaxis] *
        samplers.normal(ps.shape(prior_level_sample),
                        dtype=self.dtype, seed=prior_observation_seed))

    return (prior_level_sample[..., tf.newaxis],
            prior_observation_sample[..., tf.newaxis])
def __init__(
    self,
    input_size,
    output_size,
    # Weights
    init_kernel_fn=None,  # tfp.experimental.nn.initializers.glorot_uniform()
    init_bias_fn=None,    # tf.initializers.zeros()
    make_kernel_bias_fn=nn_util_lib.make_kernel_bias,
    dtype=tf.float32,
    batch_shape=(),
    # Misc
    activation_fn=None,
    name=None):
  """Constructs layer.

  Args:
    input_size: ...
    output_size: ...
    init_kernel_fn: ...
      Default value: `None` (i.e.,
      `tfp.experimental.nn.initializers.glorot_uniform()`).
    init_bias_fn: ...
      Default value: `None` (i.e., `tf.initializers.zeros()`).
    make_kernel_bias_fn: ...
      Default value: `tfp.experimental.nn.util.make_kernel_bias`.
    dtype: ...
      Default value: `tf.float32`.
    batch_shape: ...
      Default value: `()`.
    activation_fn: ...
      Default value: `None`.
    name: ...
      Default value: `None` (i.e., `'Affine'`).
  """
  batch_shape = tf.constant(
      [], dtype=tf.int32) if batch_shape is None else prefer_static.cast(
          prefer_static.reshape(batch_shape, shape=[-1]), tf.int32)
  batch_ndims = prefer_static.size(batch_shape)
  kernel_shape = prefer_static.concat(
      [batch_shape, [input_size, output_size]], axis=0)
  bias_shape = prefer_static.concat([batch_shape, [output_size]], axis=0)
  apply_kernel_fn = lambda x, k: tf.matmul(  # pylint: disable=g-long-lambda
      x[..., tf.newaxis, :], k)[..., 0, :]
  kernel, bias = make_kernel_bias_fn(
      kernel_shape, bias_shape,
      init_kernel_fn, init_bias_fn,
      batch_ndims, batch_ndims,
      dtype)
  self._make_kernel_bias_fn = make_kernel_bias_fn  # For tracking.
  super(Affine, self).__init__(
      kernel=kernel,
      bias=bias,
      apply_kernel_fn=apply_kernel_fn,
      activation_fn=activation_fn,
      dtype=dtype,
      name=name)
def _log_prob(self, value):
  """Log probability of multivariate normal.

  Costs a log_abs_determinant, matvec, and a reduce_sum over a squared
  (batch of) vector(s).

  Args:
    value: Floating point `Tensor`.

  Returns:
    Floating point `Tensor` with batch shape.
  """
  dim = self.precision_factor.domain_dimension_tensor()
  return (
      ps.cast(-0.5 * np.log(2 * np.pi), self.dtype) *
      ps.cast(dim, self.dtype) +
      # Notice the sign on the LinearOperator.log_abs_determinant is
      # positive, since it is precision_factor not scale.
      self._precision_factor.log_abs_determinant() +
      self._log_prob_unnormalized(value))
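# Added reference note (not in the original source): writing the precision as
# `F @ adjoint(F)` with `F = precision_factor`, the expression above is the
# standard multivariate normal log density,
#   log p(x) = -(d / 2) * log(2 * pi) + log|det F| - 0.5 * ||F^T (x - loc)||^2,
# where the last term is what `_log_prob_unnormalized` is assumed to supply.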
def _forward_log_det_jacobian(self, x):
  # This code is similar to tf.math.log_softmax but different because we have
  # an implicit zero column to handle. I.e., instead of:
  #   reduce_sum(logits - reduce_sum(exp(logits), dim))
  # we must do:
  #   log_normalization = 1 + reduce_sum(exp(logits))
  #   -log_normalization + reduce_sum(logits - log_normalization)
  np1 = prefer_static.cast(1 + prefer_static.shape(x)[-1], dtype=x.dtype)
  return (0.5 * prefer_static.log(np1) +
          tf.reduce_sum(x, axis=-1) -
          np1 * tf.math.softplus(tf.reduce_logsumexp(x, axis=-1)))
def _reshape_part(part, event_shape):
  # Note: `self` is a free variable here; in the original source this helper
  # is nested inside a method that provides it.
  part = tf.cast(part, self.dtype)
  new_shape = ps.concat(
      [ps.shape(part)[:ps.size(ps.shape(part)) - ps.size(event_shape)],
       [-1]],
      axis=-1)
  return tf.reshape(part, ps.cast(new_shape, tf.int32))
def _forward(self, x):
  ndims = ps.rank(x)
  indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1])
  return tf.pad(
      x,
      paddings=ps.tensor_scatter_nd_update(
          ps.zeros([ndims, 2], dtype=tf.int32),
          indices, self.paddings),
      mode=self.mode,
      constant_values=ps.cast(self.constant_values, dtype=x.dtype))
def __init__(self,
             samples,
             event_ndims=0,
             validate_args=False,
             allow_nan_stats=True,
             name='Empirical'):
  """Initialize `Empirical` distributions.

  Args:
    samples: Numeric `Tensor` of shape `[B1, ..., Bk, S, E1, ..., En]`,
      `k, n >= 0`. Samples or batches of samples on which the distribution
      is based. The first `k` dimensions index into a batch of independent
      distributions. Length of `S` dimension determines number of samples in
      each multiset. The last `n` dimensions represent samples for each
      distribution. `n` is specified by argument `event_ndims`.
    event_ndims: Python `int32`, default `0`. Number of dimensions for each
      event. When `0` this distribution has scalar samples. When `1` this
      distribution has vector-like samples.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value `NaN` to indicate the result
      is undefined. When `False`, an exception is raised if one or more of
      the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if the rank of `samples` is statically known and less than
      `event_ndims + 1`.
  """
  parameters = dict(locals())
  with tf.name_scope(name):
    self._samples = tensor_util.convert_nonref_to_tensor(samples)
    dtype = dtype_util.common_dtype(
        [self._samples], dtype_hint=self._samples.dtype)
    self._event_ndims = event_ndims

    # Note: this tf.rank call affects the graph, but is ok in `__init__`
    # because we don't expect shapes (or ranks) to be runtime-variable, nor
    # ever need to differentiate with respect to them.
    samples_rank = prefer_static.rank(self._samples)
    self._samples_axis = prefer_static.cast(
        samples_rank - self._event_ndims - 1, tf.int32)

    super(Empirical, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
def _initialize(shape, dtype, batch_ndims, scale, mode, distribution,
                seed=None):
  """Samples a random `Tensor` per specified args."""
  if not dtype_util.is_floating(dtype):
    raise TypeError('Argument `dtype` must be float type (saw: "{}").'.format(
        dtype))
  shape = prefer_static.reshape(shape, shape=[-1])  # Ensure shape is vector.
  fan_in, fan_out = _compute_fans_from_shape(shape, batch_ndims)
  fans = _summarize_fans(fan_in, fan_out, mode, dtype)
  scale = prefer_static.cast(scale, dtype)
  return _sample_distribution(shape, scale / fans, distribution, seed, dtype)
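# Hedged usage sketch (added): a Glorot-style draw for a [64, 32] kernel. The
# mode/distribution strings and the `_compute_fans_from_shape` /
# `_sample_distribution` helpers referenced above are assumed to be those of
# the surrounding initializers module.
def _example_initialize(seed=None):
  return _initialize(shape=[64, 32], dtype=tf.float32, batch_ndims=0,
                     scale=1., mode='fan_avg', distribution='uniform',
                     seed=seed)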
def sample(self, sample_shape=(), seed=None, name=None):
  with tf.name_scope(name or 'sample'):
    # Grab the required number of values from the provided tensors.
    sample_shape = dist_util.expand_to_vector(sample_shape)
    n = ps.cast(ps.reduce_prod(sample_shape), dtype=tf.int32)

    # Check that we're not trying to draw too many samples.
    assertions = []
    will_overflow_ = tf.get_static_value(n > self.max_num_samples)
    if will_overflow_:
      raise ValueError('Trying to draw {} samples from a '
                       '`DeterministicEmpirical` instance for which only {} '
                       'samples were provided.'.format(
                           tf.get_static_value(n),
                           tf.get_static_value(self.max_num_samples)))
    elif (will_overflow_ is None  # Couldn't determine statically.
          and self.validate_args):
      assertions.append(
          tf.debugging.assert_less_equal(
              n, self.max_num_samples,
              message='Number of samples to draw '
              'from a `DeterministicEmpirical` instance must not exceed the '
              'number provided at construction.'))

    # Extract the appropriate number of sampled values.
    with tf.control_dependencies(assertions):
      sampled = tf.nest.map_structure(lambda x: x[:n, ...],
                                      self.values_with_sample_dim)

    # Reshape the values to the appropriate sample shape.
    return tf.nest.map_structure(
        lambda x: tf.reshape(x,  # pylint: disable=g-long-lambda
                             ps.concat([
                                 ps.cast(sample_shape, tf.int32),
                                 ps.cast(ps.shape(x)[1:], tf.int32)
                             ], axis=0)),
        sampled)
def _get_reinterpreted_batch_ndims(self,
                                   distribution_batch_shape_tensor=None):
  if self._static_reinterpreted_batch_ndims is not None:
    return self._static_reinterpreted_batch_ndims
  if self._reinterpreted_batch_ndims is not None:
    return tf.convert_to_tensor(self._reinterpreted_batch_ndims)

  if distribution_batch_shape_tensor is None:
    distribution_batch_shape_tensor = self.distribution.batch_shape_tensor()
  return ps.cast(
      ps.maximum(0, ps.size(distribution_batch_shape_tensor) - 1),
      np.int32)
def even_odd_swaps(num_replica, batch_shape=(), step_count=None, seed=None):
  """Make deterministic even_odd one time swaps."""
  if step_count is None:
    raise ValueError('`step_count` must be supplied. Found `None`.')
  del seed  # Unused for this function.
  # Note: `swap_frequency` and `name` are free variables here; in the original
  # source this function is defined inside a swap-proposal factory that
  # supplies them.
  with tf.name_scope(name or 'even_odd_swaps'):
    # Period is 1 / frequency, and we want period = Inf if frequency = 0.
    # safe_swap_period is the correct swap period in case swap_frequency > 0.
    # If swap_frequency == 0, safe_swap_period is set to 1 (to avoid integer
    # div by zero below). We will hard-set this case to "null swap."
    swap_freq = tf.convert_to_tensor(swap_frequency, name='swap_frequency')
    safe_swap_period = tf.cast(
        tf.where(swap_freq > 0,
                 tf.math.ceil(tf.math.reciprocal_no_nan(swap_freq)),
                 1),
        # Although period = 1 / frequency may have roundoff error, and result
        # in a period different than what the user intended, the user will
        # end up with a single integer period, and thus well defined
        # deterministic swaps.
        tf.int32,
    )

    # u selects parity.  E.g.,
    #  u==False ==> [1, 0, 3, 2, 4] even parity swaps
    #  u==True  ==> [0, 2, 1, 4, 3] odd parity swaps
    # If there are 2 replicas, then the "True" swaps are null
    # swaps...which would contradict the user provided `swap_frequency`.
    # So special case num_replica==2, forcing u==False in this case.
    u_shape = ps.concat(
        (ps.ones(1, dtype=tf.int32), ps.cast(batch_shape, tf.int32)), axis=0)
    u = tf.fill(u_shape,
                tf.cast((step_count // safe_swap_period) % 2, tf.bool))
    u = tf.where(num_replica > 2, u, False)

    x = bu.left_justified_expand_dims_to(
        tf.range(num_replica, dtype=tf.int64), rank=ps.size(u_shape))
    y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1, x - 1)
    y = tf.clip_by_value(y, 0, num_replica - 1)
    # TODO(b/142689785): Consider using tf.cond and returning an empty list
    # then in REMC consider using a tf.cond for short-circuiting.
    return tf.where(
        (tf.cast(step_count % safe_swap_period, tf.bool)
         | tf.math.equal(swap_freq, 0)),
        x,  # Don't swap
        y,  # Swap
    )
def _make_input_and_kernel(
    make_input, input_batch_shape, input_shape, kernel_batch_shape,
    filter_shape, channels_out, dtype):
  total_input_shape = ps.concat([input_batch_shape, input_shape], axis=0)
  total_kernel_shape = ps.concat(
      [kernel_batch_shape,
       [filter_shape[0] * filter_shape[1] * input_shape[-1], channels_out]],
      axis=0)
  # Use integers for numerical stability.
  sample_fn = lambda s: make_input(tf.cast(  # pylint: disable=g-long-lambda
      tf.random.uniform(
          ps.cast(s, tf.int32), minval=-10, maxval=10, dtype=tf.int32),
      dtype=dtype))
  return sample_fn(total_input_shape), sample_fn(total_kernel_shape)
def _scatter_nd_batch(indices, updates, shape, batch_dims=0):
  """A partial implementation of `scatter_nd` supporting `batch_dims`."""

  # `tf.scatter_nd` does not support a `batch_dims` argument.
  # Instead we use the gradient of `tf.gather_nd`.
  # From a purely mathematical perspective this works because
  # (if `tf.scatter_nd` supported `batch_dims`)
  # `gather_nd` and `scatter_nd` (with matching `indices`) are
  # adjoint linear operators and
  # the gradient w.r.t `x` of `dot(y, A(x))` is `adjoint(A)(y)`.
  #
  # Another perspective: back propagating through a "neural" network
  # containing a gather operation carries derivatives backwards through the
  # network, accumulating the derivatives in the locations that
  # were gathered from, i.e. they are scattered.
  # If the network multiplies each gathered element by
  # some quantity, then the backwardly propagating derivatives are scaled
  # by this quantity before being scattered.
  # Combining this with the fact that `GradientTape.gradient`
  # starts back-propagation with derivatives equal to `1`, this allows us
  # to use the multipliers to determine the quantities scattered.
  #
  # However, derivatives are only supported for floating point types
  # so we 'tunnel' our types through the `float64` type.
  # So the implementation is "partial" in the sense that it supports
  # data that can be losslessly converted to `tf.float64` and back.
  dtype = updates.dtype
  internal_dtype = tf.float64
  multipliers = ps.cast(updates, internal_dtype)
  with tf.GradientTape() as tape:
    zeros = tf.zeros(shape, dtype=internal_dtype)
    tape.watch(zeros)
    weighted_gathered = multipliers * tf.gather_nd(
        zeros, indices, batch_dims=batch_dims)
  grad = tape.gradient(weighted_gathered, zeros)
  return ps.cast(grad, dtype=dtype)
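# Added illustration: batched scatter of two updates per batch row into a
# length-5 output, i.e. the adjoint of `tf.gather_nd(..., batch_dims=1)`.
def _example_scatter_nd_batch():
  indices = tf.constant([[[0], [2]],
                         [[1], [4]]])  # shape [2, 2, 1]
  updates = tf.constant([[10., 20.],
                         [30., 40.]])  # shape [2, 2]
  return _scatter_nd_batch(indices, updates, shape=[2, 5], batch_dims=1)
  # -> [[10., 0., 20., 0., 0.],
  #     [0., 30., 0., 0., 40.]]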
def make_rwmh_kernel_fn(target_log_prob_fn, init_state, scalings):
  """Generate a Random Walk MH kernel."""
  with tf.name_scope('make_rwmh_kernel_fn'):
    state_std = [
        tf.math.reduce_std(x, axis=0, keepdims=True)
        for x in init_state
    ]
    step_size = [
        s * ps.cast(  # pylint: disable=g-complex-comprehension
            bu.left_justified_expand_dims_like(scalings, s),
            s.dtype)
        for s in state_std
    ]
    return random_walk_metropolis.RandomWalkMetropolis(
        target_log_prob_fn,
        new_state_fn=random_walk_metropolis.random_walk_normal_fn(
            scale=step_size))
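# Hedged usage sketch (added; the target and shapes are illustrative only):
# 100 parallel chains over a 3-dimensional standard normal with per-chain
# scalings, assuming the module's `tf`, `bu`, and `random_walk_metropolis`
# imports are in scope.
def _example_make_rwmh_kernel():
  target_log_prob_fn = lambda x: -0.5 * tf.reduce_sum(x**2, axis=-1)
  init_state = [tf.random.normal([100, 3])]
  scalings = tf.ones([100])
  return make_rwmh_kernel_fn(target_log_prob_fn, init_state, scalings)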