def check_persistence_via_pickle(Estimator):
    """Check that a fitted estimator survives a pickle round trip.

    An instance is built from ``Estimator`` with ``_construct_instance``, then
    seeded with ``set_random_state`` and fitted on arguments from
    ``_make_args(estimator, "fit")``. The code records the output of every
    method in ``NON_STATE_CHANGING_METHODS`` that the estimator has. It then
    pickles and unpickles the estimator and calls the same methods on the
    restored copy with the same arguments. Each pair of outputs must agree
    according to ``_assert_almost_equal``.

    Parameters
    ----------
    Estimator
        The estimator under test, in whatever form ``_construct_instance``
        accepts (presumably an estimator class — confirm against callers).
    """
    estimator = _construct_instance(Estimator)
    set_random_state(estimator)
    fit_args = _make_args(estimator, "fit")
    estimator.fit(*fit_args)

    # Generate results before pickling. Keep the exact call arguments so the
    # unpickled estimator is queried with identical inputs.
    results = {}
    args = {}
    for method in NON_STATE_CHANGING_METHODS:
        if hasattr(estimator, method):
            args[method] = _make_args(estimator, method)
            results[method] = getattr(estimator, method)(*args[method])

    # Pickle and unpickle. The payload is produced in this same process, so
    # pickle.loads is not applied to untrusted data here.
    # TODO: an assertion that sktime estimators embed a version tag
    # (b"version" in the pickled payload) was previously sketched here but
    # is disabled.
    pickled_estimator = pickle.dumps(estimator)
    unpickled_estimator = pickle.loads(pickled_estimator)

    # Compare against results after pickling. Name the method in the failure
    # message, as check_fit_idempotent does.
    for method, value in results.items():
        unpickled_result = getattr(unpickled_estimator, method)(*args[method])
        _assert_almost_equal(
            value,
            unpickled_result,
            err_msg=f"Results of method {method} differ after pickling",
        )
def check_fit_idempotent(Estimator):
    """Verify that fitting twice produces the same outputs as fitting once.

    An instance is built with ``_construct_instance``, then seeded and fitted
    once. The code snapshots the output of each method in
    ``NON_STATE_CHANGING_METHODS`` that the estimator has, with the call
    arguments from ``_make_args``. It then re-seeds the estimator, refits it
    on the same fit arguments, and calls each available method again with
    the stored arguments. The new output must match the snapshot according
    to ``_assert_almost_equal``.

    Parameters
    ----------
    Estimator
        The estimator under test, in whatever form ``_construct_instance``
        accepts (presumably an estimator class — confirm against callers).
    """
    estimator = _construct_instance(Estimator)
    set_random_state(estimator)

    # First fit, then snapshot every available non-state-changing method.
    fit_args = _make_args(estimator, "fit")
    estimator.fit(*fit_args)

    baseline = {}
    call_args = {}
    for name in NON_STATE_CHANGING_METHODS:
        if not hasattr(estimator, name):
            continue
        method_args = _make_args(estimator, name)
        call_args[name] = method_args
        baseline[name] = getattr(estimator, name)(*method_args)

    # Re-seed so the second fit starts from the same random state, then refit
    # on the very same data.
    set_random_state(estimator)
    estimator.fit(*fit_args)

    for name in NON_STATE_CHANGING_METHODS:
        if not hasattr(estimator, name):
            continue
        refit_output = getattr(estimator, name)(*call_args[name])
        _assert_almost_equal(
            baseline[name],
            refit_output,
            err_msg=f"Idempotency check failed for method {name}",
        )