def test_load_external():
    """Check that a model saved on another machine loads and reproduces its score.

    ``saved_model`` is expected to hold a model trained elsewhere. After loading
    it into a fresh ``MRMP`` estimator, its score on the y = x**2 dataset over
    [-10, 10] (2000 points) must match the reference score recorded on the
    original machine, within ``np.isclose``'s default tolerances.
    """
    # Score recorded on the machine that trained and saved the model.
    reference_score = -24.101043

    grid = np.linspace(-10.0, 10.0, 2000)
    targets = grid**2
    # The estimator expects a 2-D feature array with one column.
    features = grid.reshape(grid.shape[0], 1)

    model = MRMP()
    model.load_nn("saved_model")

    loaded_score = model.score(features, targets)
    assert np.isclose(loaded_score, reference_score)
def test_save_local():
    """Round-trip a freshly trained model through ``save_nn`` / ``load_nn``.

    Trains an ``MRMP`` estimator on y = x**2 over [-10, 10] (2000 points) and
    records its score. It then saves the model to ``saved_test_model``, loads it
    back into the same estimator, and asserts the score is unchanged.

    The save directory is removed in a ``finally`` block. A failing assertion,
    or an exception from saving/loading, therefore does not leave stale
    artifacts behind for later runs.
    """
    x = np.linspace(-10.0, 10.0, 2000)
    y = x**2
    # The estimator expects a 2-D feature array with one column.
    x = np.reshape(x, (x.shape[0], 1))

    estimator = MRMP()
    estimator.fit(x=x, y=y)
    score_after_training = estimator.score(x, y)

    save_dir = "saved_test_model"
    try:
        estimator.save_nn(save_dir=save_dir)
        estimator.load_nn(save_dir=save_dir)
        score_after_loading = estimator.score(x, y)
        assert score_after_loading == score_after_training
    finally:
        # ignore_errors: if save_nn failed before creating the directory,
        # don't replace the real failure with a FileNotFoundError.
        shutil.rmtree(save_dir, ignore_errors=True)