コード例 #1
0
    def sample(
        self,
        distr: Distribution,
        parents: [storch.Tensor],
        plates: [Plate],
        requires_grad: bool,
    ) -> (storch.StochasticTensor, Plate):
        """Enumerate the support of ``distr`` and wrap it in a stochastic tensor.

        Args:
            distr: Distribution to enumerate; must report ``has_enumerate_support``.
            parents: Forwarded unchanged to ``storch.StochasticTensor``.
            plates: Existing plates. NOTE: this list is mutated in place -- the
                newly created plate is inserted at index 0.
            requires_grad: Forwarded unchanged to ``storch.StochasticTensor``.

        Returns:
            ``(s_tensor, plate)``: the stochastic tensor holding the enumerated
            values and the plate created for the enumeration dimension.

        Raises:
            ValueError: If ``distr`` has no enumerable support, or if the number
                of enumerated rows exceeds ``self.budget``.
        """
        # TODO: Currently very inefficient as it isn't batched
        # TODO: What if the expectation has a parent
        if not distr.has_enumerate_support:
            raise ValueError(
                "Can only calculate the expected value for distributions with enumerable support."
            )
        # Both the expanded and the non-expanded enumeration are needed: the
        # former for shapes, the latter as the factors of the product below.
        support: torch.Tensor = distr.enumerate_support(expand=True)
        support_non_expanded: torch.Tensor = distr.enumerate_support(
            expand=False)
        # Size of the leading dimension of the expanded support.
        expect_size = support.shape[0]

        batch_len = len(plates)
        # Dimensions of the expanded support that lie after the first
        # ``batch_len + 1`` dims and before the trailing event dims.
        sizes = support.shape[batch_len + 1:len(support.shape) -
                              len(distr.event_shape)]
        amt_samples_used = expect_size
        cross_products = 1 if not sizes else None
        for dim in sizes:
            # Each remaining dimension raises the running count to its power.
            amt_samples_used = amt_samples_used**dim
            if not cross_products:
                cross_products = dim
            else:
                # NOTE(review): later dims are applied as exponents here rather
                # than multiplied -- confirm this is intended when len(sizes) > 1.
                cross_products = cross_products**dim

        # Refuse to materialise more rows than the configured budget allows.
        if amt_samples_used > self.budget:
            raise ValueError(
                "Computing the expectation on this distribution would exceed the computation budget."
            )

        # One row per enumerated combination; trailing dims copied from support.
        enumerate_tensor = support.new_zeros([amt_samples_used] +
                                             list(support.shape[1:]))
        support_non_expanded = support_non_expanded.squeeze().unsqueeze(1)
        # Fill row i with the i-th element of the Cartesian product of the
        # support with itself, repeated ``cross_products`` times.
        for i, t in enumerate(
                itertools.product(support_non_expanded,
                                  repeat=cross_products)):
            enumerate_tensor[i] = torch.cat(t, dim=0)

        enumerate_tensor = enumerate_tensor.detach()

        plate_size = enumerate_tensor.shape[0]

        # The new plate receives a copy of the current plates; the caller's
        # list is then extended in place with the new plate at the front.
        plate = Plate(self.plate_name, plate_size, plates.copy())
        plates.insert(0, plate)

        s_tensor = storch.StochasticTensor(
            enumerate_tensor,
            parents,
            plates,
            self.plate_name,
            plate_size,
            distr,
            requires_grad,
        )
        return s_tensor, plate
コード例 #2
0
ファイル: ppo.py プロジェクト: yskim525/ppo-pytorch
    def forward(self, distribution_old: Distribution, value_old: Tensor,
                distribution: Distribution, value: Tensor, action: Tensor,
                reward: Tensor, advantage: Tensor):
        """Compute the combined PPO loss and report its components.

        Args:
            distribution_old: Action distribution of the old policy.
            value_old: Value estimates of the old critic.
            distribution: Action distribution of the current policy.
            value: Value estimates of the current critic.
            action: Actions whose log-probabilities are compared.
            reward: Regression target for the value loss.
            advantage: Advantage estimates; normalised (zero mean, unit std)
                and detached inside this method.

        Returns:
            The scalar total loss tensor to minimise.
        """
        # Value loss: element-wise minimum of the loss on the clipped value
        # (moved at most ``v_clip_range`` away from ``value_old``) and the
        # loss on the raw value.
        value_old_clipped = value_old + (value - value_old).clamp(
            -self.v_clip_range, self.v_clip_range)
        v_old_loss_clipped = (reward - value_old_clipped).pow(2)
        v_loss = (reward - value).pow(2)
        value_loss = torch.min(v_old_loss_clipped, v_loss).mean()

        # Policy loss
        advantage = (advantage -
                     advantage.mean()) / (advantage.std(unbiased=False) + 1e-8)
        advantage.detach_()
        log_prob = distribution.log_prob(action)
        log_prob_old = distribution_old.log_prob(action)
        ratio = (log_prob - log_prob_old).exp().view(-1)

        surrogate = advantage * ratio
        surrogate_clipped = advantage * ratio.clamp(1 - self.clip_range,
                                                    1 + self.clip_range)
        policy_loss = torch.min(surrogate, surrogate_clipped).mean()

        # Entropy
        entropy = distribution.entropy().mean()

        # Total loss: the objective is maximised, so the loss is its negation.
        losses = policy_loss + self.c_entropy * entropy - self.c_value * value_loss
        total_loss = -losses
        self.reporter.scalar('ppo_loss/policy', -policy_loss.item())
        self.reporter.scalar('ppo_loss/entropy', -entropy.item())
        self.reporter.scalar('ppo_loss/value_loss', value_loss.item())
        # Report a plain float like the other scalars, so the reporter never
        # holds a reference to the autograd graph.
        self.reporter.scalar('ppo_loss/total', total_loss.item())
        return total_loss
