def test_dist_normal_multivariate_batched(self):
    """Check a Normal built from 2x3 nested lists of means and stddevs.

    Verifies the sample shape exactly, and the analytic mean/stddev, the
    empirical mean/stddev and the log-probability evaluated at the means
    with ``np.allclose(..., atol=0.1)``.
    """
    dist_sample_shape_correct = [2, 3]
    dist_means_correct = [[0, 2, 0], [2, 0, 2]]
    dist_stddevs_correct = [[1, 3, 1], [3, 1, 3]]
    # Log-density of a normal at its own mean is -0.5*log(2*pi) - log(stddev):
    # about -0.918939 for stddev 1 and -2.01755 for stddev 3. Each expected
    # row value is the sum over that row's three components.
    at_stddev_1 = -0.918939
    at_stddev_3 = -2.01755
    dist_log_probs_correct = [
        [sum([at_stddev_1, at_stddev_3, at_stddev_1])],
        [sum([at_stddev_3, at_stddev_1, at_stddev_3])],
    ]

    dist = Normal(dist_means_correct, dist_stddevs_correct)
    # A single draw is taken for the shape check before the empirical draws.
    first_draw = dist.sample()
    dist_sample_shape = list(first_draw.size())
    draws = [dist.sample() for _ in range(empirical_samples)]
    dist_empirical = Empirical(draws)

    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

    # NOTE(review): util.debug receives local variable *names* as strings;
    # presumably it resolves them in this frame, so the locals above must
    # keep these exact names -- confirm against util.debug.
    util.debug(
        'dist_sample_shape',
        'dist_sample_shape_correct',
        'dist_means',
        'dist_means_empirical',
        'dist_means_correct',
        'dist_stddevs',
        'dist_stddevs_empirical',
        'dist_stddevs_correct',
        'dist_log_probs',
        'dist_log_probs_correct',
    )

    tolerance = 0.1
    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    self.assertTrue(
        np.allclose(dist_means, dist_means_correct, atol=tolerance))
    self.assertTrue(
        np.allclose(dist_means_empirical, dist_means_correct, atol=tolerance))
    self.assertTrue(
        np.allclose(dist_stddevs, dist_stddevs_correct, atol=tolerance))
    self.assertTrue(
        np.allclose(dist_stddevs_empirical, dist_stddevs_correct, atol=tolerance))
    self.assertTrue(
        np.allclose(dist_log_probs, dist_log_probs_correct, atol=tolerance))
def test_dist_normal(self):
    """Check a single-element Normal with mean 0 and stddev 1.

    Verifies the sample shape exactly, and the analytic mean/stddev, the
    empirical mean/stddev and the log-probability evaluated at the mean
    with ``np.allclose(..., atol=0.1)``.
    """
    dist_sample_shape_correct = [1]
    dist_means_correct = [0]
    dist_stddevs_correct = [1]
    # -0.5 * log(2*pi): log-density of a unit-stddev normal at its mean.
    dist_log_probs_correct = [-0.918939]

    dist = Normal(dist_means_correct, dist_stddevs_correct)
    # A single draw is taken for the shape check before the empirical draws.
    first_draw = dist.sample()
    dist_sample_shape = list(first_draw.size())
    draws = [dist.sample() for _ in range(empirical_samples)]
    dist_empirical = Empirical(draws)

    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

    # NOTE(review): util.debug receives local variable *names* as strings;
    # presumably it resolves them in this frame, so the locals above must
    # keep these exact names -- confirm against util.debug.
    util.debug(
        'dist_sample_shape',
        'dist_sample_shape_correct',
        'dist_means',
        'dist_means_empirical',
        'dist_means_correct',
        'dist_stddevs',
        'dist_stddevs_empirical',
        'dist_stddevs_correct',
        'dist_log_probs',
        'dist_log_probs_correct',
    )

    tolerance = 0.1
    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    self.assertTrue(
        np.allclose(dist_means, dist_means_correct, atol=tolerance))
    self.assertTrue(
        np.allclose(dist_means_empirical, dist_means_correct, atol=tolerance))
    self.assertTrue(
        np.allclose(dist_stddevs, dist_stddevs_correct, atol=tolerance))
    self.assertTrue(
        np.allclose(dist_stddevs_empirical, dist_stddevs_correct, atol=tolerance))
    self.assertTrue(
        np.allclose(dist_log_probs, dist_log_probs_correct, atol=tolerance))
