예제 #1
0
    def test_dist_normal_multivariate_batched(self):
        """Check a Normal built from 2x3 nested lists of means and stddevs.

        Verifies the size of a drawn sample, the distribution's own mean and
        stddev, the mean and stddev of an Empirical built from repeated
        samples, and the log-probability evaluated at the means.
        """
        # NOTE: util.debug below is handed variable names as strings
        # (presumably resolved in this frame), so every local it names must
        # keep its exact spelling.
        dist_means_correct = [[0, 2, 0], [2, 0, 2]]
        dist_stddevs_correct = [[1, 3, 1], [3, 1, 3]]
        dist_sample_shape_correct = [2, 3]
        # Per-element reference terms: the first lines up with the stddev-1
        # entries above, the second with the stddev-3 entries. Each row is
        # summed into a single expected value.
        unit_term = -0.918939
        wide_term = -2.01755
        dist_log_probs_correct = [
            [sum([unit_term, wide_term, unit_term])],
            [sum([wide_term, unit_term, wide_term])],
        ]

        dist = Normal(dist_means_correct, dist_stddevs_correct)
        dist_sample_shape = list(dist.sample().size())
        drawn = [dist.sample() for _ in range(empirical_samples)]
        dist_empirical = Empirical(drawn)
        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_means', 'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs', 'dist_stddevs_empirical',
            'dist_stddevs_correct', 'dist_log_probs',
            'dist_log_probs_correct',
        )

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        # (observed, reference) pairs, checked in order with the same
        # absolute tolerance.
        checks = (
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for observed, reference in checks:
            self.assertTrue(np.allclose(observed, reference, atol=0.1))
예제 #2
0
    def test_dist_normal(self):
        """Check a single-element Normal with mean [0] and stddev [1].

        Verifies the size of a drawn sample, the distribution's own mean and
        stddev, the mean and stddev of an Empirical built from repeated
        samples, and the log-probability evaluated at the mean.
        """
        # NOTE: util.debug below is handed variable names as strings
        # (presumably resolved in this frame), so every local it names must
        # keep its exact spelling.
        dist_means_correct = [0]
        dist_stddevs_correct = [1]
        dist_sample_shape_correct = [1]
        dist_log_probs_correct = [-0.918939]

        dist = Normal(dist_means_correct, dist_stddevs_correct)
        dist_sample_shape = list(dist.sample().size())
        drawn = [dist.sample() for _ in range(empirical_samples)]
        dist_empirical = Empirical(drawn)
        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_means', 'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs', 'dist_stddevs_empirical',
            'dist_stddevs_correct', 'dist_log_probs',
            'dist_log_probs_correct',
        )

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        # (observed, reference) pairs, checked in order with the same
        # absolute tolerance.
        checks = (
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for observed, reference in checks:
            self.assertTrue(np.allclose(observed, reference, atol=0.1))
예제 #3
0
    def test_dist_empirical_resample(self):
        """Check that a resampled Empirical keeps the source mean and stddev.

        An Empirical is built from samples of Normal([2], [5]) and then
        resampled down to half of ``empirical_samples``; its mean and stddev
        are compared to the Normal's parameters with an absolute tolerance
        of 0.25.
        """
        # NOTE: util.debug below is handed variable names as strings
        # (presumably resolved in this frame), so every local it names must
        # keep its exact spelling.
        dist_means_correct = [2]
        dist_stddevs_correct = [5]

        dist = Normal(dist_means_correct, dist_stddevs_correct)
        drawn = [dist.sample() for _ in range(empirical_samples)]
        half_count = int(empirical_samples / 2)
        dist_empirical = Empirical(drawn).resample(half_count)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)

        util.debug(
            'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs_empirical', 'dist_stddevs_correct',
        )

        # Mean first, then stddev, both against the generating parameters.
        checks = (
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
        )
        for observed, reference in checks:
            self.assertTrue(np.allclose(observed, reference, atol=0.25))
예제 #4
0
    def test_dist_empirical_combine_uniform_weights(self):
        """Check Empirical.combine over three sample-based Empiricals.

        Three Normals are each sampled ``empirical_samples`` times into an
        Empirical (no explicit weights are passed). Each Empirical's mean and
        stddev is checked against its generating Normal, and the combined
        Empirical is checked against the reference mean and stddev below.
        """
        # NOTE: util.debug below is handed variable names as strings
        # (presumably resolved in this frame), so every local it names must
        # keep its exact spelling.
        dist1_mean_correct = 1
        dist1_stddev_correct = 3
        dist2_mean_correct = 5
        dist2_stddev_correct = 2
        dist3_mean_correct = -2.5
        dist3_stddev_correct = 1.2
        dist_combined_mean_correct = 1.16667
        dist_combined_stddev_correct = 3.76858

        # First component.
        dist1 = Normal(dist1_mean_correct, dist1_stddev_correct)
        dist1_draws = [dist1.sample() for _ in range(empirical_samples)]
        dist1_empirical = Empirical(dist1_draws)
        dist1_empirical_mean = float(dist1_empirical.mean)
        dist1_empirical_stddev = float(dist1_empirical.stddev)

        # Second component.
        dist2 = Normal(dist2_mean_correct, dist2_stddev_correct)
        dist2_draws = [dist2.sample() for _ in range(empirical_samples)]
        dist2_empirical = Empirical(dist2_draws)
        dist2_empirical_mean = float(dist2_empirical.mean)
        dist2_empirical_stddev = float(dist2_empirical.stddev)

        # Third component.
        dist3 = Normal(dist3_mean_correct, dist3_stddev_correct)
        dist3_draws = [dist3.sample() for _ in range(empirical_samples)]
        dist3_empirical = Empirical(dist3_draws)
        dist3_empirical_mean = float(dist3_empirical.mean)
        dist3_empirical_stddev = float(dist3_empirical.stddev)

        # Combination of the three Empiricals.
        components = [dist1_empirical, dist2_empirical, dist3_empirical]
        dist_combined_empirical = Empirical.combine(components)
        dist_combined_empirical_mean = float(dist_combined_empirical.mean)
        dist_combined_empirical_stddev = float(dist_combined_empirical.stddev)

        util.debug(
            'dist1_empirical_mean', 'dist1_empirical_stddev',
            'dist1_mean_correct', 'dist1_stddev_correct',
            'dist2_empirical_mean', 'dist2_empirical_stddev',
            'dist2_mean_correct', 'dist2_stddev_correct',
            'dist3_empirical_mean', 'dist3_empirical_stddev',
            'dist3_mean_correct', 'dist3_stddev_correct',
            'dist_combined_empirical_mean', 'dist_combined_empirical_stddev',
            'dist_combined_mean_correct', 'dist_combined_stddev_correct',
        )

        # (observed, reference) pairs, checked in order to one decimal place.
        checks = (
            (dist1_empirical_mean, dist1_mean_correct),
            (dist1_empirical_stddev, dist1_stddev_correct),
            (dist2_empirical_mean, dist2_mean_correct),
            (dist2_empirical_stddev, dist2_stddev_correct),
            (dist3_empirical_mean, dist3_mean_correct),
            (dist3_empirical_stddev, dist3_stddev_correct),
            (dist_combined_empirical_mean, dist_combined_mean_correct),
            (dist_combined_empirical_stddev, dist_combined_stddev_correct),
        )
        for observed, reference in checks:
            self.assertAlmostEqual(observed, reference, places=1)