Exemplo n.º 1
0
def test_post_rescaling_with_logit_update_bounds(reparam):
    """Check that the logit post-rescaling is rejected with update bounds.

    With ``_update_bounds`` enabled, requesting ``'logit'`` must make
    ``configure_post_rescaling`` raise a ``RuntimeError`` that mentions the
    conflict.
    """
    reparam._update_bounds = True
    # The message is a plain literal (no regex metacharacters), so ``match``
    # performs the same substring check as ``in str(excinfo.value)``.
    with pytest.raises(
        RuntimeError, match='Cannot use logit with update bounds'
    ):
        RescaleToBounds.configure_post_rescaling(reparam, 'logit')
Exemplo n.º 2
0
def test_post_rescaling_with_functions(reparam):
    """Check that a ``(forward, inverse)`` tuple of callables is stored.

    The forward callable should become ``post_rescaling`` and the inverse
    ``post_rescaling_inv``. Post-rescaling should be flagged as enabled
    and ``has_prime_prior`` as disabled.
    """
    forward, inverse = np.exp, np.log
    RescaleToBounds.configure_post_rescaling(reparam, (forward, inverse))
    # Identity checks: the exact callables passed in must be kept.
    assert reparam.post_rescaling is forward
    assert reparam.post_rescaling_inv is inverse
    assert reparam.has_post_rescaling is True
    assert reparam.has_prime_prior is False
Exemplo n.º 3
0
def test_post_rescaling_with_str(reparam):
    """Check configuring the post-rescaling from a named function.

    ``'logit'`` is used, so this also covers the logit-specific setup.
    With update bounds disabled, the forward/inverse pair should come from
    ``rescaling_functions['logit']``, and the rescale bounds for ``x``
    should become ``[0, 1]``.
    """
    from nessai.utils.rescaling import rescaling_functions

    reparam.parameters = ['x']
    reparam._update_bounds = False
    RescaleToBounds.configure_post_rescaling(reparam, 'logit')

    expected = rescaling_functions['logit']
    assert reparam.has_post_rescaling is True
    assert reparam.has_prime_prior is False
    assert reparam.post_rescaling is expected[0]
    assert reparam.post_rescaling_inv is expected[1]
    assert reparam.rescale_bounds == {'x': [0, 1]}
Exemplo n.º 4
0
def test_post_rescaling_invalid_input(reparam):
    """Assert an error is raised for a malformed post-rescaling tuple.

    The input here *is* a tuple, but it holds only a forward function and
    no inverse. It is therefore not a valid ``(function, inverse)`` pair.
    ``configure_post_rescaling`` should reject it with a ``RuntimeError``
    whose message contains 'Post-rescaling must be a str or tuple'.
    """
    # NOTE(review): only the one-element-tuple case is exercised here; a
    # non-str, non-tuple input (e.g. an int) is not covered by this test.
    with pytest.raises(RuntimeError) as excinfo:
        RescaleToBounds.configure_post_rescaling(reparam, (np.exp, ))
    assert 'Post-rescaling must be a str or tuple' in str(excinfo.value)
Exemplo n.º 5
0
def test_post_rescaling_with_invalid_str(reparam):
    """Check that an unrecognised rescaling name raises a ``RuntimeError``.

    The error message is expected to include the offending name.
    """
    unknown_name = 'not_a_rescaling'
    # Literal message without regex metacharacters: ``match`` is equivalent
    # to a substring check on the exception text.
    with pytest.raises(
        RuntimeError, match='Unknown rescaling function: not_a_rescaling'
    ):
        RescaleToBounds.configure_post_rescaling(reparam, unknown_name)
Exemplo n.º 6
0
def test_configure_post_rescaling_none(reparam):
    """Check that ``None`` results in ``has_post_rescaling`` being False."""
    no_rescaling = None
    RescaleToBounds.configure_post_rescaling(reparam, no_rescaling)
    assert reparam.has_post_rescaling is False