Ejemplo n.º 1
0
def check_persistence_via_pickle(Estimator):
    """Check that a fitted estimator survives a pickle round trip.

    The estimator is constructed, given a fixed random state and fitted.
    Every method in ``NON_STATE_CHANGING_METHODS`` that it exposes is then
    called and its output recorded. After serialising and restoring the
    estimator with :mod:`pickle`, the same methods are called on the
    restored copy with the same inputs, and the outputs must agree to
    6 decimal places.

    Parameters
    ----------
    Estimator : class
        Estimator class to check; instantiated via ``_construct_instance``.

    Raises
    ------
    AssertionError
        If any method's output differs after unpickling.
    """
    fitted = _construct_instance(Estimator)
    set_random_state(fitted)
    fitted.fit(*_make_args(fitted, "fit"))

    # Record each available method's inputs and output before pickling, so
    # the restored estimator can be queried with exactly the same arguments.
    inputs = {}
    expected = {}
    for name in NON_STATE_CHANGING_METHODS:
        if not hasattr(fitted, name):
            continue
        method_args = _make_args(fitted, name)
        inputs[name] = method_args
        expected[name] = getattr(fitted, name)(*method_args)

    # Round trip through pickle; the bytes are produced right here, so
    # loading them is safe.
    restored = pickle.loads(pickle.dumps(fitted))

    for name, before in expected.items():
        after = getattr(restored, name)(*inputs[name])
        _assert_array_almost_equal(
            before,
            after,
            decimal=6,
            err_msg="Results are not the same after pickling",
        )
Ejemplo n.º 2
0
def check_fit_idempotent(Estimator):
    """Check that calling ``fit`` twice is equivalent to calling it once.

    The estimator is fitted and the outputs of every available method in
    ``NON_STATE_CHANGING_METHODS`` are recorded. The estimator is then
    refitted on the same data after resetting its random state. Each
    recorded method, called again with the same inputs, must reproduce its
    earlier output.

    Parameters
    ----------
    Estimator : class
        Estimator class to check; instantiated via ``_construct_instance``.

    Raises
    ------
    AssertionError
        If any method's output differs after the second fit.
    """
    estimator = _construct_instance(Estimator)
    set_random_state(estimator)

    # Fit for the first time
    fit_args = _make_args(estimator, "fit")
    estimator.fit(*fit_args)

    results = dict()
    args = dict()
    for method in NON_STATE_CHANGING_METHODS:
        if hasattr(estimator, method):
            args[method] = _make_args(estimator, method)
            results[method] = getattr(estimator, method)(*args[method])

    # Fit again on identical data with the random state reset, so any
    # difference is attributable to fit not being idempotent.
    set_random_state(estimator)
    estimator.fit(*fit_args)

    # Compare exactly the methods recorded after the first fit. Re-checking
    # hasattr here could select a method with no stored result (KeyError).
    for method, expected in results.items():
        new_result = getattr(estimator, method)(*args[method])
        _assert_array_almost_equal(
            expected,
            new_result,
            err_msg=f"Idempotency check failed for method {method}",
        )
Ejemplo n.º 3
0
def test_load_UCR_UEA_dataset_download(tmpdir):
    """Download ArrowHead via ``load_UCR_UEA_dataset`` and validate it.

    Checks three things:

    - the dataset is extracted into ``extract_path/<name>``;
    - that folder contains exactly the expected set of files;
    - the loaded data matches ``load_arrow_head`` (X to 4 decimals, y
      exactly).

    Parameters
    ----------
    tmpdir : py.path.local
        pytest fixture providing a fresh, automatically cleaned-up
        per-test temporary directory.
    """
    # Extract into pytest's managed directory. Note that
    # ``tmpdir.mkdtemp()`` is a classmethod of py.path.local: it creates a
    # new directory under the system temp root rather than inside
    # ``tmpdir``, so pytest would never clean up the download.
    extract_path = str(tmpdir)
    name = "ArrowHead"
    actual_X, actual_y = load_UCR_UEA_dataset(
        name, return_X_y=True, extract_path=extract_path
    )
    data_path = os.path.join(extract_path, name)
    assert os.path.exists(data_path)

    # The extracted folder must contain exactly these files, no more and
    # no fewer (README.md is not part of the expected set).
    expected_files = sorted(
        [
            f"{name}.txt",
            f"{name}_TEST.arff",
            f"{name}_TEST.ts",
            f"{name}_TEST.txt",
            f"{name}_TRAIN.arff",
            f"{name}_TRAIN.ts",
            f"{name}_TRAIN.txt",
        ]
    )
    assert sorted(os.listdir(data_path)) == expected_files

    # check data
    expected_X, expected_y = load_arrow_head(return_X_y=True)
    _assert_array_almost_equal(actual_X, expected_X, decimal=4)
    np.testing.assert_array_equal(expected_y, actual_y)
Ejemplo n.º 4
0
def check_transform_inverse_transform_equivalent(Estimator):
    """Check that ``inverse_transform`` undoes ``fit_transform``.

    The estimator is constructed with a fixed random state. Its
    ``fit_transform`` is applied to the first generated ``fit`` argument,
    and ``inverse_transform`` of the result must be almost equal to that
    original input.

    Parameters
    ----------
    Estimator : class
        Estimator (transformer) class to check; instantiated via
        ``_construct_instance``.

    Raises
    ------
    AssertionError
        If the round-tripped data differs from the input.
    """
    estimator = _construct_instance(Estimator)
    # Fix the random state, as the other checks do, so failures reproduce.
    set_random_state(estimator)
    X = _make_args(estimator, "fit")[0]
    Xt = estimator.fit_transform(X)
    Xit = estimator.inverse_transform(Xt)
    _assert_array_almost_equal(
        X,
        Xit,
        err_msg="inverse_transform did not recover the input of fit_transform",
    )