def _eval(self, x, weights):
  """Applies the layer to `x` using a Flipout-style kernel perturbation.

  When a kernel is present, the pre-bias output is the sum of two paths:

  * a mean path, `apply_kernel(x, kernel_loc)`, using the location of the
    posterior kernel distribution; and
  * a perturbation path, `s_out * apply_kernel(s_in * x, eps)`, where
    `s_in` / `s_out` are random Rademacher sign tensors drawn with the shapes
    of the input and of the kernel output, and `eps` is one draw from
    `Normal(0, kernel_scale)`.

  Without a kernel, `x` passes through unchanged to the bias/activation step.
  The bias (if any) is then added and `activation_fn` (if any) is applied.

  Args:
    x: Layer input; passed to `self._apply_kernel_fn` as its first argument.
      NOTE(review): presumably a `Tensor` or array-like matching the layer's
      expected input shape -- confirm against callers.
    weights: Packed layer weights. They are split by `self.unpack_weights_fn`
      into `(kernel, bias)` and also passed as `value=` to
      `self.posterior.sample_distributions`.

  Returns:
    The layer output `y`.
  """
  kernel, bias = self.unpack_weights_fn(weights)  # pylint: disable=not-callable
  y = x
  if kernel is not None:
    # Only the presence of `kernel` matters here; its sampled value is not
    # used. The mean path below uses the posterior distribution's location.
    kernel_dist, _ = self.unpack_weights_fn(  # pylint: disable=not-callable
        self.posterior.sample_distributions(value=weights)[0])
    kernel_loc, kernel_scale = get_spherical_normal_loc_scale(kernel_dist)

    # Out-of-place products on purpose: `y` aliases `x` at this point, and an
    # in-place `*=` on a mutable array input (e.g. a NumPy ndarray) would
    # overwrite the caller's `x` with sign-flipped values -- which the mean
    # path below reads again.
    # NOTE(review): assumes each `self._seed()` call yields a fresh seed so
    # the two sign tensors and the perturbation are independent -- confirm.
    y = y * random_rademacher(prefer_static.shape(y),
                              dtype=y.dtype,
                              seed=self._seed())
    kernel_perturb = normal_lib.Normal(loc=0., scale=kernel_scale)
    y = self._apply_kernel_fn(  # E.g., tf.matmul.
        y,
        kernel_perturb.sample(seed=self._seed()))
    y = y * random_rademacher(prefer_static.shape(y),
                              dtype=y.dtype,
                              seed=self._seed())
    y = y + self._apply_kernel_fn(x, kernel_loc)

  if bias is not None:
    y = y + bias

  if self.activation_fn is not None:
    y = self.activation_fn(y)  # pylint: disable=not-callable

  return y
def proposal(seed):
  """Proposal for log-concave rejection sampler.

  Draws one candidate per element of `mode_shape`, together with the height
  of the envelope at that candidate. Each candidate is
  `mode + round(sign * magnitude)`, where `sign` is a random Rademacher draw.
  `magnitude` comes from one of two envelope pieces, chosen per element by
  comparing a uniform draw against `top_fraction`:

  * flat top piece: magnitude `V * top_width / mode_height` with `V` uniform;
    envelope height `mode_height`.
  * exponential tail: magnitude `(top_width + E) / mode_height` with `E`
    drawn from `exponential_distribution`; envelope height
    `exponential_distribution.prob(E) * mode_height`.

  `mode`, `mode_shape`, `mode_height`, `top_width`, `top_fraction`, `dtype`
  and `exponential_distribution` are free variables supplied by the enclosing
  scope. The letters V, E and U follow the notation of "ref [1]" cited there.

  Args:
    seed: PRNG seed; split into four sub-seeds, one per random draw.

  Returns:
    A `(potential_samples, envelope_height)` pair.
  """
  sub_seeds = samplers.split_seed(
      seed, n=4, salt='log_concave_rejection_sampler_proposal')
  flat_seed, tail_seed, selector_seed, sign_seed = sub_seeds

  # Flat top piece of the envelope (V in ref [1]).
  v_draws = samplers.uniform(mode_shape, seed=flat_seed, dtype=dtype)
  flat_magnitudes = v_draws * top_width / mode_height

  # Exponential tail piece of the envelope (E in ref [1]).
  e_draws = exponential_distribution.sample(mode_shape, seed=tail_seed)
  tail_heights = exponential_distribution.prob(e_draws) * mode_height
  tail_magnitudes = (top_width + e_draws) / mode_height

  # Choose a piece per element (U in ref [1]).
  u_draws = samplers.uniform(mode_shape, seed=selector_seed, dtype=dtype)
  use_flat = tf.less_equal(u_draws, top_fraction)
  magnitudes = tf.where(use_flat, flat_magnitudes, tail_magnitudes)

  # Attach a random sign, then round the signed offset to a whole number.
  signs = random_rademacher(mode_shape, seed=sign_seed, dtype=dtype)
  rounded_offsets = tf.round(signs * magnitudes)

  candidates = mode + rounded_offsets
  heights = tf.where(use_flat, mode_height, tail_heights)
  return candidates, heights