コード例 #3
0
ファイル: user_input_checks.py プロジェクト: www3cam/sbi
def process_pytorch_prior(
        prior: Distribution) -> Tuple[Distribution, int, bool]:
    """Validate a PyTorch prior and adapt it for use with sbi.

    Args:
        prior: PyTorch distribution prior provided by the user.

    Raises:
        ValueError: If a sample from the prior is a zero-dimensional tensor,
            i.e. the prior is defined over an unwrapped scalar variable.

    Returns:
        prior: The prior, wrapped in ``PytorchReturnTypeWrapper`` when its
            samples are not ``float32``.
        theta_numel: Number of elements in a single sample from the prior.
        prior_returns_numpy: Always ``False``.
    """
    # A 0-d sample means the prior has neither a batch nor an event dimension.
    sample_is_scalar = prior.sample().ndim == 0
    if sample_is_scalar:
        raise ValueError(
            "Detected scalar prior. Please make sure to pass a PyTorch prior with "
            "`batch_shape=torch.Size([1])` or `event_shape=torch.Size([1])`.")

    check_prior_batch_behavior(prior)
    check_prior_batch_dims(prior)

    # Force float32 samples by wrapping priors that return another dtype.
    if prior.sample().dtype != float32:
        prior = PytorchReturnTypeWrapper(prior, return_type=float32)

    check_prior_return_type(prior)

    return prior, prior.sample().numel(), False
コード例 #4
0
def run_pmmh(
    filter_: BaseFilter,
    state: FilterResult,
    prop_kernel: Distribution,
    prop_filt,
    y: torch.Tensor,
    size=torch.Size([]),
    **kwargs,
) -> Tuple[torch.Tensor, FilterResult, BaseFilter]:
    """
    Runs one iteration of a vectorized Particle Marginal Metropolis-Hastings.

    :param filter_: Filter holding the current parameters
    :param state: Filter result for the current parameters
    :param prop_kernel: Proposal distribution the new parameters are drawn from
    :param prop_filt: Filter that receives the proposed parameters and is re-run
    :param y: Data passed to ``prop_filt.longfilter``
    :param size: Sample shape passed to ``prop_kernel.sample``
    :param kwargs: Forwarded to ``prop_filt.longfilter``
    :return: Boolean acceptance tensor, the new filter result and ``prop_filt``
    """

    proposed = prop_kernel.sample(size)
    prop_filt.ssm.parameters_from_array(proposed, constrained=False)

    new_res = prop_filt.longfilter(y, bar=False, **kwargs)

    loglik_delta = new_res.loglikelihood - state.loglikelihood

    proposed_prior = prop_filt.ssm.eval_prior_log_prob(False)
    current_prior = filter_.ssm.eval_prior_log_prob(False)
    prior_delta = (proposed_prior - current_prior).squeeze()

    # The kernel term is skipped for an ``Independent`` kernel sampled with an
    # empty sample shape.
    skip_kernel_term = isinstance(prop_kernel,
                                  Independent) and size == torch.Size([])
    if skip_kernel_term:
        kernel_delta = 0.0
    else:
        current = filter_.ssm.parameters_to_array(constrained=False)
        kernel_delta = prop_kernel.log_prob(current) - prop_kernel.log_prob(
            proposed)

    log_acc_prob = kernel_delta + prior_delta + loglik_delta
    # Accept where log(u) < log acceptance probability, u ~ U(0, 1).
    log_u = torch.empty_like(log_acc_prob).uniform_().log()
    accepted: torch.Tensor = log_u < log_acc_prob

    return accepted, new_res, prop_filt
コード例 #5
0
ファイル: vae.py プロジェクト: lizeyan/snippets
 def _sample(self,
             distribution: dist.Distribution,
             sample_shape: Union[torch.Size, tuple] = torch.Size()):
     """Draw from ``distribution``: ``rsample`` while training, else ``sample``.

     Args:
         distribution: Distribution to draw from.
         sample_shape: Shape forwarded as ``sample_shape`` to the draw call.

     Returns:
         The drawn tensor.
     """
     draw = distribution.rsample if self.training else distribution.sample
     return draw(sample_shape=sample_shape)
