def test_dist_poisson_multivariate_batched(self):
    """Poisson over a [2, 3] batch of rates: sample shape, moments, rates, log-probs.

    Each expected log-prob is the sum over the three components of
    log Poisson(k = rate | rate), evaluated at the distribution's own means.
    The empirical checks use a looser tolerance (0.25) because rates go up to 300.
    """
    dist_rates_correct = [[1, 2, 15], [100, 200, 300]]
    dist_sample_shape_correct = [2, 3]
    dist_means_correct = [[1, 2, 15], [100, 200, 300]]
    # A Poisson's variance equals its rate.
    dist_stddevs_correct = [[math.sqrt(rate) for rate in row] for row in dist_rates_correct]
    dist_log_probs_correct = [[sum([-1, -1.30685, -2.27852])],
                              [sum([-3.22236, -3.56851, -3.77110])]]

    dist = Poisson(dist_rates_correct)
    dist_sample_shape = list(dist.sample().size())
    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])

    dist_rates = util.to_numpy(dist.rate)
    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_rates', 'dist_rates_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    comparisons = (
        (dist_means, dist_means_correct, 0.1),
        (dist_means_empirical, dist_means_correct, 0.25),
        (dist_stddevs, dist_stddevs_correct, 0.1),
        (dist_stddevs_empirical, dist_stddevs_correct, 0.25),
        (dist_rates, dist_rates_correct, 0.1),
        (dist_log_probs, dist_log_probs_correct, 0.1),
    )
    for actual, expected, tolerance in comparisons:
        self.assertTrue(np.allclose(actual, expected, atol=tolerance))
def test_dist_poisson_batched(self):
    """Poisson over a [2, 1] batch of rates (4 and 100): shape, moments, rates, log-probs.

    The empirical mean/stddev checks use atol=0.25 rather than 0.1. For rate
    100 the per-sample stddev is 10, so the Monte Carlo error of the empirical
    mean is roughly 10 / sqrt(empirical_samples), and that of the empirical
    stddev roughly 10 / sqrt(2 * empirical_samples). A 0.1 tolerance is the
    same order as that error and makes the test flaky. 0.25 is what the
    multivariate batched Poisson test uses for the same rates.
    Analytic quantities keep the tight 0.1 tolerance.
    """
    dist_sample_shape_correct = [2, 1]
    dist_means_correct = [[4], [100]]
    dist_stddevs_correct = [[math.sqrt(4)], [math.sqrt(100)]]
    dist_rates_correct = [[4], [100]]
    # log Poisson(k = rate | rate) for rate 4 and rate 100.
    dist_log_probs_correct = [[-1.63288], [-3.22236]]

    dist = Poisson(dist_rates_correct)
    dist_sample_shape = list(dist.sample().size())
    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])
    dist_rates = util.to_numpy(dist.rate)
    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_rates', 'dist_rates_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    self.assertTrue(np.allclose(dist_means, dist_means_correct, atol=0.1))
    self.assertTrue(np.allclose(dist_means_empirical, dist_means_correct, atol=0.25))
    self.assertTrue(np.allclose(dist_stddevs, dist_stddevs_correct, atol=0.1))
    self.assertTrue(np.allclose(dist_stddevs_empirical, dist_stddevs_correct, atol=0.25))
    self.assertTrue(np.allclose(dist_rates, dist_rates_correct, atol=0.1))
    self.assertTrue(np.allclose(dist_log_probs, dist_log_probs_correct, atol=0.1))
def test_dist_poisson_multivariate_from_flat_params(self):
    """Poisson built from the flat rate list [1, 2, 15] should behave as a [1, 3] batch.

    The expected log-prob is the sum of the three per-component
    log Poisson(k = rate | rate) terms.
    """
    dist_rates_correct = [[1, 2, 15]]
    dist_sample_shape_correct = [1, 3]
    dist_means_correct = [[1, 2, 15]]
    dist_stddevs_correct = [[math.sqrt(rate) for rate in dist_rates_correct[0]]]
    dist_log_probs_correct = [sum([-1, -1.30685, -2.27852])]

    # Deliberately pass the un-nested rates to exercise the flat-parameter path.
    flat_rates = dist_rates_correct[0]
    dist = Poisson(flat_rates)
    dist_sample_shape = list(dist.sample().size())
    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])

    dist_rates = util.to_numpy(dist.rate)
    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_rates', 'dist_rates_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    pairs = (
        (dist_means, dist_means_correct),
        (dist_means_empirical, dist_means_correct),
        (dist_stddevs, dist_stddevs_correct),
        (dist_stddevs_empirical, dist_stddevs_correct),
        (dist_rates, dist_rates_correct),
        (dist_log_probs, dist_log_probs_correct),
    )
    for actual, expected in pairs:
        self.assertTrue(np.allclose(actual, expected, atol=0.1))
