def test_decapitate_model_lazy_input():
    """Test that a warning (not an error) is emitted for a lazy input layer.

    The model below declares ``input_shape`` on its first ``Dense`` layer, which
    this test treats as "lazy input layer initialization"; decapitating it is
    expected to emit exactly one warning mentioning "depth issues".
    """
    error_model = Sequential(
        [Dense(40, input_shape=(100, )), Dense(20), Activation('softmax')])

    with warnings.catch_warnings(record=True) as warning_check:
        # Without "always", Python's once-per-location warning registry can
        # suppress the warning if another test already triggered it, making
        # this test order-dependent.
        warnings.simplefilter("always")
        _decapitate_model(error_model, 1)

    # "always" also surfaces unrelated warnings (e.g. backend deprecations),
    # so count only the warning under test.
    depth_warnings = [
        caught for caught in warning_check
        if "depth issues" in str(caught.message)
    ]
    assert len(depth_warnings) == 1
def test_decapitate_model():
    """Check that decapitating CHECK_MODEL at depth 5 yields a truncated network.

    Verifies that the layer at index 3 is the model's final layer, that it has
    no outbound connections, that the model's outputs are that layer's output,
    and that the final layer's output shape is ``(None, 20)``.
    """
    decapitated = _decapitate_model(CHECK_MODEL, 5)
    layers = decapitated.layers
    final_layer = layers[3]

    # Index 3 must also be the last index, i.e. nothing follows it.
    assert layers[-1] == final_layer
    # No remaining layer should consume the final layer's output.
    assert final_layer.outbound_nodes == []
    # The model's outputs must be exactly the final layer's output.
    assert decapitated.outputs == [final_layer.output]
    assert layers[-1].output_shape == (None, 20)
def test_decapitate_model_too_deep():
    """Decapitating deeper than the network allows must raise ValueError."""
    # Any depth >= (number of layers in the network) - 1 is rejected.
    excessive_depth = 8
    with pytest.raises(ValueError):
        _decapitate_model(CHECK_MODEL, excessive_depth)