コード例 #6
0
ファイル: modified_elbo.py プロジェクト: zhangdabao96/Bagel
def m_elbo(observe_x, observe_z, normal, p_xz: dist.Distribution,
           q_zx: dist.Distribution, p_z: dist.Distribution):
    """
    Compute the negated modified ELBO, averaged over all samples and batch items.

    :param observe_x: (batch_size, x_dims)
    :param observe_z: (sample_size, batch_size, z_dims) or (batch_size, z_dims,)
    :param normal: (batch_size, x_dims)
    :param p_xz: samples in shape (sample_size, batch_size, x_dims)
    :param q_zx: samples in shape (sample_size, batch_size, z_dims)
    :param p_z: samples in shape (z_dims, )
    :return: scalar tensor
        ``-mean(sum(log p(x|z) * normal) + log p(z) * ratio - log q(z|x))``
        where ``ratio`` is the per-item mean of ``normal`` over ``x_dims``
    """
    observe_x = torch.unsqueeze(observe_x, 0)  # (1, batch_size, x_dims)
    normal = torch.unsqueeze(normal, 0)  # (1, batch_size, x_dims)
    log_p_xz = p_xz.log_prob(observe_x)  # (1, batch_size, x_dims)
    if observe_z.dim() == 2:
        # Add the leading sample dimension. torch.unsqueeze takes exactly
        # (input, dim) and returns a new tensor, so the result is assigned.
        observe_z = torch.unsqueeze(observe_z, 0)  # (1, batch_size, z_dims)
    # noinspection PyArgumentList
    log_q_zx = torch.sum(q_zx.log_prob(observe_z),
                         -1)  # (sample_size, batch_size)
    # noinspection PyArgumentList
    log_p_z = torch.sum(p_z.log_prob(observe_z),
                        -1)  # (sample_size, batch_size)
    # Mean of `normal` over x_dims; it scales the prior term.
    # noinspection PyArgumentList
    ratio = (torch.sum(normal, -1) / float(normal.size()[-1])
             )  # (1, batch_size)
    # noinspection PyArgumentList
    return -torch.mean(
        torch.sum(log_p_xz * normal, -1) + log_p_z * ratio - log_q_zx)
コード例 #7
0
ファイル: vi_utils.py プロジェクト: bkmi/sbi
def check_sample_shape_and_support(q: Distribution,
                                   prior: Distribution) -> None:
    """Check that samples of the variational distribution agree with the prior.

    Verifies that event and batch shapes match, that samples of ``q`` lie in
    the prior's support (when the prior exposes one), that samples of both
    distributions have the same shape, and that each distribution assigns a
    finite log-probability to samples of the other.

    Args:
        q: Variational distribution which is checked.
        prior: Prior providing the reference shapes and support.

    Raises:
        AssertionError: If any of the checks fails.
    """
    assert (q.event_shape == prior.event_shape
            ), "The event shape of q must match that of the prior"
    assert (q.batch_shape == prior.batch_shape
            ), "The batch sahpe of q must match that of the prior"

    sample_shape = torch.Size((1000, ))
    samples = q.sample(sample_shape)
    samples_prior = prior.sample(sample_shape).to(samples.device)
    try:
        _ = prior.support
        has_support = True
    except (NotImplementedError, AttributeError):
        has_support = False
    if has_support:
        # Reduce with a tensor op: the builtin `all` iterates over the first
        # dimension and raises on multi-element rows (e.g. batch_shape > 1).
        in_support = torch.as_tensor(
            prior.support.check(samples))  # type: ignore
        assert bool(in_support.all()
                    ), "The support of q must match that of the prior"
    assert (
        samples.shape == samples_prior.shape
    ), "Something is wrong with sample shape and event_shape or batch_shape attributes."
    assert torch.isfinite(q.log_prob(
        samples_prior)).all(), "Invalid values in logprob on prior samples."
    assert torch.isfinite(prior.log_prob(
        samples)).all(), "Invalid values in logprob on q samples."
コード例 #8
0
ファイル: utils.py プロジェクト: HaoWen6588/pyfilter
def _eval_kernel(params: Iterable[Parameter], dist: Distribution, n_params: Iterable[Parameter]):
    """
    Evaluates the kernel used for performing the MCMC move.
    :param params: The current parameters
    :param dist: The distribution whose log-probability is evaluated
    :param n_params: The new parameters to evaluate against
    :return: ``dist.log_prob`` at the current values minus at the new values
    """

    def _transformed(parameter):
        # Each parameter is stacked via its ``t_values`` attribute.
        return parameter.t_values

    current = stacker(params, _transformed)
    proposed = stacker(n_params, _transformed)

    log_prob_current = dist.log_prob(current.concated)
    log_prob_proposed = dist.log_prob(proposed.concated)
    return log_prob_current - log_prob_proposed
コード例 #9
0
    def compute_estimator(
        self, tensor: StochasticTensor, cost: CostTensor, plate_index: int,
        n: int, distribution: Distribution
    ) -> Tuple[storch.Tensor, storch.Tensor, storch.Tensor]:
        """Compute the log-probability and the two control-variate terms.

        Args:
            tensor: Sampled tensor. In the ``self.rebar`` branch it is sliced
                along axis ``plate_index`` into three consecutive chunks of
                length ``n``.
            cost: Cost tensor; only used in the ``self.rebar`` branch, where it
                is sliced in the same way as ``tensor``.
            plate_index: Axis along which ``tensor`` and ``cost`` are sliced.
            n: Length of each of the three slices.
            distribution: Distribution used for ``log_prob`` (and ``probs`` in
                the non-REBAR branch).

        Returns:
            A 3-tuple ``(log_prob, c_phi_relaxed, c_phi_cond)``.
        """
        if self.rebar:
            # Log-probability is taken over the full (unsplit) tensor.
            log_prob = distribution.log_prob(tensor)

            @storch.deterministic
            def compute_REBAR_cvs(estimator: RELAX, tensor: StochasticTensor,
                                  cost: CostTensor):
                # TODO: Does it make more sense to implement REBAR by adding another plate dimension
                #  that controls the three different types of samples? (hard, relaxed, cond)?
                #  Then this can be implemented by simply reducing that plate, ie not requiring the weird zeros
                # Split the sampled tensor
                # Leading full slices select everything before axis `plate_index`.
                empty_slices = (slice(None), ) * plate_index
                _index1 = empty_slices + (slice(n), )
                _index2 = empty_slices + (slice(n, 2 * n), )
                _index3 = empty_slices + (slice(2 * n, 3 * n), )
                # Second chunk -> relaxed, third chunk -> conditional.
                relaxed_sample, cond_sample = (
                    tensor[_index2],
                    tensor[_index3],
                )
                relaxed_cost, cond_cost = cost[_index2], cost[_index3]

                # Compute the control variates and log probabilities for the samples
                _c_phi_relaxed = estimator.c_phi(relaxed_sample) + relaxed_cost
                _c_phi_cond = estimator.c_phi(cond_sample) + cond_cost

                # Add zeros to ensure plates align
                # Results are written into the first chunk of a zero tensor
                # shaped like `cost`; the other two chunks stay zero.
                c_phi_relaxed = torch.zeros_like(cost)
                c_phi_relaxed[_index1] = _c_phi_relaxed
                c_phi_cond = torch.zeros_like(cost)
                c_phi_cond[_index1] = _c_phi_cond
                return c_phi_relaxed, c_phi_cond

            return (log_prob, ) + compute_REBAR_cvs(self, tensor, cost)
        else:
            # Here `tensor` itself is the relaxed sample; the hard sample is
            # obtained by discretising it.
            hard_sample = discretize(tensor, distribution)
            relaxed_sample = tensor
            cond_sample = storch.conditional_gumbel_rsample(
                hard_sample, distribution.probs,
                isinstance(distribution, torch.distributions.Bernoulli),
                self.temperature)
            # Input rsampled values into c_phi
            _c_phi_relaxed = self.c_phi(relaxed_sample)
            _c_phi_cond = self.c_phi(cond_sample)
            # Log-probability is evaluated at the discretised sample.
            _log_prob = distribution.log_prob(hard_sample)
            return _log_prob, _c_phi_relaxed, _c_phi_cond