def test_dist_uniform(self):
    """Uniform(0, 1): sample shape, bounds, moments and log-density at the mean.

    Expected stddev is 1/sqrt(12); the density is 1 everywhere on the support,
    so the log-density at the mean is 0.
    """
    dist_lows_correct = [0]
    dist_highs_correct = [1]
    dist_sample_shape_correct = [1]
    dist_means_correct = [0.5]
    dist_stddevs_correct = [0.288675]
    dist_log_probs_correct = [0]

    dist = Uniform(dist_lows_correct, dist_highs_correct)
    dist_sample_shape = list(dist.sample().size())
    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])

    dist_lows = util.to_numpy(dist.low)
    dist_highs = util.to_numpy(dist.high)
    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_lows', 'dist_lows_correct', 'dist_highs', 'dist_highs_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    for actual, expected in (
        (dist_means, dist_means_correct),
        (dist_means_empirical, dist_means_correct),
        (dist_stddevs, dist_stddevs_correct),
        (dist_stddevs_empirical, dist_stddevs_correct),
        (dist_lows, dist_lows_correct),
        (dist_highs, dist_highs_correct),
        (dist_log_probs, dist_log_probs_correct),
    ):
        self.assertTrue(np.allclose(actual, expected, atol=0.1))
def test_dist_uniform_batched(self):
    """Batched Uniform with four [low, high] rows, including boundary log-probs.

    Rows 3 and 4 evaluate the log-density at the interval edges (0 and 1) and
    expect -inf there; rows 1 and 2 evaluate at the interval midpoint.
    """
    dist_lows_correct = [[0], [5], [0], [0]]
    dist_highs_correct = [[1], [10], [1], [1]]
    dist_sample_shape_correct = [4, 1]
    dist_means_correct = [[0.5], [7.5], [0.5], [0.5]]
    dist_stddevs_correct = [[0.288675], [1.44338], [0.288675], [0.288675]]
    dist_values = [[0.5], [7.5], [0], [1]]
    dist_log_probs_correct = [[0], [-1.60944], [float('-inf')], [float('-inf')]]

    dist = Uniform(dist_lows_correct, dist_highs_correct)
    dist_sample_shape = list(dist.sample().size())
    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])

    dist_lows = util.to_numpy(dist.low)
    dist_highs = util.to_numpy(dist.high)
    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_values))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_lows', 'dist_lows_correct', 'dist_highs', 'dist_highs_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_values', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    for actual, expected in (
        (dist_means, dist_means_correct),
        (dist_means_empirical, dist_means_correct),
        (dist_stddevs, dist_stddevs_correct),
        (dist_stddevs_empirical, dist_stddevs_correct),
        (dist_lows, dist_lows_correct),
        (dist_highs, dist_highs_correct),
        # np.allclose treats matching infinities as equal, so -inf rows compare fine.
        (dist_log_probs, dist_log_probs_correct),
    ):
        self.assertTrue(np.allclose(actual, expected, atol=0.1))
def test_dist_truncated_normal_batched(self):
    """Batched TruncatedNormal: N(0, 1) on [-1, 1] and N(2, 3) on [-4, 4].

    The expected mean/stddev are the moments of the truncated distributions.
    Log-probs are evaluated at the untruncated means.
    """
    dist_means_non_truncated_correct = [[0], [2]]
    dist_stddevs_non_truncated_correct = [[1], [3]]
    dist_lows_correct = [[-1], [-4]]
    dist_highs_correct = [[1], [4]]
    dist_sample_shape_correct = [2, 1]
    dist_means_correct = [[0], [0.901189]]
    dist_stddevs_correct = [[0.53956], [1.95118]]
    dist_log_probs_correct = [[-0.537223], [-1.69563]]

    dist = TruncatedNormal(dist_means_non_truncated_correct, dist_stddevs_non_truncated_correct,
                           dist_lows_correct, dist_highs_correct)
    dist_sample_shape = list(dist.sample().size())
    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])

    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_means_non_truncated_correct))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    for actual, expected in (
        (dist_means, dist_means_correct),
        (dist_means_empirical, dist_means_correct),
        (dist_stddevs, dist_stddevs_correct),
        (dist_stddevs_empirical, dist_stddevs_correct),
        (dist_log_probs, dist_log_probs_correct),
    ):
        self.assertTrue(np.allclose(actual, expected, atol=0.1))
