def test_pre_rescaling_with_functions(reparam):
    """Check that a (forward, inverse) pair of callables is stored as given.

    The pair ``(np.exp, np.log)`` is passed directly, and the very same
    function objects should end up on the reparameterisation.
    """
    forward, inverse = np.exp, np.log
    RescaleToBounds.configure_pre_rescaling(reparam, (forward, inverse))
    assert reparam.has_pre_rescaling is True
    assert reparam.pre_rescaling is forward
    assert reparam.pre_rescaling_inv is inverse
def test_pre_rescaling_with_str(reparam):
    """Check that a rescaling given by name is resolved via a lookup table.

    Passing ``'logit'`` should assign the first entry stored under that
    key in ``rescaling_functions`` as the pre-rescaling and the second
    entry as its inverse.
    """
    from nessai.utils.rescaling import rescaling_functions

    name = 'logit'
    RescaleToBounds.configure_pre_rescaling(reparam, name)
    assert reparam.has_pre_rescaling is True
    # Look up the registered pair only after configuration, as before.
    forward_and_inverse = rescaling_functions[name]
    assert reparam.pre_rescaling is forward_and_inverse[0]
    assert reparam.pre_rescaling_inv is forward_and_inverse[1]
def test_pre_rescaling_invalid_input(reparam):
    """Assert an error is raised for an incomplete tuple of functions.

    A tuple holding only a forward function (``(np.exp, )``, with no
    inverse) is not a valid ``(forward, inverse)`` pair, so configuring
    the pre-rescaling should fail with a ``RuntimeError``.
    """
    with pytest.raises(RuntimeError) as excinfo:
        # One-element tuple: the inverse function is missing.
        RescaleToBounds.configure_pre_rescaling(reparam, (np.exp, ))
    assert 'Pre-rescaling must be a str or tuple' in str(excinfo.value)
def test_pre_rescaling_with_invalid_str(reparam):
    """Check that an unrecognised rescaling name raises a RuntimeError.

    The error message is expected to mention the offending name.
    """
    with pytest.raises(RuntimeError) as error_info:
        RescaleToBounds.configure_pre_rescaling(reparam, 'not_a_rescaling')
    message = str(error_info.value)
    assert 'Unknown rescaling function: not_a_rescaling' in message
def test_configure_pre_rescaling_none(reparam):
    """Check that passing ``None`` results in pre-rescaling being disabled."""
    no_rescaling = None
    RescaleToBounds.configure_pre_rescaling(reparam, no_rescaling)
    assert reparam.has_pre_rescaling is False