def test_metrics_lambda_reset():
    """Constructing a MetricsLambda, and calling its ``reset``, must both reset
    every metric it was built from (whether passed positionally or by keyword).
    """
    m0, m1, m2 = (ListGatherMetric(index) for index in range(3))
    wrapped = (m0, m1, m2)

    def feed_all():
        # A fresh list literal per metric, exactly as the metrics are fed one by one.
        for metric in wrapped:
            metric.update([1, 10, 100])

    def check_all_reset():
        for metric in wrapped:
            assert metric.list_ is None

    feed_all()

    def constant_fn(x, y, z, t):
        return 1

    lam = MetricsLambda(constant_fn, m0, m1, z=m2, t=0)

    # Merely creating the MetricsLambda must already reset its argument metrics.
    check_all_reset()

    feed_all()
    lam.reset()
    check_all_reset()
def test_metrics_lambda_update_and_attach_together():
    """``update`` and ``attach`` are mutually exclusive on a MetricsLambda.

    Calling ``update`` on an attached MetricsLambda raises ``ValueError``, and
    attaching one whose underlying metrics were already updated raises
    ``ValueError`` until ``reset`` is called.
    """

    def random_batch():
        # Predictions are drawn first, then targets, to keep the RNG order stable.
        preds = torch.randint(0, 2, size=(15, 10, 4)).float()
        target = torch.randint(0, 2, size=(15, 10, 4)).long()
        return preds, target

    def update_fn(engine, batch):
        preds, target = batch
        return preds, target

    def fbeta(r, p, beta):
        return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r)).item()

    # Case 1: attached first, so a manual update is rejected.
    batch = random_batch()
    engine = Engine(update_fn)
    precision = Precision(average=False)
    recall = Recall(average=False)
    f1 = MetricsLambda(fbeta, recall, precision, 1)
    f1.attach(engine, "f1")
    with pytest.raises(ValueError, match=r"MetricsLambda is already attached to an engine"):
        f1.update(batch)

    # Case 2: updated first, so attaching is rejected until reset() is called.
    batch = random_batch()
    f1 = MetricsLambda(fbeta, recall, precision, 1)
    f1.update(batch)
    engine = Engine(update_fn)
    with pytest.raises(ValueError, match=r"The underlying metrics are already updated"):
        f1.attach(engine, "f1")
    f1.reset()
    f1.attach(engine, "f1")
def _binary_counts(y_pred, y):
    """Return ``(true_positives, predicted_positives, actual_positives)``.

    Each count is a sum over dim 0 of binary (0/1) tensors ``y_pred`` and ``y``.
    These are the reference values the binary Precision/Recall state is
    compared against below.
    """
    # When there are no correct predictions this sum is already a zero tensor,
    # so no special case is needed.
    true_positives = (y * y_pred).sum(dim=0)
    predicted_positives = y_pred.sum(dim=0)
    actual_positives = y.sum(dim=0)
    return true_positives, predicted_positives, actual_positives


def test_metrics_lambda_update():
    """Check that MetricsLambda drives its underlying metrics.

    The test verifies that ``update`` and ``reset`` are forwarded to the
    wrapped Precision/Recall, that several updates accumulate in their state,
    and that ``compute`` matches scikit-learn's F1.
    """
    y_pred = torch.randint(0, 2, size=(15, 10, 4)).float()
    y = torch.randint(0, 2, size=(15, 10, 4)).long()

    precision = Precision(average=False)
    recall = Recall(average=False)

    def Fbeta(r, p, beta):
        return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r)).item()

    F1 = MetricsLambda(Fbeta, recall, precision, 1)

    # Part 1: update/reset on the MetricsLambda reach the underlying metrics.
    F1.update((y_pred, y))
    assert precision._updated
    assert recall._updated
    F1.reset()
    assert not precision._updated
    assert not recall._updated

    # Part 2: multiple updates accumulate into the underlying metrics' state.
    y_pred1 = torch.randint(0, 2, size=(15,))
    y1 = torch.randint(0, 2, size=(15,))
    y_pred2 = torch.randint(0, 2, size=(15,))
    y2 = torch.randint(0, 2, size=(15,))
    F1.update((y_pred1, y1))
    F1.update((y_pred2, y2))

    true_positives1, predicted1, actual1 = _binary_counts(y_pred1, y1)
    true_positives2, predicted2, actual2 = _binary_counts(y_pred2, y2)
    true_positives = true_positives1 + true_positives2

    # Precision counts its positives over the predictions...
    assert precision._type == "binary"
    assert precision._true_positives == true_positives
    assert precision._positives == predicted1 + predicted2

    # ...whereas recall counts its positives over the targets.
    assert recall._type == "binary"
    assert recall._true_positives == true_positives
    assert recall._positives == actual1 + actual2

    # Part 3: compute() on a single batch agrees with scikit-learn's F1.
    F1.reset()
    F1.update((y_pred1, y1))
    F1_metrics_lambda = F1.compute()
    F1_sklearn = f1_score(y1.numpy(), y_pred1.numpy())
    assert pytest.approx(F1_metrics_lambda) == F1_sklearn