def test_dist_truncated_normal_clamped_batched(self):
    """TruncatedNormal with clamp_mean_between_low_high=True.

    The raw means [0, 2] lie outside / on the edge of their intervals
    ([0.5, 1] and [-4, 1]). After clamping, the stored untruncated means are
    expected to be [0.5, 1].
    """
    # Constructor inputs, before clamping.
    means_before_clamping = [[0], [2]]
    stddevs_non_truncated = [[1], [3]]
    dist_lows_correct = [[0.5], [-4]]
    dist_highs_correct = [[1], [1]]

    dist_sample_shape_correct = [2, 1]
    dist_means_non_truncated_correct = [[0.5], [1]]
    dist_means_correct = [[0.744836], [-0.986679]]
    dist_stddevs_correct = [[0.143681], [1.32416]]
    dist_log_prob_arguments = [[0.75], [-3]]
    dist_log_probs_correct = [[0.702875], [-2.11283]]

    dist = TruncatedNormal(means_before_clamping, stddevs_non_truncated, dist_lows_correct,
                           dist_highs_correct, clamp_mean_between_low_high=True)
    dist_sample_shape = list(dist.sample().size())
    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])

    dist_means_non_truncated = util.to_numpy(dist._mean_non_truncated)
    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_log_prob_arguments))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_means_non_truncated', 'dist_means_non_truncated_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    for actual, expected in (
        (dist_means_non_truncated, dist_means_non_truncated_correct),
        (dist_means, dist_means_correct),
        (dist_means_empirical, dist_means_correct),
        (dist_stddevs, dist_stddevs_correct),
        (dist_stddevs_empirical, dist_stddevs_correct),
        (dist_log_probs, dist_log_probs_correct),
    ):
        self.assertTrue(np.allclose(actual, expected, atol=0.1))
def test_dist_empirical(self):
    """Weighted Empirical over values [1, 2, 3] with log-weights [1, 2, 3].

    The expected statistics are computed with the softmax-normalised weights.
    Mean and stddev are also re-estimated from an Empirical built out of
    samples drawn from this one.
    """
    values = Variable(util.Tensor([1, 2, 3]))
    log_weights = Variable(util.Tensor([1, 2, 3]))
    dist_mean_correct = 2.5752103328704834
    dist_stddev_correct = 0.6514633893966675
    dist_expectation_sin_correct = 0.3921678960323334
    dist_map_sin_mean_correct = 0.3921678960323334
    dist_min_correct = 1
    dist_max_correct = 3

    dist = Empirical(values, log_weights)
    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])

    dist_mean = float(dist.mean)
    dist_mean_empirical = float(dist_empirical.mean)
    dist_stddev = float(dist.stddev)
    dist_stddev_empirical = float(dist_empirical.stddev)
    dist_expectation_sin = float(dist.expectation(torch.sin))
    dist_map_sin_mean = float(dist.map(torch.sin).mean)
    dist_min = float(dist.min)
    dist_max = float(dist.max)
    # TODO: assert dist_sample_shape once the expected shape for a
    # scalar-valued Empirical is settled (presumably []).
    dist_sample_shape = list(dist.sample().size())

    util.debug('dist_mean', 'dist_mean_empirical', 'dist_mean_correct', 'dist_stddev', 'dist_stddev_empirical', 'dist_stddev_correct', 'dist_expectation_sin', 'dist_expectation_sin_correct', 'dist_map_sin_mean', 'dist_map_sin_mean_correct', 'dist_min', 'dist_min_correct', 'dist_max', 'dist_max_correct')

    for actual, expected in (
        (dist_mean, dist_mean_correct),
        (dist_mean_empirical, dist_mean_correct),
        (dist_stddev, dist_stddev_correct),
        (dist_stddev_empirical, dist_stddev_correct),
        (dist_expectation_sin, dist_expectation_sin_correct),
        (dist_map_sin_mean, dist_map_sin_mean_correct),
        (dist_min, dist_min_correct),
        (dist_max, dist_max_correct),
    ):
        self.assertAlmostEqual(actual, expected, places=1)
def test_dist_mixture_batched(self):
    """Batched Mixture of three Normals with per-row mixture probabilities.

    The expected mean/stddev are the moments of each row's mixture. Log-probs
    are evaluated at those expected means.
    """
    components = [
        Normal([[0], [1]], [[0.1], [1]]),
        Normal([[2], [5]], [[0.1], [1]]),
        Normal([[3], [10]], [[0.1], [1]]),
    ]
    dist_sample_shape_correct = [2, 1]
    dist_means_correct = [[0.7], [8.1]]
    dist_stddevs_correct = [[1.10454], [3.23883]]
    dist_log_probs_correct = [[-23.473], [-3.06649]]

    dist = Mixture(components, probs=[[0.7, 0.2, 0.1], [0.1, 0.2, 0.7]])
    dist_sample_shape = list(dist.sample().size())
    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])

    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    for actual, expected in (
        (dist_means, dist_means_correct),
        (dist_means_empirical, dist_means_correct),
        (dist_stddevs, dist_stddevs_correct),
        (dist_stddevs_empirical, dist_stddevs_correct),
        (dist_log_probs, dist_log_probs_correct),
    ):
        self.assertTrue(np.allclose(actual, expected, atol=0.1))
