def test_ignores_trivial_names(self):
  """Trivial `_` argument names must not influence name resolution."""
  resolve = joint_distribution_sequential._resolve_distribution_names
  # Each entry: (dist_fn_args, instance_names, expected resolved names).
  cases = [
      # A trivial reference downstream of the real name `z` is ignored.
      ([None, ['z'], ['w', '_']], [None, None, None], ['z', 'w', 'y']),
      # A trivial reference upstream of the real name `z` is ignored.
      ([None, ['_'], ['w', 'z']], [None, None, None], ['z', 'w', 'y']),
      # The only direct reference is trivial, so the instance name `z` wins.
      ([None, ['_']], ['z', None], ['z', 'y']),
  ]
  for fn_args, names_on_instances, expected in cases:
    resolved = resolve(
        dist_fn_args=fn_args,
        dist_names=None,
        leaf_name='y',
        instance_names=names_on_instances)
    self.assertAllEqual(resolved, expected)
def test_dummy_names_are_unique(self):
  """Generated placeholder names derived from the leaf name never collide."""
  resolve = joint_distribution_sequential._resolve_distribution_names
  # Each entry: (instance_names, expected resolved names); every variable
  # has no dist_fn args and the leaf is named 'x'.
  cases = [
      # No names anywhere: suffixed dummies are generated from the leaf.
      ([None, None, None], ['x2', 'x1', 'x']),
      # Instance names already claim 'x' and 'x1', so the leaf gets 'x2'.
      (['x', 'x1', None], ['x', 'x1', 'x2']),
  ]
  for names_on_instances, expected in cases:
    resolved = resolve(
        dist_fn_args=[None, None, None],
        dist_names=None,
        leaf_name='x',
        instance_names=names_on_instances)
    self.assertAllEqual(resolved, expected)
def test_inconsistent_names_raise_error(self):
  """Conflicting names for the same variable raise a ValueError.

  Uses `assertRaisesRegex`; the deprecated `assertRaisesRegexp` alias was
  removed from `unittest.TestCase` in Python 3.12.
  """
  with self.assertRaisesRegex(ValueError, 'Inconsistent names'):
    # Refers to first variable as both `z` and `x`.
    joint_distribution_sequential._resolve_distribution_names(
        dist_fn_args=[None, ['z'], ['x', 'w']],
        dist_names=None,
        leaf_name='y',
        instance_names=[None, None, None])
  with self.assertRaisesRegex(ValueError, 'Inconsistent names'):
    # Refers to first variable as `x`, but it was explicitly named `z`.
    joint_distribution_sequential._resolve_distribution_names(
        dist_fn_args=[None, ['x']],
        dist_names=None,
        leaf_name='y',
        instance_names=['z', None])