コード例 #10
0
ファイル: vi_quality_control.py プロジェクト: bkmi/sbi
def psis_diagnostics(
        potential_function: Callable,
        q: Distribution,
        proposal: Optional[Distribution] = None,
        N: int = int(5e4),
) -> float:
    r"""Evaluate the quality of a posterior approximation via its importance weights.

    If q is a perfect posterior approximation then $q(\theta) \propto
    p(\theta, x_o)$, so $\log w(\theta) = \log \frac{p(\theta, x_o)}{q(\theta)} =
    \log p(x_o)$ is constant. This function fits a Generalized Pareto
    distribution to the largest weights. The shape parameter k serves as the
    metric, as detailed in [1]. In short, it is related to the variance of an
    importance sampling estimate; for k > 1 the variance is infinite.

    NOTE: In our experience this metric does distinguish "very bad" from "ok", but
    becomes less sensitive to distinguish "ok" from "good".

    Args:
        potential_function: Potential function of target.
        q: Variational distribution, should be proportional to the potential_function
        proposal: Proposal for samples. Typically this is q.
        N: Number of samples involved in the test.

    Returns:
        float: Quality metric

    Reference:
        [1] _Yes, but Did It Work?: Evaluating Variational Inference_, Yuling Yao, Aki
        Vehtari, Daniel Simpson, Andrew Gelman, 2018, https://arxiv.org/abs/1802.02538

    """
    # Number of largest weights handed to the tail fit.
    tail_size = int(min(N / 5, 3 * np.sqrt(N)))
    sampler = q if proposal is None else proposal
    with torch.no_grad():
        samples = sampler.sample(Size((N, )))
        log_q = q.log_prob(samples)
        log_potential = potential_function(samples)
        log_w = log_potential - log_q
        # Non-finite log-weights are dropped before the fit.
        finite_log_w = log_w[torch.isfinite(log_w)]
        # Subtracting the maximum only rescales the weights.
        scaled_weights = torch.exp(finite_log_w - finite_log_w.max())
        ordered, _ = scaled_weights.sort()
        k, _ = gpdfit(ordered[-tail_size:])
    return k
コード例 #11
0
ファイル: distributions.py プロジェクト: zero-cooper/SPFlow
def dist_sample(distribution: dist.Distribution, context: SamplingContext = None) -> torch.Tensor:
    """
    Sample from a given distribution and select one repetition per sample.

    Args:
        distribution (dists.Distribution): Base distribution to sample from.
        context (SamplingContext): Provides ``is_mpe``, ``n``,
            ``repetition_indices`` and ``parent_indices``. Although the default
            is ``None``, the attributes are read unconditionally.

    Returns:
        torch.Tensor: The selected samples; additionally gathered along
        dimension 2 when ``context.parent_indices`` is given.
    """

    # Either take the mode (MPE) or draw ``context.n`` samples.
    if context.is_mpe:
        samples = _mode(distribution, context)
    else:
        samples = distribution.sample(sample_shape=(context.n,))

    assert (
        samples.shape[1] == 1
    ), "Something went wrong. First sample size dimension should be size 1 due to the distribution parameter dimensions. Please report this issue."
    samples.squeeze_(1)
    n, d, c, _ = samples.shape

    # Keep, for each sample, only the repetition selected for it.
    rep_idx = context.repetition_indices
    selected = torch.zeros(n, d, c, device=rep_idx.device)
    for i, sample in enumerate(samples):
        selected[i, :, :] = sample[:, :, rep_idx[i]]

    parent_idx = context.parent_indices
    if parent_idx is None:
        return selected

    # Pick the entry along dimension 2 addressed by the parent indices.
    gathered = torch.gather(selected, dim=2, index=parent_idx.unsqueeze(-1))
    return gathered.squeeze(-1)