def test_dist_mixture(self):
    """Mixture of N(0, 0.1), N(2, 0.1), N(3, 0.1) with probabilities [0.7, 0.2, 0.1].

    Checks the sample shape, the analytic and empirical moments, and the
    log-density at the expected mean.
    """
    components = [Normal(0, 0.1), Normal(2, 0.1), Normal(3, 0.1)]
    dist_sample_shape_correct = [1]
    dist_means_correct = [0.7]
    dist_stddevs_correct = [1.10454]
    dist_log_probs_correct = [-23.473]

    dist = Mixture(components, probs=[0.7, 0.2, 0.1])
    dist_sample_shape = list(dist.sample().size())
    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])

    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    for actual, expected in (
        (dist_means, dist_means_correct),
        (dist_means_empirical, dist_means_correct),
        (dist_stddevs, dist_stddevs_correct),
        (dist_stddevs_empirical, dist_stddevs_correct),
        (dist_log_probs, dist_log_probs_correct),
    ):
        self.assertTrue(np.allclose(actual, expected, atol=0.1))
def test_dist_kumaraswamy(self):
    """Kumaraswamy(2, 5): sample shape, analytic/empirical moments, log-density at the mean."""
    dist_shape1s_correct = [2]
    dist_shape2s_correct = [5]
    dist_sample_shape_correct = [1]
    dist_means_correct = [0.369408]
    dist_stddevs_correct = [0.173793]
    dist_log_probs_correct = [0.719861]

    dist = Kumaraswamy(dist_shape1s_correct, dist_shape2s_correct)
    dist_sample_shape = list(dist.sample().size())
    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])

    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    for actual, expected in (
        (dist_means, dist_means_correct),
        (dist_means_empirical, dist_means_correct),
        (dist_stddevs, dist_stddevs_correct),
        (dist_stddevs_empirical, dist_stddevs_correct),
        (dist_log_probs, dist_log_probs_correct),
    ):
        self.assertTrue(np.allclose(actual, expected, atol=0.1))
def test_dist_empirical_save_load(self):
    """Round-trip a weighted Empirical through save / Distribution.load.

    The reloaded distribution must reproduce the mean, stddev and sin-based
    statistics of an Empirical over values [1, 2, 3] with log-weights [1, 2, 3].

    The file is written inside a TemporaryDirectory, so both the file and its
    directory are removed when the block exits, even if save or load raises.
    Previously the mkdtemp() directory was never deleted, and os.remove was
    skipped on failure.
    """
    values = Variable(util.Tensor([1, 2, 3]))
    log_weights = Variable(util.Tensor([1, 2, 3]))
    dist_mean_correct = 2.5752103328704834
    dist_stddev_correct = 0.6514633893966675
    dist_expectation_sin_correct = 0.3921678960323334
    dist_map_sin_mean_correct = 0.3921678960323334

    dist_on_file = Empirical(values, log_weights)
    with tempfile.TemporaryDirectory() as temp_dir:
        file_name = os.path.join(temp_dir, str(uuid.uuid4()))
        dist_on_file.save(file_name)
        dist = Distribution.load(file_name)
    # The loaded object is only used after the file is gone, exactly as when
    # the original code called os.remove right after loading.

    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])
    dist_mean = float(dist.mean)
    dist_mean_empirical = float(dist_empirical.mean)
    dist_stddev = float(dist.stddev)
    dist_stddev_empirical = float(dist_empirical.stddev)
    dist_expectation_sin = float(dist.expectation(torch.sin))
    dist_map_sin_mean = float(dist.map(torch.sin).mean)

    util.debug('file_name', 'dist_mean', 'dist_mean_empirical', 'dist_mean_correct', 'dist_stddev', 'dist_stddev_empirical', 'dist_stddev_correct', 'dist_expectation_sin', 'dist_expectation_sin_correct', 'dist_map_sin_mean', 'dist_map_sin_mean_correct')

    self.assertAlmostEqual(dist_mean, dist_mean_correct, places=1)
    self.assertAlmostEqual(dist_mean_empirical, dist_mean_correct, places=1)
    self.assertAlmostEqual(dist_stddev, dist_stddev_correct, places=1)
    self.assertAlmostEqual(dist_stddev_empirical, dist_stddev_correct, places=1)
    self.assertAlmostEqual(dist_expectation_sin, dist_expectation_sin_correct, places=1)
    self.assertAlmostEqual(dist_map_sin_mean, dist_map_sin_mean_correct, places=1)
