def test_extract_regularizer_mutliple_raises():
    """Passing two regularizer callbacks must raise NotImplementedError.

    The raised exception's text is compared exactly against the expected
    message string.
    """
    duplicated = [SplineRegularizer() for _ in range(2)]
    with pytest.raises(NotImplementedError) as raised:
        extract_regularizer(duplicated)
    # NOTE(review): "where provided" reads like a typo for "were provided";
    # it presumably mirrors the message raised by extract_regularizer, so
    # both would need to change together — confirm against that function.
    expected = 'Multiple regularizer callbacks where provided.'
    assert expected == str(raised.value)
def test_extract_regularizer(callback):
    """A lone regularizer mixed with a non-regularizer callback is returned.

    ``callback`` is presumably a parametrized pytest fixture supplying a
    regularizer instance — confirm against the fixture definition.
    """
    found = extract_regularizer([Logger(), callback])
    assert found == callback
def test_extract_regularizer_no_regularizer():
    """With no regularizer among the callbacks, None is returned."""
    result = extract_regularizer([Logger()])
    assert result is None