コード例 #12
0
ファイル: model_helper.py プロジェクト: pierresegonne/SGGM
def density_gradient_descent(distribution: D.Distribution, x_0: torch.Tensor,
                             params: dict) -> torch.Tensor:
    """Move ``x_0`` against the normalised gradient of ``distribution.log_prob``.

    Args:
        distribution: Distribution whose mean log-probability is differentiated.
        x_0: Starting points; cloned, so the input tensor is not modified.
        params: Dict with keys ``"N_steps"`` (number of updates), ``"lr"``
            (step size) and ``"threshold"`` (updates are applied only to rows
            whose gradient norm along dim 1 is below this value).

    Returns:
        The detached tensor of updated points.
    """
    N_steps = params["N_steps"]
    lr = params["lr"]
    threshold = params["threshold"]

    x_hat = x_0.clone()
    x_hat.requires_grad = True

    print("   PIG gradient descent:", end=" ", flush=True)
    for n in range(N_steps):
        print(f"{n+1}", end=" ")
        # Gradients stay enabled inside the surrounding no_grad context.
        with torch.no_grad(), torch.set_grad_enabled(True):
            mean_log_prob = distribution.log_prob(x_hat).mean()
            (density_grad, ) = torch.autograd.grad(mean_log_prob,
                                                   x_hat,
                                                   retain_graph=True)
            normed_density_grad = normalise_grad(density_grad)
            # Rows at or above the threshold receive a zero step.
            below_threshold = (
                torch.linalg.norm(density_grad, dim=1)[:, None] < threshold)
            step = torch.where(
                below_threshold,
                normed_density_grad,
                torch.zeros_like(normed_density_grad),
            )
            x_hat = x_hat - lr * step
    print("-> OK")
    return x_hat.detach()
