예제 #1
0
def test_produce():
    """Check produce() with all_confidences=False.

    Expects a two-column frame with one row per row of `features`, whose
    columns are exactly 'target' and 'confidence'.
    """
    hyperparams = mlp_hp(
        mlp_hp.defaults(),
        weights_filepath='/scratch_dir/model_weights.pth',
        all_confidences=False,
    )
    classifier = MlpClassifierPrimitive(hyperparams=hyperparams)
    classifier.set_params(params=mlp_params)

    result = classifier.produce(inputs=features)
    predictions = result.value

    n_rows = features.shape[0]
    assert predictions.shape == (n_rows, 2)
    # The shape check above pins the column count to 2, so a plain list
    # comparison is equivalent to an element-wise match.
    assert list(predictions.columns) == ['target', 'confidence']
예제 #2
0
def test_produce_explanations_all_classes():
    """Check produce_explanations() with explain_all_classes=True.

    Expects a frame of shape (number of rows in `features`, `_nclasses`) whose
    first cell converts to a 120x120 array.
    """
    classifier = MlpClassifierPrimitive(
        hyperparams=mlp_hp(
            mlp_hp.defaults(),
            weights_filepath='/scratch_dir/model_weights.pth',
            explain_all_classes=True,
        )
    )
    classifier.set_params(params=mlp_params)

    explanations = classifier.produce_explanations(inputs=features).value
    expected_shape = (features.shape[0], classifier._nclasses)
    assert explanations.shape == expected_shape

    # NOTE(review): 120x120 is presumably the spatial size of the inputs the
    # explanations are computed over — confirm against the fixture.
    first_cell = explanations.iloc[0, 0]
    assert np.array(first_cell).shape == (120, 120)
예제 #3
0
def test_fit():
    """Fit the classifier for one epoch and share its params with other tests.

    Also checks that the first dimension of the last layer's weight in
    `_clf_model` equals `_nclasses`.
    """
    global mlp_params

    hyperparams = mlp_hp(
        mlp_hp.defaults(),
        epochs=1,
        weights_filepath='/scratch_dir/model_weights.pth',
    )
    classifier = MlpClassifierPrimitive(hyperparams=hyperparams)
    classifier.set_training_data(inputs=features, outputs=labels)
    classifier.fit()

    final_layer = classifier._clf_model[-1]
    assert final_layer.weight.shape[0] == classifier._nclasses

    # Store the fitted params in the module-level global that the produce
    # tests pass to set_params(); they presumably rely on this test running
    # first — confirm the test ordering.
    mlp_params = classifier.get_params()
예제 #4
0
def test_produce_all_confidences():
    """Check produce() with default hyperparams (one row per input/class pair).

    The output must have `features.shape[0] * _nclasses` rows and exactly the
    columns 'target' and 'confidence'. Rows are treated as consecutive blocks
    of `_nclasses` rows per input. When there are more than two classes, each
    block's confidences must sum to 1 (within a small tolerance).
    """
    mlp = MlpClassifierPrimitive(hyperparams=mlp_hp(
        mlp_hp.defaults(),
        weights_filepath='/scratch_dir/model_weights.pth',
    ))
    mlp.set_params(params=mlp_params)

    preds = mlp.produce(inputs=features).value
    nc = mlp._nclasses
    n_inputs = features.shape[0]
    assert preds.shape == (n_inputs * nc, 2)
    # NOTE(review): the sum check is skipped for binary problems, as in the
    # original; presumably binary confidences are not normalized — confirm.
    if nc > 2:
        # The original used `round(sum) == 1`, which accepted any sum in
        # (0.5, 1.5). It also sliced the Series with `[]`. Here a row-major
        # reshape of the raw values gives one row of class confidences per
        # input, matching the original i*nc:(i+1)*nc slicing.
        confidences = np.asarray(preds['confidence'], dtype=float)
        row_sums = confidences.reshape(n_inputs, nc).sum(axis=1)
        # atol leaves room for float32 rounding in the model output; tighten
        # it if the primitive is known to emit exact probabilities.
        bad_rows = np.flatnonzero(~np.isclose(row_sums, 1.0, atol=1e-4))
        assert bad_rows.size == 0, (
            f"confidences do not sum to 1 for inputs {bad_rows.tolist()}: "
            f"{row_sums[bad_rows].tolist()}"
        )
    assert (preds.columns == ['target', 'confidence']).all()