def test_load_cv_folds_non_float_ids():
    """
    Test to check that CV folds with non-float IDs raise error when converted
    to floats.

    The generated folds are written to ``<_my_dir>/other/custom_folds.csv``
    (header row ``id,fold``) and then loaded with ``ids_to_floats=True``.
    The ``ValueError`` raised by ``_load_cv_folds`` is deliberately not
    caught here; it propagates out of this function.
    """
    # NOTE(review): there is no in-body assertion, so the expected ValueError
    # must be checked by whoever runs this test -- presumably a decorator just
    # above this function (e.g. nose's ``@raises(ValueError)``); confirm.

    # create custom CV folds (second element of make_cv_folds_data()'s result;
    # presumably its IDs are non-numeric strings -- TODO confirm)
    custom_cv_folds = make_cv_folds_data()[1]

    # write the generated CV folds to a CSV file. On Python 3 the csv module
    # requires the file to be opened with newline=''; otherwise an extra '\r'
    # is emitted per row on Windows, which reads back as blank rows. Python 2's
    # built-in open() has no ``newline`` argument, so binary mode is used there.
    fold_file_path = join(_my_dir, 'other', 'custom_folds.csv')
    open_kwargs = {} if PY2 else {'newline': ''}
    with open(fold_file_path, 'wb' if PY2 else 'w', **open_kwargs) as foldf:
        w = csv.writer(foldf)
        w.writerow(['id', 'fold'])
        # one (example_id, fold_label) row per entry
        w.writerows(custom_cv_folds.items())

    # now read the CSV file using _load_cv_folds, which should raise ValueError
    _load_cv_folds(fold_file_path, ids_to_floats=True)
def test_load_cv_folds():
    """
    Test to check that cross-validation folds are correctly loaded from a
    CSV file.

    The generated folds are written to ``<_my_dir>/other/custom_folds.csv``
    (header row ``id,fold``), read back with ``_load_cv_folds``, and the
    result is compared for equality with the original mapping.
    """
    # create custom CV folds (second element of make_cv_folds_data()'s result)
    custom_cv_folds = make_cv_folds_data()[1]

    # write the generated CV folds to a CSV file. On Python 3 the csv module
    # requires the file to be opened with newline=''; otherwise an extra '\r'
    # is emitted per row on Windows, which reads back as blank rows. Python 2's
    # built-in open() has no ``newline`` argument, so binary mode is used there.
    fold_file_path = join(_my_dir, 'other', 'custom_folds.csv')
    open_kwargs = {} if PY2 else {'newline': ''}
    with open(fold_file_path, 'wb' if PY2 else 'w', **open_kwargs) as foldf:
        w = csv.writer(foldf)
        w.writerow(['id', 'fold'])
        # one (example_id, fold_label) row per entry
        w.writerows(custom_cv_folds.items())

    # now read the CSV file using _load_cv_folds and check the round trip
    custom_cv_folds_loaded = _load_cv_folds(fold_file_path)
    eq_(custom_cv_folds_loaded, custom_cv_folds)