コード例 #13
0
ファイル: test_utils.py プロジェクト: jyotikab/sbi
def get_normalization_uniform_prior(
    posterior: DirectPosterior,
    prior: Distribution,
    true_observation: Tensor,
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Return the unnormalized posterior likelihood, the normalized posterior likelihood,
    and the estimated acceptance probability.

    Args:
        posterior: estimated posterior
        prior: prior distribution; a single sample from it is the evaluation point
        true_observation: observation passed to the leakage correction
    """

    theta = prior.sample()

    def _density(normalised: bool) -> Tensor:
        # exp(log_prob) with or without the posterior's normalisation.
        return torch.exp(posterior.log_prob(theta, norm_posterior=normalised))

    posterior_likelihood_unnorm = _density(False)
    posterior_likelihood_norm = _density(True)

    # Acceptance probability as reported by the posterior's leakage correction.
    acceptance_prob = posterior.leakage_correction(x=true_observation)

    return posterior_likelihood_unnorm, posterior_likelihood_norm, acceptance_prob
コード例 #14
0
    def decode(
        self,
        distribution: Distribution,
        joint_log_probs: Optional[storch.Tensor],
        parents: [storch.Tensor],
        orig_distr_plates: [storch.Plate],
    ) -> (storch.Tensor, storch.Tensor, storch.Tensor):
        """Sample from ``distribution`` and accumulate the joint log-probability.

        Args:
            distribution: Distribution to sample from.
            joint_log_probs: Log-probabilities accumulated so far, or ``None``
                on the first call. When given, it is updated in place.
            parents: Forwarded to ``self.mc_sample``.
            orig_distr_plates: Plates of the distribution. If one of them is
                named ``self.plate_name``, a single sample is drawn and its
                leading dimension squeezed; otherwise ``self.k`` samples are
                drawn.

        Returns:
            ``(sample, joint_log_probs, None)``.
        """
        is_conditional_sample = any(
            plate.name == self.plate_name for plate in orig_distr_plates)

        if is_conditional_sample:
            sample = self.mc_sample(distribution, parents, orig_distr_plates,
                                    1).squeeze(0)
        else:
            sample = self.mc_sample(distribution, parents, orig_distr_plates,
                                    self.k)

        s_log_probs = distribution.log_prob(sample)
        # Test against None explicitly: the truth value of a tensor depends on
        # its contents and is ambiguous for more than one element.
        if joint_log_probs is not None:
            joint_log_probs += s_log_probs
        else:
            joint_log_probs = s_log_probs

        # NOTE(review): the type of `self.finished_samples` is not visible
        # here; if it can be a multi-element tensor this truthiness test has
        # the same problem as the one fixed above -- confirm.
        if self.finished_samples:
            # Make sure we do not change the log probabilities for samples that were already finished
            # NOTE(review): both terms use the already-updated value, so this
            # is algebraically the identity -- confirm the intended operand.
            joint_log_probs = (joint_log_probs * self.finished_samples +
                               joint_log_probs * (1 - self.finished_samples))
            sample[self.finished_samples] = self.eos

        return sample, joint_log_probs, None
コード例 #15
0
def loss(distr: Distribution, actions: Tensor, critic_value: Tensor,
         c_entropy: float) -> Tensor:
    """Computes A2C actor loss, i.e. -C(s, a) log pi(a | s) - c_entropy H(pi(.|s)).

    :args distr: distribution on actions, accepting actions of the same size as actions
    :args actions: actions performed
    :args critic_value: advantage corresponding to the actions performed;
        detached, so no gradient flows into the critic
    :args c_entropy: entropy loss weighting

    :return: scalar loss (both terms are averaged)
    """
    log_pi = distr.log_prob(actions)
    policy_entropy = distr.entropy()

    advantage = critic_value.detach()
    policy_term = (-log_pi * advantage).mean()
    entropy_term = c_entropy * policy_entropy.mean()
    return policy_term - entropy_term
コード例 #16
0
ファイル: main.py プロジェクト: pierresegonne/SGGM
def plot_comparison(n_display: int, x_og: torch.Tensor, p_x: D.Distribution,
                    input_dims: Tuple[int]) -> Figure:
    """Plot originals, means, variances and samples in a 4 x ``n_display`` grid.

    Args:
        n_display: Number of columns (images) to show.
        x_og: Original images; ``x_og[n]`` is permuted with ``(1, 2, 0)``
            before being displayed.
        p_x: Distribution providing ``sample()``, ``mean`` and ``variance``.
        input_dims: Forwarded to ``batch_reshape``.

    Returns:
        The created figure.
    """
    fig = plt.figure(figsize=(6, 6))
    gs = fig.add_gridspec(4,
                          n_display,
                          width_ratios=[1] * n_display,
                          height_ratios=[1, 1, 1, 1])
    gs.update(wspace=0, hspace=0)

    x_hat = batch_reshape(p_x.sample(), input_dims).clip(0, 1)
    x_mu = batch_reshape(p_x.mean, input_dims).clip(0, 1)
    x_var = batch_reshape(p_x.variance, input_dims).clip(0, 1)

    # One entry per grid row: (images, extra imshow keyword arguments).
    # The variance row is shown without a fixed colour range.
    unit_range = {"vmin": 0, "vmax": 1}
    rows = (
        (x_og, unit_range),  # Original
        (x_mu, unit_range),  # Mean
        (x_var, {}),  # Variance
        (x_hat, unit_range),  # Sample
    )

    for n in range(n_display):
        for k, (images, imshow_kwargs) in enumerate(rows):
            ax = disable_ticks(plt.subplot(gs[k, n]))
            # Move the first axis of the image to the end for imshow.
            ax.imshow(images[n, :].permute(1, 2, 0), **imshow_kwargs)

    return fig
コード例 #17
0
ファイル: user_input_checks.py プロジェクト: jyotikab/sbi
def process_pytorch_prior(
        prior: Distribution) -> Tuple[Distribution, int, bool]:
    """Validate a PyTorch prior and adapt it for use with sbi.

    Args:
        prior: PyTorch distribution prior provided by the user.

    Raises:
        ValueError: If a sample from the prior is a zero-dimensional tensor,
            i.e. the prior is defined over an unwrapped scalar variable.

    Returns:
        prior: The prior; a 1D ``Uniform`` is replaced by a ``BoxUniform`` and
            priors with non-``float32`` samples are wrapped.
        theta_numel: Number of elements in a single sample from the prior.
        prior_returns_numpy: Always ``False``.
    """

    # Disable argument validation so that `log_prob()` can be evaluated on
    # samples outside of the support.
    # NOTE(review): in torch this is a static method that changes the default
    # for all distributions, not just `prior` -- confirm this is intended.
    prior.set_default_validate_args(False)

    # Unwrapped scalar priors are rejected.
    if prior.sample().ndim == 0:
        raise ValueError(
            "Detected scalar prior. Please make sure to pass a PyTorch prior with "
            "`batch_shape=torch.Size([1])` or `event_shape=torch.Size([1])`.")

    # A 1D Uniform is cast to BoxUniform to avoid a shape error in mdn log prob.
    is_1d_uniform = (isinstance(prior, Uniform)
                     and prior.batch_shape.numel() == 1)
    if is_1d_uniform:
        prior = BoxUniform(low=prior.low, high=prior.high)
        warnings.warn(
            "Casting 1D Uniform prior to BoxUniform to match sbi batch requirements."
        )

    check_prior_batch_behavior(prior)
    check_prior_batch_dims(prior)

    # Force float32 samples by wrapping priors that return another dtype.
    if prior.sample().dtype != float32:
        prior = PytorchReturnTypeWrapper(prior,
                                         return_type=float32,
                                         validate_args=False)

    check_prior_return_type(prior)

    return prior, prior.sample().numel(), False
コード例 #18
0
ファイル: generative.py プロジェクト: vitaka/AEVNMT.pt
    def log_prob(self, likelihood: Distribution, x):
        """
        Return the log-probability of an observation, ignoring padding.

        likelihood: as returned by forward
        x: [B, Tx] observed token ids

        Positions equal to ``self.pad_idx`` contribute zero; the remaining
        per-token log-probabilities are summed over the last dimension.
        """
        token_log_probs = likelihood.log_prob(x)
        not_padding = (x != self.pad_idx).float()
        return (token_log_probs * not_padding).sum(-1)
コード例 #19
0
ファイル: affine.py プロジェクト: HaoWen6588/pyfilter
def _define_transdist(loc: torch.Tensor, scale: torch.Tensor,
                      inc_dist: Distribution, ndim: int):
    """Build ``inc_dist`` transformed by the affine map ``loc + scale * x``.

    :param loc: Location of the affine transform
    :param scale: Scale of the affine transform; broadcast against ``loc``
    :param inc_dist: Base distribution, expanded to the batch shape
    :param ndim: Number of event dimensions; the trailing ``ndim`` dimensions
        of the broadcast ``loc`` are excluded from the batch shape
    :return: The resulting ``TransformedDistribution``
    """
    loc, scale = torch.broadcast_tensors(loc, scale)

    if ndim > 0:
        batch_shape = loc.shape[:-ndim]
    else:
        batch_shape = loc.shape

    base = inc_dist.expand(batch_shape)
    affine = AffineTransform(loc, scale, event_dim=ndim)
    return TransformedDistribution(base, affine)
コード例 #20
0
ファイル: method.py プロジェクト: dudugang/storchastic
 def mc_sample(
     self,
     distr: Distribution,
     parents: [storch.Tensor],
     plates: [Plate],
     amt_samples: int,
 ) -> torch.Tensor:
     """Draw ``self.n_samples`` samples from ``distr``.

     ``parents``, ``plates`` and ``amt_samples`` are accepted but not used.
     NOTE(review): the sample count comes from ``self.n_samples`` rather than
     ``amt_samples`` -- confirm this is intended.
     """
     sample_shape = (self.n_samples,)
     return distr.sample(sample_shape)
コード例 #21
0
ファイル: generative.py プロジェクト: vitaka/AEVNMT.pt
    def log_prob(self, likelihood: Distribution, y):
        """
        Return the log-probability of an observation, ignoring padding.

        likelihood: as returned by forward
        y: [B, Ty] observed token ids

        Positions equal to ``self.tgt_embedder.padding_idx`` contribute zero;
        the remaining per-token log-probabilities are summed over the last
        dimension.
        """
        token_log_probs = likelihood.log_prob(y)
        not_padding = (y != self.tgt_embedder.padding_idx).float()
        return (token_log_probs * not_padding).sum(-1)
コード例 #22
0
 def _compute_entropy(dist: td.Distribution):
     """Return ``(entropy, entropy_for_gradient)`` for ``dist``.

     For every distribution that is not a ``TransformedDistribution`` both
     entries are the same ``dist.entropy()`` result.
     """
     if not isinstance(dist, td.TransformedDistribution):
         exact_entropy = dist.entropy()
         return exact_entropy, exact_entropy

     # TransformedDistribution is used by NormalProjectionNetwork with
     # scale_distribution=True, in which case we estimate with sampling.
     entropy, entropy_for_gradient = estimated_entropy(dist)
     return entropy, entropy_for_gradient
コード例 #23
0
 def mc_sample(
     self,
     distr: Distribution,
     parents: [storch.Tensor],
     plates: [Plate],
     amt_samples: int,
 ) -> torch.Tensor:
     """Draw ``self.n_samples`` samples from ``distr``.

     ``parents``, ``plates`` and ``amt_samples`` are accepted but not used.
     """
     # TODO: Why does this ignore amt_samples?
     n_draws = self.n_samples
     return distr.sample((n_draws,))
コード例 #24
0
ファイル: vi_quality_control.py プロジェクト: bkmi/sbi
def proportional_to_joint_diagnostics(
        potential_function: Callable,
        q: Distribution,
        proposal: Optional[Distribution] = None,
        N: int = int(5e4),
) -> float:
    r"""Evaluate the quality of a posterior approximation with a linear fit.

    If q is a perfect posterior approximation then $q(\theta) \propto
    p(\theta, x_o)$. Thus we should be able to fit a line through the origin to
    $(q(\theta), p(\theta, x_o))$, where the slope is proportional to the
    normalizing constant. The quality of this linear fit is hence a direct
    metric for the quality of q. We use the R2 statistic.

    NOTE: In our experience this metric does distinguish "good" from "ok", but
    becomes less sensitive to distinguish "very bad" from "ok".

    Args:
        potential_function: Potential function of target.
        q: Variational distribution, should be proportional to the potential_function
        proposal: Proposal for samples. Typically this is q.
        N: Number of samples involved in the test.

    Returns:
        float: Quality metric

    """
    sampler = q if proposal is None else proposal

    with torch.no_grad():
        samples = sampler.sample(Size((N, )))
        log_q = q.log_prob(samples)
        log_potential = potential_function(samples)

        # Column vectors of the density and the potential.
        X = log_q.exp().unsqueeze(-1)
        Y = log_potential.exp().unsqueeze(-1)

        # Least-squares slope from the normal equations (no intercept).
        gram = X.T @ X
        w = torch.linalg.solve(gram, X.T @ Y)

        residual_ss = ((Y - w * X)**2).sum()
        total_ss = ((Y - Y.mean())**2).sum()
        r2 = 1 - residual_ss / total_ss  # R2 statistic to evaluate fit
    return r2.item()
コード例 #25
0
def loss(distr: Distribution,
         actions: Tensor,
         critic_value: Tensor,
         c_entropy: float,
         eps_clamp: float,
         c_kl: float,
         old_logp: Tensor,
         old_distr: Optional[Distribution] = None) -> Tensor:
    """Computes PPO actor loss. See
    https://spinningup.openai.com/en/latest/algorithms/ppo.html
    for a detailed explanation

    :args distr: current distribution of actions,
        accepting actions of the same size as actions
    :args actions: actions performed
    :args critic_value: advantage corresponding to the actions performed;
        detached, so no gradient flows into the critic
    :args c_entropy: entropy loss weighting
    :args eps_clamp: clamping parameter
    :args c_kl: kl penalty coefficient
    :args old_logp: log probabilities of the old distribution of actions
    :args old_distr: old distribution of actions; required when c_kl != 0

    :return: scalar loss (mean over all elements)
    :raises ValueError: if ``c_kl`` is non-zero and ``old_distr`` is None
    """
    logp_action = distr.log_prob(actions)
    log_ratio = logp_action - old_logp
    advantage = critic_value.detach()

    # The ratio is clipped in log space: from above where the advantage is
    # positive, from below otherwise.
    ratio_clipped = torch.where(
        advantage > 0,
        torch.clamp(log_ratio, max=np.log(1 + eps_clamp)),
        torch.clamp(log_ratio, min=np.log(1 - eps_clamp))).exp()

    # Out-of-place updates below let the extra terms broadcast freely; the
    # local name no longer shadows this function.
    actor_loss = -ratio_clipped * advantage
    if c_entropy != 0.:
        actor_loss = actor_loss - c_entropy * distr.entropy()

    if c_kl != 0.:
        if old_distr is None:
            raise ValueError(
                "Optional argument old_distr is required if c_kl > 0")
        actor_loss = actor_loss + c_kl * kl_divergence(old_distr, distr)

    return actor_loss.mean()
コード例 #26
0
ファイル: test_simulator_utils.py プロジェクト: boyali/sbi
def test_process_simulator(simulator: Callable, prior: Distribution):
    """Check that a processed simulator returns a correctly batched Tensor."""

    prior, _, prior_returns_numpy = process_prior(prior)
    simulator = process_simulator(simulator, prior, prior_returns_numpy)

    n_batch = 2
    theta = prior.sample((n_batch, ))
    x = simulator(theta)

    assert isinstance(x, Tensor), "Processed simulator must return Tensor."
    leading_dim_matches = x.shape[0] == n_batch
    assert (
        leading_dim_matches
    ), "Processed simulator must return as many data points as parameters in batch."
コード例 #27
0
 def reparam_sample(
     self,
     distr: Distribution,
     parents: [storch.Tensor],
     plates: [Plate],
     amt_samples: int,
 ):
     """Draw ``amt_samples`` reparameterised samples from ``distr``.

     ``parents`` and ``plates`` are accepted but not used.

     Raises:
         ValueError: If ``distr.has_rsample`` is false.
     """
     if distr.has_rsample:
         return distr.rsample((amt_samples, ))
     raise ValueError(
         "The input distribution has not implemented rsample. If you use a discrete "
         "distribution, make sure to use eg GumbelSoftmax.")
コード例 #28
0
def importance_resampling(
    num_samples: int,
    potential_fn: Callable,
    proposal: Distribution,
    K: int = 32,
    num_samples_batch: int = 10000,
    **kwargs,
) -> Tensor:
    """Perform sequential importance resampling (SIR).

    Args:
        num_samples: Number of samples to draw.
        potential_fn: Potential function, this may be used to debias the proposal.
        proposal: Proposal distribution to propose samples.
        K: Number of proposed samples from which only one is selected based on its
            importance weight.
        num_samples_batch: Number of samples processed in parallel. For large K you may
            want to reduce this, depending on your memory capabilities.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        Tensor: Samples of shape (num_samples, event_shape)

    Raises:
        ValueError: If `num_samples` or `num_samples_batch` is smaller than 1.
    """
    if num_samples < 1 or num_samples_batch < 1:
        raise ValueError(
            "num_samples and num_samples_batch must both be at least 1.")

    final_samples = []
    # Track the number of samples drawn (not the number of batches) so that a
    # final, smaller batch covers any remainder.
    num_drawn = 0
    while num_drawn < num_samples:
        batch_size = min(num_samples_batch, num_samples - num_drawn)
        with torch.no_grad():
            thetas = proposal.sample(torch.Size((batch_size * K, )))
            logp = potential_fn(thetas)
            logq = proposal.log_prob(thetas)
            # Cumulative normalised importance weights per group of K proposals.
            weights = (logp - logq).reshape(batch_size,
                                            K).softmax(-1).cumsum(-1)
            u = torch.rand(batch_size, 1, device=thetas.device)
            # Select the first proposal whose cumulative weight reaches u.
            mask = torch.cumsum(weights >= u, -1) == 1
            samples = thetas.reshape(batch_size, K, -1)[mask]
            final_samples.append(samples)
        num_drawn += batch_size
    return torch.vstack(final_samples)
コード例 #29
0
ファイル: sem.py プロジェクト: Lakshitha0912/rfi1
    def _support_check_wrapper(dist: Distribution):
        """Patch ``dist.log_prob`` in place to return -inf outside the support.

        The instance's ``log_prob`` is replaced by a wrapper that calls the
        original and overwrites NaN entries, as well as entries whose value
        fails ``dist.support.check``, with negative infinity.
        """
        unwrapped_log_prob = dist.log_prob

        def new_log_prob_method(value):
            log_p = unwrapped_log_prob(value)

            # NaN results are treated like out-of-support values.
            nan_mask = torch.isnan(log_p)
            log_p[nan_mask] = -float("Inf")

            outside_support = ~dist.support.check(value)
            log_p[outside_support] = -float("Inf")

            return log_p

        dist.log_prob = new_log_prob_method
コード例 #30
0
def check_transform(prior: Distribution, transform: TorchTransform) -> None:
    """Check that ``transform`` round-trips prior samples without changing shape.

    Two samples are drawn from ``prior``, mapped through ``transform.inv`` and
    back through ``transform``; the shapes must match and the round trip must
    reproduce the samples.
    """

    theta = prior.sample(torch.Size((2, )))
    theta_unconstrained = transform.inv(theta)

    shapes_agree = theta_unconstrained.shape == theta.shape  # type: ignore
    assert shapes_agree, """Mismatch between transformed and untransformed space. Note that you cannot
    use a transforms when using a MultipleIndependent prior with a Dirichlet prior."""

    theta_roundtrip = transform(theta_unconstrained)  # type: ignore
    assert torch.allclose(
        theta, theta_roundtrip
    ), "Mismatch between true and re-transformed parameters."