def test_dist_mixture_batched(self):
    """Check a batched Mixture of three Normals against closed-form values.

    Row 0 mixes N(0, 0.1), N(2, 0.1), N(3, 0.1) with weights
    (0.7, 0.2, 0.1); row 1 mixes N(1, 1), N(5, 1), N(10, 1) with weights
    (0.1, 0.2, 0.7). The expected mean is sum(p * mu), the expected
    standard deviation is sqrt(sum(p * (sigma**2 + mu**2)) - mean**2), and
    the expected log-prob is the mixture log-density evaluated at the mean.
    Both the analytic moments and moments of an Empirical built from
    `empirical_samples` draws are compared within an absolute tolerance.
    """
    # Reference values. These names (and the dist_* results below) are
    # handed to util.debug as strings -- presumably looked up in this frame,
    # so they must keep these exact names.
    dist_sample_shape_correct = [2, 1]
    dist_means_correct = [[0.7], [8.1]]
    dist_stddevs_correct = [[1.10454], [3.23883]]
    dist_log_probs_correct = [[-23.473], [-3.06649]]

    # One (loc, scale) pair per mixture component; each is a 2x1 batch.
    components = [
        Normal(loc, scale)
        for loc, scale in (
            ([[0], [1]], [[0.1], [1]]),
            ([[2], [5]], [[0.1], [1]]),
            ([[3], [10]], [[0.1], [1]]),
        )
    ]
    mixture = Mixture(components, probs=[[0.7, 0.2, 0.1], [0.1, 0.2, 0.7]])

    # The shape-check sample is drawn before the empirical batch.
    dist_sample_shape = list(mixture.sample().size())
    empirical = Empirical([mixture.sample() for _ in range(empirical_samples)])

    dist_means = util.to_numpy(mixture.mean)
    dist_means_empirical = util.to_numpy(empirical.mean)
    dist_stddevs = util.to_numpy(mixture.stddev)
    dist_stddevs_empirical = util.to_numpy(empirical.stddev)
    dist_log_probs = util.to_numpy(mixture.log_prob(dist_means_correct))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    # Checked in order; the first mismatch fails the test.
    tolerance_checks = (
        (dist_means, dist_means_correct),
        (dist_means_empirical, dist_means_correct),
        (dist_stddevs, dist_stddevs_correct),
        (dist_stddevs_empirical, dist_stddevs_correct),
        (dist_log_probs, dist_log_probs_correct),
    )
    for actual, expected in tolerance_checks:
        self.assertTrue(np.allclose(actual, expected, atol=0.1))
def test_dist_mixture(self):
    """Check an unbatched Mixture of three Normals against closed-form values.

    Components N(0, 0.1), N(2, 0.1), N(3, 0.1) with weights (0.7, 0.2, 0.1):
      * mean   = 0.7*0 + 0.2*2 + 0.1*3 = 0.7
      * stddev = sqrt(sum(p * (sigma**2 + mu**2)) - mean**2)
               = sqrt(1.71 - 0.49) ~= 1.10454
      * log-prob at the mean ~= log(0.7) + log N(0.7; 0, 0.1) ~= -23.473
        (the other two components contribute negligibly there).
    Both the analytic moments and moments of an Empirical built from
    `empirical_samples` draws are compared within an absolute tolerance.
    """
    # Reference values. These names (and the dist_* results below) are
    # handed to util.debug as strings -- presumably looked up in this frame,
    # so they must keep these exact names.
    dist_sample_shape_correct = [1]
    dist_means_correct = [0.7]
    dist_stddevs_correct = [1.10454]
    dist_log_probs_correct = [-23.473]

    dist_1 = Normal(0, 0.1)
    dist_2 = Normal(2, 0.1)
    dist_3 = Normal(3, 0.1)
    dist = Mixture([dist_1, dist_2, dist_3], probs=[0.7, 0.2, 0.1])

    # The shape-check sample is drawn before the empirical batch.
    dist_sample_shape = list(dist.sample().size())
    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])

    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    self.assertTrue(np.allclose(dist_means, dist_means_correct, atol=0.1))
    self.assertTrue(np.allclose(dist_means_empirical, dist_means_correct, atol=0.1))
    self.assertTrue(np.allclose(dist_stddevs, dist_stddevs_correct, atol=0.1))
    self.assertTrue(np.allclose(dist_stddevs_empirical, dist_stddevs_correct, atol=0.1))
    self.assertTrue(np.allclose(dist_log_probs, dist_log_probs_correct, atol=0.1))