Example #1
    def testProbSwapNumReplicaNoBatch(self, prob_swap, num_replica):
        fn = tfp.mcmc.default_swap_proposal_fn(prob_swap)
        num_results = 100
        seeds = samplers.split_seed(test_util.test_seed(), n=num_results)
        swaps = tf.stack(
            [fn(num_replica, seed=seeds[i]) for i in range(num_results)],
            axis=0)

        self.assertAllEqual((num_results, num_replica), swaps.shape)
        self.check_swaps_with_no_batch_shape(self.evaluate(swaps), prob_swap)
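A minimal sketch (not part of the original test) of the split_seed contract the example above relies on: one stateless seed in, `n` derived seeds out, each usable independently and reproducibly. The salt string here is illustrative.

from tensorflow_probability.python.internal import samplers

seed = samplers.sanitize_seed(17)                    # accepts an int or a shape-[2] tensor
seeds = samplers.split_seed(seed, n=3, salt='demo')
x = samplers.normal([4], seed=seeds[0])
y = samplers.normal([4], seed=seeds[1])              # statistically independent of x
z = samplers.uniform([4], seed=seeds[2])
# Re-running this block with the same integer seed reproduces x, y, and z.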
Example #2
 def loop_body(should_continue, samples, prod, num_iters, seed):
     u_seed, next_seed = samplers.split_seed(seed)
     prod = prod * samplers.uniform(
         sample_shape, dtype=internal_dtype, seed=u_seed)
     accept = should_continue & (prod <= exp_neg_rate)
     samples = tf.where(accept, num_iters, samples)
     return [
         should_continue & (~accept), samples, prod, num_iters + 1,
         next_seed
     ]
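The snippet above is only the loop body; below is a hedged, self-contained sketch of how such a body can be driven. It implements Knuth's product-of-uniforms Poisson sampler: count how many uniforms can be multiplied together before the running product drops below exp(-rate). Only `loop_body` mirrors the snippet; the surrounding driver and its names are assumptions.

import tensorflow as tf
from tensorflow_probability.python.internal import samplers

def poisson_by_uniform_products(sample_shape, rate, seed,
                                internal_dtype=tf.float64):
  exp_neg_rate = tf.exp(-tf.cast(rate, internal_dtype))
  init = [
      tf.fill(sample_shape, True),                  # should_continue
      tf.zeros(sample_shape, dtype=tf.int32),       # samples
      tf.ones(sample_shape, dtype=internal_dtype),  # running product
      tf.zeros(sample_shape, dtype=tf.int32),       # num_iters
      samplers.sanitize_seed(seed),
  ]

  def loop_body(should_continue, samples, prod, num_iters, seed):
    u_seed, next_seed = samplers.split_seed(seed)
    prod = prod * samplers.uniform(
        sample_shape, dtype=internal_dtype, seed=u_seed)
    accept = should_continue & (prod <= exp_neg_rate)
    samples = tf.where(accept, num_iters, samples)
    return [should_continue & (~accept), samples, prod, num_iters + 1,
            next_seed]

  return tf.while_loop(
      cond=lambda should_continue, *_: tf.reduce_any(should_continue),
      body=loop_body,
      loop_vars=init)[1]

# draws = poisson_by_uniform_products([100000], rate=2.5, seed=(0, 1))  # mean ~= 2.5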
Example #3
 def _sample_n(self, n, seed=None):
   scale = tf.convert_to_tensor(self.scale)
   shape = ps.concat([[n], ps.shape(scale)], axis=0)
   shrinkage_seed, sample_seed = samplers.split_seed(seed,
                                                     salt='random_horseshoe')
   local_shrinkage = self._half_cauchy.sample(shape, seed=shrinkage_seed)
   shrinkage = scale * local_shrinkage
   sampled = samplers.normal(
       shape=shape, mean=0., stddev=1., dtype=scale.dtype, seed=sample_seed)
   return sampled * shrinkage
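A small sketch of what the `salt` argument above buys (the second salt string is illustrative): splitting the same base seed under different salts yields different derived seeds, so call sites that happen to share a seed do not share randomness.

from tensorflow_probability.python.internal import samplers

base = samplers.sanitize_seed(7)
shrinkage_seed, _ = samplers.split_seed(base, salt='random_horseshoe')
other_seed, _ = samplers.split_seed(base, salt='some_other_component')
# shrinkage_seed != other_seed, yet re-running either line reproduces its value.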
Example #4
  def test_sampled_weights_follow_correct_distribution(self):
    seed = test_util.test_seed(sampler_type='stateless')
    design_seed, true_weights_seed, sampled_weights_seed = samplers.split_seed(
        seed, 3, 'test_sampled_weights_follow_correct_distribution')
    num_timesteps = 10
    num_features = 2
    batch_shape = [3, 1]
    design_matrix = samplers.normal(
        batch_shape + [num_timesteps, num_features], seed=design_seed)
    true_weights = samplers.normal(
        batch_shape + [num_features, 1], seed=true_weights_seed) * 10.0
    targets = tf.matmul(design_matrix, true_weights)
    is_missing = tf.convert_to_tensor([False, False, False, True, True,
                                       False, False, True, False, False],
                                      dtype=tf.bool)
    prior_scale = tf.convert_to_tensor(5.)
    likelihood_scale = tf.convert_to_tensor(0.1)

    # Analytically compute the true posterior distribution on weights.
    valid_design_matrix = tf.boolean_mask(design_matrix, ~is_missing, axis=-2)
    valid_targets = tf.boolean_mask(targets, ~is_missing, axis=-2)
    num_valid_observations = tf.shape(valid_design_matrix)[-2]
    weights_posterior_mean, weights_posterior_cov, _ = linear_gaussian_update(
        prior_mean=tf.zeros([num_features, 1]),
        prior_cov=tf.eye(num_features) * prior_scale**2,
        observation_matrix=tfl.LinearOperatorFullMatrix(valid_design_matrix),
        observation_noise=tfd.MultivariateNormalDiag(
            loc=tf.zeros([num_valid_observations]),
            scale_diag=likelihood_scale * tf.ones([num_valid_observations])),
        x_observed=valid_targets)

    # Check that the empirical moments of sampled weights match the true values.
    sampled_weights = parallel_for.pfor(
        lambda i: gibbs_sampler._resample_weights(  # pylint: disable=g-long-lambda
            design_matrix=design_matrix,
            target_residuals=targets[..., 0],
            observation_noise_scale=likelihood_scale,
            weights_prior_scale=prior_scale,
            is_missing=is_missing,
            seed=sampled_weights_seed),
        10000)
    sampled_weights_mean = tf.reduce_mean(sampled_weights, axis=0)
    centered_weights = sampled_weights - weights_posterior_mean[..., 0]
    sampled_weights_cov = tf.reduce_mean(centered_weights[..., :, tf.newaxis] *
                                         centered_weights[..., tf.newaxis, :],
                                         axis=0)

    (sampled_weights_mean_, weights_posterior_mean_,
     sampled_weights_cov_, weights_posterior_cov_) = self.evaluate((
         sampled_weights_mean, weights_posterior_mean[..., 0],
         sampled_weights_cov, weights_posterior_cov))
    self.assertAllClose(sampled_weights_mean_, weights_posterior_mean_,
                        atol=0.01, rtol=0.05)
    self.assertAllClose(sampled_weights_cov_, weights_posterior_cov_,
                        atol=0.01, rtol=0.05)
Example #5
    def test_sampled_latents_have_correct_marginals(self, use_slope):
        seed = test_util.test_seed(sampler_type='stateless')
        residuals_seed, is_missing_seed, level_seed = samplers.split_seed(
            seed, 3, 'test_sampled_level_has_correct_marginals')

        num_timesteps = 10

        observed_residuals = samplers.normal([3, 1, num_timesteps],
                                             seed=residuals_seed)
        is_missing = samplers.uniform([3, 1, num_timesteps],
                                      seed=is_missing_seed) > 0.8
        level_scale = 1.5 * tf.ones([3, 1])
        observation_noise_scale = 0.2 * tf.ones([3, 1])

        if use_slope:
            initial_state_prior = tfd.MultivariateNormalDiag(
                loc=[-30., 2.], scale_diag=[1., 0.2])
            slope_scale = 0.5 * tf.ones([3, 1])
            ssm = tfp.sts.LocalLinearTrendStateSpaceModel(
                num_timesteps=num_timesteps,
                initial_state_prior=initial_state_prior,
                observation_noise_scale=observation_noise_scale,
                level_scale=level_scale,
                slope_scale=slope_scale)
        else:
            initial_state_prior = tfd.MultivariateNormalDiag(loc=[-30.],
                                                             scale_diag=[100.])
            slope_scale = None
            ssm = tfp.sts.LocalLevelStateSpaceModel(
                num_timesteps=num_timesteps,
                initial_state_prior=initial_state_prior,
                observation_noise_scale=observation_noise_scale,
                level_scale=level_scale)

        posterior_means, posterior_covs = ssm.posterior_marginals(
            observed_residuals[..., tf.newaxis], mask=is_missing)
        latents_samples = gibbs_sampler._resample_latents(
            observed_residuals=observed_residuals,
            level_scale=level_scale,
            slope_scale=slope_scale,
            observation_noise_scale=observation_noise_scale,
            initial_state_prior=initial_state_prior,
            is_missing=is_missing,
            sample_shape=10000,
            seed=level_seed)

        (posterior_means_, posterior_covs_, latents_means_,
         latents_covs_) = self.evaluate((posterior_means, posterior_covs,
                                         tf.reduce_mean(latents_samples,
                                                        axis=0),
                                         tfp.stats.covariance(latents_samples,
                                                              sample_axis=0,
                                                              event_axis=-1)))
        self.assertAllClose(latents_means_, posterior_means_, atol=0.1)
        self.assertAllClose(latents_covs_, posterior_covs_, atol=0.1)
Example #6
 def run(seed, log_accept_ratio):
     kernel = tfp.mcmc.UncalibratedHamiltonianMonteCarlo(
         target_log_prob, step_size=1e-2, num_leapfrog_steps=2)
     kernel = FakeMHKernel(kernel, log_accept_ratio)
     kernel = tfp.mcmc.SimpleStepSizeAdaptation(
         kernel,
         10,
         reduce_fn=reduce_fn,
         experimental_reduce_chain_axis_names=self.axis_name)
     sharded_kernel = tfp.experimental.mcmc.Sharded(
         kernel, self.axis_name)
     init_seed, sample_seed = samplers.split_seed(seed)
     state_seeds = samplers.split_seed(init_seed)
     state = [
         samplers.normal(seed=state_seeds[0], shape=[]),
         samplers.normal(seed=state_seeds[1], shape=[4])
     ]
     kr = sharded_kernel.bootstrap_results(state)
     _, kr = sharded_kernel.one_step(state, kr, seed=sample_seed)
     return kr.new_step_size
Example #7
    def _sample_n(self, n, seed=None):
        seeds = samplers.split_seed(seed,
                                    n=self.num_components + 1,
                                    salt='Mixture')
        try:
            seed_stream = SeedStream(seed, salt='Mixture')
        except TypeError as e:  # Can happen for Tensor seed.
            seed_stream = None
            seed_stream_err = e

        # This sampling approach is almost the same as the approach used by
        # `MixtureSameFamily`. The differences are due to having a list of
        # `Distribution` objects rather than a single object.
        samples = []
        cat_samples = self.cat.sample(n, seed=seeds[0])

        for c in range(self.num_components):
            try:
                samples.append(self.components[c].sample(n, seed=seeds[c + 1]))
                if seed_stream is not None:
                    seed_stream()
            except TypeError as e:
                if ('Expected int for argument' not in str(e)
                        and TENSOR_SEED_MSG_PREFIX not in str(e)):
                    raise
                if seed_stream is None:
                    raise seed_stream_err
                msg = (
                    'Falling back to stateful sampling for `components[{}]` {} of '
                    'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
                    'This fallback may be removed after 20-Aug-2020. ({})')
                warnings.warn(
                    msg.format(c, self.components[c].name,
                               type(self.components[c]), str(e)))
                samples.append(self.components[c].sample(n,
                                                         seed=seed_stream()))
        stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
        x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
        # TODO(b/170730865): Is all this masking stuff really called for?
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        mask = tf.one_hot(
            indices=cat_samples,  # [n, B]
            depth=self._num_components,  # == k
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k]
        mask = distribution_util.pad_mixture_dimensions(
            mask, self, self._cat,
            tensorshape_util.rank(
                self._static_event_shape))  # [n, B, k, [1]*e]
        if x.dtype.is_floating:
            masked = tf.math.multiply_no_nan(x, mask)
        else:
            masked = x * mask
        return tf.reduce_sum(masked, axis=stack_axis)  # [n, B, E]
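A short usage sketch of the public API whose `_sample_n` is shown above; the distribution parameters are illustrative, not from the original. With a stateless (two-integer) seed, the mixture assignment and every component draw are derived from a single `split_seed` call.

import tensorflow_probability as tfp
tfd = tfp.distributions

mix = tfd.Mixture(
    cat=tfd.Categorical(probs=[0.3, 0.7]),
    components=[tfd.Normal(loc=-1., scale=0.5),
                tfd.Normal(loc=2., scale=1.)])
samples = mix.sample(5, seed=(0, 42))  # the same seed pair reproduces the same 5 draws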
Example #8
def _gen_gaussian_updating_example(x_dim, y_dim, seed):
  """An implementation of section 2.3.3 from [1].

  We initialize a joint distribution

  x ~ N(mu, Lambda^{-1})
  y ~ N(Ax, L^{-1})

  Then condition the model on an observation for y. We can test to confirm that
  Cov(p(x | y_obs)) is near to

  Sigma = (Lambda + A^T L A)^{-1}

  This test can actually check whether the posterior samples have the proper
  covariance, and whether the windowed tuning recovers 1 / diag(Sigma) as the
  diagonal scaling factor.

  References:
  [1] Bishop, Christopher M. Pattern Recognition and Machine Learning.
      Springer, 2006.

  Args:
    x_dim: Python `int`, dimension of `x`.
    y_dim: Python `int`, dimension of `y`.
    seed: PRNG seed, for reproducibility.
  Returns:
    (tfd.JointDistribution, tf.Tensor), representing the joint distribution
    above, and the posterior variance.
  """
  seeds = samplers.split_seed(seed, 5)
  x_mean = samplers.normal((x_dim,), seed=seeds[0])
  x_scale_diag = samplers.normal((x_dim,), seed=seeds[1])
  y_scale_diag = samplers.normal((y_dim,), seed=seeds[2])
  scale_mat = samplers.normal((y_dim, x_dim), seed=seeds[3])
  y_shift = samplers.normal((y_dim,), seed=seeds[4])

  @tfd.JointDistributionCoroutine
  def model():
    x = yield Root(tfd.MultivariateNormalDiag(
        x_mean, scale_diag=x_scale_diag, name='x'))
    yield tfd.MultivariateNormalDiag(
        tf.linalg.matvec(scale_mat, x) + y_shift,
        scale_diag=y_scale_diag,
        name='y')

  dists, _ = model.sample_distributions()
  precision_x = tf.linalg.inv(dists.x.covariance())
  precision_y = tf.linalg.inv(dists.y.covariance())
  true_cov = tf.linalg.inv(precision_x  +
                           tf.linalg.matmul(
                               tf.linalg.matmul(scale_mat, precision_y,
                                                transpose_a=True),
                               scale_mat))
  return model, tf.linalg.diag_part(true_cov)
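A hedged numerical check (plain NumPy, with arbitrary illustrative precisions) of the formula quoted in the docstring, Sigma = (Lambda + A^T L A)^{-1}; it is the same formula that `true_cov` implements above with `tf.linalg`.

import numpy as np

x_dim, y_dim = 3, 2
rng = np.random.default_rng(0)
Lambda = np.diag(rng.uniform(0.5, 2.0, x_dim))  # prior precision of x
L = np.diag(rng.uniform(0.5, 2.0, y_dim))       # noise precision of y given x
A = rng.normal(size=(y_dim, x_dim))
Sigma = np.linalg.inv(Lambda + A.T @ L @ A)     # posterior covariance of x given y
print(np.diag(Sigma))                           # what the windowed tuning should recover (inverted)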
Example #9
    def test_constrained_affine_from_distributions(self,
                                                   dist_classes,
                                                   event_shape,
                                                   operators,
                                                   initial_loc,
                                                   implicit_batch_shape,
                                                   bijector,
                                                   dtype,
                                                   is_static,
                                                   is_stateless=JAX_MODE):
        if not tf.executing_eagerly() and not is_static:
            self.skipTest(
                'tfb.Reshape requires statically known shapes in graph'
                ' mode.')

        init_seed, grads_seed, shapes_seed, dtype_seed = samplers.split_seed(
            test_util.test_seed(sampler_type='stateless'), n=4)

        # pylint: disable=g-long-lambda
        initial_loc = tf.nest.map_structure(
            lambda s: self.maybe_static(np.array(s, dtype=dtype),
                                        is_static=is_static), initial_loc)
        distributions = nest.map_structure_up_to(
            dist_classes, lambda d, loc, s: tfd.Independent(
                d(loc=loc, scale=1.),
                reinterpreted_batch_ndims=ps.rank_from_shape(s)), dist_classes,
            initial_loc, event_shape)
        # pylint: enable=g-long-lambda

        surrogate_posterior = self._initialize_surrogate(
            'build_affine_surrogate_posterior_from_base_distribution',
            is_stateless=is_stateless,
            seed=init_seed,
            base_distribution=distributions,
            operators=operators,
            bijector=bijector,
            validate_args=True)

        event_shape = nest.map_structure(lambda d: d.event_shape_tensor(),
                                         distributions)
        if bijector is not None:
            event_shape = nest.map_structure(
                lambda b, s: s
                if b is None else b.forward_event_shape_tensor(s), bijector,
                event_shape)

        self._test_shapes(surrogate_posterior,
                          batch_shape=implicit_batch_shape,
                          event_shape=event_shape,
                          seed=shapes_seed)
        self._test_dtype(surrogate_posterior, dtype, dtype_seed)
        if not is_stateless:
            self._test_gradients(surrogate_posterior, seed=grads_seed)
Example #10
 def loop_body(unconstrained_parameters_seed, cooling_fraction):
     unconstrained_parameters, seed = unconstrained_parameters_seed
     step_seed, seed = samplers.split_seed(seed)
     return (self.one_step(
         observations=observations,
         num_particles=num_particles,
         perturbation_scale=tf.nest.map_structure(
             lambda s: cooling_fraction * s,
             initial_perturbation_scale),
         initial_unconstrained_parameters=unconstrained_parameters,
         seed=step_seed,
         **kwargs), seed)
Example #11
 def randomized_computation(seed):
   """Internal randomized computation."""
   proposal_seed, mask_seed = samplers.split_seed(
       seed, salt='batched_rejection_sampler')
   proposed_samples, proposed_values = proposal_fn(proposal_seed)
   good_samples_mask = tf.less_equal(
       proposed_values * samplers.uniform(
           prefer_static.shape(proposed_samples),
           seed=mask_seed,
           dtype=dtype),
       target_fn(proposed_samples))
   return proposed_samples, good_samples_mask
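For intuition, a hedged NumPy sketch (illustrative target and envelope, not from the original) of the acceptance rule used above: thin proposals whose values upper-bound the target, accepting a sample exactly when u * bound(x) <= target(x).

import numpy as np

rng = np.random.default_rng(0)
target = lambda x: np.exp(-x**2 / 2.)       # unnormalized standard normal
bound = lambda x: np.exp(-np.abs(x) + 0.5)  # Laplace envelope, >= target everywhere
x = rng.laplace(size=10_000)                # proposals drawn from the envelope
u = rng.uniform(size=x.shape)
accepted = x[u * bound(x) <= target(x)]     # approximately N(0, 1) samples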
Example #12
  def test_steps_are_reproducible(self):

    def propose_and_update_log_weights_fn(_, weighted_particles, seed=None):
      proposed_particles = tfd.Normal(
          loc=weighted_particles.particles, scale=1.).sample(seed=seed)
      return WeightedParticles(
          particles=proposed_particles,
          log_weights=weighted_particles.log_weights + tfd.Normal(
              loc=-2.6, scale=0.1).log_prob(proposed_particles))

    num_particles = 16
    initial_state = self.evaluate(
        WeightedParticles(
            particles=tf.random.normal([num_particles],
                                       seed=test_util.test_seed()),
            log_weights=tf.fill([num_particles],
                                -tf.math.log(float(num_particles)))))

    # Run a couple of steps.
    seeds = samplers.split_seed(
        test_util.test_seed(sampler_type='stateless'), n=2)
    kernel = SequentialMonteCarlo(
        propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
        resample_fn=tfp.experimental.mcmc.resample_systematic,
        resample_criterion_fn=tfp.experimental.mcmc.ess_below_threshold)
    state, results = kernel.one_step(
        state=initial_state,
        kernel_results=kernel.bootstrap_results(initial_state),
        seed=seeds[0])
    state, results = kernel.one_step(state=state, kernel_results=results,
                                     seed=seeds[1])
    state, results = self.evaluate(
        (tf.nest.map_structure(tf.convert_to_tensor, state),
         tf.nest.map_structure(tf.convert_to_tensor, results)))

    # Re-initialize and run the same steps with the same seed.
    kernel2 = SequentialMonteCarlo(
        propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
        resample_fn=tfp.experimental.mcmc.resample_systematic,
        resample_criterion_fn=tfp.experimental.mcmc.ess_below_threshold)
    state2, results2 = kernel2.one_step(
        state=initial_state,
        kernel_results=kernel2.bootstrap_results(initial_state),
        seed=seeds[0])
    state2, results2 = kernel2.one_step(state=state2, kernel_results=results2,
                                        seed=seeds[1])
    state2, results2 = self.evaluate(
        (tf.nest.map_structure(tf.convert_to_tensor, state2),
         tf.nest.map_structure(tf.convert_to_tensor, results2)))

    # Results should match.
    self.assertAllCloseNested(state, state2)
    self.assertAllCloseNested(results, results2)
Example #13
def random_poisson_rejection_sampler(sample_shape,
                                     log_rate,
                                     internal_dtype=tf.float64,
                                     seed=None):
    """Samples from the Poisson distribution.

  The sampling algorithm is rejection sampling.

  Args:
    sample_shape: The output sample shape. Must broadcast with `log_rate`.
    log_rate: Floating point tensor, log rate.
    internal_dtype: dtype to use for internal computations.
    seed: (optional) The random seed.

  Returns:
    Samples from the Poisson distribution.
  """
    output_dtype = dtype_util.common_dtype([log_rate], dtype_hint=tf.float32)
    good_params_mask = ~tf.math.is_nan(log_rate)

    seed_lo, seed_hi = samplers.split_seed(seed)

    # First, we sample the values for which rate >= 10.
    # When replacing NaN or < 10 values, use 100 for the log rate, since that
    # leads to a high likelihood of the rejection sampler accepting on the
    # first pass.
    high_params_mask = good_params_mask & tf.math.greater_equal(
        log_rate, np.log(10.))
    cast_log_rate = tf.cast(log_rate, internal_dtype)
    safe_log_rate = tf.where(high_params_mask, cast_log_rate, 100.)
    high_rate_samples = _random_poisson_high_rate(
        sample_shape,
        log_rate=safe_log_rate,
        internal_dtype=internal_dtype,
        seed=seed_hi)
    high_rate_samples = tf.cast(high_rate_samples, output_dtype)

    # Next, we sample the values for which rate < 10. When replacing NaN or high
    # values, use a small number so that the sum-of-exponentials sampler
    # terminates on the first pass with high likelihood.
    low_params_mask = good_params_mask & ~high_params_mask
    safe_rate = tf.where(low_params_mask, tf.math.exp(cast_log_rate), 1e-5)
    low_rate_samples = _random_poisson_low_rate(sample_shape,
                                                rate=safe_rate,
                                                internal_dtype=internal_dtype,
                                                seed=seed_lo)
    low_rate_samples = tf.cast(low_rate_samples, output_dtype)

    samples = tf.where(
        good_params_mask,
        tf.where(high_params_mask, high_rate_samples, low_rate_samples),
        np.nan)

    return samples
Example #14
    def sampler_loop_body(previous_sample, _):
        """Runs one sampler iteration, resampling all model variables."""

        (weights_seed, level_seed, observation_noise_scale_seed,
         level_scale_seed,
         loop_seed) = samplers.split_seed(previous_sample.seed,
                                          n=5,
                                          salt='sampler_loop_body')

        # We encourage a reasonable initialization by sampling the weights first,
        # so at the first step they are regressed directly against the observed
        # time series. If we instead sampled the level first it might 'explain away'
        # some observed variation that we would ultimately prefer to explain through
        # the regression weights, because the level can represent arbitrary
        # variation, while the weights are limited to representing variation in the
        # subspace given by the design matrix.
        weights = _resample_weights(
            design_matrix=design_matrix,
            target_residuals=(observed_time_series - previous_sample.level),
            observation_noise_scale=previous_sample.observation_noise_scale,
            weights_prior_scale=weights_param.prior.distribution.scale,
            is_missing=is_missing,
            seed=weights_seed)

        regression_residuals = observed_time_series - tf.linalg.matvec(
            design_matrix, weights)
        level = _resample_level(
            observed_residuals=regression_residuals,
            level_scale=previous_sample.level_scale,
            observation_noise_scale=previous_sample.observation_noise_scale,
            initial_state_prior=level_component.initial_state_prior,
            is_missing=is_missing,
            seed=level_seed)

        # Estimate level scale from the empirical changes in level.
        level_scale = _resample_scale(prior=level_scale_variance_prior,
                                      observed_residuals=level[..., 1:] -
                                      level[..., :-1],
                                      is_missing=None,
                                      seed=level_scale_seed)
        # Estimate noise scale from the residuals.
        observation_noise_scale = _resample_scale(
            prior=observation_noise_variance_prior,
            observed_residuals=regression_residuals - level,
            is_missing=is_missing,
            seed=observation_noise_scale_seed)

        return GibbsSamplerState(
            observation_noise_scale=observation_noise_scale,
            level_scale=level_scale,
            weights=weights,
            level=level,
            seed=loop_seed)
Example #15
def _setup_mcmc(model, n_chains, seed, **pins):
    """Construct bijector and transforms needed for windowed MCMC.

  This pins the initial model, constructs a bijector that unconstrains and
  flattens each dimension and adds a leading batch shape of `n_chains`,
  initializes a point in the unconstrained space, and constructs a transformed
  log probability using the bijector.

  Note that we must manually construct this target log probability instead of
  using a TransformedTransitionKernel, because that kernel assumes the input
  shape is the same as the output shape.

  Args:
    model: `tfd.JointDistribution`
      The model to sample from.
    n_chains: int
      Number of chains (independent examples) to run.
    seed: A seed for reproducible sampling.
    **pins:
      Values passed to `model.experimental_pin`.


  Returns:
    target_log_prob_fn: Callable on the transformed space.
    initial_transformed_position: `tf.Tensor`, sampled from a uniform (-2, 2).
    bijector: `tfb.Bijector` instance, which unconstrains and flattens.
  """
    pinned_model = model.experimental_pin(**pins)
    bijector = _get_flat_unconstraining_bijector(pinned_model)
    initial_position = pinned_model.sample_unpinned(n_chains)
    initial_transformed_position = bijector.forward(initial_position)

    # Jitter init
    seeds = samplers.split_seed(seed, n=len(initial_transformed_position))
    unconstrained_unif_init_position = []
    for p, seed in zip(initial_transformed_position, seeds):
        unconstrained_unif_init_position.append(
            samplers.uniform(ps.shape(p),
                             minval=-2.,
                             maxval=2.,
                             seed=seed,
                             dtype=p.dtype))

    # pylint: disable=g-long-lambda
    def target_log_prob_fn(*args):
        return (
            pinned_model.unnormalized_log_prob(bijector.inverse(args)) +
            bijector.inverse_log_det_jacobian(
                args, event_ndims=[1 for _ in initial_transformed_position]))

    # pylint: enable=g-long-lambda
    return target_log_prob_fn, unconstrained_unif_init_position, bijector
Example #16
    def body(values, good_values_mask, num_iters, seed):
      """Batched Las Vegas Algorithm body."""

      trial_seed, new_seed = samplers.split_seed(seed)
      new_values, new_good_values_mask = batched_las_vegas_trial_fn(trial_seed)

      values = tf.nest.map_structure(
          lambda new, old: tf.where(new_good_values_mask, new, old),
          new_values, values)

      good_values_mask = tf.logical_or(good_values_mask, new_good_values_mask)

      return values, good_values_mask, num_iters + 1, new_seed
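A hedged driver sketch for a Las Vegas body like the one above, specialized to a single Tensor of values; `trial_fn` stands in for `batched_las_vegas_trial_fn`, and the surrounding names are assumptions. It repeats trials, keeping any newly good values, until the mask is all True.

import tensorflow as tf
from tensorflow_probability.python.internal import samplers

def run_until_all_good(trial_fn, seed):
  init_seed, loop_seed = samplers.split_seed(samplers.sanitize_seed(seed))
  values, good_mask = trial_fn(init_seed)

  def body(values, good_mask, num_iters, seed):
    trial_seed, next_seed = samplers.split_seed(seed)
    new_values, new_mask = trial_fn(trial_seed)
    values = tf.where(new_mask, new_values, values)
    return values, good_mask | new_mask, num_iters + 1, next_seed

  values, _, num_iters, _ = tf.while_loop(
      cond=lambda values, mask, *rest: ~tf.reduce_all(mask),
      body=body,
      loop_vars=(values, good_mask, tf.constant(1), loop_seed))
  return values, num_iters

def positive_normal_trial(seed):
  x = samplers.normal([8], seed=seed)
  return x, x > 0.  # "good" means positive, so the loop draws half-normals

half_normals, num_trials = run_until_all_good(positive_normal_trial, seed=3)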
Example #17
 def _sample_n(self, n, seed=None):
   seed1, seed2 = samplers.split_seed(seed, salt='beta')
   concentration1 = tf.convert_to_tensor(self.concentration1)
   concentration0 = tf.convert_to_tensor(self.concentration0)
   shape = self._batch_shape_tensor(concentration1, concentration0)
   expanded_concentration1 = tf.broadcast_to(concentration1, shape)
   expanded_concentration0 = tf.broadcast_to(concentration0, shape)
   gamma1_sample = gamma_lib.random_gamma(
       shape=[n], concentration=expanded_concentration1, seed=seed1)
   gamma2_sample = gamma_lib.random_gamma(
       shape=[n], concentration=expanded_concentration0, seed=seed2)
   beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
   return beta_sample
Example #18
            def resample_one_feature(step, seed, sampler_state):
                seed, next_seed = samplers.split_seed(seed, n=2)
                idx = tf.gather(feature_permutation, step)

                # Maybe flip this weight's sparsity indicator.
                proposed_sampler_state = self._flip_feature(sampler_state,
                                                            idx=idx)
                should_flip = bernoulli.Bernoulli(
                    logits=(proposed_sampler_state.unnormalized_log_prob -
                            sampler_state.unnormalized_log_prob),
                    dtype=tf.bool).sample(seed=seed)
                return step + 1, next_seed, mcmc_util.choose(
                    should_flip, proposed_sampler_state, sampler_state)
Example #19
 def _fn(self, state_parts: List[tf.Tensor],
         seed: Optional[int]) -> List[tf.Tensor]:
     with tf.name_scope(self._name or "categorical_uniform_fn"):
         part_seeds = samplers.split_seed(seed,
                                          n=len(state_parts),
                                          salt="CategoricalUniformFn")
         deltas = tf.nest.map_structure(
             lambda x, s: tfd.Categorical(logits=tf.ones(self.classes)).
             sample(seed=s, sample_shape=tf.shape(x)),
             state_parts,
             part_seeds,
         )
         return deltas
Example #20
def _asvi_surrogate_for_markov_chain(dist,
                                     base_distribution_surrogate_fn,
                                     sample_shape=None,
                                     variables=None,
                                     seed=None):
    """Builds a structured surrogate posterior for a Markov chain."""
    prior_seed, transition_seed = samplers.split_seed(seed, 2)
    if variables is None:
        prior_variables, transition_variables = None, None
    else:
        prior_variables, transition_variables = variables

    surrogate_prior, prior_variables = _asvi_surrogate_for_distribution(
        dist.initial_state_prior,
        base_distribution_surrogate_fn=base_distribution_surrogate_fn,
        variables=prior_variables,
        seed=prior_seed)

    if transition_variables is None:
        # Construct variables for all chain steps in a single call. These will have
        # an initial dimension of size `num_steps - 1`, which we can gather from
        # as the chain runs.
        all_steps = tf.range(dist.num_steps - 1)
        batch_state = dist.initial_state_prior.sample(dist.num_steps - 1)
        _, transition_variables = _asvi_surrogate_for_distribution(
            dist.transition_fn(all_steps, batch_state),
            base_distribution_surrogate_fn=base_distribution_surrogate_fn,
            variables=None,
            sample_shape=sample_shape,
            seed=transition_seed)

    def surrogate_transition_fn(step, state):
        surrogate_new_dist, _ = _asvi_surrogate_for_distribution(
            dist.transition_fn(step, state),
            base_distribution_surrogate_fn=base_distribution_surrogate_fn,
            variables=tf.nest.map_structure(
                # Gather parameters for this specific step of the chain.
                lambda v: tf.gather(v, step, axis=0),
                transition_variables),
            sample_shape=sample_shape,
            seed=transition_seed)
        return surrogate_new_dist

    chain_surrogate = markov_chain.MarkovChain(
        initial_state_prior=surrogate_prior,
        transition_fn=surrogate_transition_fn,
        num_steps=dist.num_steps,
        validate_args=dist.validate_args,
        name=_get_name(dist))

    return chain_surrogate, [prior_variables, transition_variables]
Example #21
  def _random_regression_task(self, num_outputs, num_features, batch_shape=(),
                              weights=None, observation_noise_scale=0.1,
                              seed=None):
    design_seed, weights_seed, noise_seed = samplers.split_seed(seed, n=3)
    batch_shape = list(batch_shape)

    design_matrix = samplers.uniform(batch_shape + [num_outputs, num_features],
                                     seed=design_seed)
    if weights is None:
      weights = samplers.normal(batch_shape + [num_features], seed=weights_seed)
    targets = (tf.linalg.matvec(design_matrix, weights) +
               observation_noise_scale * samplers.normal(
                   batch_shape + [num_outputs], seed=noise_seed))
    return design_matrix, weights, targets
Example #22
    def _sample_and_log_prob(self, sample_shape, seed, **kwargs):
        all_seeds = samplers.split_seed(seed,
                                        len(self._distributions),
                                        salt='BatchConcat')
        samples = []
        log_probs = []
        for d, s in zip(self._distributions, all_seeds):
            x, lp = d.experimental_sample_and_log_prob(sample_shape, s)
            samples.append(self._broadcast(x, sample_shape))
            log_probs.append(self._broadcast(lp, sample_shape))

        sample_shape_size = ps.rank_from_shape(sample_shape)
        return (tf.concat(samples, axis=self._axis + sample_shape_size),
                tf.concat(log_probs, axis=self._axis + sample_shape_size))
Example #23
 def randomized_computation(seed):
     """Internal randomized computation."""
     proposal_seed, mask_seed = samplers.split_seed(
         seed, salt='batched_rejection_sampler')
     proposed_samples, proposed_values = proposal_fn(proposal_seed)
     # The comparison needs to be strictly less to avoid spurious acceptances
     # when the uniform samples exactly 0 (or when the product underflows to
     # 0).
     good_samples_mask = tf.less(
         proposed_values *
         samplers.uniform(prefer_static.shape(proposed_samples),
                          seed=mask_seed,
                          dtype=dtype), target_fn(proposed_samples))
     return proposed_samples, good_samples_mask
Example #24
 def _sample_n(self, n, seed=None):
   seed = samplers.sanitize_seed(seed)
   seed1, seed2 = samplers.split_seed(seed, salt='Skellam')
   log_rate1 = self._log_rate1_parameter_no_checks()
   log_rate2 = self._log_rate2_parameter_no_checks()
   batch_shape = self._batch_shape_tensor(
       log_rate1=log_rate1, log_rate2=log_rate2)
   log_rate1 = ps.broadcast_to(log_rate1, batch_shape)
   log_rate2 = ps.broadcast_to(log_rate2, batch_shape)
   sample1 = poisson_lib.random_poisson(
       [n], log_rates=log_rate1, seed=seed1)[0]
   sample2 = poisson_lib.random_poisson(
       [n], log_rates=log_rate2, seed=seed2)[0]
   return sample1 - sample2
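A hedged NumPy moment check (rates chosen arbitrarily) of the construction above: the difference of two independent Poissons with rates r1 and r2 is Skellam distributed, with mean r1 - r2 and variance r1 + r2.

import numpy as np

rng = np.random.default_rng(0)
r1, r2 = 3.0, 1.5
x = rng.poisson(r1, 200_000) - rng.poisson(r2, 200_000)
print(x.mean(), x.var())  # approximately 1.5 and 4.5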
Example #25
    def body(values, good_values_mask, num_iters, seed):
      """Batched Las Vegas Algorithm body."""

      trial_seed, new_seed = samplers.split_seed(seed)
      new_values, new_good_values_mask = batched_las_vegas_trial_fn(trial_seed)

      def pick(new, old):
        return bu.where_left_justified_mask(new_good_values_mask, new, old)

      values = tf.nest.map_structure(pick, new_values, values)

      good_values_mask = good_values_mask | new_good_values_mask

      return values, good_values_mask, num_iters + 1, new_seed
Example #26
 def _sample_n(self, n, seed=None):
   seed1, seed2 = samplers.split_seed(seed, salt='beta')
   concentration1 = tf.convert_to_tensor(self.concentration1)
   concentration0 = tf.convert_to_tensor(self.concentration0)
   shape = self._batch_shape_tensor(concentration1, concentration0)
   expanded_concentration1 = tf.broadcast_to(concentration1, shape)
   expanded_concentration0 = tf.broadcast_to(concentration0, shape)
   log_gamma1 = gamma_lib.random_gamma(
       shape=[n], concentration=expanded_concentration1, seed=seed1,
       log_space=True)
   log_gamma2 = gamma_lib.random_gamma(
       shape=[n], concentration=expanded_concentration0, seed=seed2,
       log_space=True)
   return tf.math.sigmoid(log_gamma1 - log_gamma2)
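A quick check of the identity behind the log-space trick above: sigmoid(log g1 - log g2) = g1 / (g1 + g2), i.e. the same Beta-from-two-Gammas construction as Example #17, but more robust when the Gamma draws would underflow.

import numpy as np

g1, g2 = 0.3, 1.7
lhs = 1. / (1. + np.exp(-(np.log(g1) - np.log(g2))))  # sigmoid(log g1 - log g2)
print(lhs, g1 / (g1 + g2))                            # both 0.15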
Example #27
    def _sample_n(self, n, seed):
        # only for MixtureSameFamilySampleFix
        import warnings
        from tensorflow_probability.python.distributions import independent
        from tensorflow_probability.python.internal import dtype_util
        from tensorflow_probability.python.internal import prefer_static
        from tensorflow_probability.python.internal import samplers
        from tensorflow_probability.python.internal import tensorshape_util
        from tensorflow_probability.python.util.seed_stream import SeedStream
        from tensorflow_probability.python.util.seed_stream import (
            TENSOR_SEED_MSG_PREFIX, )

        components_seed, mix_seed = samplers.split_seed(
            seed, salt="MixtureSameFamily")
        try:
            seed_stream = SeedStream(seed, salt="MixtureSameFamily")
        except TypeError as e:  # Can happen for Tensor seeds.
            seed_stream = None
            seed_stream_err = e
        try:
            mix_sample = self.mixture_distribution.sample(
                n, seed=mix_seed)  # [n, B] or [n]
        except TypeError as e:
            if "Expected int for argument" not in str(
                    e) and TENSOR_SEED_MSG_PREFIX not in str(e):
                raise
            if seed_stream is None:
                raise seed_stream_err
            msg = (
                "Falling back to stateful sampling for `mixture_distribution` "
                "{} of type `{}`. Please update to use `tf.random.stateless_*` "
                "RNGs. This fallback may be removed after 20-Aug-2020. ({})")
            warnings.warn(
                msg.format(
                    self.mixture_distribution.name,
                    type(self.mixture_distribution),
                    str(e),
                ))
            mix_sample = self.mixture_distribution.sample(
                n, seed=seed_stream())  # [n, B] or [n]
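        # Note: this path materializes the seed and the mixture indices with
        # `.numpy()` and iterates over `mix_sample` in Python, so it assumes
        # eager execution; each component draw then uses a distinct stateful
        # integer seed derived from `components_seed`.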
        _seed = int(components_seed[0].numpy())
        ret = tf.stack(
            [
                self.components_distribution[i_component.numpy()].sample(
                    seed=_seed + i) for i, i_component in enumerate(mix_sample)
            ],
            axis=0,
        )
        return ret
Example #28
        def randomized_computation(seed):
            """Internal randomized computation."""
            proposal_seed, mask_seed = samplers.split_seed(
                seed, salt='batched_rejection_sampler')

            proposed_samples, proposed_values = proposal_fn(proposal_seed)

            # The comparison needs to be strictly less to avoid spurious acceptances
            # when the uniform samples exactly 0 (or when the product underflows to
            # 0).
            target_values = target_fn(proposed_samples)
            good_samples_mask = tf.less(
                proposed_values *
                samplers.uniform(prefer_static.shape(proposed_samples),
                                 seed=mask_seed,
                                 dtype=dtype), target_values)

            # If either the `proposed_value` or the corresponding `target_value` is
            # `nan`, force that `proposed_sample` to `nan` and accept.  Why?
            #
            # - A `nan` would never be accepted, because tf.less must return False
            #   when either argument is `nan`.
            #
            # - If `nan` happens every time (e.g., due to `nan` in the parameters of
            #   the distribution we are trying to sample from), then we should clearly
            #   return `nan` after going around the rejection loop only once, rather
            #   than looping forever.
            #
            # - If `nan` happens only some of the time, it would silently skew the
            #   distribution on results to always reject, because some of those `nan`
            #   values may have stood for proposals that would have been accepted if
            #   we had computed more accurately.  Instead we forward the `nan`
            #   upstream, so the client can fix their proposal or evaluation
            #   functions.
            #
            # - We force the `proposed_sample` to `nan` because not doing so would
            #   hide a `nan` that occurred in only the `proposed_value` or
            #   `target_value`, silently skewing the distribution on results.
            #
            # - Corner case: if the `proposed_sample` is `nan` but both the
            #   corresponding `proposed_value` and `proposed_target` are for some
            #   reason not `nan`, we trust the user and proceed normally.
            nans = tf.math.is_nan(proposed_values) | tf.math.is_nan(
                target_values)
            proposed_samples = tf.where(
                nans, tf.cast(np.nan, proposed_samples.dtype),
                proposed_samples)
            good_samples_mask |= nans
            return proposed_samples, good_samples_mask
Example #29
 def _sample_n(self, n, seed=None):
   normal_seed, exp_seed = samplers.split_seed(seed, salt='emg_sample')
   # need to make sure component distributions are broadcast appropriately
   # for correct generation of samples
   loc = tf.convert_to_tensor(self.loc)
   rate = tf.convert_to_tensor(self.rate)
   scale = tf.convert_to_tensor(self.scale)
   batch_shape = self._batch_shape_tensor(loc, scale, rate)
   loc_broadcast = tf.broadcast_to(loc, batch_shape)
   rate_broadcast = tf.broadcast_to(rate, batch_shape)
   normal_dist = normal_lib.Normal(loc=loc_broadcast, scale=scale)
   exp_dist = exponential_lib.Exponential(rate_broadcast)
   x = normal_dist.sample(n, normal_seed)
   y = exp_dist.sample(n, exp_seed)
   return x + y
Example #30
 def _sample_n(self, n, seed=None):
   # Here we use the fact that if:
   # lam ~ Gamma(concentration=total_count, rate=(1-probs)/probs)
   # then X ~ Poisson(lam) is Negative Binomially distributed.
   logits = self._logits_parameter_no_checks()
   gamma_seed, poisson_seed = samplers.split_seed(
       seed, salt='NegativeBinomial')
   rate = samplers.gamma(
       shape=[n],
       alpha=self.total_count,
       beta=tf.math.exp(-logits),
       dtype=self.dtype,
       seed=gamma_seed)
   return samplers.poisson(
       shape=[], lam=rate, dtype=self.dtype, seed=poisson_seed)
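A hedged NumPy moment check (parameters chosen arbitrarily) of the Gamma-Poisson mixture used above: if lam ~ Gamma(total_count, rate=(1 - probs) / probs) and X | lam ~ Poisson(lam), then X is negative binomial with mean total_count * probs / (1 - probs).

import numpy as np

rng = np.random.default_rng(0)
total_count, probs = 5., 0.3
lam = rng.gamma(shape=total_count, scale=probs / (1. - probs), size=200_000)
x = rng.poisson(lam)
print(x.mean(), total_count * probs / (1. - probs))  # both approximately 2.14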