def test_dist_normal_multivariate_batched(self):
    """Normal over a [2, 3] batch: sample shape, moments, summed log-density at the means.

    Each expected log-prob sums the per-component log N(mu | mu, sigma):
    -0.918939 for sigma = 1 and -2.01755 for sigma = 3.
    """
    dist_sample_shape_correct = [2, 3]
    dist_means_correct = [[0, 2, 0], [2, 0, 2]]
    dist_stddevs_correct = [[1, 3, 1], [3, 1, 3]]
    dist_log_probs_correct = [[sum([-0.918939, -2.01755, -0.918939])],
                              [sum([-2.01755, -0.918939, -2.01755])]]

    dist = Normal(dist_means_correct, dist_stddevs_correct)
    dist_sample_shape = list(dist.sample().size())
    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])

    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    for actual, expected in (
        (dist_means, dist_means_correct),
        (dist_means_empirical, dist_means_correct),
        (dist_stddevs, dist_stddevs_correct),
        (dist_stddevs_empirical, dist_stddevs_correct),
        (dist_log_probs, dist_log_probs_correct),
    ):
        self.assertTrue(np.allclose(actual, expected, atol=0.1))
def test_model_remote_with_replacement_trace_length_statistics(self):
    """Every trace of this model is expected to have length exactly 2.

    So the mean, min and max should all be 2, and the stddev 0.
    """
    samples = 2000
    trace_length_mean_correct = 2
    trace_length_stddev_correct = 0
    trace_length_min_correct = 2
    trace_length_max_correct = 2

    model = self._model
    trace_length_mean = float(model.trace_length_mean(samples))
    trace_length_stddev = float(model.trace_length_stddev(samples))
    trace_length_min = float(model.trace_length_min(samples))
    trace_length_max = float(model.trace_length_max(samples))

    util.debug('samples', 'trace_length_mean', 'trace_length_mean_correct', 'trace_length_stddev', 'trace_length_stddev_correct', 'trace_length_min', 'trace_length_min_correct', 'trace_length_max', 'trace_length_max_correct')

    for observed, expected in (
        (trace_length_mean, trace_length_mean_correct),
        (trace_length_stddev, trace_length_stddev_correct),
        (trace_length_min, trace_length_min_correct),
        (trace_length_max, trace_length_max_correct),
    ):
        self.assertAlmostEqual(observed, expected, places=0)
def test_model_remote_trace_length_statistics(self):
    """Trace-length summary statistics of the remote model.

    Mean and stddev are compared with reference values, and the minimum must
    be 2. The maximum has no reference value. Previously it was computed and
    logged but never checked. It is now at least required to be no smaller
    than the known minimum trace length, which holds whenever the min check
    does.
    """
    samples = 2000
    trace_length_mean_correct = 2.5630438327789307
    trace_length_stddev_correct = 1.2081329822540283
    trace_length_min_correct = 2

    trace_length_mean = float(self._model.trace_length_mean(samples))
    trace_length_stddev = float(self._model.trace_length_stddev(samples))
    trace_length_min = float(self._model.trace_length_min(samples))
    trace_length_max = float(self._model.trace_length_max(samples))

    util.debug('samples', 'trace_length_mean', 'trace_length_mean_correct', 'trace_length_stddev', 'trace_length_stddev_correct', 'trace_length_min', 'trace_length_min_correct', 'trace_length_max')

    self.assertAlmostEqual(trace_length_mean, trace_length_mean_correct, places=0)
    self.assertAlmostEqual(trace_length_stddev, trace_length_stddev_correct, places=0)
    self.assertAlmostEqual(trace_length_min, trace_length_min_correct, places=0)
    self.assertGreaterEqual(trace_length_max, trace_length_min_correct)
def test_inference_gum_marsaglia_posterior_inference_compilation(self):
    """Posterior for observation [8, 9] using a trained inference network.

    The expected posterior is Normal(7.25, sqrt(1 / 1.2)). The result must
    match its mean and stddev to 0 decimal places, and the KL divergence must
    be below 0.25. The KL value is also reported as the
    inference-compilation performance score.
    """
    observation = [8, 9]
    posterior_mean_correct = 7.25
    posterior_stddev_correct = math.sqrt(1 / 1.2)

    # The [1, 1] observation used for training is presumably only a
    # shape example for the network; inference below uses the real observation.
    self._model.learn_inference_network(observation=[1, 1], early_stop_traces=training_traces)
    posterior = self._model.posterior_distribution(samples, use_inference_network=True,
                                                   observation=observation)

    posterior_mean = float(posterior.mean)
    posterior_mean_unweighted = float(posterior.mean_unweighted)
    posterior_stddev = float(posterior.stddev)
    posterior_stddev_unweighted = float(posterior.stddev_unweighted)
    kl_divergence = float(util.kl_divergence_normal(posterior_mean_correct, posterior_stddev_correct,
                                                    posterior.mean, posterior_stddev))

    util.debug('training_traces', 'samples', 'posterior_mean_unweighted', 'posterior_mean', 'posterior_mean_correct', 'posterior_stddev_unweighted', 'posterior_stddev', 'posterior_stddev_correct', 'kl_divergence')
    add_perf_score_inference_compilation(kl_divergence)

    for estimate, expected in ((posterior_mean, posterior_mean_correct),
                               (posterior_stddev, posterior_stddev_correct)):
        self.assertAlmostEqual(estimate, expected, places=0)
    self.assertLess(kl_divergence, 0.25)
