# Example 1
def test_weighted_waic():
    """Weighted WAIC over three samples with (unnormalized) weights 4:2:3
    must agree with unweighted WAIC over the equivalent expanded sample set
    in which each sample is repeated according to its weight."""
    s1 = 1 + torch.rand(10)
    s2 = 1 + torch.rand(10)
    s3 = 1 + torch.rand(10)
    expanded_x = torch.stack([s1, s2, s3, s1, s2, s1, s3, s1, s3]).log()
    x = torch.stack([s1, s2, s3]).log()
    log_weights = torch.tensor([4.0, 2, 3]).log()
    # weights need not be normalized; shift all of them by a random constant
    log_weights = log_weights - torch.randn(1)

    weighted_w, weighted_p = waic(x, log_weights)
    expanded_w, expanded_p = waic(expanded_x)

    # lpd = -waic/2 + p_waic must agree between the two formulations
    assert_equal(-0.5 * weighted_w + weighted_p, -0.5 * expanded_w + expanded_p)

    # p_waic uses a biased variance; undo the bias factor (n-1)/n for each
    # sample count (3 weighted vs 9 expanded) before comparing
    assert_equal(weighted_p * 2 / 3, expanded_p * 8 / 9)

    # transposing the input with dim=-1 must give identical results
    transposed_w, transposed_p = waic(x.t(), log_weights, dim=-1)
    assert_equal(weighted_w, transposed_w)
    assert_equal(weighted_p, transposed_p)
# Example 2
def test_waic():
    """Check WAIC values against the R `loo` reference implementation and
    verify pointwise/summed consistency and dim invariance."""
    x = -torch.arange(1.0, 101).log().reshape(25, 4)
    pointwise_w, pointwise_p = waic(x, pointwise=True)
    summed_w, summed_p = waic(x)
    transposed_w, transposed_p = waic(x.t(), dim=1)

    # reference values from the loo package:
    # http://mc-stan.org/loo/reference/waic.html
    assert_equal(pointwise_w, torch.tensor([7.49, 7.75, 7.86, 7.92]), prec=0.01)
    assert_equal(pointwise_p, torch.tensor([1.14, 0.91, 0.79, 0.70]), prec=0.01)

    # the scalar result is the sum of the pointwise result
    assert_equal(summed_w, pointwise_w.sum())
    assert_equal(summed_p, pointwise_p.sum())

    # computing along the other dim of the transposed input is equivalent
    assert_equal(summed_w, transposed_w)
    assert_equal(summed_p, transposed_p)
# Example 3
    def information_criterion(self, pointwise=False):
        """
        Computes information criterion of the model. Currently, returns only "Widely
        Applicable/Watanabe-Akaike Information Criterion" (WAIC) and the corresponding
        effective number of parameters.

        Reference:

        [1] `Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC`,
        Aki Vehtari, Andrew Gelman, and Jonah Gabry

        :param bool pointwise: a flag to decide if we want to get a vectorized WAIC or not. When
            ``pointwise=False``, returns the sum.
        :returns: a dictionary containing values of WAIC and its effective number of
            parameters (empty when no traces have been collected).
        :rtype: :class:`OrderedDict`
        :raises ValueError: if any trace has a number of observation nodes other
            than one, or if the observation node differs between traces.
        """
        if not self.exec_traces:
            return {}
        obs_node = None
        log_likelihoods = []
        for trace in self.exec_traces:
            obs_nodes = trace.observation_nodes
            # Exactly one observation node is required. Checking `!= 1`
            # (rather than `> 1`) also rejects traces with zero observation
            # nodes, which previously fell through to `obs_nodes[0]` and
            # raised an opaque IndexError. Also fixes the "Infomation" typo.
            if len(obs_nodes) != 1:
                raise ValueError(
                    "Information criterion calculation only works for models "
                    "with one observation node.")
            if obs_node is None:
                obs_node = obs_nodes[0]
            elif obs_node != obs_nodes[0]:
                raise ValueError(
                    "Observation node has been changed, expected {} but got {}"
                    .format(obs_node, obs_nodes[0]))

            # per-trace log-likelihood of the observed value under its fn
            log_likelihoods.append(trace.nodes[obs_node]["fn"].log_prob(
                trace.nodes[obs_node]["value"]))

        # shape: (num_traces, ...) — one row of pointwise log-likelihoods per trace
        ll = torch.stack(log_likelihoods, dim=0)
        waic_value, p_waic = waic(
            ll, torch.tensor(self.log_weights, device=ll.device), pointwise)
        return OrderedDict([("waic", waic_value), ("p_waic", p_waic)])