Example #1
File: vi_utils.py Project: bkmi/sbi
def check_sample_shape_and_support(q: Distribution,
                                   prior: Distribution) -> None:
    """Checks the samples shape and support between variational distribution and the
    prior. Especially it checks if the shapes match and that the support between q and
    the prior matches (a property which holds for the true posterior in any case).

    Args:
        q: Variational distribution to check.
        prior: Prior against which shapes and support are checked.

    """
    assert (q.event_shape == prior.event_shape
            ), "The event shape of q must match that of the prior"
    assert (q.batch_shape == prior.batch_shape
            ), "The batch sahpe of q must match that of the prior"

    sample_shape = torch.Size((1000, ))
    samples = q.sample(sample_shape)
    samples_prior = prior.sample(sample_shape).to(samples.device)
    try:
        _ = prior.support
        has_support = True
    except (NotImplementedError, AttributeError):
        has_support = False
    if has_support:
        assert all(prior.support.check(samples)  # type: ignore
                   ), "The support of q must match that of the prior"
    assert (
        samples.shape == samples_prior.shape
    ), "Something is wrong with sample shape and event_shape or batch_shape attributes."
    assert torch.isfinite(q.log_prob(
        samples_prior)).all(), "Invalid values in logprob on prior samples."
    assert torch.isfinite(prior.log_prob(
        samples)).all(), "Invalid values in logprob on q samples."
Example #2
def process_pytorch_prior(
        prior: Distribution) -> Tuple[Distribution, int, bool]:
    """Return PyTorch prior adapted to the requirements for sbi.

    Args:
        prior: PyTorch distribution prior provided by the user.

    Raises:
        ValueError: If prior is defined over an unwrapped scalar variable.

    Returns:
        prior: PyTorch distribution prior.
        theta_numel: Number of parameters, i.e., elements in a single sample from the prior.
        prior_returns_numpy: False.
    """

    # Reject unwrapped scalar priors.
    if prior.sample().ndim == 0:
        raise ValueError(
            "Detected scalar prior. Please make sure to pass a PyTorch prior with "
            "`batch_shape=torch.Size([1])` or `event_shape=torch.Size([1])`.")

    check_prior_batch_behavior(prior)
    check_prior_batch_dims(prior)

    if not prior.sample().dtype == float32:
        prior = PytorchReturnTypeWrapper(prior, return_type=float32)

    # This will fail for float64 priors.
    check_prior_return_type(prior)

    theta_numel = prior.sample().numel()

    return prior, theta_numel, False
Example #3
def plot_comparison(n_display: int, x_og: torch.Tensor, p_x: D.Distribution,
                    input_dims: Tuple[int]) -> Figure:
    fig = plt.figure(figsize=(6, 6))
    gs = fig.add_gridspec(4,
                          n_display,
                          width_ratios=[1] * n_display,
                          height_ratios=[1, 1, 1, 1])
    gs.update(wspace=0, hspace=0)

    x_hat = batch_reshape(p_x.sample(), input_dims).clip(0, 1)
    x_mu = batch_reshape(p_x.mean, input_dims).clip(0, 1)
    x_var = batch_reshape(p_x.variance, input_dims).clip(0, 1)

    for n in range(n_display):
        for k in range(4):
            ax = plt.subplot(gs[k, n])
            ax = disable_ticks(ax)
            # Original
            # imshow accepts (H, W, C)
            if k == 0:
                ax.imshow(x_og[n, :].permute(1, 2, 0), vmin=0, vmax=1)
            # Mean
            elif k == 1:
                ax.imshow(x_mu[n, :].permute(1, 2, 0), vmin=0, vmax=1)
            # Variance
            elif k == 2:
                ax.imshow(x_var[n, :].permute(1, 2, 0))
            # Sample
            elif k == 3:
                ax.imshow(x_hat[n, :].permute(1, 2, 0), vmin=0, vmax=1)

    return fig
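The helpers `batch_reshape` and `disable_ticks` come from the surrounding project. A self-contained sketch of the reshaping and permutation the figure relies on, with hypothetical image dimensions:

import torch
from torch.distributions import Independent, Normal

input_dims = (3, 8, 8)  # (C, H, W), made-up shape
n = 4
flat_mean = torch.rand(n, 3 * 8 * 8)

# Per-pixel independent Normal over flattened images, standing in for p_x.
p_x = Independent(Normal(flat_mean, 0.1 * torch.ones_like(flat_mean)), 1)

# batch_reshape(p_x.sample(), input_dims) essentially does this:
x_hat = p_x.sample().reshape(n, *input_dims).clip(0, 1)

# imshow expects (H, W, C), hence the permute in the plotting loop above.
print(x_hat[0].permute(1, 2, 0).shape)  # torch.Size([8, 8, 3])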
Example #4
def get_normalization_uniform_prior(
    posterior: DirectPosterior,
    prior: Distribution,
    true_observation: Tensor,
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Return the unnormalized posterior likelihood, the normalized posterior likelihood,
    and the estimated acceptance probability.

    Args:
        posterior: estimated posterior
        prior: prior distribution
        true_observation: observation where we evaluate the posterior
    """

    # Test normalization.
    prior_sample = prior.sample()

    # Compute unnormalized density, i.e. just the output of the density estimator.
    posterior_likelihood_unnorm = torch.exp(
        posterior.log_prob(prior_sample, norm_posterior=False))
    # Compute the normalized density, scale up output of the density
    # estimator by the ratio of posterior samples within the prior bounds.
    posterior_likelihood_norm = torch.exp(
        posterior.log_prob(prior_sample, norm_posterior=True))

    # Estimate acceptance ratio through rejection sampling.
    acceptance_prob = posterior.leakage_correction(x=true_observation)

    return posterior_likelihood_unnorm, posterior_likelihood_norm, acceptance_prob
Example #5
def dist_sample(distribution: dist.Distribution, context: SamplingContext = None) -> torch.Tensor:
    """
    Sample ``context.n`` samples from a given distribution.

    Args:
        distribution (dists.Distribution): Base distribution to sample from.
        context (SamplingContext): Sampling context providing ``n``, ``is_mpe``,
            ``repetition_indices``, and ``parent_indices``.
    """

    # Sample from the specified distribution
    if context.is_mpe:
        samples = _mode(distribution, context)
    else:
        samples = distribution.sample(sample_shape=(context.n,))

    assert (
        samples.shape[1] == 1
    ), "Something went wrong. First sample size dimension should be size 1 due to the distribution parameter dimensions. Please report this issue."
    samples.squeeze_(1)
    n, d, c, r = samples.shape

    # Filter each sample by its specific repetition
    tmp = torch.zeros(n, d, c, device=context.repetition_indices.device)
    for i in range(n):
        tmp[i, :, :] = samples[i, :, :, context.repetition_indices[i]]
    samples = tmp

    # If parent index into out_channels are given
    if context.parent_indices is not None:
        # Choose only specific samples for each feature/scope
        samples = torch.gather(samples, dim=2, index=context.parent_indices.unsqueeze(-1)).squeeze(-1)

    return samples
Example #6
 def _sample(self,
             distribution: dist.Distribution,
             sample_shape: Union[torch.Size, tuple] = torch.Size()):
     if self.training:
         return distribution.rsample(sample_shape=sample_shape)
     else:
         return distribution.sample(sample_shape=sample_shape)
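For context (not part of the project code): `rsample` draws reparameterized samples that stay on the autograd graph, whereas `sample` is detached, which is why the two are switched on `self.training`. A minimal illustration:

import torch
from torch.distributions import Normal

mu = torch.zeros(3, requires_grad=True)
d = Normal(mu, torch.ones(3))

d.rsample().sum().backward()     # gradients flow back to mu
print(mu.grad)                   # tensor([1., 1., 1.])
print(d.sample().requires_grad)  # False: the plain sampling path is detached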
Example #7
def run_pmmh(
    filter_: BaseFilter,
    state: FilterResult,
    prop_kernel: Distribution,
    prop_filt,
    y: torch.Tensor,
    size=torch.Size([]),
    **kwargs,
) -> Tuple[torch.Tensor, FilterResult, BaseFilter]:
    """
    Runs one iteration of a vectorized Particle Marginal Metropolis-Hastings.
    """

    rvs = prop_kernel.sample(size)
    prop_filt.ssm.parameters_from_array(rvs, constrained=False)

    new_res = prop_filt.longfilter(y, bar=False, **kwargs)

    diff_logl = new_res.loglikelihood - state.loglikelihood
    diff_prior = (prop_filt.ssm.eval_prior_log_prob(False) -
                  filter_.ssm.eval_prior_log_prob(False)).squeeze()

    if isinstance(prop_kernel, Independent) and size == torch.Size([]):
        diff_prop = 0.0
    else:
        diff_prop = prop_kernel.log_prob(
            filter_.ssm.parameters_to_array(
                constrained=False)) - prop_kernel.log_prob(rvs)

    log_acc_prob = diff_prop + diff_prior + diff_logl
    res: torch.Tensor = torch.empty_like(
        log_acc_prob).uniform_().log() < log_acc_prob

    return res, new_res, prop_filt
Example #8
def process_pytorch_prior(
        prior: Distribution) -> Tuple[Distribution, int, bool]:
    """Return PyTorch prior adapted to the requirements for sbi.

    Args:
        prior: PyTorch distribution prior provided by the user.

    Raises:
        ValueError: If prior is defined over an unwrapped scalar variable.

    Returns:
        prior: PyTorch distribution prior.
        theta_numel: Number of parameters, i.e., elements in a single sample from the prior.
        prior_returns_numpy: False.
    """

    # Turn off validation of input arguments to allow `log_prob()` on samples outside
    # of the support.
    prior.set_default_validate_args(False)

    # Reject unwrapped scalar priors, i.e., priors whose samples have zero dimensions
    # (e.g., `Uniform(0.0, 1.0)`).
    if prior.sample().ndim == 0:
        raise ValueError(
            "Detected scalar prior. Please make sure to pass a PyTorch prior with "
            "`batch_shape=torch.Size([1])` or `event_shape=torch.Size([1])`.")
    # Cast 1D Uniform to BoxUniform to avoid shape error in mdn log prob.
    elif isinstance(prior, Uniform) and prior.batch_shape.numel() == 1:
        prior = BoxUniform(low=prior.low, high=prior.high)
        warnings.warn(
            "Casting 1D Uniform prior to BoxUniform to match sbi batch requirements."
        )

    check_prior_batch_behavior(prior)
    check_prior_batch_dims(prior)

    if not prior.sample().dtype == float32:
        prior = PytorchReturnTypeWrapper(prior,
                                         return_type=float32,
                                         validate_args=False)

    # This will fail for float64 priors.
    check_prior_return_type(prior)

    theta_numel = prior.sample().numel()

    return prior, theta_numel, False
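The cast exists because a one-dimensional `Uniform` carries its single dimension in `batch_shape`, while sbi expects it in `event_shape`. A shape-only sketch using `Independent` as a stand-in for what `BoxUniform` amounts to (an assumption about its implementation):

import torch
from torch.distributions import Independent, Uniform

u = Uniform(torch.zeros(1), torch.ones(1))
print(u.batch_shape, u.event_shape)          # torch.Size([1]) torch.Size([])
print(u.log_prob(u.sample((2,))).shape)      # torch.Size([2, 1])

box = Independent(u, 1)                      # roughly what BoxUniform does
print(box.batch_shape, box.event_shape)      # torch.Size([]) torch.Size([1])
print(box.log_prob(box.sample((2,))).shape)  # torch.Size([2])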
Example #9
 def mc_sample(
     self,
     distr: Distribution,
     parents: [storch.Tensor],
     plates: [Plate],
     amt_samples: int,
 ) -> torch.Tensor:
     return distr.sample((self.n_samples,))
Example #10
 def mc_sample(
     self,
     distr: Distribution,
     parents: [storch.Tensor],
     plates: [Plate],
     amt_samples: int,
 ) -> torch.Tensor:
     # TODO: Why does this ignore amt_samples?
     return distr.sample((self.n_samples,))
Example #11
def test_process_simulator(simulator: Callable, prior: Distribution):

    prior, theta_dim, prior_returns_numpy = process_prior(prior)
    simulator = process_simulator(simulator, prior, prior_returns_numpy)

    n_batch = 2
    x = simulator(prior.sample((n_batch, )))

    assert isinstance(x, Tensor), "Processed simulator must return Tensor."
    assert (
        x.shape[0] == n_batch
    ), "Processed simulator must return as many data points as parameters in batch."
Example #12
def check_transform(prior: Distribution, transform: TorchTransform) -> None:
    """Check validity of transformed and re-transformed samples."""

    theta = prior.sample(torch.Size((2, )))

    theta_unconstrained = transform.inv(theta)
    assert (
        theta_unconstrained.shape == theta.shape  # type: ignore
    ), """Mismatch between transformed and untransformed space. Note that you cannot
    use a transforms when using a MultipleIndependent prior with a Dirichlet prior."""

    assert torch.allclose(
        theta,
        transform(theta_unconstrained)  # type: ignore
    ), "Mismatch between true and re-transformed parameters."
Example #13
def process_pytorch_prior(
    prior: Distribution, ) -> Tuple[Distribution, int, bool]:
    """Return corrected prior after checking requirements for SBI.
    
    Args:
        prior: PyTorch distribution prior provided by the user.
    
    Raises:
        ValueError: If prior is defined over a scalar variable.
    
    Returns:
        prior: PyTorch distribution prior. 
        parameter_dim: event shape of the prior, number of parameters.
        prior_returns_numpy: False.
    """

    # reject scalar priors
    if prior.sample().ndim == 0:
        raise ValueError(
            "Detected scalar prior. Please make sure to pass a PyTorch prior with "
            "batch_shape=torch.Size([1]), or event_shape=torch.Size([1])")

    assert prior.batch_shape in (
        torch.Size([1]),
        torch.Size([]),
    ), f"""The prior must have batch shape torch.Size([]) or torch.Size([1]), but has
        {prior.batch_shape}.
        """

    check_prior_batch_behavior(prior)

    check_for_batch_reinterpretation_extra_d_uniform(prior)

    parameter_dim = prior.sample().numel()

    return prior, parameter_dim, False
Example #14
def naive_sampling(num_samples: int, potential_fn: Callable,
                   proposal: Distribution, **kwargs) -> Tensor:
    """Basic sampling method, which just samples from the proposal i.e. the variational
    posterior.

    Args:
        num_samples: Number of samples to draw.
        potential_fn: Potential function, this may be used to debias the proposal.
        proposal: Proposal distribution to propose samples.

    Returns:
        Tensor: Samples of shape (num_samples, event_shape)

    """
    return proposal.sample(torch.Size((num_samples, )))
Example #15
def psis_diagnostics(
        potential_function: Callable,
        q: Distribution,
        proposal: Optional[Distribution] = None,
        N: int = int(5e4),
) -> float:
    r"""This will evaluate the posteriors quality by investingating its importance
    weights. If q is a perfect posterior approximation then $q(\theta) \propto
    p(\theta, x_o)$ thus $\log w(\theta) = \log \frac{p(\theta, x_o)}{\log q(\theta)} =
    \log p(x_o)$ is constant. This function will fit a Generalized Paretto
    distribution to the tails of w. The shape parameter k serves as metric as detailed
    in [1]. In short it is related to the variance of a importance sampling estimate,
    especially for k > 1 the variance will be infinite.

    NOTE: In our experience this metric does distinguish "very bad" from "ok", but
    is less sensitive when it comes to distinguishing "ok" from "good".

    Args:
        potential_function: Potential function of target.
        q: Variational distribution, should be proportional to the potential_function
        proposal: Proposal for samples. Typically this is q.
        N: Number of samples involved in the test.

    Returns:
        float: Quality metric

    Reference:
        [1] _Yes, but Did It Work?: Evaluating Variational Inference_, Yuling Yao, Aki
        Vehtari, Daniel Simpson, Andrew Gelman, 2018, https://arxiv.org/abs/1802.02538

    """
    M = int(min(N / 5, 3 * np.sqrt(N)))
    with torch.no_grad():
        if proposal is None:
            samples = q.sample(Size((N, )))
        else:
            samples = proposal.sample(Size((N, )))
        log_q = q.log_prob(samples)
        log_potential = potential_function(samples)
        logweights = log_potential - log_q
        logweights = logweights[torch.isfinite(logweights)]
        logweights_max = logweights.max()
        weights = torch.exp(logweights -
                            logweights_max)  # This only affects the scale
        vals, _ = weights.sort()
        largest_weights = vals[-M:]
        k, _ = gpdfit(largest_weights)
    return k
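A self-contained sanity check of the premise stated in the docstring (a perfect approximation yields constant log-weights); the target and its log-normalizer are made up for illustration:

import torch
from torch.distributions import MultivariateNormal

q = MultivariateNormal(torch.zeros(2), torch.eye(2))
log_Z = 1.234  # arbitrary log-normalizer of the hypothetical joint

def potential_function(theta):
    # Unnormalized target that is exactly proportional to q.
    return q.log_prob(theta) + log_Z

samples = q.sample(torch.Size((1000,)))
log_w = potential_function(samples) - q.log_prob(samples)
print(log_w.std())  # ~0: constant weights, so the fitted k would be small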
Example #16
def _mcmc_move(params: Iterable[Parameter], dist: Distribution, stacked: StackedObject, shape: int):
    """
    Performs an MCMC move to rejuvenate parameters.
    :param params: The parameters to use for defining the distribution
    :param dist: The distribution to use for sampling
    :param stacked: The mask to apply for parameters
    :param shape: The shape to sample
    :return: True; the sampled values are written into the parameters in place
    """

    rvs = dist.sample((shape,))

    for p, msk, ps in zip(params, stacked.mask, stacked.prev_shape):
        p.t_values = unflattify(rvs[:, msk], ps)

    return True
Example #17
def proportional_to_joint_diagnostics(
        potential_function: Callable,
        q: Distribution,
        proposal: Optional[Distribution] = None,
        N: int = int(5e4),
) -> float:
    r"""This will evaluate the posteriors quality by investingating its importance
    weights. If q is a perfect posterior approximation then $q(\theta) \propto
    p(\theta, x_o)$. Thus we should be able to fit a line to $(q(\theta),
    p(\theta, x_o))$, whereas the slope will be proportional to the normalizing
    constant. The quality of a linear fit is hence a direct metric for the quality of q.
    We use R2 statistic.

    NOTE: In our experience this metric does distinguish "good" from "ok", but
    is less sensitive when it comes to distinguishing "very bad" from "ok".

    Args:
        potential_function: Potential function of target.
        q: Variational distribution, should be proportional to the potential_function
        proposal: Proposal for samples. Typically this is q.
        N: Number of samples involved in the test.

    Returns:
        float: Quality metric

    """

    with torch.no_grad():
        if proposal is None:
            samples = q.sample(Size((N, )))
        else:
            samples = proposal.sample(Size((N, )))
        log_q = q.log_prob(samples)
        log_potential = potential_function(samples)

        X = log_q.exp().unsqueeze(-1)
        Y = log_potential.exp().unsqueeze(-1)
        w = torch.linalg.solve(X.T @ X, X.T @ Y)  # Linear regression

        residuals = Y - w * X
        var_res = torch.sum(residuals**2)
        var_tot = torch.sum((Y - Y.mean())**2)
        r2 = 1 - var_res / var_tot  # R2 statistic to evaluate fit
    return r2.item()
Example #18
def importance_resampling(
    num_samples: int,
    potential_fn: Callable,
    proposal: Distribution,
    K: int = 32,
    num_samples_batch: int = 10000,
    **kwargs,
) -> Tensor:
    """Perform sequential importance resampling (SIR).

    Args:
        num_samples: Number of samples to draw.
        potential_fn: Potential function, this may be used to debias the proposal.
        proposal: Proposal distribution to propose samples.
        K: Number of proposed samples from which only one is selected based on its
            importance weight.
        num_samples_batch: Number of samples processed in parallel. For large K you may
            want to reduce this, depending on your memory capabilities.

    Returns:
        Tensor: Samples of shape (num_samples, event_shape)

    """
    final_samples = []
    num_samples_batch = min(num_samples, num_samples_batch)
    iters = int(num_samples / num_samples_batch)
    for _ in range(iters):
        batch_size = min(num_samples_batch, num_samples - len(final_samples))
        with torch.no_grad():
            thetas = proposal.sample(torch.Size((batch_size * K, )))
            logp = potential_fn(thetas)
            logq = proposal.log_prob(thetas)
            weights = (logp - logq).reshape(batch_size,
                                            K).softmax(-1).cumsum(-1)
            u = torch.rand(batch_size, 1, device=thetas.device)
            mask = torch.cumsum(weights >= u, -1) == 1
            samples = thetas.reshape(batch_size, K, -1)[mask]
            final_samples.append(samples)
    return torch.vstack(final_samples)
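A usage sketch, assuming the function above is in scope along with its imports (`torch`, `Tensor`, `Callable`, `Distribution`); the proposal and target below are toy choices:

import torch
from torch.distributions import Independent, Normal

proposal = Independent(Normal(torch.zeros(2), 3.0 * torch.ones(2)), 1)
target = Independent(Normal(torch.ones(2), 0.5 * torch.ones(2)), 1)

samples = importance_resampling(
    num_samples=1000,
    potential_fn=target.log_prob,  # unnormalized log-density of the target
    proposal=proposal,
    K=32,
)
print(samples.shape)    # torch.Size([1000, 2])
print(samples.mean(0))  # roughly near the target mean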
Example #19
def check_sbi_inputs(simulator: Callable, prior: Distribution) -> None:
    """Assert requirements for simulator, prior and observation for usage in sbi.

    Args:
        simulator: simulator function
        prior: prior (Distribution like)
        x_shape: Shape of single simulation output $x$.
    """
    num_prior_samples = 1
    theta = prior.sample((num_prior_samples, ))
    theta_batch_shape, *_ = theta.shape
    simulation = simulator(theta)
    sim_batch_shape, *sim_event_shape = simulation.shape

    assert isinstance(theta, Tensor), "Parameters theta must be a `Tensor`."
    assert isinstance(simulation,
                      Tensor), "Simulator output must be a `Tensor`."

    assert (theta_batch_shape == num_prior_samples
            ), f"""Theta batch shape {theta_batch_shape} must match
        num_samples={num_prior_samples}."""
    assert (sim_batch_shape == num_prior_samples
            ), f"""Simulation batch shape {sim_batch_shape} must match
        num_samples={num_prior_samples}."""
Example #20
def mcmc_transform(
    prior: Distribution,
    num_prior_samples_for_zscoring: int = 1000,
    enable_transform: bool = True,
    device: str = "cpu",
    **kwargs,
) -> TorchTransform:
    """
    Builds a transform that is applied to parameters during MCMC.

    The resulting transform is defined such that the forward mapping maps from
    constrained to unconstrained space.

    It does two things:
    1) When the prior support is bounded, it transforms the parameters into unbounded
        space.
    2) It z-scores the parameters such that MCMC is performed in a z-scored space.

    Args:
        prior: The prior distribution.
        num_prior_samples_for_zscoring: The number of samples drawn from the prior
            to infer the `mean` and `stddev` of the prior used for z-scoring. Unused if
            the prior has bounded support or when the prior has `mean` and `stddev`
            attributes.
        enable_transform: Whether or not to use a transformation during MCMC.

    Returns: A transformation whose `forward()` maps from constrained (or non-z-scored)
        to unconstrained (or z-scored) space.
    """

    if enable_transform:
        # Some distributions have a `support` attribute, but accessing it raises a
        # NotImplementedError. We catch this case here.
        try:
            _ = prior.support
            has_support = True
        except (NotImplementedError, AttributeError):
            # NotImplementedError -> Distribution that inherits from torch dist but
            # does not implement support.
            # AttributeError -> Custom distribution that has no support attribute.
            warnings.warn(
                """The passed prior has no support property, transform will be
                constructed from mean and std. If the passed prior is supposed to be
                bounded consider implementing the prior.support property.""")
            has_support = False

        # If the distribution has a `support`, check if the support is bounded.
        # If it is not bounded, we want to z-score the space. This is not done
        # by `biject_to()`, so we have to deal with this case separately.
        if has_support:
            if hasattr(prior.support, "base_constraint"):
                constraint = prior.support.base_constraint  # type: ignore
            else:
                constraint = prior.support
            if isinstance(constraint, constraints._Real):
                support_is_bounded = False
            else:
                support_is_bounded = True
        else:
            support_is_bounded = False

        # Prior with bounded support, e.g., uniform priors.
        if has_support and support_is_bounded:
            transform = biject_to(prior.support)
        # For all other cases build affine transform with mean and std.
        else:
            if hasattr(prior, "mean") and hasattr(prior, "stddev"):
                prior_mean = prior.mean.to(device)
                prior_std = prior.stddev.to(device)
            else:
                theta = prior.sample(
                    torch.Size((num_prior_samples_for_zscoring, )))
                prior_mean = theta.mean(dim=0).to(device)
                prior_std = theta.std(dim=0).to(device)

            transform = torch_tf.AffineTransform(loc=prior_mean,
                                                 scale=prior_std)
    else:
        transform = torch_tf.identity_transform

    # PyTorch `transforms` do not sum the log-determinant over the parameter dimensions.
    # However, if the `transform` explicitly is an `IndependentTransform`, it does. Since
    # our `BoxUniform` is an `Independent` distribution, it will automatically get an
    # `IndependentTransform` wrapper in `biject_to`. Our solution here is to wrap all
    # transforms as `IndependentTransform`.
    if not isinstance(transform, torch_tf.IndependentTransform):
        transform = torch_tf.IndependentTransform(transform,
                                                  reinterpreted_batch_ndims=1)

    check_transform(prior, transform)  # type: ignore

    return transform.inv  # type: ignore
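A condensed, self-contained sketch of the two branches described above (bounded support handled by `biject_to`, everything else by affine z-scoring), using plain torch objects rather than sbi's wrappers:

import torch
from torch.distributions import Independent, Normal, Uniform, biject_to
from torch.distributions import transforms as torch_tf

# Bounded prior: biject_to maps between unconstrained space and the support.
bounded_prior = Independent(Uniform(torch.zeros(2), torch.ones(2)), 1)
to_constrained = biject_to(bounded_prior.support)
theta = bounded_prior.sample(torch.Size((5,)))
theta_unconstrained = to_constrained.inv(theta)  # unbounded space for MCMC
assert torch.allclose(to_constrained(theta_unconstrained), theta, atol=1e-5)

# Unbounded prior: fall back to an affine z-scoring transform.
normal_prior = Independent(Normal(torch.ones(2), 2.0 * torch.ones(2)), 1)
affine = torch_tf.IndependentTransform(
    torch_tf.AffineTransform(loc=normal_prior.mean, scale=normal_prior.stddev),
    reinterpreted_batch_ndims=1,
)
z = affine.inv(normal_prior.sample(torch.Size((5,))))  # roughly zero-mean, unit-scale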
Example #21
def run_pmmh(
        context: ParameterContext,
        state: FilterAlgorithmState,
        proposal: BaseProposal,
        proposal_kernel: Distribution,
        proposal_filter: BaseFilter,
        proposal_context: ParameterContext,
        y: torch.Tensor,
        size=torch.Size([]),
        mutate_kernel=False,
) -> torch.BoolTensor:
    r"""
    Runs one iteration of the PMMH update step in which we sample a candidate :math:`\theta^*` from the proposal
    kernel, run a filter for the considered dataset with :math:`\theta^*`, and accept based on the acceptance
    probability given by the article.

    Args:
        context: the parameter context of the main algorithm.
        state: the latest algorithm state.
        proposal: the proposal to use when generating the candidate sample :math:`\theta^*`.
        proposal_kernel: the kernel from which to draw the candidate sample :math:`\theta^*`. To clarify, ``proposal``
            corresponds to the ``BaseProposal`` instance that was used when generating ``proposal_kernel``.
        proposal_filter: the proposal filter to use.
        proposal_context: the parameter context of the proposal filter.
        y: see ``pyfilter.inference.base.BaseAlgorithm``.
        size: optional parameter specifying the number of samples to draw from ``proposal_kernel``. Should be empty
            if we draw from an independent kernel.
        mutate_kernel: optional parameter specifying whether to update ``proposal_kernel`` with the newly accepted
            candidate sample.

    Returns:
        A boolean mask indicating which candidate sample(s) were accepted.
    """

    constrained = False

    rvs = proposal_kernel.sample(size)
    proposal_context.unstack_parameters(rvs, constrained=constrained)

    new_res = proposal_filter.batch_filter(y, bar=False)

    diff_logl = new_res.loglikelihood - state.filter_state.loglikelihood
    diff_prior = proposal_context.eval_priors(
        constrained=constrained) - context.eval_priors(constrained=constrained)

    new_prop_kernel = proposal.build(proposal_context,
                                     state.replicate(new_res), proposal_filter,
                                     y)
    params_as_tensor = context.stack_parameters(constrained=constrained)

    diff_prop = new_prop_kernel.log_prob(
        params_as_tensor) - proposal_kernel.log_prob(rvs)

    log_acc_prob = diff_prop + diff_prior.squeeze(-1) + diff_logl
    accepted: torch.BoolTensor = torch.empty_like(
        log_acc_prob).uniform_().log() < log_acc_prob

    state.filter_state.exchange(new_res, accepted)
    context.exchange(proposal_context, accepted)

    if mutate_kernel:
        proposal.exchange(proposal_kernel, new_prop_kernel, accepted)

    return accepted
Example #22
 def sample(self, actions: Distribution) -> Tensor:
     return actions.sample().clone().long()