def test_post_rescaling_with_logit_update_bounds(reparam):
    """Check that logit post-rescaling is rejected when updating bounds.

    With ``_update_bounds`` set to True on the reparameterisation,
    configuring the ``'logit'`` post-rescaling must raise a RuntimeError.
    """
    reparam._update_bounds = True
    with pytest.raises(RuntimeError) as excinfo:
        RescaleToBounds.configure_post_rescaling(reparam, 'logit')
    error_message = str(excinfo.value)
    assert 'Cannot use logit with update bounds' in error_message
def test_post_rescaling_with_functions(reparam):
    """Check configuration from an explicit pair of callables.

    The first element of the tuple must end up as ``post_rescaling`` and
    the second as ``post_rescaling_inv``.
    """
    rescale_fn, inverse_fn = np.exp, np.log
    RescaleToBounds.configure_post_rescaling(
        reparam, (rescale_fn, inverse_fn)
    )
    assert reparam.has_post_rescaling is True
    assert reparam.has_prime_prior is False
    assert reparam.post_rescaling is rescale_fn
    assert reparam.post_rescaling_inv is inverse_fn
def test_post_rescaling_with_str(reparam):
    """Check configuration when the rescaling is specified by name.

    Uses ``'logit'``, so the logit-specific configuration is covered as
    well: the rescale bounds for parameter ``'x'`` must be ``[0, 1]``.
    """
    from nessai.utils.rescaling import rescaling_functions

    reparam._update_bounds = False
    reparam.parameters = ['x']

    RescaleToBounds.configure_post_rescaling(reparam, 'logit')

    assert reparam.has_post_rescaling is True
    assert reparam.has_prime_prior is False
    # Entry 0 is compared with the rescaling, entry 1 with its inverse.
    logit_entry = rescaling_functions['logit']
    assert reparam.post_rescaling is logit_entry[0]
    assert reparam.post_rescaling_inv is logit_entry[1]
    assert reparam.rescale_bounds == {'x': [0, 1]}
def test_post_rescaling_invalid_input(reparam):
    """Check that a one-element tuple is rejected as a post-rescaling.

    The input is a tuple containing a single callable; a RuntimeError
    describing the accepted input types must be raised.
    """
    bad_input = (np.exp,)
    with pytest.raises(RuntimeError) as excinfo:
        RescaleToBounds.configure_post_rescaling(reparam, bad_input)
    error_message = str(excinfo.value)
    assert 'Post-rescaling must be a str or tuple' in error_message
def test_post_rescaling_with_invalid_str(reparam):
    """Check that an unrecognised rescaling name raises a RuntimeError."""
    with pytest.raises(RuntimeError) as excinfo:
        RescaleToBounds.configure_post_rescaling(reparam, 'not_a_rescaling')
    error_message = str(excinfo.value)
    assert 'Unknown rescaling function: not_a_rescaling' in error_message
def test_configure_post_rescaling_none(reparam):
    """Check that configuring with None sets has_post_rescaling to False."""
    no_rescaling = None
    RescaleToBounds.configure_post_rescaling(reparam, no_rescaling)
    assert reparam.has_post_rescaling is False