Ejemplo n.º 1
0
def test_decapitate_model_lazy_input():
    """Test a warning is raised when the model has a lazy input layer initialization.

    The model below declares its input via ``input_shape`` on the first
    ``Dense`` layer instead of an explicit input layer. Decapitating it is
    expected to emit exactly one warning whose message mentions
    "depth issues".
    """
    error_model = Sequential(
        [Dense(40, input_shape=(100, )),
         Dense(20),
         Activation('softmax')])

    with warnings.catch_warnings(record=True) as warning_check:
        # Without "always", a warning already emitted from the same code
        # location (e.g. by an earlier test) is deduplicated via the
        # module's __warningregistry__ and would not be recorded here.
        warnings.simplefilter("always")
        _decapitate_model(error_model, 1)

    # Count only the warning under test: the framework may emit unrelated
    # (e.g. deprecation) warnings during the call.
    depth_warnings = [w for w in warning_check
                      if "depth issues" in str(w.message)]
    assert len(depth_warnings) == 1
Ejemplo n.º 2
0
def test_decapitate_model():
    """
    Test that decapitating the toy network CHECK_MODEL by 5 layers leaves a
    correctly truncated model:

    - exactly 4 layers remain,
    - the new last layer has no outbound connections,
    - the model's outputs are the new last layer's output,
    - the new last layer's output shape is (None, 20).
    """
    # NOTE(review): if _decapitate_model modifies its argument in place, this
    # alters the shared CHECK_MODEL for any later test — confirm.
    test_model = _decapitate_model(CHECK_MODEL, 5)
    last_layer = test_model.layers[-1]

    # 4 layers remain, so the truncated model ends at index 3.
    assert len(test_model.layers) == 4
    assert last_layer.outbound_nodes == []
    assert test_model.outputs == [last_layer.output]
    assert last_layer.output_shape == (None, 20)
Ejemplo n.º 3
0
def test_decapitate_model_too_deep():
    """Decapitating CHECK_MODEL by 8 layers must raise ValueError.

    _decapitate_model is expected to reject any depth that is
    >= (number of layers in the network) - 1; a depth of 8 is chosen to
    hit that case for CHECK_MODEL.
    """
    excessive_depth = 8
    with pytest.raises(ValueError):
        _decapitate_model(CHECK_MODEL, excessive_depth)