예제 #1
0
파일: test_stats.py 프로젝트: pyro-ppl/pyro
def test_fit_generalized_pareto(k, sigma, n_samples=5000):
    """
    Check that ``fit_generalized_pareto`` recovers known parameters.

    Draws ``n_samples`` values from scipy's ``genpareto`` with shape ``c=k``
    and ``scale=sigma``, fits them, and compares each fitted parameter with
    the value used for sampling via ``assert_equal(..., prec=0.02)``.

    :param k: shape parameter passed to ``genpareto.rvs`` as ``c``.
    :param sigma: scale parameter passed to ``genpareto.rvs`` as ``scale``.
    :param int n_samples: number of samples to draw. defaults to 5000.
    """
    # RuntimeWarnings raised while importing scipy.stats are silenced.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        from scipy.stats import genpareto

    draws = genpareto.rvs(c=k, scale=sigma, size=n_samples)
    estimated_k, estimated_sigma = fit_generalized_pareto(torch.tensor(draws))

    # Same order as before: shape parameter first, then scale.
    for expected, estimated in ((k, estimated_k), (sigma, estimated_sigma)):
        assert_equal(expected, estimated, prec=0.02)
예제 #2
0
def psis_diagnostic(model, guide, *args, **kwargs):
    """
    Computes the Pareto tail index k for a model/guide pair using the technique
    described in [1], which builds on previous work in [2]. If :math:`0 < k < 0.5`
    the guide is a good approximation to the model posterior, in the sense
    described in [1]. If :math:`0.5 \\le k \\le 0.7`, the guide provides a suboptimal
    approximation to the posterior, but may still be useful in practice. If
    :math:`k > 0.7` the guide program provides a poor approximation to the full
    posterior, and caution should be used when using the guide. Note, however,
    that a guide may be a poor fit to the full posterior while still yielding
    reasonable model predictions. If :math:`k < 0.0` the importance weights
    corresponding to the model and guide appear to be bounded from above; this
    would be a bizarre outcome for a guide trained via ELBO maximization. Please
    see [1] for a more complete discussion of how the tail index k should be
    interpreted.

    Please be advised that a large number of samples may be required for an
    accurate estimate of k.

    Note that we assume that the model and guide are both vectorized and have
    static structure. As is canonical in Pyro, the args and kwargs are passed
    to the model and guide.

    References
    [1] 'Yes, but Did It Work?: Evaluating Variational Inference.'
    Yuling Yao, Aki Vehtari, Daniel Simpson, Andrew Gelman
    [2] 'Pareto Smoothed Importance Sampling.'
    Aki Vehtari, Andrew Gelman, Jonah Gabry

    :param callable model: the model program.
    :param callable guide: the guide program.
    :param int num_particles: the total number of times we run the model and guide in
        order to compute the diagnostic. must be positive. defaults to 1000.
    :param int max_simultaneous_particles: the maximum number of simultaneous samples
        drawn from the model and guide. must be positive. defaults to `num_particles`.
        `num_particles` must be divisible by `max_simultaneous_particles`.
    :param int max_plate_nesting: optional bound on max number of nested :func:`pyro.plate`
        contexts in the model/guide. defaults to 7.
    :returns float: the PSIS diagnostic k, or ``float('inf')`` (after emitting a
        warning) when fewer than 10 tail samples are available.
    :raises ValueError: if `num_particles` or `max_simultaneous_particles` is not
        positive, or if `num_particles` is not divisible by
        `max_simultaneous_particles`.
    """

    # These three options are consumed here; whatever remains in kwargs is
    # forwarded unchanged to vectorized_importance_weights.
    num_particles = kwargs.pop('num_particles', 1000)
    max_simultaneous_particles = kwargs.pop('max_simultaneous_particles',
                                            num_particles)
    max_plate_nesting = kwargs.pop('max_plate_nesting', 7)

    # Reject non-positive counts up front: a zero batch size would otherwise
    # surface as a ZeroDivisionError from the modulo below, and a negative
    # num_particles as a math domain error from math.sqrt.
    if num_particles <= 0 or max_simultaneous_particles <= 0:
        raise ValueError(
            "num_particles and max_simultaneous_particles must be positive.")

    if num_particles % max_simultaneous_particles != 0:
        raise ValueError(
            "num_particles must be divisible by max_simultaneous_particles.")

    # Draw the particles in equally sized batches and keep only the first
    # element returned by each vectorized_importance_weights call.
    num_batches = num_particles // max_simultaneous_particles
    log_weights = torch.cat([
        vectorized_importance_weights(model,
                                      guide,
                                      *args,
                                      num_samples=max_simultaneous_particles,
                                      max_plate_nesting=max_plate_nesting,
                                      **kwargs)[0]
        for _ in range(num_batches)
    ])
    # Shift so the largest log weight is zero, then sort ascending so the
    # largest weights sit at the end of the tensor.
    log_weights = log_weights - log_weights.max()
    log_weights = torch.sort(log_weights, descending=False)[0]

    # The cutoff is the element just below the top
    # ceil(min(0.2 * num_particles, 3 * sqrt(num_particles))) entries.
    tail_size = int(
        math.ceil(min(0.2 * num_particles,
                      3.0 * math.sqrt(num_particles))))
    # Clamp so that a very small sample (e.g. num_particles=1) falls through to
    # the "not enough tail samples" warning instead of raising IndexError.
    cutoff_index = -min(tail_size + 1, len(log_weights))
    # The cutoff is floored at log(1e-15).
    lw_cutoff = max(math.log(1.0e-15), log_weights[cutoff_index])
    lw_tail = log_weights[log_weights > lw_cutoff]

    if len(lw_tail) < 10:
        warnings.warn(
            "Not enough tail samples to compute PSIS diagnostic; increase num_particles."
        )
        k = float('inf')
    else:
        # Fit to the tail weights' excess over the cutoff weight; only the
        # first returned value (k) is used.
        k, _ = fit_generalized_pareto(lw_tail.exp() - math.exp(lw_cutoff))

    return k