def test_dist_empirical_resample(self):
    """Check that a resampled Empirical keeps the source mean and stddev.

    Builds an Empirical from draws of Normal(mean 2, stddev 5), resamples
    it to half as many samples, and compares the resulting mean and stddev
    to the Normal's parameters with ``np.allclose(..., atol=0.25)``.
    """
    dist_means_correct = [2]
    dist_stddevs_correct = [5]

    dist = Normal(dist_means_correct, dist_stddevs_correct)
    draws = [dist.sample() for _ in range(empirical_samples)]
    resample_count = int(empirical_samples / 2)
    dist_empirical = Empirical(draws)
    dist_empirical = dist_empirical.resample(resample_count)

    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)

    # NOTE(review): util.debug receives local variable *names* as strings;
    # presumably it resolves them in this frame, so the locals above must
    # keep these exact names -- confirm against util.debug.
    util.debug(
        'dist_means_empirical',
        'dist_means_correct',
        'dist_stddevs_empirical',
        'dist_stddevs_correct',
    )

    tolerance = 0.25
    self.assertTrue(
        np.allclose(dist_means_empirical, dist_means_correct, atol=tolerance))
    self.assertTrue(
        np.allclose(dist_stddevs_empirical, dist_stddevs_correct, atol=tolerance))
def test_dist_empirical_combine_uniform_weights(self):
    """Check Empirical.combine on three equally sized Empiricals.

    Each Empirical is built from ``empirical_samples`` draws of its own
    Normal. The individual and the combined mean/stddev are compared to
    their expected values with ``assertAlmostEqual(..., places=1)``.
    """
    dist1_mean_correct = 1
    dist1_stddev_correct = 3
    dist2_mean_correct = 5
    dist2_stddev_correct = 2
    dist3_mean_correct = -2.5
    dist3_stddev_correct = 1.2
    # Expected moments of the equal-weight mixture of the three normals:
    # mean = (1 + 5 - 2.5) / 3, variance = mean(stddev^2 + mean^2) - mean^2.
    dist_combined_mean_correct = 1.16667
    dist_combined_stddev_correct = 3.76858

    # The three distributions are sampled strictly in order 1, 2, 3.
    dist1 = Normal(dist1_mean_correct, dist1_stddev_correct)
    dist1_draws = [dist1.sample() for _ in range(empirical_samples)]
    dist1_empirical = Empirical(dist1_draws)
    dist1_empirical_mean = float(dist1_empirical.mean)
    dist1_empirical_stddev = float(dist1_empirical.stddev)

    dist2 = Normal(dist2_mean_correct, dist2_stddev_correct)
    dist2_draws = [dist2.sample() for _ in range(empirical_samples)]
    dist2_empirical = Empirical(dist2_draws)
    dist2_empirical_mean = float(dist2_empirical.mean)
    dist2_empirical_stddev = float(dist2_empirical.stddev)

    dist3 = Normal(dist3_mean_correct, dist3_stddev_correct)
    dist3_draws = [dist3.sample() for _ in range(empirical_samples)]
    dist3_empirical = Empirical(dist3_draws)
    dist3_empirical_mean = float(dist3_empirical.mean)
    dist3_empirical_stddev = float(dist3_empirical.stddev)

    parts = [dist1_empirical, dist2_empirical, dist3_empirical]
    dist_combined_empirical = Empirical.combine(parts)
    dist_combined_empirical_mean = float(dist_combined_empirical.mean)
    dist_combined_empirical_stddev = float(dist_combined_empirical.stddev)

    # NOTE(review): util.debug receives local variable *names* as strings;
    # presumably it resolves them in this frame, so the locals above must
    # keep these exact names -- confirm against util.debug.
    util.debug(
        'dist1_empirical_mean',
        'dist1_empirical_stddev',
        'dist1_mean_correct',
        'dist1_stddev_correct',
        'dist2_empirical_mean',
        'dist2_empirical_stddev',
        'dist2_mean_correct',
        'dist2_stddev_correct',
        'dist3_empirical_mean',
        'dist3_empirical_stddev',
        'dist3_mean_correct',
        'dist3_stddev_correct',
        'dist_combined_empirical_mean',
        'dist_combined_empirical_stddev',
        'dist_combined_mean_correct',
        'dist_combined_stddev_correct',
    )

    self.assertAlmostEqual(
        dist1_empirical_mean, dist1_mean_correct, places=1)
    self.assertAlmostEqual(
        dist1_empirical_stddev, dist1_stddev_correct, places=1)
    self.assertAlmostEqual(
        dist2_empirical_mean, dist2_mean_correct, places=1)
    self.assertAlmostEqual(
        dist2_empirical_stddev, dist2_stddev_correct, places=1)
    self.assertAlmostEqual(
        dist3_empirical_mean, dist3_mean_correct, places=1)
    self.assertAlmostEqual(
        dist3_empirical_stddev, dist3_stddev_correct, places=1)
    self.assertAlmostEqual(
        dist_combined_empirical_mean, dist_combined_mean_correct, places=1)
    self.assertAlmostEqual(
        dist_combined_empirical_stddev, dist_combined_stddev_correct, places=1)