Exemplo n.º 1
0
def test_Ensemble_fit_generator():
    """Train an ``Ensemble`` through ``fit_generator`` for several GNN model types.

    For each of the GraphSAGE, HinSAGE, GCN and GAT test models, the Keras model
    is wrapped in an ``Ensemble`` (2 estimators, 1 prediction). It is compiled and
    trained for a single epoch on ``train_gen``. The test then checks that passing
    ``generator`` (the third element of the factory tuple) as the ``generator``
    argument, which is the wrong type, raises ``ValueError``.
    """
    tf.keras.backend.clear_session()

    graph = example_graph_1(feature_size=10)

    # Each factory returns (base_model, keras_model, generator, train_gen).
    model_factories = (
        create_graphSAGE_model,
        create_HinSAGE_model,
        create_GCN_model,
        create_GAT_model,
    )
    gnn_models = [factory(graph) for factory in model_factories]

    for gnn_model in gnn_models:
        keras_model, generator, train_gen = gnn_model[1:4]

        ensemble = Ensemble(keras_model, n_estimators=2, n_predictions=1)
        ensemble.compile(
            optimizer=Adam(),
            loss=categorical_crossentropy,
            weighted_metrics=["acc"],
        )

        # Short, unshuffled training run with the correct input.
        ensemble.fit_generator(train_gen, epochs=1, verbose=0, shuffle=False)

        # Handing over `generator` itself is the wrong type and must be rejected.
        with pytest.raises(ValueError):
            ensemble.fit_generator(
                generator=generator,
                epochs=10,
                validation_data=train_gen,
                verbose=0,
                shuffle=False,
            )
Exemplo n.º 2
0
def test_deprecated_methods():
    """Check the deprecated ``*_generator`` methods of ``Ensemble``/``BaggingEnsemble``.

    Each of ``fit_generator``, ``evaluate_generator`` and ``predict_generator`` must
    emit a ``DeprecationWarning`` that names its replacement (``fit``, ``evaluate``,
    ``predict``). Its result must also look like what the replacement produces.
    """
    tf.keras.backend.clear_session()

    train_data = np.array([1, 2])
    train_targets = np.array([[1, 0], [0, 1]])

    graph = example_graph_1(feature_size=2)
    _, keras_model, gen, train_gen = create_GAT_model(graph)

    ensemble = Ensemble(keras_model, n_estimators=1, n_predictions=1)
    bagging = BaggingEnsemble(keras_model, n_estimators=1, n_predictions=1)
    models = [ensemble, bagging]
    for candidate in models:
        candidate.compile(optimizer=Adam(), loss=binary_crossentropy)

    # Ensemble.fit_generator is called with just the training sequence, while
    # BaggingEnsemble.fit_generator is called with the generator plus the raw
    # data and targets.
    fit_cases = [
        (ensemble, (train_gen,)),
        (bagging, (gen, train_data, train_targets)),
    ]
    for candidate, fit_args in fit_cases:
        with pytest.warns(DeprecationWarning, match="'fit_generator' .* 'fit'"):
            history = candidate.fit_generator(*fit_args, epochs=2, verbose=0)
        # Expect a single history entry holding one loss value per epoch.
        assert len(history) == 1
        assert len(history[0].history["loss"]) == 2

    # (deprecated method name, replacement method name, expected warning pattern)
    renamed_methods = [
        ("evaluate_generator", "evaluate", "'evaluate_generator' .* 'evaluate'"),
        ("predict_generator", "predict", "'predict_generator' .* 'predict'"),
    ]
    for candidate in models:
        for legacy_name, modern_name, pattern in renamed_methods:
            with pytest.warns(DeprecationWarning, match=pattern):
                legacy_result = getattr(candidate, legacy_name)(train_gen, verbose=0)
            modern_result = getattr(candidate, modern_name)(train_gen, verbose=0)
            np.testing.assert_array_equal(legacy_result, modern_result)