def test_weighted_waic():
    """Weighted WAIC on three samples with (unnormalized) weights 4:2:3 must
    agree with unweighted WAIC on the samples replicated 4, 2 and 3 times."""
    sample_a = 1 + torch.rand(10)
    sample_b = 1 + torch.rand(10)
    sample_c = 1 + torch.rand(10)
    # Replicating rows 4:2:3 times is the unweighted equivalent of weighting them 4:2:3.
    expanded_x = torch.stack(
        [sample_a, sample_b, sample_c, sample_a, sample_b,
         sample_a, sample_c, sample_a, sample_c]).log()
    x = torch.stack([sample_a, sample_b, sample_c]).log()
    log_weights = torch.tensor([4.0, 2, 3]).log()
    # assume weights are unnormalized
    log_weights = log_weights - torch.randn(1)
    w1, p1 = waic(x, log_weights)
    w2, p2 = waic(expanded_x)
    # test lpd: lpd = -0.5 * waic + p_waic must match between both formulations
    assert_equal(-0.5 * w1 + p1, -0.5 * w2 + p2)
    # test p_waic (also test for weighted_variance): undo the unbiased-variance
    # correction factor n/(n-1) for each sample count (3 vs 9) before comparing
    assert_equal(p1 * 2 / 3, p2 * 8 / 9)
    # test correctness for dim=-1: transposed input along the last dim must
    # reproduce the default-dim result
    w3, p3 = waic(x.t(), log_weights, dim=-1)
    assert_equal(w1, w3)
    assert_equal(p1, p3)
def test_waic():
    """Check WAIC on a deterministic log-likelihood grid against reference
    values from the R ``loo`` package, plus pointwise/sum and dim consistency."""
    log_likes = -torch.arange(1.0, 101).log().reshape(25, 4)
    waic_pw, p_pw = waic(log_likes, pointwise=True)
    waic_sum, p_sum = waic(log_likes)
    waic_t, p_t = waic(log_likes.t(), dim=1)
    # test against loo package: http://mc-stan.org/loo/reference/waic.html
    assert_equal(waic_pw, torch.tensor([7.49, 7.75, 7.86, 7.92]), prec=0.01)
    assert_equal(p_pw, torch.tensor([1.14, 0.91, 0.79, 0.70]), prec=0.01)
    # the scalar form is the sum of the pointwise form
    assert_equal(waic_sum, waic_pw.sum())
    assert_equal(p_sum, p_pw.sum())
    # computing along dim=1 of the transpose gives identical results
    assert_equal(waic_sum, waic_t)
    assert_equal(p_sum, p_t)
def information_criterion(self, pointwise=False):
    """
    Computes information criterion of the model. Currently, returns only "Widely
    Applicable/Watanabe-Akaike Information Criterion" (WAIC) and the corresponding
    effective number of parameters.

    Reference:

    [1] `Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC`,
    Aki Vehtari, Andrew Gelman, and Jonah Gabry

    :param bool pointwise: a flag to decide if we want to get a vectorized WAIC or not. When
        ``pointwise=False``, returns the sum.
    :returns: a dictionary containing values of WAIC and its effective number of
        parameters.
    :rtype: :class:`OrderedDict`
    """
    if not self.exec_traces:
        # Keep the documented return type (OrderedDict) even for the empty case.
        return OrderedDict()
    obs_node = None
    log_likelihoods = []
    for trace in self.exec_traces:
        obs_nodes = trace.observation_nodes
        # Exactly one observation node is required: `!= 1` also rejects models
        # with zero observation nodes (previously an opaque IndexError below).
        if len(obs_nodes) != 1:
            raise ValueError(
                "Information criterion calculation only works for models "
                "with one observation node.")
        if obs_node is None:
            obs_node = obs_nodes[0]
        elif obs_node != obs_nodes[0]:
            # The observation site must be the same across all traces.
            raise ValueError(
                "Observation node has been changed, expected {} but got {}"
                .format(obs_node, obs_nodes[0]))
        # Per-trace log-likelihood of the observed value under its distribution.
        log_likelihoods.append(trace.nodes[obs_node]["fn"].log_prob(
            trace.nodes[obs_node]["value"]))

    # Stack to (num_traces, ...) so waic reduces over the trace dimension.
    ll = torch.stack(log_likelihoods, dim=0)
    waic_value, p_waic = waic(
        ll, torch.tensor(self.log_weights, device=ll.device), pointwise)
    return OrderedDict([("waic", waic_value), ("p_waic", p_waic)])