Ejemplo n.º 1
0
def test_generate_densenet_model_dropout():
    """Build a DenseNet with dropout enabled and check its I/O shapes.

    The architecture is generated with ``n_upsampling=4``, ``n_filter=48``,
    ``stride=2`` and ``dropout_rate=0.1``. It is loaded on a ``[128, 128, 3]``
    input, and the resulting model's input/output shapes are asserted. This
    way the test catches more than a construction error.
    """
    architecture = generate_densenet_2d_architecture(n_upsampling=4,
                                                     n_filter=48,
                                                     stride=2,
                                                     dropout_rate=0.1)

    input_params = {'shape': [128, 128, 3]}

    model = load_architecture(architecture, input_params)
    model.summary()

    # Enabling dropout is expected to leave the I/O shapes unchanged.
    # The single output channel mirrors test_generate_densenet_model_resize,
    # which uses the same generator settings.
    assert model.input_shape == (None, 128, 128, 3)
    assert model.output_shape == (None, 128, 128, 1)
Ejemplo n.º 2
0
def test_generate_densenet_model_batchnorm():
    """Build a DenseNet with batch normalization enabled and check its I/O shapes.

    The architecture is generated with ``n_upsampling=4``, ``n_filter=48``,
    ``stride=2`` and ``batchnorm=True``. It is loaded on a ``[128, 128, 3]``
    input, and the resulting model's input/output shapes are asserted. This
    way the test catches more than a construction error.
    """
    architecture = generate_densenet_2d_architecture(n_upsampling=4,
                                                     n_filter=48,
                                                     stride=2,
                                                     batchnorm=True)

    input_params = {'shape': [128, 128, 3]}

    model = load_architecture(architecture, input_params)
    model.summary()

    # Enabling batchnorm is expected to leave the I/O shapes unchanged.
    # The single output channel mirrors test_generate_densenet_model_resize,
    # which uses the same generator settings.
    assert model.input_shape == (None, 128, 128, 3)
    assert model.output_shape == (None, 128, 128, 1)
Ejemplo n.º 3
0
def test_generate_densenet_model_resize():
    """Check that a non-power-of-two input size round-trips through the model.

    The architecture uses ``n_upsampling=4``, ``n_filter=48`` and ``stride=2``.
    It is loaded on a ``[129, 129, 3]`` input, and the model's output must keep
    the same 129x129 spatial size, with one channel.
    """
    architecture = generate_densenet_2d_architecture(n_upsampling=4,
                                                     n_filter=48,
                                                     stride=2)

    input_params = {'shape': [129, 129, 3]}

    model = load_architecture(architecture, input_params)
    model.summary()

    # Shapes are compared as plain tuples; the leading None is the batch axis.
    assert model.input_shape == (None, 129, 129, 3)
    assert model.output_shape == (None, 129, 129, 1)
Ejemplo n.º 4
0
def test_generate_densenet_model():
    """Compare a generated DenseNet against the stored reference architecture.

    A model is built from ``generate_densenet_2d_architecture`` with
    ``n_upsampling=4``, ``n_filter=48`` and ``stride=2`` on a
    ``[128, 128, 3]`` input. A second model is built from
    ``tests/json/densenet_architecture.json`` with the same input params.
    Both are then checked with ``check_same_models``.
    """
    generated_config = generate_densenet_2d_architecture(
        n_upsampling=4, n_filter=48, stride=2)
    shared_params = {'shape': [128, 128, 3]}

    built = load_architecture(generated_config, shared_params)
    built.summary()

    reference_config = load_json_config('tests/json/densenet_architecture.json')
    reference = load_architecture(reference_config, shared_params)

    assert check_same_models(built, reference)
Ejemplo n.º 5
0
def test_create_densenet_json():
    """Check that the JSON written by ``generate_densenet_2d_json`` matches the generator.

    The default architecture from ``generate_densenet_2d_architecture`` is
    compared with the config read back via ``load_json_config`` from
    ``FILE_NAME``. That file is written by ``generate_densenet_2d_json``.
    """
    # NOTE(review): FILE_NAME is left on disk after this test; consider
    # cleaning it up if no other test relies on it.
    architecture = generate_densenet_2d_architecture()
    generate_densenet_2d_json(FILE_NAME)
    actual = load_json_config(FILE_NAME)

    # Config equality already yields a single bool; no np.all needed.
    assert architecture == actual