def test_produce():
    """With all_confidences disabled, produce yields one (target, confidence) row per input."""
    hp = mlp_hp(
        mlp_hp.defaults(),
        weights_filepath='/scratch_dir/model_weights.pth',
        all_confidences=False,
    )
    primitive = MlpClassifierPrimitive(hyperparams=hp)
    primitive.set_params(params=mlp_params)
    predictions = primitive.produce(inputs=features).value
    assert predictions.shape == (features.shape[0], 2)
    assert (predictions.columns == ['target', 'confidence']).all()
def test_produce_explanations_all_classes():
    """Explanations requested for all classes should have one column per class,
    each holding a 120x120 saliency-style map."""
    hp = mlp_hp(
        mlp_hp.defaults(),
        weights_filepath='/scratch_dir/model_weights.pth',
        explain_all_classes=True,
    )
    primitive = MlpClassifierPrimitive(hyperparams=hp)
    primitive.set_params(params=mlp_params)
    explanations = primitive.produce_explanations(inputs=features).value
    assert explanations.shape == (features.shape[0], primitive._nclasses)
    first_map = np.array(explanations.iloc[0, 0])
    assert first_map.shape == (120, 120)
def test_fit():
    """Fitting for a single epoch should size the classifier head to the class count."""
    hp = mlp_hp(
        mlp_hp.defaults(),
        epochs=1,
        weights_filepath='/scratch_dir/model_weights.pth',
    )
    primitive = MlpClassifierPrimitive(hyperparams=hp)
    primitive.set_training_data(inputs=features, outputs=labels)
    primitive.fit()
    # The final linear layer must emit one logit per discovered class.
    assert primitive._clf_model[-1].weight.shape[0] == primitive._nclasses
    # Stash the fitted params at module level for the produce tests that reuse them.
    global mlp_params
    mlp_params = primitive.get_params()
def test_produce_all_confidences():
    """With the default confidence setting, produce emits one row per (sample, class)
    pair, and per-sample confidences sum to 1 when there are more than two classes."""
    hp = mlp_hp(
        mlp_hp.defaults(),
        weights_filepath='/scratch_dir/model_weights.pth',
    )
    primitive = MlpClassifierPrimitive(hyperparams=hp)
    primitive.set_params(params=mlp_params)
    predictions = primitive.produce(inputs=features).value
    n_classes = primitive._nclasses
    assert predictions.shape == (features.shape[0] * n_classes, 2)
    if n_classes > 2:
        for sample_idx in range(features.shape[0]):
            # Confidence rows for one sample occupy a contiguous block of n_classes rows.
            block = predictions['confidence'][sample_idx * n_classes:(sample_idx + 1) * n_classes]
            assert round(block.sum()) == 1
    assert (predictions.columns == ['target', 'confidence']).all()