def test_kl_divergence(self):
    """KL divergence between two identically configured priors is ~0."""
    # Two separate prior objects built from the same fixture values.
    prior_p, prior_q = (
        beer.NormalGammaPrior(self.mean, self.precision, self.prior_count)
        for _ in range(2)
    )
    self.assertAlmostEqual(beer.kl_div(prior_p, prior_q), 0.)
def create_modelset_diag(ncomps, dim, type_t):
    """Build a ``beer.NormalDiagonalCovarianceSet`` with ``ncomps`` components.

    The shared prior and every component posterior are
    ``beer.NormalGammaPrior`` objects built from an all-zeros vector, an
    all-ones vector (both of length ``dim``) and a count of ``1.``.

    Args:
        ncomps: number of component posteriors in the set.
        dim: length of the vectors given to each NormalGammaPrior.
        type_t: tensor type passed to ``Tensor.type`` for those vectors.

    Returns:
        The constructed ``beer.NormalDiagonalCovarianceSet``.
    """
    def standard_prior():
        return beer.NormalGammaPrior(
            torch.zeros(dim).type(type_t),
            torch.ones(dim).type(type_t),
            1.,
        )

    # The component posteriors are created before the shared prior.
    posteriors = [standard_prior() for _ in range(ncomps)]
    return beer.NormalDiagonalCovarianceSet(standard_prior(), posteriors)
def create_normalgamma(t_type):
    """Return a ``beer.NormalGammaPrior`` with random dimension and parameters.

    The dimension is ``1 + randint(100)``, i.e. between 1 and 100. The mean is
    drawn from a standard normal; scale, shape and rate are each
    ``1 + x**2`` with ``x`` standard normal, so every entry is >= 1.

    Args:
        t_type: tensor type passed to ``Tensor.type`` for every parameter.
    """
    dim = int(1 + torch.randint(100, (1, 1)).item())

    def at_least_one():
        return (1 + torch.randn(dim) ** 2).type(t_type)

    mean = torch.randn(dim).type(t_type)
    # Tuple elements are evaluated left to right: scale, shape, then rate.
    scale, shape, rate = at_least_one(), at_least_one(), at_least_one()
    return beer.NormalGammaPrior(mean, scale, shape, rate)
def test_log_norm(self):
    """The model's log-normalizer matches ``normalgamma_log_norm``.

    The reference value is computed from the model's natural parameters.
    The comparison uses ``np.allclose`` with the module tolerance ``TOL``
    (the same tolerance the sufficient-statistics test uses) rather than
    ``assertAlmostEqual``'s fixed 7 decimal places. A fixed absolute
    precision can be too strict when the log-normalizer is large in
    magnitude or the fixture tensors are single precision.
    """
    model = beer.NormalGammaPrior(self.mean, self.precision, self.prior_count)
    model_log_norm = model.log_norm.numpy()
    natural_params = model.natural_params.numpy()
    log_norm = normalgamma_log_norm(natural_params)
    self.assertTrue(np.allclose(model_log_norm, log_norm, rtol=TOL, atol=TOL))
def setUp(self):
    """Draw a random dimension and NormalGammaPrior parameters per test.

    Sets ``dim``, ``mean``, ``scale``, ``shape``, ``rate`` and ``model`` on
    the test case. Tensors are converted with ``Tensor.type(self.type)``.
    """
    self.dim = int(1 + torch.randint(100, (1, 1)).item())
    dim, dtype = self.dim, self.type

    def positive_vector():
        # 1 + x**2 with x ~ N(0, 1) is always >= 1.
        return (1 + torch.randn(dim) ** 2).type(dtype)

    self.mean = torch.randn(dim).type(dtype)
    self.scale = positive_vector()
    self.shape = positive_vector()
    self.rate = positive_vector()
    self.model = beer.NormalGammaPrior(
        self.mean, self.scale, self.shape, self.rate)
def test_exp_sufficient_statistics(self):
    """Expected sufficient statistics match ``normalgamma_grad_log_norm``.

    The reference is evaluated at the model's natural parameters. Judging
    by the helper's name, it is the gradient of the log-normalizer.
    Values are compared with ``np.allclose`` using the module tolerance
    ``TOL``.
    """
    prior = beer.NormalGammaPrior(self.mean, self.precision, self.prior_count)
    expected = prior.expected_sufficient_statistics.numpy()
    reference = normalgamma_grad_log_norm(prior.natural_params.numpy())
    self.assertTrue(np.allclose(expected, reference, rtol=TOL, atol=TOL))
def test_create(self):
    """Constructing a NormalGammaPrior yields a ``beer.ExpFamilyDensity``."""
    model = beer.NormalGammaPrior(self.mean, self.precision, self.prior_count)
    # assertIsInstance reports the object's actual type on failure, whereas
    # assertTrue(isinstance(...)) only reports "False is not true".
    self.assertIsInstance(model, beer.ExpFamilyDensity)