def test_compute_precision(self):
    # Precision computed from the accumulated per-label stats must match
    # sklearn's precision_score on the same classifications.
    file_path = os.path.join(ROOT_DIR, 'datasets/tests/example_keys')
    eval_tool = EvaluationTool(legit=0)
    load_tool = LoadingTool()
    stats = defaultdict(lambda: defaultdict(int))
    true_chunks = []
    pred_chunks = []
    for chunk in load_tool.load_classifications(file_path, ';', True):
        chunk_stats = eval_tool.compute_stats(chunk)
        true_chunks.append(chunk[0])
        pred_chunks.append(chunk[1])
        for label in chunk_stats:
            stats[label]['FP'] += chunk_stats[label]['FP']
            stats[label]['FN'] += chunk_stats[label]['FN']
            stats[label]['TP'] += chunk_stats[label]['TP']
    # Concatenate all chunks once at the end (pd.Series.append no longer
    # exists in pandas >= 2.0).
    trues = pd.concat(true_chunks)
    preds = pd.concat(pred_chunks)
    labels = [1, 2]
    prec = [eval_tool.compute_precision(x, stats) for x in labels]
    prec_sklearn = list(
        precision_score(y_true=trues, y_pred=preds, labels=labels,
                        average=None))
    assert prec == prec_sklearn
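
# A minimal sketch, not the project's actual implementation, of the
# precision formula these tests assume EvaluationTool.compute_precision
# follows: precision = TP / (TP + FP) over the accumulated stats.  The
# helper name `_precision_sketch` is hypothetical.
def _precision_sketch(label, stats):
    tp = stats[label]['TP']
    fp = stats[label]['FP']
    if tp + fp == 0:
        # The label was never predicted: 0 / 0, precision undefined.
        return np.nan
    return tp / (tp + fp)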
def test_compute_precision_unbalanced(self):
    # A label that is never predicted should get an undefined (NaN)
    # precision rather than a misleading number.
    file_path = os.path.join(ROOT_DIR, 'datasets/tests/example_unbalanced')
    eval_tool = EvaluationTool()
    load_tool = LoadingTool()
    stats = defaultdict(lambda: defaultdict(int))
    for chunk in load_tool.load_classifications(file_path, ';'):
        chunk_stats = eval_tool.compute_stats(chunk)
        for label in chunk_stats:
            stats[label]['FP'] += chunk_stats[label]['FP']
            stats[label]['FN'] += chunk_stats[label]['FN']
            stats[label]['TP'] += chunk_stats[label]['TP']
    prec = [
        eval_tool.compute_precision(x, stats) for x in eval_tool.labels
    ]
    assert np.isnan(prec[4])
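
# A self-contained illustration (hypothetical numbers, exercising only the
# sketch above) of the undefined case the unbalanced test asserts: a label
# that appears among the true values but is never predicted accumulates
# TP == 0 and FP == 0, so TP / (TP + FP) is 0 / 0 and yields NaN.
def _demo_precision_sketch_nan():
    stats = defaultdict(lambda: defaultdict(int))
    stats[5]['FN'] = 3  # label 5 occurs only among the true labels
    assert np.isnan(_precision_sketch(5, stats))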
def test_get_labels_with_prec_above(self):
    # Only the labels whose precision clears the threshold should be
    # returned.
    file_path = os.path.join(ROOT_DIR, 'datasets/tests/example_keys')
    e_tool = EvaluationTool()
    l_tool = LoadingTool()
    stats = defaultdict(lambda: defaultdict(int))
    for chunk in l_tool.load_classifications(file_path, ';', True):
        chunk_stats = e_tool.compute_stats(chunk)
        for label in chunk_stats:
            stats[label]['FP'] += chunk_stats[label]['FP']
            stats[label]['FN'] += chunk_stats[label]['FN']
            stats[label]['TP'] += chunk_stats[label]['TP']
    threshold = 0.3
    precs_above_threshold = e_tool.get_labels_with_prec_above_thres(
        threshold, e_tool.labels, stats)
    expected = [0, 1]
    assert expected == precs_above_threshold
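
# A minimal sketch (again hypothetical, not the project's code) of the
# threshold filter the last test exercises: keep, in input order, the
# labels whose precision clears the threshold.  Whether the real
# get_labels_with_prec_above_thres uses a strict or non-strict comparison
# is an assumption here; NaN compares False either way, so labels with
# undefined precision are always dropped.
def _labels_with_prec_above_sketch(threshold, labels, stats):
    return [x for x in labels
            if _precision_sketch(x, stats) > threshold]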