Example #1
def run_pmmh(
    filter_: BaseFilter,
    state: FilterResult,
    prop_kernel: Distribution,
    prop_filt,
    y: torch.Tensor,
    size=torch.Size([]),
    **kwargs,
) -> Tuple[torch.Tensor, FilterResult, BaseFilter]:
    """
    Runs one iteration of a vectorized Particle Marginal Metropolis hastings.
    """

    rvs = prop_kernel.sample(size)
    prop_filt.ssm.parameters_from_array(rvs, constrained=False)

    new_res = prop_filt.longfilter(y, bar=False, **kwargs)

    diff_logl = new_res.loglikelihood - state.loglikelihood
    diff_prior = (prop_filt.ssm.eval_prior_log_prob(False) -
                  filter_.ssm.eval_prior_log_prob(False)).squeeze()

    if isinstance(prop_kernel, Independent) and size == torch.Size([]):
        diff_prop = 0.0
    else:
        diff_prop = prop_kernel.log_prob(
            filter_.ssm.parameters_to_array(
                constrained=False)) - prop_kernel.log_prob(rvs)

    log_acc_prob = diff_prop + diff_prior + diff_logl
    res: torch.Tensor = torch.empty_like(
        log_acc_prob).uniform_().log() < log_acc_prob

    return res, new_res, prop_filt
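The accept step above is a standard Metropolis test done in log space. A minimal self-contained sketch of that pattern, with toy values in place of the library objects:

import torch

# Accept where log(U) < log acceptance probability, with U ~ Uniform(0, 1).
# Non-negative log-ratios are always accepted, since log(U) <= 0.
log_acc_prob = torch.tensor([-0.1, -2.3, 0.5])  # one log-ratio per chain
u = torch.empty_like(log_acc_prob).uniform_()
accepted = u.log() < log_acc_prob  # BoolTensor, one decision per chain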
Example #2
    def forward(self, distribution_old: Distribution, value_old: Tensor,
                distribution: Distribution, value: Tensor, action: Tensor,
                reward: Tensor, advantage: Tensor):
        # Value loss
        value_old_clipped = value_old + (value - value_old).clamp(
            -self.v_clip_range, self.v_clip_range)
        v_old_loss_clipped = (reward - value_old_clipped).pow(2)
        v_loss = (reward - value).pow(2)
        # Take the pessimistic (larger) of the clipped and unclipped errors
        value_loss = torch.max(v_old_loss_clipped, v_loss).mean()

        # Policy loss
        advantage = (advantage -
                     advantage.mean()) / (advantage.std(unbiased=False) + 1e-8)
        advantage.detach_()
        log_prob = distribution.log_prob(action)
        log_prob_old = distribution_old.log_prob(action)
        ratio = (log_prob - log_prob_old).exp().view(-1)

        surrogate = advantage * ratio
        surrogate_clipped = advantage * ratio.clamp(1 - self.clip_range,
                                                    1 + self.clip_range)
        policy_loss = torch.min(surrogate, surrogate_clipped).mean()

        # Entropy
        entropy = distribution.entropy().mean()

        # Total loss
        losses = policy_loss + self.c_entropy * entropy - self.c_value * value_loss
        total_loss = -losses
        self.reporter.scalar('ppo_loss/policy', -policy_loss.item())
        self.reporter.scalar('ppo_loss/entropy', -entropy.item())
        self.reporter.scalar('ppo_loss/value_loss', value_loss.item())
        self.reporter.scalar('ppo_loss/total', total_loss.item())
        return total_loss
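For intuition, here is a toy rendering of the clipped surrogate computed above, with made-up numbers:

import torch

clip_range = 0.2
advantage = torch.tensor([1.0, -1.0])
ratio = torch.tensor([1.5, 0.5])  # pi_new(a|s) / pi_old(a|s)

surrogate = advantage * ratio
surrogate_clipped = advantage * ratio.clamp(1 - clip_range, 1 + clip_range)
# Pessimistic bound: elementwise minimum of the two objectives.
objective = torch.min(surrogate, surrogate_clipped)  # tensor([1.2000, -0.8000])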
Example #3
def m_elbo(observe_x, observe_z, normal, p_xz: dist.Distribution,
           q_zx: dist.Distribution, p_z: dist.Distribution):
    """
    :param observe_x: (batch_size, x_dims)
    :param observe_z: (sample_size, batch_size, z_dims) or (batch_size, z_dims,)
    :param normal: (batch_size, x_dims)
    :param p_xz: samples in shape (sample_size, batch_size, x_dims)
    :param q_zx: samples in shape (sample_size, batch_size, z_dims)
    :param p_z: samples in shape (z_dims, )
    :return:
    """
    observe_x = torch.unsqueeze(observe_x, 0)  # (1, batch_size, x_dims)
    normal = torch.unsqueeze(normal, 0)  # (1, batch_size, x_dims)
    log_p_xz = p_xz.log_prob(observe_x)  # (1, batch_size, x_dims)
    if observe_z.dim() == 2:
        observe_z = torch.unsqueeze(
            observe_z, 0)  # (sample_size, batch_size, z_dims)
    # noinspection PyArgumentList
    log_q_zx = torch.sum(q_zx.log_prob(observe_z),
                         -1)  # (sample_size, batch_size)
    # noinspection PyArgumentList
    log_p_z = torch.sum(p_z.log_prob(observe_z),
                        -1)  # (sample_size, batch_size)
    # noinspection PyArgumentList
    ratio = (torch.sum(normal, -1) / float(normal.size()[-1])
             )  # (1, batch_size)
    # noinspection PyArgumentList
    return -torch.mean(
        torch.sum(log_p_xz * normal, -1) + log_p_z * ratio - log_q_zx)
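The `normal` mask weights the reconstruction term per dimension, so missing or anomalous dimensions drop out of the sum. A shape-only sketch of that weighting, with illustrative sizes:

import torch

sample_size, batch_size, x_dims = 4, 2, 3
log_p_xz = torch.randn(sample_size, batch_size, x_dims)
normal = torch.ones(1, batch_size, x_dims)
normal[0, 0, 1] = 0.0  # pretend one dimension of the first point is abnormal

recon = torch.sum(log_p_xz * normal, -1)  # (sample_size, batch_size)
ratio = normal.sum(-1) / x_dims           # fraction of normal dims, (1, batch_size)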
Example #4
def check_sample_shape_and_support(q: Distribution,
                                   prior: Distribution) -> None:
    """Checks the samples shape and support between variational distribution and the
    prior. Especially it checks if the shapes match and that the support between q and
    the prior matches (a property which holds for the true posterior in any case).

    Args:
        q: Variational distribution which is checked
        prior: Prior to check certain attributes which should be satisfied.

    """
    assert (q.event_shape == prior.event_shape
            ), "The event shape of q must match that of the prior"
    assert (q.batch_shape == prior.batch_shape
            ), "The batch shape of q must match that of the prior"

    sample_shape = torch.Size((1000, ))
    samples = q.sample(sample_shape)
    samples_prior = prior.sample(sample_shape).to(samples.device)
    try:
        _ = prior.support
        has_support = True
    except (NotImplementedError, AttributeError):
        has_support = False
    if has_support:
        assert bool(prior.support.check(samples).all()  # type: ignore
                    ), "The support of q must match that of the prior"
    assert (
        samples.shape == samples_prior.shape
    ), "Something is wrong with sample shape and event_shape or batch_shape attributes."
    assert torch.isfinite(q.log_prob(
        samples_prior)).all(), "Invalid values in logprob on prior samples."
    assert torch.isfinite(prior.log_prob(
        samples)).all(), "Invalid values in logprob on q samples."
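A hypothetical usage sketch, assuming q and the prior share shapes and support:

import torch
from torch.distributions import Independent, Normal

prior = Independent(Normal(torch.zeros(2), torch.ones(2)), 1)
q = Independent(Normal(torch.zeros(2), 0.5 * torch.ones(2)), 1)
check_sample_shape_and_support(q, prior)  # all assertions pass silently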
Example #5
def _eval_kernel(params: Iterable[Parameter], dist: Distribution, n_params: Iterable[Parameter]):
    """
    Evaluates the kernel used for performing the MCMC move.
    :param params: The current parameters
    :param dist: The distribution to use for evaluating the prior
    :param n_params: The new parameters to evaluate against
    :return: The log difference in priors
    """

    p_vals = stacker(params, lambda u: u.t_values)
    n_p_vals = stacker(n_params, lambda u: u.t_values)

    return dist.log_prob(p_vals.concated) - dist.log_prob(n_p_vals.concated)
Example #6
    def compute_estimator(
        self, tensor: StochasticTensor, cost: CostTensor, plate_index: int,
        n: int, distribution: Distribution
    ) -> Tuple[storch.Tensor, storch.Tensor, storch.Tensor]:
        if self.rebar:
            log_prob = distribution.log_prob(tensor)

            @storch.deterministic
            def compute_REBAR_cvs(estimator: RELAX, tensor: StochasticTensor,
                                  cost: CostTensor):
                # TODO: Does it make more sense to implement REBAR by adding another plate dimension
                #  that controls the three different types of samples? (hard, relaxed, cond)?
                #  Then this can be implemented by simply reducing that plate, ie not requiring the weird zeros
                # Split the sampled tensor
                empty_slices = (slice(None), ) * plate_index
                _index1 = empty_slices + (slice(n), )
                _index2 = empty_slices + (slice(n, 2 * n), )
                _index3 = empty_slices + (slice(2 * n, 3 * n), )
                relaxed_sample, cond_sample = (
                    tensor[_index2],
                    tensor[_index3],
                )
                relaxed_cost, cond_cost = cost[_index2], cost[_index3]

                # Compute the control variates and log probabilities for the samples
                _c_phi_relaxed = estimator.c_phi(relaxed_sample) + relaxed_cost
                _c_phi_cond = estimator.c_phi(cond_sample) + cond_cost

                # Add zeros to ensure plates align
                c_phi_relaxed = torch.zeros_like(cost)
                c_phi_relaxed[_index1] = _c_phi_relaxed
                c_phi_cond = torch.zeros_like(cost)
                c_phi_cond[_index1] = _c_phi_cond
                return c_phi_relaxed, c_phi_cond

            return (log_prob, ) + compute_REBAR_cvs(self, tensor, cost)
        else:
            hard_sample = discretize(tensor, distribution)
            relaxed_sample = tensor
            cond_sample = storch.conditional_gumbel_rsample(
                hard_sample, distribution.probs,
                isinstance(distribution, torch.distributions.Bernoulli),
                self.temperature)
            # Input rsampled values into c_phi
            _c_phi_relaxed = self.c_phi(relaxed_sample)
            _c_phi_cond = self.c_phi(cond_sample)
            _log_prob = distribution.log_prob(hard_sample)
            return _log_prob, _c_phi_relaxed, _c_phi_cond
Example #7
def density_gradient_descent(distribution: D.Distribution, x_0: torch.Tensor,
                             params: dict) -> torch.Tensor:
    N_steps, lr, threshold = params["N_steps"], params["lr"], params[
        "threshold"]

    x_hat = x_0.clone()
    x_hat.requires_grad = True

    print("   PIG gradient descent:", end=" ", flush=True)
    for n in range(N_steps):
        print(f"{n+1}", end=" ")
        with torch.enable_grad():
            log_prob = distribution.log_prob(x_hat).mean()
            density_grad = torch.autograd.grad(log_prob, x_hat)[0]
        with torch.no_grad():
            normed_density_grad = normalise_grad(density_grad)
            # Keep the normalised gradient only where the raw gradient norm
            # is below the threshold; zero it elsewhere.
            normed_density_grad = torch.where(
                torch.linalg.norm(density_grad, dim=1)[:, None] < threshold,
                normed_density_grad,
                torch.zeros_like(normed_density_grad),
            )
            x_hat = x_hat - lr * normed_density_grad
        x_hat.requires_grad = True
    print("-> OK")
    x_hat = x_hat.detach()
    return x_hat
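Stripped of the thresholding and logging, the core move is a gradient step on the log density. A minimal sketch of a single step on an assumed toy distribution:

import torch
import torch.distributions as D

dist = D.MultivariateNormal(torch.zeros(2), torch.eye(2))
x_hat = torch.randn(5, 2, requires_grad=True)

log_prob = dist.log_prob(x_hat).mean()
density_grad = torch.autograd.grad(log_prob, x_hat)[0]
with torch.no_grad():
    # Step against the gradient, i.e. away from high-density regions, as above.
    x_hat = x_hat - 0.1 * density_grad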
Example #8
    def decode(
        self,
        distribution: Distribution,
        joint_log_probs: Optional[storch.Tensor],
        parents: [storch.Tensor],
        orig_distr_plates: [storch.Plate],
    ) -> (storch.Tensor, storch.Tensor, storch.Tensor):

        is_conditional_sample = False

        for plate in orig_distr_plates:
            if plate.name == self.plate_name:
                is_conditional_sample = True

        if is_conditional_sample:
            sample = self.mc_sample(distribution, parents, orig_distr_plates,
                                    1).squeeze(0)
        else:
            sample = self.mc_sample(distribution, parents, orig_distr_plates,
                                    self.k)

        s_log_probs = distribution.log_prob(sample)
        if joint_log_probs is not None:
            joint_log_probs += s_log_probs
        else:
            joint_log_probs = s_log_probs

        if self.finished_samples is not None:
            # Make sure we do not change the log probabilities for samples that were already finished
            joint_log_probs = (joint_log_probs * self.finished_samples +
                               joint_log_probs * (1 - self.finished_samples))
            sample[self.finished_samples] = self.eos

        return sample, joint_log_probs, None
Example #9
    def log_prob(self, likelihood: Distribution, x):
        """
        Return log-probability of observation.

        likelihood: as returned by forward
        x: [B, Tx] observed token ids
        """
        return (likelihood.log_prob(x) * (x != self.pad_idx).float()).sum(-1)
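A standalone sketch of this pad-masked log-likelihood with a toy Categorical likelihood (in the real module, `likelihood` comes from `forward`):

import torch
from torch.distributions import Categorical

pad_idx = 0
likelihood = Categorical(logits=torch.randn(2, 3, 5))  # batch 2, length 3, vocab 5
x = torch.tensor([[4, 2, pad_idx], [1, 3, 2]])

# Padding positions contribute zero before the sum over the time dimension.
log_p = (likelihood.log_prob(x) * (x != pad_idx).float()).sum(-1)  # shape (2,)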
Example #10
    def log_prob(self, likelihood: Distribution, y):
        """
        Return log-probability of observation.

        likelihood: as returned by forward
        y: [B, Ty] observed token ids
        """
        return (likelihood.log_prob(y) *
                (y != self.tgt_embedder.padding_idx).float()).sum(-1)
Example #11
    def _support_check_wrapper(dist: Distribution):

        old_log_prob = dist.log_prob

        def new_log_prob_method(value):
            result = old_log_prob(value)

            # Out of support values
            result[torch.isnan(result)] = -float("Inf")
            result[~dist.support.check(value)] = -float("Inf")

            return result

        dist.log_prob = new_log_prob_method
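What the wrapper buys you, sketched on a toy distribution. Note `validate_args=False`: otherwise `log_prob` itself would raise on out-of-support values before the mask is applied:

import torch
from torch.distributions import Exponential

dist = Exponential(torch.tensor(1.0), validate_args=False)
_support_check_wrapper(dist)

values = torch.tensor([1.0, -1.0])  # -1.0 lies outside the support
dist.log_prob(values)               # tensor([-1., -inf])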
Example #12
def psis_diagnostics(
        potential_function: Callable,
        q: Distribution,
        proposal: Optional[Distribution] = None,
        N: int = int(5e4),
) -> float:
    r"""This will evaluate the posteriors quality by investingating its importance
    weights. If q is a perfect posterior approximation then $q(\theta) \propto
    p(\theta, x_o)$ thus $\log w(\theta) = \log \frac{p(\theta, x_o)}{\log q(\theta)} =
    \log p(x_o)$ is constant. This function will fit a Generalized Paretto
    distribution to the tails of w. The shape parameter k serves as metric as detailed
    in [1]. In short it is related to the variance of a importance sampling estimate,
    especially for k > 1 the variance will be infinite.

    NOTE: In our experience this metric does distinguish "very bad" from "ok", but
    becomes less sensitive to distinguish "ok" from "good".

    Args:
        potential_function: Potential function of target.
        q: Variational distribution, should be proportional to the potential_function
        proposal: Proposal for samples. Typically this is q.
        N: Number of samples involved in the test.

    Returns:
        float: Quality metric

    Reference:
        [1] _Yes, but Did It Work?: Evaluating Variational Inference_, Yuling Yao, Aki
        Vehtari, Daniel Simpson, Andrew Gelman, 2018, https://arxiv.org/abs/1802.02538

    """
    M = int(min(N / 5, 3 * np.sqrt(N)))
    with torch.no_grad():
        if proposal is None:
            samples = q.sample(Size((N, )))
        else:
            samples = proposal.sample(Size((N, )))
        log_q = q.log_prob(samples)
        log_potential = potential_function(samples)
        logweights = log_potential - log_q
        logweights = logweights[torch.isfinite(logweights)]
        logweights_max = logweights.max()
        weights = torch.exp(logweights -
                            logweights_max)  # subtracting the max only rescales
        vals, _ = weights.sort()
        largest_weights = vals[-M:]
        k, _ = gpdfit(largest_weights)
    return k
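The diagnostic rests on stabilised importance weights; a self-contained sketch of those (the generalized Pareto fit itself is delegated to `gpdfit` above):

import torch
from torch.distributions import Normal

target = Normal(0.0, 1.0)
q = Normal(0.5, 1.5)  # a deliberately mismatched approximation

samples = q.sample(torch.Size((5000,)))
log_w = target.log_prob(samples) - q.log_prob(samples)
weights = (log_w - log_w.max()).exp()  # subtracting the max only rescales
M = int(min(len(weights) / 5, 3 * len(weights) ** 0.5))
tail = weights.sort().values[-M:]      # the tail that gpdfit analyses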
Example #13
def loss(distr: Distribution, actions: Tensor, critic_value: Tensor,
         c_entropy: float) -> Tensor:
    """Computes A2C actor loss, i.e. -C(s, a) log pi(a | s) - c_entropy H(pi(.|s)).

    :args distr: distribution on actions, accepting actions of the same size as actions
    :args actions: actions performed
    :args critic_value: advantage corresponding to the actions performed
    :args c_entropy: entropy loss weighting

    :return: loss
    """
    logp_action = distr.log_prob(actions)
    entropy = distr.entropy()

    loss_critic = (-logp_action * critic_value.detach()).mean()
    return loss_critic - c_entropy * entropy.mean()
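A hypothetical call with a discrete policy, just to fix the shapes:

import torch
from torch.distributions import Categorical

distr = Categorical(logits=torch.randn(4, 3))  # 4 steps, 3 actions
actions = torch.tensor([0, 2, 1, 1])
advantage = torch.randn(4)

actor_loss = loss(distr, actions, advantage, c_entropy=0.01)  # scalar Tensor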
Example #14
def plot_dist(xx, yy, dist: td.Distribution, data=None):
    xlim = [xx.min(), xx.max()]
    ylim = [yy.min(), yy.max()]
    with torch.no_grad():
        if len(dist.batch_shape) > 0:
            zz = eval_grid(xx, yy, lambda xy: dist.log_prob(xy)[..., 0])
        else:
            zz = eval_grid(xx, yy, dist.log_prob)
    plt.imshow(zz.exp().T,
               interpolation='bilinear',
               origin='lower',
               extent=[*xlim, *ylim])
    # plt.contourf(xx, yy, zz.exp(), cmap='viridis')
    if data is not None:
        plt.scatter(*data.T, c='k', s=4)
    plt.xlim(xlim)
    plt.ylim(ylim)
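A hypothetical call, assuming the module's `eval_grid` helper is in scope:

import torch
import torch.distributions as td

xx, yy = torch.meshgrid(torch.linspace(-3, 3, 100),
                        torch.linspace(-3, 3, 100), indexing='ij')
dist = td.MultivariateNormal(torch.zeros(2), torch.eye(2))
plot_dist(xx, yy, dist)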
Example #15
def proportional_to_joint_diagnostics(
        potential_function: Callable,
        q: Distribution,
        proposal: Optional[Distribution] = None,
        N: int = int(5e4),
) -> float:
    r"""This will evaluate the posteriors quality by investingating its importance
    weights. If q is a perfect posterior approximation then $q(\theta) \propto
    p(\theta, x_o)$. Thus we should be able to fit a line to $(q(\theta),
    p(\theta, x_o))$, whereas the slope will be proportional to the normalizing
    constant. The quality of a linear fit is hence a direct metric for the quality of q.
    We use R2 statistic.

    NOTE: In our experience this metric does distinguish "good" from "ok", but
    becomes less sensitive to distinguish "very bad" from "ok".

    Args:
        potential_function: Potential function of target.
        q: Variational distribution, should be proportional to the potential_function
        proposal: Proposal for samples. Typically this is q.
        N: Number of samples involved in the test.

    Returns:
        float: Quality metric

    """

    with torch.no_grad():
        if proposal is None:
            samples = q.sample(Size((N, )))
        else:
            samples = proposal.sample(Size((N, )))
        log_q = q.log_prob(samples)
        log_potential = potential_function(samples)

        X = log_q.exp().unsqueeze(-1)
        Y = log_potential.exp().unsqueeze(-1)
        w = torch.linalg.solve(X.T @ X, X.T @ Y)  # Linear regression

        residuals = Y - w * X
        var_res = torch.sum(residuals**2)
        var_tot = torch.sum((Y - Y.mean())**2)
        r2 = 1 - var_res / var_tot  # R2 statistic to evaluate fit
    return r2.item()
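A quick sanity check: if q equals the normalized target exactly, then Y is proportional to X and the fit is perfect:

from torch.distributions import Normal

target = Normal(0.0, 1.0)
q = Normal(0.0, 1.0)
r2 = proportional_to_joint_diagnostics(target.log_prob, q, N=1000)  # 1.0 up to numerics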
Example #16
def loss(distr: Distribution,
         actions: Tensor,
         critic_value: Tensor,
         c_entropy: float,
         eps_clamp: float,
         c_kl: float,
         old_logp: Tensor,
         old_distr: Optional[Distribution] = None) -> Tensor:
    """Computes PPO actor loss. See
    https://spinningup.openai.com/en/latest/algorithms/ppo.html
    for a detailed explanation.

    :args distr: current distribution of actions,
        accepting actions of the same size as actions
    :args actions: actions performed
    :args critic_value: advantage corresponding to the actions performed
    :args c_entropy: entropy loss weighting
    :args eps_clamp: clamping parameter
    :args c_kl: kl penalty coefficient
    :args old_logp: log probabilities of the old distribution of actions
    :args old_distr: old distribution of actions

    :return: loss
    """
    logp_action = distr.log_prob(actions)
    logr = (logp_action - old_logp)

    r_clipped = torch.where(critic_value.detach() > 0,
                            torch.clamp(logr, max=np.log(1 + eps_clamp)),
                            torch.clamp(logr,
                                        min=np.log(1 - eps_clamp))).exp()

    loss = -r_clipped * critic_value.detach()
    if c_entropy != 0.:
        loss -= c_entropy * distr.entropy()

    if c_kl != 0.:
        if old_distr is None:
            raise ValueError(
                "Optional argument old_distr is required if c_kl > 0")
        loss += c_kl * kl_divergence(old_distr, distr)

    return loss.mean()
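The clamp-by-sign above is the usual clipped objective written compactly: for a positive advantage only the upper clip can bind, for a negative one only the lower. A toy check:

import numpy as np
import torch

eps_clamp = 0.2
logr = torch.tensor([0.5, 0.5])
adv = torch.tensor([1.0, -1.0])
r_clipped = torch.where(adv > 0,
                        torch.clamp(logr, max=np.log(1 + eps_clamp)),
                        torch.clamp(logr, min=np.log(1 - eps_clamp))).exp()
# tensor([1.2000, 1.6487]): clipped upward for adv > 0, untouched for adv < 0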
Example #17
    def compute_loss(
            self, x: Tensor, z_e: Tensor, z: Tensor, z_q: Tensor,
            x_posterior: Distribution) -> Tuple[Tensor, Dict[str, Tensor]]:
        """TODO docstring"""
        # ELBO = E[log p(x|z)] - KL(q(z)||p(z))
        log_likelihood = x_posterior.log_prob(x).mean(0)
        elbo = log_likelihood - self._kl

        stats = {
            'elbo': elbo.detach().clone(),
            'log_likelihood': log_likelihood.detach().clone(),
            'kl': self._kl.clone()
        }

        vq_loss, vq_stats = self.vq.compute_loss(z_e, z, z_q)
        loss = -elbo + vq_loss
        stats.update(vq_stats)

        return loss, stats
Example #18
def importance_resampling(
    num_samples: int,
    potential_fn: Callable,
    proposal: Distribution,
    K: int = 32,
    num_samples_batch: int = 10000,
    **kwargs,
) -> Tensor:
    """Perform sequential importance resampling (SIR).

    Args:
        num_samples: Number of samples to draw.
        potential_fn: Potential function, this may be used to debias the proposal.
        proposal: Proposal distribution to propose samples.
        K: Number of proposed samples from which only one is selected based on its
            importance weight.
        num_samples_batch: Number of samples processed in parallel. For large K you may
            want to reduce this, depending on your memory capabilities.

    Returns:
        Tensor: Samples of shape (num_samples, event_shape)

    """
    final_samples = []
    num_samples_batch = min(num_samples, num_samples_batch)
    # Round up so that a final, smaller batch is not silently dropped
    iters = (num_samples + num_samples_batch - 1) // num_samples_batch
    num_collected = 0
    for _ in range(iters):
        batch_size = min(num_samples_batch, num_samples - num_collected)
        num_collected += batch_size
        with torch.no_grad():
            thetas = proposal.sample(torch.Size((batch_size * K, )))
            logp = potential_fn(thetas)
            logq = proposal.log_prob(thetas)
            weights = (logp - logq).reshape(batch_size,
                                            K).softmax(-1).cumsum(-1)
            u = torch.rand(batch_size, 1, device=thetas.device)
            mask = torch.cumsum(weights >= u, -1) == 1
            samples = thetas.reshape(batch_size, K, -1)[mask]
            final_samples.append(samples)
    return torch.vstack(final_samples)
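A hypothetical run with a broad Gaussian proposal against a narrower target:

import torch
from torch.distributions import Independent, Normal

proposal = Independent(Normal(torch.zeros(2), 2.0 * torch.ones(2)), 1)
target = Independent(Normal(torch.ones(2), 0.5 * torch.ones(2)), 1)

samples = importance_resampling(1000, target.log_prob, proposal, K=16)
samples.mean(0)  # roughly the target mean (1, 1)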
Example #19
    def decode(
        self,
        distr: Distribution,
        joint_log_probs: Optional[storch.Tensor],
        parents: [storch.Tensor],
        orig_distr_plates: [storch.Plate],
    ) -> (storch.Tensor, storch.Tensor, storch.Tensor):
        """
        Decode given the input arguments
        :param distribution: The distribution to decode
        :param joint_log_probs: The log probabilities of the samples so far. prev_plates x amt_samples
        :param parents: List of parents of this tensor
        :param orig_distr_plates: List of plates from the distribution. Can include the self plate k.
        :return: 3-tuple of `storch.Tensor`. 1: The sampled value. 2: The new joint log probabilities of the samples.
        3: How the samples index the parent samples. Can just be a range if there is no choosing happening.
        """
        ancestral_distrplate_index = -1
        is_conditional_sample = False

        multi_dim_distr_plates = []
        multi_dim_index = 0
        for plate in orig_distr_plates:
            if plate.n > 1:
                if plate.name == self.plate_name:
                    ancestral_distrplate_index = multi_dim_index
                    is_conditional_sample = True
                else:
                    multi_dim_distr_plates.append(plate)
                multi_dim_index += 1
        # plates? x k x events x

        # TODO: This doesn't properly combine two ancestral plates with the same name but different variable index
        #  (they should merge).
        all_multi_dim_plates = multi_dim_distr_plates.copy()
        if self.variable_index > 0:
            # Previous variables have been sampled. add the prev_plates to all_plates
            for plate in self.joint_log_probs.multi_dim_plates():
                if plate not in multi_dim_distr_plates:
                    all_multi_dim_plates.append(plate)

        amt_multi_dim_plates = len(all_multi_dim_plates)
        amt_multi_dim_distr_plates = len(multi_dim_distr_plates)
        amt_multi_dim_orig_distr_plates = amt_multi_dim_distr_plates + (
            1 if is_conditional_sample else 0
        )
        amt_multi_dim_prev_plates = amt_multi_dim_plates - amt_multi_dim_distr_plates
        if not distr.has_enumerate_support:
            raise ValueError("Can only decode distributions with enumerable support.")

        with storch.ignore_wrapping():
            # |D_yv| x (|distr_plates| + |k?| + |event_dims|) * (1,) x |D_yv|
            support_non_expanded: torch.Tensor = distr.enumerate_support(expand=False)
            # Compute the log-probability of the different events
            # |D_yv| x distr_plate[0] x ... k? ... x distr_plate[n-1] x events
            d_log_probs = distr.log_prob(support_non_expanded)

            # Note: Use amt_orig_distr_plates here because it might include k? dimension. amt_distr_plates filters this one.
            # distr_plate[0] x ... k? ... x distr_plate[n-1] x |D_yv| x events
            d_log_probs = storch.Tensor(
                d_log_probs.permute(
                    tuple(range(1, amt_multi_dim_orig_distr_plates + 1))
                    + (0,)
                    + tuple(
                        range(
                            amt_multi_dim_orig_distr_plates + 1, len(d_log_probs.shape)
                        )
                    )
                ),
                [],
                orig_distr_plates,
            )

        # |D_yv| x distr_plate[0] x ... x k? x ... x distr_plate[n-1] x events x event_shape
        support = distr.enumerate_support(expand=True)

        if is_conditional_sample:
            # Reduce ancestral dimension in the support. As the dimension is just an expanded version, this should
            # not change the underlying data.
            # |D_yv| x distr_plates x events x event_shape
            support = support[(slice(None),) * (ancestral_distrplate_index + 1) + (0,)]

            # Gather the correct log probabilities
            # distr_plate[0] x ... k ... x distr_plate[n-1] x |D_yv| x events
            # TODO: Move this down below to the other scary TODO
            d_log_probs = self.new_plate.on_unwrap_tensor(d_log_probs)
            # Permute the dimensions of d_log_probs so that the k dimension comes after the plates.
            for i, plate in enumerate(d_log_probs.multi_dim_plates()):
                if plate.name == self.plate_name:
                    d_log_probs.plates.remove(plate)
                    # distr_plates x k x |D_yv| x events
                    d_log_probs._tensor = d_log_probs._tensor.permute(
                        tuple(range(0, i))
                        + tuple(range(i + 1, amt_multi_dim_orig_distr_plates))
                        + (i,)
                        + tuple(
                            range(
                                amt_multi_dim_orig_distr_plates, len(d_log_probs.shape)
                            )
                        )
                    )
                    break

        # Equal to event_shape
        element_shape = distr.event_shape
        support_permutation = (
            tuple(range(1, amt_multi_dim_distr_plates + 1))
            + (0,)
            + tuple(range(amt_multi_dim_distr_plates + 1, len(support.shape)))
        )
        # distr_plates x |D_yv| x events x event_shape
        support = support.permute(support_permutation)

        if amt_multi_dim_plates != amt_multi_dim_distr_plates:
            # If previous samples had plate dimensions that are not in the distribution plates, add these to the support.
            support = support[
                (slice(None),) * amt_multi_dim_distr_plates
                + (None,) * amt_multi_dim_prev_plates
            ]
            all_plate_dims = tuple(map(lambda _p: _p.n, all_multi_dim_plates))
            # plates x |D_yv| x events x event_shape (where plates = distr_plates x prev_plates)
            support = support.expand(
                all_plate_dims + (-1,) * (len(support.shape) - amt_multi_dim_plates)
            )
        # plates x |D_yv| x events x event_shape
        support = storch.Tensor(support, [], all_multi_dim_plates)

        # Equal to events: Shape for the different conditional independent dimensions
        event_shape = support.shape[
            amt_multi_dim_plates + 1 : -len(element_shape)
            if len(element_shape) > 0
            else None
        ]

        ranges = []
        for size in event_shape:
            ranges.append(list(range(size)))

        amt_samples = 0
        parent_indexing = None
        if joint_log_probs is not None:
            # Initialize a tensor (self.parent_indexing) that keeps track of what samples link to previous choices of samples
            # Note that joint_log_probs.shape[-1] is amt_samples, not k. It's possible that amt_samples < k!
            amt_samples = joint_log_probs.shape[-1]
            # plates x k
            parent_indexing = support.new_zeros(
                size=support.shape[:amt_multi_dim_plates] + (self.k,), dtype=torch.long
            )

            # probably can go wrong if plates are missing.
            parent_indexing[..., :amt_samples] = left_expand_as(
                torch.arange(amt_samples), parent_indexing
            )
        # plates x k x events
        sampled_support_indices = support.new_zeros(
            size=support.shape[:amt_multi_dim_plates]  # plates
            + (self.k,)
            + support.shape[
                amt_multi_dim_plates + 1 : -len(element_shape)
                if len(element_shape) > 0
                else None
            ],  # events
            dtype=torch.long,
        )
        # Sample independent tensors in sequence
        # Iterate over the different (conditionally) independent samples being taken (the events)
        for indices in itertools.product(*ranges):
            # Log probabilities of the different options for this sample step (event)
            # distr_plates x k? x |D_yv|
            yv_log_probs = d_log_probs[(...,) + indices]
            (
                sampled_support_indices,
                joint_log_probs,
                parent_indexing,
                amt_samples,
            ) = self.decode_step(
                indices,
                yv_log_probs,
                joint_log_probs,
                sampled_support_indices,
                parent_indexing,
                is_conditional_sample,
                amt_multi_dim_plates,
                amt_samples,
            )
        # Finally, index the support using the sampled indices to get the sample!
        if amt_samples < self.k:
            # plates x amt_samples x events
            sampled_support_indices = sampled_support_indices[
                (...,) + (slice(amt_samples),) + (slice(None),) * len(ranges)
            ]
        expanded_indices = right_expand_as(sampled_support_indices, support)
        sample = support.gather(dim=amt_multi_dim_plates, index=expanded_indices)
        return sample, joint_log_probs, parent_indexing
Example #20
    def get_log_prob(self, dist: Distribution,
                     action: torch.Tensor) -> torch.Tensor:
        if self.action_space.is_discrete:
            return dist.log_prob(action.squeeze(-1)).unsqueeze(-1)
        else:
            return dist.log_prob(action).sum(dim=-1, keepdim=True)
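Summing per-dimension log-probs in the continuous branch is equivalent to treating the action dimensions as a single event; a small check of that equivalence:

import torch
from torch.distributions import Independent, Normal

dist = Normal(torch.zeros(4, 3), torch.ones(4, 3))  # batch of 4, 3-dim actions
action = torch.randn(4, 3)

summed = dist.log_prob(action).sum(dim=-1, keepdim=True)
joint = Independent(dist, 1).log_prob(action).unsqueeze(-1)
torch.allclose(summed, joint)  # True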
Example #21
    def _weight_with_kernel(self, y: torch.Tensor, x_dist: Distribution,
                            x_new: TimeseriesState,
                            kernel: Distribution) -> torch.Tensor:
        y_dist = self._model.build_density(x_new)
        return y_dist.log_prob(y) + x_dist.log_prob(
            x_new.values) - kernel.log_prob(x_new.values)
Example #22
def run_pmmh(
        context: ParameterContext,
        state: FilterAlgorithmState,
        proposal: BaseProposal,
        proposal_kernel: Distribution,
        proposal_filter: BaseFilter,
        proposal_context: ParameterContext,
        y: torch.Tensor,
        size=torch.Size([]),
        mutate_kernel=False,
) -> torch.BoolTensor:
    r"""
    Runs one iteration of the PMMH update step in which we sample a candidate :math:`\theta^*` from the proposal
    kernel, run a filter for the considered dataset with :math:`\theta^*`, and accept based on the acceptance
    probability given by the article.

    Args:
        context: the parameter context of the main algorithm.
        state: the latest algorithm state.
        proposal: the proposal to use when generating the candidate sample :math:`\theta^*`.
        proposal_kernel: the kernel from which to draw the candidate sample :math:`\theta^*`. To clarify, ``proposal``
            corresponds to the ``BaseProposal`` class that was used when generating ``proposal_kernel``.
        proposal_filter: the proposal filter to use.
        proposal_context: the parameter context of the proposal filter.
        y: see ``pyfilter.inference.base.BaseAlgorithm``.
        size: optional parameter specifying the number of samples to draw from ``proposal_kernel``. Should be empty
            if we draw from an independent kernel.
        mutate_kernel: optional parameter specifying whether to update ``proposal_kernel`` with the newly accepted
            candidate sample.

    Returns:
        A boolean tensor indicating, for each candidate sample, whether it was accepted.
    """

    constrained = False

    rvs = proposal_kernel.sample(size)
    proposal_context.unstack_parameters(rvs, constrained=constrained)

    new_res = proposal_filter.batch_filter(y, bar=False)

    diff_logl = new_res.loglikelihood - state.filter_state.loglikelihood
    diff_prior = proposal_context.eval_priors(
        constrained=constrained) - context.eval_priors(constrained=constrained)

    new_prop_kernel = proposal.build(proposal_context,
                                     state.replicate(new_res), proposal_filter,
                                     y)
    params_as_tensor = context.stack_parameters(constrained=constrained)

    diff_prop = new_prop_kernel.log_prob(
        params_as_tensor) - proposal_kernel.log_prob(rvs)

    log_acc_prob = diff_prop + diff_prior.squeeze(-1) + diff_logl
    accepted: torch.BoolTensor = torch.empty_like(
        log_acc_prob).uniform_().log() < log_acc_prob

    state.filter_state.exchange(new_res, accepted)
    context.exchange(proposal_context, accepted)

    if mutate_kernel:
        proposal.exchange(proposal_kernel, new_prop_kernel, accepted)

    return accepted
Example #23
    def compute_divergence(self, sample: Tensor,
                           posterior: td.Distribution) -> Tensor:
        log_p = self.prior.log_prob(sample)
        log_q = posterior.log_prob(sample)

        return (log_q - log_p).sum()
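The method estimates KL(q || p) from the sample at hand; averaged over many samples, the same log-ratio converges to the analytic divergence:

import torch
import torch.distributions as td

prior = td.Normal(0.0, 1.0)
posterior = td.Normal(0.5, 0.8)

sample = posterior.rsample(torch.Size((10000,)))
mc_kl = (posterior.log_prob(sample) - prior.log_prob(sample)).mean()
td.kl_divergence(posterior, prior)  # close to mc_kl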