def test_dimN(N):
    """Smoke-test CleanLearning on N-dimensional feature arrays.

    Builds ``X`` with shape ``(100, 3, ..., 3)`` (``N`` axes in total) and
    random integer labels in ``[0, 4)``. It then runs fit / predict /
    predict_proba / score and only checks that none of them raise.

    The wrapped classifier is ``ReshapingLogisticRegression``. Presumably it
    reshapes the trailing axes before delegating to a 2-D model; confirm
    against its definition.

    Parameters
    ----------
    N : int
        Number of dimensions of the feature array ``X``.
    """
    num_examples = 100
    num_classes = 4
    examples_per_class = 10

    cl = CleanLearning(clf=ReshapingLogisticRegression())
    size = [num_examples] + [3] * (N - 1)
    X = np.random.normal(size=size)
    labels = np.random.randint(0, num_classes, size=num_examples)

    # Overwrite a contiguous, non-overlapping run of examples for each class.
    # This guarantees every class appears in the training labels regardless of
    # what randint produced.
    for k in range(num_classes):
        start = k * examples_per_class
        labels[start : start + examples_per_class] = k

    # Just make sure none of these calls crash.
    cl.fit(X, labels)
    cl.predict(X)
    cl.predict_proba(X)
    cl.score(X, labels)
def test_pred_and_pred_proba(sparse):
    """Check the output shapes of ``CleanLearning.predict`` / ``predict_proba``.

    Parameters
    ----------
    sparse : bool
        If true, use the ``SPARSE_DATA`` fixture; otherwise use ``DATA``.
    """
    data = SPARSE_DATA if sparse else DATA
    cl = CleanLearning()
    cl.fit(data["X_train"], data["labels"])

    n = np.shape(data["true_labels_test"])[0]
    # The model is fit on data["labels"], so the expected number of
    # probability columns is the number of classes in the training labels.
    # The test split need not contain every class.
    m = len(np.unique(data["labels"]))

    pred = cl.predict(data["X_test"])
    probs = cl.predict_proba(data["X_test"])

    # Just check that these functions return what we expect.
    assert np.shape(pred)[0] == n
    assert np.shape(probs) == (n, m)
def test_aux_inputs():
    """Exercise CleanLearning with its auxiliary inputs.

    Covers:
    - a precomputed ``confident_joint`` passed through ``find_label_issues_kwargs``
    - ``label_issues`` given as a DataFrame, a boolean mask, or ranked indices
    - ``pred_probs`` supplied directly to ``find_label_issues`` and ``fit``
    - ``save_space`` and the ``verbose`` switch
    """
    data = DATA

    def make_clf():
        # Identical base classifier for every CleanLearning instance below.
        return LogisticRegression(multi_class="auto", solver="lbfgs", random_state=SEED)

    # Confident joint with 10 on the diagonal and 1 everywhere else.
    K = len(np.unique(data["labels"]))
    confident_joint = np.ones(shape=(K, K))
    np.fill_diagonal(confident_joint, 10)
    find_label_issues_kwargs = {
        "confident_joint": confident_joint,
        "min_examples_per_class": 2,
    }
    cl = CleanLearning(
        clf=make_clf(),
        find_label_issues_kwargs=find_label_issues_kwargs,
        verbose=1,
    )

    label_issues_df = cl.find_label_issues(data["X_train"], data["labels"], clf_kwargs={})
    assert isinstance(label_issues_df, pd.DataFrame)
    FIND_OUTPUT_COLUMNS = ["is_label_issue", "label_quality", "given_label", "predicted_label"]
    assert list(label_issues_df.columns) == FIND_OUTPUT_COLUMNS
    assert label_issues_df.equals(cl.get_label_issues())

    # Fit using the label issues found above.
    cl.fit(
        data["X_train"],
        data["labels"],
        label_issues=label_issues_df,
        clf_kwargs={},
        clf_final_kwargs={},
    )
    label_issues_df = cl.get_label_issues()
    assert isinstance(label_issues_df, pd.DataFrame)
    assert list(label_issues_df.columns) == (FIND_OUTPUT_COLUMNS + ["sample_weight"])
    score = cl.score(data["X_test"], data["true_labels_test"])
    # The score was previously computed but never checked; it should be a
    # fraction. Assumes score() reports accuracy, as sklearn's
    # LogisticRegression.score does.
    assert 0.0 <= score <= 1.0

    # Test a second fit.
    cl.fit(data["X_train"], data["labels"])

    # Test cl.find_label_issues with pred_probs input (no features needed).
    pred_probs_test = cl.predict_proba(data["X_test"])
    label_issues_df = cl.find_label_issues(
        X=None, labels=data["true_labels_test"], pred_probs=pred_probs_test
    )
    assert isinstance(label_issues_df, pd.DataFrame)
    assert list(label_issues_df.columns) == FIND_OUTPUT_COLUMNS
    assert label_issues_df.equals(cl.get_label_issues())
    cl.save_space()
    assert cl.label_issues_df is None

    # Verbose off.
    cl = CleanLearning(clf=make_clf(), verbose=0)
    cl.save_space()  # dummy call test
    cl = CleanLearning(clf=make_clf(), verbose=0)
    cl.find_label_issues(
        labels=data["true_labels_test"], pred_probs=pred_probs_test, save_space=True
    )

    cl = CleanLearning(clf=make_clf(), verbose=1)

    # Test with label_issues_mask input.
    label_issues_mask = find_label_issues(
        labels=data["true_labels_test"],
        pred_probs=pred_probs_test,
    )
    cl.fit(data["X_test"], data["true_labels_test"], label_issues=label_issues_mask)
    label_issues_df = cl.get_label_issues()
    assert isinstance(label_issues_df, pd.DataFrame)
    assert set(label_issues_df.columns).issubset(FIND_OUTPUT_COLUMNS)

    # Test with label_issues_indices input; the resulting issue flags must
    # match those produced by the equivalent boolean mask above.
    label_issues_indices = find_label_issues(
        labels=data["true_labels_test"],
        pred_probs=pred_probs_test,
        return_indices_ranked_by="confidence_weighted_entropy",
    )
    cl.fit(data["X_test"], data["true_labels_test"], label_issues=label_issues_indices)
    label_issues_df2 = cl.get_label_issues().copy()
    assert isinstance(label_issues_df2, pd.DataFrame)
    assert set(label_issues_df2.columns).issubset(FIND_OUTPUT_COLUMNS)
    assert label_issues_df2["is_label_issue"].equals(label_issues_df["is_label_issue"])

    # Test fit() with pred_probs input.
    cl.fit(
        data["X_test"],
        data["true_labels_test"],
        pred_probs=pred_probs_test,
        label_issues=label_issues_mask,
    )
    label_issues_df = cl.get_label_issues()
    assert isinstance(label_issues_df, pd.DataFrame)
    assert set(label_issues_df.columns).issubset(FIND_OUTPUT_COLUMNS)
    assert "label_quality" in label_issues_df.columns