def test_dist_kumaraswamy_batched(self):
    """Batched Kumaraswamy with four (a, b) rows, including boundary log-probs.

    The last two rows evaluate the log-density at 0 and 1 and expect -inf there.
    """
    dist_shape1s_correct = [[0.5], [7.5], [7.5], [7.5]]
    dist_shape2s_correct = [[0.75], [2.5], [2.5], [2.5]]
    dist_sample_shape_correct = [4, 1]
    dist_means_correct = [[0.415584], [0.807999], [0.807999], [0.807999]]
    dist_stddevs_correct = [[0.327509], [0.111605], [0.111605], [0.111605]]
    dist_values = [[0.415584], [0.807999], [0.], [1.]]
    dist_log_probs_correct = [[-0.283125], [1.20676], [float('-inf')], [float('-inf')]]

    dist = Kumaraswamy(dist_shape1s_correct, dist_shape2s_correct)
    dist_sample_shape = list(dist.sample().size())
    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])

    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_values))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_values', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    for actual, expected in (
        (dist_means, dist_means_correct),
        (dist_means_empirical, dist_means_correct),
        (dist_stddevs, dist_stddevs_correct),
        (dist_stddevs_empirical, dist_stddevs_correct),
        (dist_log_probs, dist_log_probs_correct),
    ):
        self.assertTrue(np.allclose(actual, expected, atol=0.1))
def test_model_remote_gum_marsaglia_posterior_random_walk_metropolis_hastings(
        self):
    """Posterior for observation [8, 9] via random-walk Metropolis-Hastings.

    The expected posterior is Normal(7.25, sqrt(1 / 1.2)). The result must
    match its mean and stddev to 0 decimal places, and the KL divergence to it
    must be below 0.25.
    """
    observation = [8, 9]
    posterior_mean_correct = 7.25
    posterior_stddev_correct = math.sqrt(1 / 1.2)

    engine = pyprob.InferenceEngine.RANDOM_WALK_METROPOLIS_HASTINGS
    posterior = self._model.posterior_distribution(samples, inference_engine=engine,
                                                   observation=observation)

    posterior_mean = float(posterior.mean)
    posterior_mean_unweighted = float(posterior.unweighted().mean)
    posterior_stddev = float(posterior.stddev)
    posterior_stddev_unweighted = float(posterior.unweighted().stddev)
    reference = Normal(posterior_mean_correct, posterior_stddev_correct)
    estimate = Normal(posterior.mean, posterior_stddev)
    kl_divergence = float(util.kl_divergence_normal(reference, estimate))

    util.debug('samples', 'posterior_mean_unweighted', 'posterior_mean', 'posterior_mean_correct', 'posterior_stddev_unweighted', 'posterior_stddev', 'posterior_stddev_correct', 'kl_divergence')

    for value, expected in ((posterior_mean, posterior_mean_correct),
                            (posterior_stddev, posterior_stddev_correct)):
        self.assertAlmostEqual(value, expected, places=0)
    self.assertLess(kl_divergence, 0.25)
def test_dist_normal(self):
    """Standard Normal: sample shape, analytic/empirical moments, log-density at 0.

    log N(0 | 0, 1) = -0.5 * log(2 * pi) ~= -0.918939.
    """
    dist_sample_shape_correct = [1]
    dist_means_correct = [0]
    dist_stddevs_correct = [1]
    dist_log_probs_correct = [-0.918939]

    dist = Normal(dist_means_correct, dist_stddevs_correct)
    dist_sample_shape = list(dist.sample().size())
    dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])

    dist_means = util.to_numpy(dist.mean)
    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs = util.to_numpy(dist.stddev)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
    dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    for actual, expected in (
        (dist_means, dist_means_correct),
        (dist_means_empirical, dist_means_correct),
        (dist_stddevs, dist_stddevs_correct),
        (dist_stddevs_empirical, dist_stddevs_correct),
        (dist_log_probs, dist_log_probs_correct),
    ):
        self.assertTrue(np.allclose(actual, expected, atol=0.1))
