Пример #1
0
def test_Ensemble_fit_generator():
    """Smoke-test ``Ensemble.fit_generator`` on several GNN model types.

    For every model built from ``example_graph_1``, a two-estimator ensemble is
    compiled and trained for a single epoch on ``train_gen``. A second call that
    passes the model's ``generator`` object instead (the wrong input type) must
    raise ``ValueError``.
    """
    tf.keras.backend.clear_session()

    graph = example_graph_1(feature_size=10)

    # Each factory returns (base_model, keras_model, generator, train_gen);
    # the base model itself is not needed by this test.
    model_factories = (
        create_graphSAGE_model,
        create_HinSAGE_model,
        create_GCN_model,
        create_GAT_model,
    )
    gnn_models = [factory(graph) for factory in model_factories]

    for entry in gnn_models:
        keras_model, generator, train_gen = entry[1:4]

        ensemble = Ensemble(keras_model, n_estimators=2, n_predictions=1)
        ensemble.compile(
            optimizer=Adam(), loss=categorical_crossentropy, weighted_metrics=["acc"]
        )

        # Training on the expected input type should run without error.
        ensemble.fit_generator(train_gen, epochs=1, verbose=0, shuffle=False)

        # Passing ``generator`` (wrong type) must be rejected up front.
        wrong_type_kwargs = dict(
            generator=generator,
            epochs=10,
            validation_data=train_gen,
            verbose=0,
            shuffle=False,
        )
        with pytest.raises(ValueError):
            ensemble.fit_generator(**wrong_type_kwargs)
Пример #2
0
def test_fit_generator():
    """Check that ``Ensemble.fit_generator`` rejects invalid argument combinations.

    For every GNN model variant built from ``example_graph_1`` (GraphSAGE and
    HinSAGE in their default and ``link_prediction=True`` forms, plus GCN and
    GAT), a two-estimator ensemble is compiled and each call listed below is
    expected to raise ``ValueError``.
    """

    def _call(*args, **kwargs):
        """Capture positional and keyword arguments for a deferred call."""
        return args, kwargs

    train_data = np.array([1, 2])
    train_targets = np.array([[1, 0], [0, 1]])

    graph = example_graph_1(feature_size=10)

    # Each entry is (base_model, keras_model, generator, train_gen).
    gnn_models = [
        create_graphSAGE_model(graph),
        create_HinSAGE_model(graph),
        create_graphSAGE_model(graph, link_prediction=True),
        create_HinSAGE_model(graph, link_prediction=True),
        create_GCN_model(graph),
        create_GAT_model(graph),
    ]

    for entry in gnn_models:
        keras_model, generator, train_gen = entry[1:4]

        ensemble = Ensemble(keras_model, n_estimators=2, n_predictions=1)
        ensemble.compile(
            optimizer=Adam(), loss=categorical_crossentropy, weighted_metrics=["acc"]
        )

        invalid_calls = [
            # Specifying train_data and train_targets implies the use of
            # bagging, so train_gen is the wrong type for this call.
            _call(
                train_gen,
                train_data=train_data,
                train_targets=train_targets,
                epochs=10,
                validation_generator=train_gen,
                verbose=0,
                shuffle=False,
            ),
            # train_targets must not be None when train_data is given.
            _call(
                generator=generator,
                train_data=train_data,
                train_targets=None,
                epochs=10,
                validation_generator=train_gen,
                verbose=0,
                shuffle=False,
            ),
            # No training data at all. Presumably ``generator`` is then the
            # wrong type (mirror of the first case); confirm against Ensemble.
            _call(
                generator=generator,
                train_data=None,
                train_targets=None,
                epochs=10,
                validation_generator=None,
                verbose=0,
                shuffle=False,
            ),
            # bag_size must be a positive integer no larger than
            # len(train_data), or None.
            _call(
                generator=generator,
                train_data=train_data,
                train_targets=train_targets,
                epochs=10,
                validation_generator=None,
                verbose=0,
                shuffle=False,
                bag_size=-1,
            ),
            # bag_size is larger than the number of training points.
            _call(
                generator=generator,
                train_data=train_data,
                train_targets=train_targets,
                epochs=10,
                validation_generator=None,
                verbose=0,
                shuffle=False,
                bag_size=10,
            ),
        ]

        # Each invalid call is checked independently, in the order listed.
        for args, kwargs in invalid_calls:
            with pytest.raises(ValueError):
                ensemble.fit_generator(*args, **kwargs)