def test_add_tie_specify_name(self):
    # Tie 'r' and 'n' under an explicit new_name and check that the model's
    # parameter names list the tie as 'dummy'. The x and y center entries
    # reuse one prior object and are expected to show up once as 'center.0'.
    shared_xy = prior.Uniform(-5, 5)
    scatterer = Sphere(
        n=prior.Uniform(1, 2),
        r=prior.Uniform(1, 2),
        center=[shared_xy, shared_xy, 10],
    )
    model = AlphaModel(scatterer)

    model.add_tie(['r', 'n'], new_name='dummy')

    self.assertEqual(model._parameter_names, ['dummy', 'center.0'])
def test_add_tie_updates_parameter_names(self):
    # Tie 'r' and 'n' without supplying a name and check the resulting
    # parameter names: the test expects the tied entry to be listed as 'n',
    # followed by 'center.0' for the prior shared by the x and y coordinates.
    shared_xy = prior.Uniform(-5, 5)
    scatterer = Sphere(
        n=prior.Uniform(1, 2),
        r=prior.Uniform(1, 2),
        center=[shared_xy, shared_xy, 10],
    )
    model = AlphaModel(scatterer)

    model.add_tie(['r', 'n'])

    self.assertEqual(model._parameter_names, ['n', 'center.0'])
def test_add_tie_updates_map(self):
    # After tying 'r' and 'n', the scatterer map is expected to point both
    # of them at '_parameter_0', while the two center entries that reuse one
    # prior object both point at '_parameter_1'.
    shared_xy = prior.Uniform(-5, 5)
    scatterer = Sphere(
        n=prior.Uniform(1, 2),
        r=prior.Uniform(1, 2),
        center=[shared_xy, shared_xy, 10],
    )
    model = AlphaModel(scatterer)

    model.add_tie(['r', 'n'])

    scatterer_entries = [
        ['n', '_parameter_0'],
        ['r', '_parameter_0'],
        ['center', ['_parameter_1', '_parameter_1', 10]],
    ]
    self.assertEqual(model._maps['scatterer'], [dict, [scatterer_entries]])
def test_yaml_preserves_parameter_ties(self):
    # Build a model with one tie coming from a reused prior object (n and
    # center[1]) and one added explicitly via add_tie, then check that the
    # model's parameters compare equal after take_yaml_round_trip.
    shared = prior.Uniform(0, 1)
    scatterer = Sphere(
        n=shared,
        r=prior.Uniform(0.6, 1, name='radius'),
        center=[prior.Uniform(0.6, 1), shared, 10],
    )
    alpha_spec = {'r': 0.6, 'g': prior.Uniform(0.8, 0.9)}
    model = AlphaModel(scatterer, alpha=alpha_spec)
    model.add_tie(['radius', 'center.0'])

    reloaded = take_yaml_round_trip(model)

    self.assertEqual(model.parameters, reloaded.parameters)
def test_add_3_way_tie(self):
    # Tie three parameters at once ('center.0', 'n.imag', 'center.1') and
    # check the scatterer map, the parameter list and the parameter names.
    # The expectations below have all three tied entries sharing
    # '_parameter_1', listed under the name 'n.imag'.
    n = prior.ComplexPrior(prior.Uniform(1, 2), prior.Uniform(0, 1))
    sphere = Sphere(
        n=n,
        r=prior.Uniform(0.5, 1),
        center=[
            prior.Uniform(0, 1),
            prior.Uniform(0, 1),
            prior.Uniform(0, 10),
        ],
    )
    model = AlphaModel(sphere)

    model.add_tie(['center.0', 'n.imag', 'center.1'])

    expected_map = [
        dict,
        [[
            ['n', [transformed_prior,
                   [complex, ['_parameter_0', '_parameter_1']]]],
            ['r', '_parameter_2'],
            ['center', ['_parameter_1', '_parameter_1', '_parameter_3']],
        ]],
    ]
    expected_parameters = [
        prior.Uniform(1, 2),
        prior.Uniform(0, 1),
        prior.Uniform(0.5, 1),
        prior.Uniform(0, 10),
    ]
    expected_names = ['n.real', 'n.imag', 'r', 'center.2']

    self.assertEqual(model._maps['scatterer'], expected_map)
    self.assertEqual(model._parameters, expected_parameters)
    self.assertEqual(model._parameter_names, expected_names)