def test_inference_gum_marsaglia_posterior_importance_sampling(self):
    """Posterior for observation [8, 9] via the default importance sampler.

    The expected posterior is Normal(7.25, sqrt(1 / 1.2)). The KL divergence
    to it must be below 0.25 and is reported as the importance-sampling
    performance score.
    """
    observation = [8, 9]
    posterior_mean_correct = 7.25
    posterior_stddev_correct = math.sqrt(1 / 1.2)

    posterior = self._model.posterior_distribution(samples, observation=observation)
    posterior_mean = float(posterior.mean)
    posterior_mean_unweighted = float(posterior.mean_unweighted)
    posterior_stddev = float(posterior.stddev)
    posterior_stddev_unweighted = float(posterior.stddev_unweighted)
    kl_divergence = float(util.kl_divergence_normal(posterior_mean_correct, posterior_stddev_correct,
                                                    posterior.mean, posterior_stddev))

    util.debug('samples', 'posterior_mean_unweighted', 'posterior_mean', 'posterior_mean_correct', 'posterior_stddev_unweighted', 'posterior_stddev', 'posterior_stddev_correct', 'kl_divergence')
    add_perf_score_importance_sampling(kl_divergence)

    for value, expected in ((posterior_mean, posterior_mean_correct),
                            (posterior_stddev, posterior_stddev_correct)):
        self.assertAlmostEqual(value, expected, places=0)
    self.assertLess(kl_divergence, 0.25)
def test_ObserveEmbeddingConvNet3D4C(self):
    """ObserveEmbeddingConvNet3D4C output shapes for batched and single inputs.

    A [32, 3, 16, 16, 16] batch must map to [32, 128]. A single
    [3, 16, 16, 16] example, with a batch dimension added, must map to [1, 128].
    Only shapes are checked, so the input tensor contents do not matter.
    """
    batch_size = 32
    channels = 3
    output_dim = 128
    input_non_batch_shape = [channels, 16, 16, 16]
    input_batch_shape = [batch_size] + input_non_batch_shape
    output_batch_shape_correct = [batch_size, output_dim]
    output_non_batch_shape_correct = [1, output_dim]

    input_batch = Variable(util.Tensor(torch.Size(input_batch_shape)))
    input_non_batch = Variable(util.Tensor(torch.Size(input_non_batch_shape)))

    embedding_net = ObserveEmbeddingConvNet3D4C(input_example_non_batch=input_non_batch,
                                                output_dim=output_dim)
    embedding_net.configure()
    output_batch_shape = list(embedding_net.forward(input_batch).size())
    output_non_batch_shape = list(embedding_net.forward(input_non_batch.unsqueeze(0)).size())

    util.debug('batch_size', 'channels', 'output_dim', 'input_batch_shape', 'output_batch_shape', 'output_batch_shape_correct', 'input_non_batch_shape', 'output_non_batch_shape', 'output_non_batch_shape_correct')

    self.assertEqual(output_batch_shape, output_batch_shape_correct)
    self.assertEqual(output_non_batch_shape, output_non_batch_shape_correct)
def test_model_remote_gum_marsaglia_prior(self):
    """The model's prior must have mean 1 and stddev sqrt(5), to 0 decimal places."""
    prior_mean_correct = 1
    prior_stddev_correct = math.sqrt(5)

    prior = self._model.prior_distribution(samples)
    prior_mean = float(prior.mean)
    prior_stddev = float(prior.stddev)

    util.debug('samples', 'prior_mean', 'prior_mean_correct', 'prior_stddev', 'prior_stddev_correct')

    for value, expected in ((prior_mean, prior_mean_correct),
                            (prior_stddev, prior_stddev_correct)):
        self.assertAlmostEqual(value, expected, places=0)
def test_model_remote_set_defaults_and_addresses_prior(self):
    """The model's prior must have mean 1 and stddev ~3.882, to 0 decimal places."""
    prior_mean_correct = 1
    prior_stddev_correct = 3.882074  # Estimate from 100k samples

    prior = self._model.prior_distribution(samples)
    prior_mean = float(prior.mean)
    prior_stddev = float(prior.stddev)

    util.debug('samples', 'prior_mean', 'prior_mean_correct', 'prior_stddev', 'prior_stddev_correct')

    for value, expected in ((prior_mean, prior_mean_correct),
                            (prior_stddev, prior_stddev_correct)):
        self.assertAlmostEqual(value, expected, places=0)
def test_model_remote_gum_marsaglia_train_save_load(self):
    """Train the inference network, then save it to disk and load it back.

    The test passes if training, save_inference_network and
    load_inference_network all complete without raising.

    The file is written inside a TemporaryDirectory, so both the file and its
    directory are removed even if saving or loading fails. Previously the
    mkdtemp() directory was never deleted, and os.remove was skipped on failure.
    """
    self._model.learn_inference_network(observation=[1, 1], num_traces=training_traces)
    with tempfile.TemporaryDirectory() as temp_dir:
        file_name = os.path.join(temp_dir, str(uuid.uuid4()))
        self._model.save_inference_network(file_name)
        self._model.load_inference_network(file_name)

    util.debug('training_traces', 'file_name')
    # Reaching this point without an exception is the real check.
    self.assertTrue(True)
