def test_dist_truncated_normal_clamped_batched(self): dist_sample_shape_correct = [2, 1] dist_means_non_truncated = [[0], [2]] dist_means_non_truncated_correct = [[0.5], [1]] dist_stddevs_non_truncated = [[1], [3]] dist_means_correct = [[0.744836], [-0.986679]] dist_stddevs_correct = [[0.143681], [1.32416]] dist_lows_correct = [[0.5], [-4]] dist_highs_correct = [[1], [1]] dist_log_prob_arguments = [[0.75], [-3]] dist_log_probs_correct = [[0.702875], [-2.11283]] dist = TruncatedNormal(dist_means_non_truncated, dist_stddevs_non_truncated, dist_lows_correct, dist_highs_correct, clamp_mean_between_low_high=True) dist_sample_shape = list(dist.sample().size()) dist_empirical = Empirical([dist.sample() for i in range(empirical_samples)]) dist_means_non_truncated = util.to_numpy(dist._mean_non_truncated) dist_means = util.to_numpy(dist.mean) dist_means_empirical = util.to_numpy(dist_empirical.mean) dist_stddevs = util.to_numpy(dist.stddev) dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev) dist_log_probs = util.to_numpy(dist.log_prob(dist_log_prob_arguments)) util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_means_non_truncated', 'dist_means_non_truncated_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_log_probs', 'dist_log_probs_correct') self.assertEqual(dist_sample_shape, dist_sample_shape_correct) self.assertTrue(np.allclose(dist_means_non_truncated, dist_means_non_truncated_correct, atol=0.1)) self.assertTrue(np.allclose(dist_means, dist_means_correct, atol=0.1)) self.assertTrue(np.allclose(dist_means_empirical, dist_means_correct, atol=0.1)) self.assertTrue(np.allclose(dist_stddevs, dist_stddevs_correct, atol=0.1)) self.assertTrue(np.allclose(dist_stddevs_empirical, dist_stddevs_correct, atol=0.1)) self.assertTrue(np.allclose(dist_log_probs, dist_log_probs_correct, atol=0.1))
def test_dist_truncated_normal_batched(self): dist_sample_shape_correct = [2, 1] dist_means_non_truncated_correct = [[0], [2]] dist_stddevs_non_truncated_correct = [[1], [3]] dist_means_correct = [[0], [0.901189]] dist_stddevs_correct = [[0.53956], [1.95118]] dist_lows_correct = [[-1], [-4]] dist_highs_correct = [[1], [4]] dist_log_probs_correct = [[-0.537223], [-1.69563]] dist = TruncatedNormal(dist_means_non_truncated_correct, dist_stddevs_non_truncated_correct, dist_lows_correct, dist_highs_correct) dist_sample_shape = list(dist.sample().size()) dist_empirical = Empirical([dist.sample() for i in range(empirical_samples)]) dist_means = util.to_numpy(dist.mean) dist_means_empirical = util.to_numpy(dist_empirical.mean) dist_stddevs = util.to_numpy(dist.stddev) dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev) dist_log_probs = util.to_numpy(dist.log_prob(dist_means_non_truncated_correct)) util.debug('dist_sample_shape', 'dist_sample_shape_correct', 'dist_means', 'dist_means_empirical', 'dist_means_correct', 'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct', 'dist_log_probs', 'dist_log_probs_correct') self.assertEqual(dist_sample_shape, dist_sample_shape_correct) self.assertTrue(np.allclose(dist_means, dist_means_correct, atol=0.1)) self.assertTrue(np.allclose(dist_means_empirical, dist_means_correct, atol=0.1)) self.assertTrue(np.allclose(dist_stddevs, dist_stddevs_correct, atol=0.1)) self.assertTrue(np.allclose(dist_stddevs_empirical, dist_stddevs_correct, atol=0.1)) self.assertTrue(np.allclose(dist_log_probs, dist_log_probs_correct, atol=0.1))