def test_dist_categorical(self):
    """Categorical([0.1, 0.2, 0.7]): sample shape, and log-prob of class 0 equals log(0.1)."""
    dist_sample_shape_correct = [1]
    dist_log_probs_correct = [-2.30259]

    dist = Categorical([0.1, 0.2, 0.7])
    dist_sample_shape = list(dist.sample().size())
    dist_log_probs = util.to_numpy(dist.log_prob(0))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    log_probs_match = np.allclose(dist_log_probs, dist_log_probs_correct, atol=0.1)
    self.assertTrue(log_probs_match)
def test_dist_empirical_resample(self):
    """Resampling an Empirical of N(2, 5) draws down to half size keeps its moments.

    Both mean and stddev must stay within 0.25 of the source Normal's parameters.
    """
    dist_means_correct = [2]
    dist_stddevs_correct = [5]

    dist = Normal(dist_means_correct, dist_stddevs_correct)
    full_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])
    half_size = int(empirical_samples / 2)
    dist_empirical = full_empirical.resample(half_size)

    dist_means_empirical = util.to_numpy(dist_empirical.mean)
    dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)

    util.debug('dist_means_empirical', 'dist_means_correct', 'dist_stddevs_empirical', 'dist_stddevs_correct')

    for actual, expected in ((dist_means_empirical, dist_means_correct),
                             (dist_stddevs_empirical, dist_stddevs_correct)):
        self.assertTrue(np.allclose(actual, expected, atol=0.25))
def test_dist_categorical_batched(self):
    """Batched Categorical over two probability rows.

    log_prob of classes [0, 1] must equal [log(0.1), log(0.5)], one per row.
    """
    dist_sample_shape_correct = [2]
    dist_log_probs_correct = [[-2.30259], [-0.693147]]

    dist = Categorical([[0.1, 0.2, 0.7], [0.2, 0.5, 0.3]])
    dist_sample_shape = list(dist.sample().size())
    dist_log_probs = util.to_numpy(dist.log_prob([[0, 1]]))

    util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_log_probs', 'dist_log_probs_correct')

    self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
    log_probs_match = np.allclose(dist_log_probs, dist_log_probs_correct, atol=0.1)
    self.assertTrue(log_probs_match)
def test_trace_controlled_uncontrolled_observed(self):
    """A single trace must have 2 controlled, 3 uncontrolled and 4 observed samples."""
    controlled_correct = 2
    uncontrolled_correct = 3
    observed_correct = 4

    (trace,) = self._model._traces(1)[:1]
    controlled = len(trace.samples)
    uncontrolled = len(trace.samples_uncontrolled)
    observed = len(trace.samples_observed)

    util.debug('controlled', 'controlled_correct', 'uncontrolled', 'uncontrolled_correct', 'observed', 'observed_correct')

    for count, expected in ((controlled, controlled_correct),
                            (uncontrolled, uncontrolled_correct),
                            (observed, observed_correct)):
        self.assertEqual(count, expected)
def test_inference_hmm_posterior_importance_sampling(self):
    """HMM posterior via importance sampling must be close to the reference posterior mean.

    Closeness is the summed pairwise L2 distance between the estimated and
    reference posterior means, which must be below 6. That distance is also
    reported as the importance-sampling performance score.
    """
    observation = self._observation
    posterior_mean_correct = self._posterior_mean_correct

    posterior = self._model.posterior_distribution(samples, observation=observation)
    posterior_mean_unweighted = posterior.mean_unweighted
    posterior_mean = posterior.mean
    distances = F.pairwise_distance(posterior_mean, posterior_mean_correct)
    l2_distance = float(distances.sum())

    util.debug('samples', 'posterior_mean_unweighted', 'posterior_mean', 'posterior_mean_correct', 'l2_distance')
    add_perf_score_importance_sampling(l2_distance)

    self.assertLess(l2_distance, 6)
def test_inference_hmm_posterior_inference_compilation(self):
    """HMM posterior via a trained inference network must be close to the reference posterior mean.

    The network is trained with a zeros(16, 3) observation, then used for
    inference on the real observation. Closeness is the summed pairwise L2
    distance, which must be below 6 and is reported as the
    inference-compilation performance score.
    """
    observation = self._observation
    posterior_mean_correct = self._posterior_mean_correct

    training_observation = torch.zeros(16, 3)
    self._model.learn_inference_network(observation=training_observation,
                                        early_stop_traces=training_traces)
    posterior = self._model.posterior_distribution(samples, use_inference_network=True,
                                                   observation=observation)
    posterior_mean_unweighted = posterior.mean_unweighted
    posterior_mean = posterior.mean
    distances = F.pairwise_distance(posterior_mean, posterior_mean_correct)
    l2_distance = float(distances.sum())

    util.debug('samples', 'posterior_mean_unweighted', 'posterior_mean', 'posterior_mean_correct', 'l2_distance')
    add_perf_score_inference_compilation(l2_distance)

    self.assertLess(l2_distance, 6)