Example #1
def test_override_required_output_keys():
    # https://discuss.pytorch.org/t/how-access-inputs-in-custom-ignite-metric/91221/5
    import torch.nn as nn

    from ignite.engine import create_supervised_evaluator

    counter = [0]

    class CustomMetric(Metric):
        required_output_keys = ("y_pred", "y", "x")

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def update(self, output):
            y_pred, y, x = output
            assert y_pred.shape == (4, 3)
            assert y.shape == (4, )
            assert x.shape == (4, 10)
            assert x.equal(data[counter[0]][0])
            assert y.equal(data[counter[0]][1])
            counter[0] += 1

        def reset(self):
            pass

        def compute(self):
            pass

    model = nn.Linear(10, 3)

    metrics = {"Precision": Precision(), "CustomMetric": CustomMetric()}

    evaluator = create_supervised_evaluator(
        model,
        metrics=metrics,
        output_transform=lambda x, y, y_pred: {
            "x": x,
            "y": y,
            "y_pred": y_pred
        })

    data = [
        (torch.rand(4, 10), torch.randint(0, 3, size=(4, ))),
        (torch.rand(4, 10), torch.randint(0, 3, size=(4, ))),
        (torch.rand(4, 10), torch.randint(0, 3, size=(4, ))),
    ]
    evaluator.run(data)
Example #2
def test_override_required_output_keys():
    # https://github.com/pytorch/ignite/issues/1415
    from ignite.engine import create_supervised_evaluator

    counter = [0]

    class DummyLoss2(Loss):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def update(self, output):
            y_pred, y, criterion_kwargs = output
            assert y_pred.shape == (4, 3)
            assert y.shape == (4, )
            assert criterion_kwargs == c_kwargs
            assert y.equal(data[counter[0]][1])
            counter[0] += 1

        def reset(self):
            pass

        def compute(self):
            pass

    model = nn.Linear(10, 3)

    metrics = {"Precision": Precision(), "DummyLoss2": DummyLoss2(nll_loss)}

    # global criterion kwargs
    c_kwargs = {"reduction": "sum"}

    evaluator = create_supervised_evaluator(
        model,
        metrics=metrics,
        output_transform=lambda x, y, y_pred: {
            "x": x,
            "y": y,
            "y_pred": y_pred,
            "criterion_kwargs": c_kwargs
        },
    )

    data = [
        (torch.rand(4, 10), torch.randint(0, 3, size=(4, ))),
        (torch.rand(4, 10), torch.randint(0, 3, size=(4, ))),
        (torch.rand(4, 10), torch.randint(0, 3, size=(4, ))),
    ]
    evaluator.run(data)
Example #3
    def _test():
        y_true = np.arange(0, n_iters * batch_size * dist.get_world_size()) % n_classes
        y_pred = 0.2 * np.random.rand(
            n_iters * batch_size * dist.get_world_size(), n_classes
        )
        for i in range(n_iters * batch_size * dist.get_world_size()):
            if np.random.rand() > 0.4:
                y_pred[i, y_true[i]] = 1.0
            else:
                j = np.random.randint(0, n_classes)
                y_pred[i, j] = 0.7

        y_true = y_true.reshape(n_iters * dist.get_world_size(), batch_size)
        y_pred = y_pred.reshape(n_iters * dist.get_world_size(), batch_size, n_classes)

        def update_fn(engine, i):
            y_true_batch = y_true[i + rank * n_iters, ...]
            y_pred_batch = y_pred[i + rank * n_iters, ...]
            return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

        evaluator = Engine(update_fn)

        precision = Precision(average=False, device=device)
        recall = Recall(average=False, device=device)

        def Fbeta(r, p, beta):
            return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r)).item()

        F1 = MetricsLambda(Fbeta, recall, precision, 1)
        F1.attach(evaluator, "f1")

        another_f1 = (
            (1.0 + precision * recall * 2 / (precision + recall + 1e-20)).mean().item()
        )
        another_f1.attach(evaluator, "ff1")

        data = list(range(n_iters))
        state = evaluator.run(data, max_epochs=1)

        assert "f1" in state.metrics
        assert "ff1" in state.metrics
        f1_true = f1_score(
            y_true.ravel(),
            np.argmax(y_pred.reshape(-1, n_classes), axis=-1),
            average="macro",
        )
        assert f1_true == approx(state.metrics["f1"])
        assert 1.0 + f1_true == approx(state.metrics["ff1"])
Example #4
def _test_distrib_integration(device):

    rank = idist.get_rank()
    torch.manual_seed(12)

    def _test(p, r, average, n_epochs):
        n_iters = 60
        s = 16
        n_classes = 7

        offset = n_iters * s
        y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device)
        y_preds = torch.rand(offset * idist.get_world_size(), n_classes).to(device)

        def update(engine, i):
            return (
                y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :],
                y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
            )

        engine = Engine(update)

        fbeta = Fbeta(beta=2.5, average=average, device=device)
        fbeta.attach(engine, "f2.5")

        data = list(range(n_iters))
        engine.run(data=data, max_epochs=n_epochs)

        assert "f2.5" in engine.state.metrics
        res = engine.state.metrics["f2.5"]
        if isinstance(res, torch.Tensor):
            res = res.cpu().numpy()

        true_res = fbeta_score(
            y_true.cpu().numpy(),
            torch.argmax(y_preds, dim=1).cpu().numpy(),
            beta=2.5,
            average="macro" if average else None,
        )

        assert pytest.approx(res) == true_res

    _test(None, None, average=True, n_epochs=1)
    _test(None, None, average=True, n_epochs=2)
    precision = Precision(average=False)
    recall = Recall(average=False)
    _test(precision, recall, average=False, n_epochs=1)
    _test(precision, recall, average=False, n_epochs=2)
Example #5
def test_integration():

    def _test(p, r, average):
        np.random.seed(1)

        n_iters = 10
        batch_size = 10
        n_classes = 10

        y_true = np.arange(0, n_iters * batch_size) % n_classes
        y_pred = 0.2 * np.random.rand(n_iters * batch_size, n_classes)
        for i in range(n_iters * batch_size):
            if np.random.rand() > 0.4:
                y_pred[i, y_true[i]] = 1.0
            else:
                j = np.random.randint(0, n_classes)
                y_pred[i, j] = 0.7

        y_true_batch_values = iter(y_true.reshape(n_iters, batch_size))
        y_pred_batch_values = iter(y_pred.reshape(n_iters, batch_size, n_classes))

        def update_fn(engine, batch):
            y_true_batch = next(y_true_batch_values)
            y_pred_batch = next(y_pred_batch_values)
            return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

        evaluator = Engine(update_fn)

        f2 = Fbeta(beta=2.0, average=average, precision=p, recall=r)
        f2.attach(evaluator, "f2")

        data = list(range(n_iters))
        state = evaluator.run(data, max_epochs=1)

        f2_true = fbeta_score(y_true, np.argmax(y_pred, axis=-1),
                              average='macro' if average else None, beta=2.0)
        if isinstance(state.metrics['f2'], torch.Tensor):
            assert (f2_true == state.metrics['f2'].numpy()).all(), "{} vs {}".format(f2_true, state.metrics['f2'])
        else:
            assert f2_true == pytest.approx(state.metrics['f2']), "{} vs {}".format(f2_true, state.metrics['f2'])

    _test(None, None, False)
    _test(None, None, True)
    precision = Precision(average=False)
    recall = Recall(average=False)
    _test(precision, recall, False)
    _test(precision, recall, True)
Example #6
def build_metrics():
    metrics = {
        'precision': Precision(lambda x: (x[0].argmax(dim=1), x[1])),
        'recall': Recall(lambda x: (x[0].argmax(dim=1), x[1])),
        'accuracy': Accuracy(lambda x: (x[0].argmax(dim=1), x[1])),
        'confusion_matrix': ConfusionMatrix(2, output_transform=prepare_confusion_matrix),
        'metric_output_results': MetricOutputResults(lambda x: (x[0].argmax(dim=1), x[1], x[0], x[2])),
        'metric_last_layer': MetricLastLayer(lambda x: (x[3])),
        # 'auroc': AUROC(lambda x: (x[0], x[1])),
    }
    return metrics
Example #7
    def _test(average):
        pr = Precision(average=average)

        y_pred = torch.softmax(torch.rand(4, 4), dim=1)
        y = torch.ones(4).type(torch.LongTensor)
        pr.update((y_pred, y))

        y_pred = torch.rand(4, 1)
        y = torch.ones(4).type(torch.LongTensor)

        with pytest.raises(RuntimeError):
            pr.update((y_pred, y))
Example #8
    def _test(average):
        pr = Precision(average=average)

        y_pred = torch.softmax(torch.rand(4, 4), dim=1)
        y = torch.ones(4).long()
        pr.update((y_pred, y))

        y_pred = torch.randint(0, 2, size=(4,))
        y = torch.ones(4).long()

        with pytest.raises(RuntimeError):
            pr.update((y_pred, y))
Example #9
def test_multiclass_wrong_inputs():
    pr = Precision()

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.rand(10, 5, 4), torch.randint(0, 2, size=(10,)).type(torch.LongTensor)))

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.rand(10, 5, 6), torch.randint(0, 5, size=(10, 5)).type(torch.LongTensor)))

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.rand(10), torch.randint(0, 5, size=(10, 5, 6)).type(torch.LongTensor)))
Example #10
def test_incorrect_shape():
    precision = Precision()

    y_pred = torch.zeros(2, 3, 2, 2)
    y = torch.zeros(2, 3)

    with pytest.raises(ValueError):
        precision.update((y_pred, y))

    y_pred = torch.zeros(2, 3, 2, 2)
    y = torch.zeros(2, 3, 4, 4)

    with pytest.raises(ValueError):
        precision.update((y_pred, y))
Example #11
def test_integration_ingredients_not_attached():
    np.random.seed(1)

    n_iters = 10
    batch_size = 10
    n_classes = 10

    y_true = np.arange(0, n_iters * batch_size) % n_classes
    y_pred = 0.2 * np.random.rand(n_iters * batch_size, n_classes)
    for i in range(n_iters * batch_size):
        if np.random.rand() > 0.4:
            y_pred[i, y_true[i]] = 1.0
        else:
            j = np.random.randint(0, n_classes)
            y_pred[i, j] = 0.7

    y_true_batch_values = iter(y_true.reshape(n_iters, batch_size))
    y_pred_batch_values = iter(y_pred.reshape(n_iters, batch_size, n_classes))

    def update_fn(engine, batch):
        y_true_batch = next(y_true_batch_values)
        y_pred_batch = next(y_pred_batch_values)
        return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

    evaluator = Engine(update_fn)

    precision = Precision(average=False)
    recall = Recall(average=False)

    def Fbeta(r, p, beta):
        return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r)).item()

    F1 = MetricsLambda(Fbeta, recall, precision, 1)
    F1.attach(evaluator, "f1")

    data = list(range(n_iters))
    state = evaluator.run(data, max_epochs=1)
    f1_true = f1_score(y_true, np.argmax(y_pred, axis=-1), average="macro")
    assert f1_true == approx(state.metrics["f1"]), "{} vs {}".format(
        f1_true, state.metrics["f1"]
    )
Example #12
def Fbeta(beta, output_transform=lambda x: x, average=True, precision=None, recall=None, device=None):
    """Calculates F-beta score

    Args:
        beta (float): weight of precision in harmonic mean
        output_transform (callable, optional): a callable that is used to transform the
            :class:`~ignite.engine.Engine`'s `process_function`'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
        average (bool, optional): if True, F-beta score is computed as the unweighted average (across all classes
            in multiclass case), otherwise, returns a tensor with F-beta score for each class in multiclass case.
        precision (Precision, optional): precision object metric with `average=False` to compute F-beta score
        recall (Recall, optional): recall object metric with `average=False` to compute F-beta score
        device (str or torch.device, optional): device specification in case of distributed computation usage.
            In most of the cases, it can be defined as "cuda:local_rank" or "cuda"
            if already set `torch.cuda.set_device(local_rank)`. By default, if a distributed process group is
            initialized and available, device is set to `cuda`.

    Returns:
        MetricsLambda, F-beta metric
    """
    if not (beta > 0):
        raise ValueError("Beta should be a positive number, but given {}".format(beta))

    if precision is None:
        precision = Precision(output_transform=output_transform, average=False, device=device)
    elif precision._average:
        raise ValueError("Input precision metric should have average=False")

    if recall is None:
        recall = Recall(output_transform=output_transform, average=False, device=device)
    elif recall._average:
        raise ValueError("Input recall metric should have average=False")

    fbeta = (1.0 + beta ** 2) * precision * recall / (beta ** 2 * precision + recall)

    if average:
        fbeta = fbeta.mean().item()

    return fbeta
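
A minimal usage sketch of this factory follows, assuming it is exposed as ignite.metrics.Fbeta and attached to an ignite Engine as in the surrounding examples; the eval_step function and the random data are illustrative placeholders, not part of the original tests.

import torch

from ignite.engine import Engine
from ignite.metrics import Fbeta


def eval_step(engine, batch):
    # batch already holds (y_pred, y); pass it through unchanged
    y_pred, y = batch
    return y_pred, y


evaluator = Engine(eval_step)

# average=True (the default) yields a single scalar; average=False would return per-class scores
f2 = Fbeta(beta=2.0)
f2.attach(evaluator, "f2")

data = [(torch.rand(8, 4), torch.randint(0, 4, size=(8,))) for _ in range(5)]
state = evaluator.run(data, max_epochs=1)
print(state.metrics["f2"])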
Example #13
def test_metrics_lambda_update_and_attach_together():

    y_pred = torch.randint(0, 2, size=(15, 10, 4)).float()
    y = torch.randint(0, 2, size=(15, 10, 4)).long()

    def update_fn(engine, batch):
        y_pred, y = batch
        return y_pred, y

    engine = Engine(update_fn)

    precision = Precision(average=False)
    recall = Recall(average=False)

    def Fbeta(r, p, beta):
        return torch.mean((1 + beta**2) * p * r / (beta**2 * p + r)).item()

    F1 = MetricsLambda(Fbeta, recall, precision, 1)

    F1.attach(engine, "f1")
    with pytest.raises(
            ValueError,
            match=r"MetricsLambda is already attached to an engine"):
        F1.update((y_pred, y))

    y_pred = torch.randint(0, 2, size=(15, 10, 4)).float()
    y = torch.randint(0, 2, size=(15, 10, 4)).long()

    F1 = MetricsLambda(Fbeta, recall, precision, 1)
    F1.update((y_pred, y))

    engine = Engine(update_fn)

    with pytest.raises(ValueError,
                       match=r"The underlying metrics are already updated"):
        F1.attach(engine, "f1")

    F1.reset()
    F1.attach(engine, "f1")
Example #14
def test_predict(model, dataloader_test, use_cuda):
    if use_cuda:
        model = model.cuda()

    precision = Precision()
    recall = Recall()
    f1 = Fbeta(beta=1.0, average=True, precision=precision, recall=recall)

    for img, label in dataloader_test:
        if use_cuda:
            img = img.cuda()
            label = label.cuda()
        # Run the forward pass and update the metrics for both CPU and GPU inputs.
        pred = model(img)
        _, my_label = torch.max(label, dim=1)
        precision.update((pred, my_label))
        recall.update((pred, my_label))
        f1.update((pred, my_label))

    precision.compute()
    recall.compute()
    print("\tF1 Score: {:0.2f}".format(f1.compute() * 100))
Example #15
def test_multilabel_wrong_inputs():
    pr = Precision(average=True, is_multilabel=True)

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)).long()))

    with pytest.raises(ValueError):
        # incompatible y_pred
        pr.update((torch.rand(10, 5), torch.randint(0, 2, size=(10, 5)).long()))

    with pytest.raises(ValueError):
        # incompatible y
        pr.update((torch.randint(0, 5, size=(10, 5, 6)), torch.rand(10)))

    with pytest.raises(ValueError):
        # incompatible shapes between two updates
        pr.update((torch.randint(0, 2, size=(20, 5)), torch.randint(0, 2, size=(20, 5)).long()))
        pr.update((torch.randint(0, 2, size=(20, 6)), torch.randint(0, 2, size=(20, 6)).long()))
Example #16
    def _test(average):
        pr = Precision(average=average)
        y_pred = torch.randint(0, 2, size=(10,))
        y = torch.randint(0, 2, size=(10,)).long()
        pr.update((y_pred, y))
        np_y = y.numpy().ravel()
        np_y_pred = y_pred.numpy().ravel()
        assert pr._type == "binary"
        assert isinstance(pr.compute(), float if average else torch.Tensor)
        pr_compute = pr.compute() if average else pr.compute().numpy()
        assert precision_score(np_y, np_y_pred, average="binary") == pytest.approx(pr_compute)

        pr.reset()
        y_pred = torch.randint(0, 2, size=(10,))
        y = torch.randint(0, 2, size=(10,)).long()
        pr.update((y_pred, y))
        np_y = y.numpy().ravel()
        np_y_pred = y_pred.numpy().ravel()
        assert pr._type == "binary"
        assert isinstance(pr.compute(), float if average else torch.Tensor)
        pr_compute = pr.compute() if average else pr.compute().numpy()
        assert precision_score(np_y, np_y_pred, average="binary") == pytest.approx(pr_compute)

        pr.reset()
        y_pred = torch.Tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.51])
        y_pred = torch.round(y_pred)
        y = torch.randint(0, 2, size=(10,)).long()
        pr.update((y_pred, y))
        np_y = y.numpy().ravel()
        np_y_pred = y_pred.numpy().ravel()
        assert pr._type == "binary"
        assert isinstance(pr.compute(), float if average else torch.Tensor)
        pr_compute = pr.compute() if average else pr.compute().numpy()
        assert precision_score(np_y, np_y_pred, average="binary") == pytest.approx(pr_compute)

        # Batched Updates
        pr.reset()
        y_pred = torch.randint(0, 2, size=(100,))
        y = torch.randint(0, 2, size=(100,)).long()

        batch_size = 16
        n_iters = y.shape[0] // batch_size + 1

        for i in range(n_iters):
            idx = i * batch_size
            pr.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

        np_y = y.numpy().ravel()
        np_y_pred = y_pred.numpy().ravel()
        assert pr._type == "binary"
        assert isinstance(pr.compute(), float if average else torch.Tensor)
        pr_compute = pr.compute() if average else pr.compute().numpy()
        assert precision_score(np_y, np_y_pred, average="binary") == pytest.approx(pr_compute)
Example #17
    def _test(average):
        pr = Precision(average=average, is_multilabel=True)

        y_pred = torch.randint(0, 2, size=(10, 5, 18, 16))
        y = torch.randint(0, 2, size=(10, 5, 18, 16)).long()
        pr.update((y_pred, y))
        np_y_pred = to_numpy_multilabel(y_pred)
        np_y = to_numpy_multilabel(y)
        assert pr._type == "multilabel"
        pr_compute = pr.compute() if average else pr.compute().mean().item()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            assert precision_score(np_y, np_y_pred, average="samples") == pytest.approx(pr_compute)

        pr.reset()
        y_pred = torch.randint(0, 2, size=(10, 4, 20, 23))
        y = torch.randint(0, 2, size=(10, 4, 20, 23)).long()
        pr.update((y_pred, y))
        np_y_pred = to_numpy_multilabel(y_pred)
        np_y = to_numpy_multilabel(y)
        assert pr._type == "multilabel"
        pr_compute = pr.compute() if average else pr.compute().mean().item()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            assert precision_score(np_y, np_y_pred, average="samples") == pytest.approx(pr_compute)

        # Batched Updates
        pr.reset()
        y_pred = torch.randint(0, 2, size=(100, 5, 12, 14))
        y = torch.randint(0, 2, size=(100, 5, 12, 14)).long()

        batch_size = 16
        n_iters = y.shape[0] // batch_size + 1

        for i in range(n_iters):
            idx = i * batch_size
            pr.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

        np_y = to_numpy_multilabel(y)
        np_y_pred = to_numpy_multilabel(y_pred)
        assert pr._type == "multilabel"
        pr_compute = pr.compute() if average else pr.compute().mean().item()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            assert precision_score(np_y, np_y_pred, average="samples") == pytest.approx(pr_compute)
Example #18
def test_binary_wrong_inputs():
    pr = Precision()

    with pytest.raises(ValueError):
        # y has not only 0 or 1 values
        pr.update((torch.randint(0, 2, size=(10,)).long(), torch.arange(0, 10).long()))

    with pytest.raises(ValueError):
        # y_pred values are not thresholded to 0, 1 values
        pr.update((torch.rand(10,), torch.randint(0, 2, size=(10,)).long()))

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10, 5)).long()))

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.randint(0, 2, size=(10, 5, 6)).long(), torch.randint(0, 2, size=(10,)).long()))

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10, 5, 6)).long()))
Example #19
    def _test(average):
        pr = Precision(average=average)

        y_pred = torch.rand(10, 5, 18, 16)
        y = torch.randint(0, 5, size=(10, 18, 16)).long()
        pr.update((y_pred, y))
        num_classes = y_pred.shape[1]
        np_y_pred = y_pred.argmax(dim=1).numpy().ravel()
        np_y = y.numpy().ravel()
        assert pr._type == "multiclass"
        assert isinstance(pr.compute(), float if average else torch.Tensor)
        pr_compute = pr.compute() if average else pr.compute().numpy()
        sk_average_parameter = "macro" if average else None
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            sk_compute = precision_score(np_y, np_y_pred, labels=range(0, num_classes), average=sk_average_parameter)
            assert sk_compute == pytest.approx(pr_compute)

        pr.reset()
        y_pred = torch.rand(10, 7, 20, 12)
        y = torch.randint(0, 7, size=(10, 20, 12)).long()
        pr.update((y_pred, y))
        num_classes = y_pred.shape[1]
        np_y_pred = y_pred.argmax(dim=1).numpy().ravel()
        np_y = y.numpy().ravel()
        assert pr._type == "multiclass"
        assert isinstance(pr.compute(), float if average else torch.Tensor)
        pr_compute = pr.compute() if average else pr.compute().numpy()
        sk_average_parameter = "macro" if average else None
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            sk_compute = precision_score(np_y, np_y_pred, labels=range(0, num_classes), average=sk_average_parameter)
            assert sk_compute == pytest.approx(pr_compute)

        # Batched Updates
        pr.reset()
        y_pred = torch.rand(100, 8, 12, 14)
        y = torch.randint(0, 8, size=(100, 12, 14)).long()

        batch_size = 16
        n_iters = y.shape[0] // batch_size + 1

        for i in range(n_iters):
            idx = i * batch_size
            pr.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

        num_classes = y_pred.shape[1]
        np_y = y.numpy().ravel()
        np_y_pred = y_pred.argmax(dim=1).numpy().ravel()
        assert pr._type == "multiclass"
        assert isinstance(pr.compute(), float if average else torch.Tensor)
        pr_compute = pr.compute() if average else pr.compute().numpy()
        sk_average_parameter = "macro" if average else None
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            sk_compute = precision_score(np_y, np_y_pred, labels=range(0, num_classes), average=sk_average_parameter)
            assert sk_compute == pytest.approx(pr_compute)
Example #20
def test_indexing_metric():
    def _test(ignite_metric, sklearn_metric, sklearn_args, index, num_classes=5):
        y_pred = torch.rand(15, 10, num_classes).float()
        y = torch.randint(0, num_classes, size=(15, 10)).long()

        def update_fn(engine, batch):
            y_pred, y = batch
            return y_pred, y

        metrics = {
            "metric": ignite_metric[index],
            "metric_wo_index": ignite_metric
        }

        validator = Engine(update_fn)

        for name, metric in metrics.items():
            metric.attach(validator, name)

        def data(y_pred, y):
            for i in range(y_pred.shape[0]):
                yield (y_pred[i], y[i])

        d = data(y_pred, y)
        state = validator.run(d, max_epochs=1, epoch_length=y_pred.shape[0])

        sklearn_output = sklearn_metric(
            y.view(-1).numpy(),
            y_pred.view(-1, num_classes).argmax(dim=1).numpy(), **sklearn_args)

        assert (state.metrics["metric_wo_index"][index] ==
                state.metrics["metric"]).all()
        assert np.allclose(state.metrics["metric"].numpy(), sklearn_output)

    num_classes = 5

    labels = list(range(0, num_classes, 2))
    _test(Precision(),
          precision_score, {
              "labels": labels,
              "average": None
          },
          index=labels)
    labels = list(range(num_classes - 1, 0, -2))
    _test(Precision(),
          precision_score, {
              "labels": labels,
              "average": None
          },
          index=labels)
    labels = [1]
    _test(Precision(),
          precision_score, {
              "labels": labels,
              "average": None
          },
          index=labels)

    labels = list(range(0, num_classes, 2))
    _test(Recall(),
          recall_score, {
              "labels": labels,
              "average": None
          },
          index=labels)
    labels = list(range(num_classes - 1, 0, -2))
    _test(Recall(),
          recall_score, {
              "labels": labels,
              "average": None
          },
          index=labels)
    labels = [1]
    _test(Recall(),
          recall_score, {
              "labels": labels,
              "average": None
          },
          index=labels)

    # np.ix_ is used to allow a 2D slice of a matrix. This is required to get an accurate result from
    # ConfusionMatrix, which must be sliced the same way row-wise and column-wise.
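    # For illustration only (hypothetical values, not part of the test): with labels = [0, 2],
    # np.ix_(labels, labels) selects the 2x2 sub-matrix formed by rows {0, 2} and columns {0, 2}:
    #   cm[np.ix_([0, 2], [0, 2])] == [[cm[0, 0], cm[0, 2]],
    #                                  [cm[2, 0], cm[2, 2]]]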
    labels = list(range(0, num_classes, 2))
    _test(ConfusionMatrix(num_classes),
          confusion_matrix, {"labels": labels},
          index=np.ix_(labels, labels))
    labels = list(range(num_classes - 1, 0, -2))
    _test(ConfusionMatrix(num_classes),
          confusion_matrix, {"labels": labels},
          index=np.ix_(labels, labels))
    labels = [1]
    _test(ConfusionMatrix(num_classes),
          confusion_matrix, {"labels": labels},
          index=np.ix_(labels, labels))
Example #21
def test_multiclass_wrong_inputs():
    pr = Precision()

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.rand(10, 5, 4), torch.randint(0, 2, size=(10,)).long()))

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.rand(10, 5, 6), torch.randint(0, 5, size=(10, 5)).long()))

    with pytest.raises(ValueError):
        # incompatible shapes
        pr.update((torch.rand(10), torch.randint(0, 5, size=(10, 5, 6)).long()))

    pr = Precision(average=True)

    with pytest.raises(ValueError):
        # incompatible shapes between two updates
        pr.update((torch.rand(10, 5), torch.randint(0, 5, size=(10,)).long()))
        pr.update((torch.rand(10, 6), torch.randint(0, 5, size=(10,)).long()))

    with pytest.raises(ValueError):
        # incompatible shapes between two updates
        pr.update((torch.rand(10, 5, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long()))
        pr.update((torch.rand(10, 6, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long()))

    pr = Precision(average=False)

    with pytest.raises(ValueError):
        # incompatible shapes between two updates
        pr.update((torch.rand(10, 5), torch.randint(0, 5, size=(10,)).long()))
        pr.update((torch.rand(10, 6), torch.randint(0, 5, size=(10,)).long()))

    with pytest.raises(ValueError):
        # incompatible shapes between two updates
        pr.update((torch.rand(10, 5, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long()))
        pr.update((torch.rand(10, 6, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long()))
Example #22
def test_pytorch_operators():
    def _test(composed_metric, metric_name, compute_true_value_fn):

        metrics = {
            metric_name: composed_metric,
        }

        y_pred = torch.rand(15, 10, 5).float()
        y = torch.randint(0, 5, size=(15, 10)).long()

        def update_fn(engine, batch):
            y_pred, y = batch
            return y_pred, y

        validator = Engine(update_fn)

        for name, metric in metrics.items():
            metric.attach(validator, name)

        def data(y_pred, y):
            for i in range(y_pred.shape[0]):
                yield (y_pred[i], y[i])

        d = data(y_pred, y)
        state = validator.run(d, max_epochs=1, epoch_length=y_pred.shape[0])

        assert set(state.metrics.keys()) == set([metric_name])
        np_y_pred = np.argmax(y_pred.numpy(), axis=-1).ravel()
        np_y = y.numpy().ravel()
        assert state.metrics[metric_name] == approx(
            compute_true_value_fn(np_y_pred, np_y))

    precision_1 = Precision(average=False)
    precision_2 = Precision(average=False)
    norm_summed_precision = (precision_1 + precision_2).norm(p=10)

    def compute_true_norm_summed_precision(y_pred, y):
        p1 = precision_score(y, y_pred, average=None)
        p2 = precision_score(y, y_pred, average=None)
        return np.linalg.norm(p1 + p2, ord=10)

    _test(norm_summed_precision,
          "mean summed precision",
          compute_true_value_fn=compute_true_norm_summed_precision)

    precision = Precision(average=False)
    recall = Recall(average=False)
    sum_precision_recall = (precision + recall).sum()

    def compute_sum_precision_recall(y_pred, y):
        p = precision_score(y, y_pred, average=None)
        r = recall_score(y, y_pred, average=None)
        return np.sum(p + r)

    _test(sum_precision_recall,
          "sum precision recall",
          compute_true_value_fn=compute_sum_precision_recall)

    precision = Precision(average=False)
    recall = Recall(average=False)
    f1 = (precision * recall * 2 / (precision + recall + 1e-20)).mean()

    def compute_f1(y_pred, y):
        f1 = f1_score(y, y_pred, average="macro")
        return f1

    _test(f1, "f1", compute_true_value_fn=compute_f1)
Example #23
def test_binary_input(average):

    pr = Precision(average=average)

    def _test(y_pred, y, batch_size):
        pr.reset()
        if batch_size > 1:
            n_iters = y.shape[0] // batch_size + 1
            for i in range(n_iters):
                idx = i * batch_size
                pr.update(
                    (y_pred[idx:idx + batch_size], y[idx:idx + batch_size]))
        else:
            pr.update((y_pred, y))

        np_y = y.numpy().ravel()
        np_y_pred = y_pred.numpy().ravel()

        assert pr._type == "binary"
        assert isinstance(pr.compute(), float if average else torch.Tensor)
        pr_compute = pr.compute() if average else pr.compute().numpy()
        assert precision_score(np_y, np_y_pred,
                               average="binary") == pytest.approx(pr_compute)

    def get_test_cases():

        test_cases = [
            # Binary input of shape (N,) or (N, 1)
            (torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)), 1),
            (torch.randint(0, 2, size=(10, 1)), torch.randint(0, 2, size=(10, 1)), 1),
            # updated batches
            (torch.randint(0, 2, size=(50,)), torch.randint(0, 2, size=(50,)), 16),
            (torch.randint(0, 2, size=(50, 1)), torch.randint(0, 2, size=(50, 1)), 16),
            # Binary input of shape (N, L)
            (torch.randint(0, 2, size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1),
            (torch.randint(0, 2, size=(10, 1, 5)), torch.randint(0, 2, size=(10, 1, 5)), 1),
            # updated batches
            (torch.randint(0, 2, size=(50, 5)), torch.randint(0, 2, size=(50, 5)), 16),
            (torch.randint(0, 2, size=(50, 1, 5)), torch.randint(0, 2, size=(50, 1, 5)), 16),
            # Binary input of shape (N, H, W)
            (torch.randint(0, 2, size=(10, 12, 10)), torch.randint(0, 2, size=(10, 12, 10)), 1),
            (torch.randint(0, 2, size=(10, 1, 12, 10)), torch.randint(0, 2, size=(10, 1, 12, 10)), 1),
            # updated batches
            (torch.randint(0, 2, size=(50, 12, 10)), torch.randint(0, 2, size=(50, 12, 10)), 16),
            (torch.randint(0, 2, size=(50, 1, 12, 10)), torch.randint(0, 2, size=(50, 1, 12, 10)), 16),
        ]

        return test_cases

    for _ in range(5):
        # check multiple random inputs as random exact occurrences are rare
        test_cases = get_test_cases()
        for y_pred, y, batch_size in test_cases:
            _test(y_pred, y, batch_size)
Example #24
def test_recursive_attachment():
    def _test(composed_metric, metric_name, compute_true_value_fn):

        metrics = {
            metric_name: composed_metric,
        }

        y_pred = torch.randint(0, 2, size=(15, 10, 4)).float()
        y = torch.randint(0, 2, size=(15, 10, 4)).long()

        def update_fn(engine, batch):
            y_pred, y = batch
            return y_pred, y

        validator = Engine(update_fn)

        for name, metric in metrics.items():
            metric.attach(validator, name)

        def data(y_pred, y):
            for i in range(y_pred.shape[0]):
                yield (y_pred[i], y[i])

        d = data(y_pred, y)
        state = validator.run(d, max_epochs=1)

        assert set(state.metrics.keys()) == set([
            metric_name,
        ])
        np_y_pred = y_pred.numpy().ravel()
        np_y = y.numpy().ravel()
        assert state.metrics[metric_name] == approx(
            compute_true_value_fn(np_y_pred, np_y))

    precision_1 = Precision()
    precision_2 = Precision()
    summed_precision = precision_1 + precision_2

    def compute_true_summed_precision(y_pred, y):
        p1 = precision_score(y, y_pred)
        p2 = precision_score(y, y_pred)
        return p1 + p2

    _test(summed_precision,
          "summed precision",
          compute_true_value_fn=compute_true_summed_precision)

    precision_1 = Precision()
    precision_2 = Precision()
    mean_precision = (precision_1 + precision_2) / 2

    def compute_true_mean_precision(y_pred, y):
        p1 = precision_score(y, y_pred)
        p2 = precision_score(y, y_pred)
        return (p1 + p2) * 0.5

    _test(mean_precision,
          "mean precision",
          compute_true_value_fn=compute_true_mean_precision)

    precision_1 = Precision()
    precision_2 = Precision()
    some_metric = 2.0 + 0.2 * (precision_1 * precision_2 + precision_1 -
                               precision_2)**0.5

    def compute_true_somemetric(y_pred, y):
        p1 = precision_score(y, y_pred)
        p2 = precision_score(y, y_pred)
        return 2.0 + 0.2 * (p1 * p2 + p1 - p2)**0.5

    _test(some_metric, "some metric", compute_true_somemetric)
Example #25
def _test_distrib_integration_multilabel(device):

    from ignite.engine import Engine

    rank = idist.get_rank()
    torch.manual_seed(12)

    def _test(average, n_epochs):
        n_iters = 60
        s = 16
        n_classes = 7

        offset = n_iters * s
        y_true = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
        y_preds = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)

        def update(engine, i):
            return (
                y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, ...],
                y_true[i * s + rank * offset : (i + 1) * s + rank * offset, ...],
            )

        engine = Engine(update)

        pr = Precision(average=average, is_multilabel=True)
        pr.attach(engine, "pr")

        data = list(range(n_iters))
        engine.run(data=data, max_epochs=n_epochs)

        assert "pr" in engine.state.metrics
        res = engine.state.metrics["pr"]
        res2 = pr.compute()
        if isinstance(res, torch.Tensor):
            res = res.cpu().numpy()
            res2 = res2.cpu().numpy()
            assert (res == res2).all()
        else:
            assert res == res2

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            true_res = precision_score(
                to_numpy_multilabel(y_true), to_numpy_multilabel(y_preds), average="samples" if average else None
            )

        assert pytest.approx(res) == true_res

    for _ in range(2):
        _test(average=True, n_epochs=1)
        _test(average=True, n_epochs=2)

    if idist.get_world_size() > 1:
        with pytest.warns(
            RuntimeWarning,
            match="Precision/Recall metrics do not work in distributed setting when "
            "average=False and is_multilabel=True",
        ):
            pr = Precision(average=False, is_multilabel=True)

        y_pred = torch.randint(0, 2, size=(4, 3, 6, 8))
        y = torch.randint(0, 2, size=(4, 3, 6, 8)).long()
        pr.update((y_pred, y))
        pr_compute1 = pr.compute()
        pr_compute2 = pr.compute()
        assert len(pr_compute1) == 4 * 6 * 8
        assert (pr_compute1 == pr_compute2).all()
Example #26
def create_eval_engine(model, is_multilabel, n_classes, cpu):

  def process_function(engine, batch):
      X, y = batch

      if cpu:
          pred = model(X.cpu())
          gold = y.cpu()
      else:
          pred = model(X.cuda())
          gold = y.cuda()

      return pred, gold

  eval_engine = Engine(process_function)

  if is_multilabel:
      accuracy = MulticlassOverallAccuracy(n_classes=n_classes)
      accuracy.attach(eval_engine, "accuracy")
      per_class_accuracy = MulticlassPerClassAccuracy(n_classes=n_classes)
      per_class_accuracy.attach(eval_engine, "per class accuracy")
      recall = MulticlassRecall(n_classes=n_classes)
      recall.attach(eval_engine, "recall")
      precision = MulticlassPrecision(n_classes=n_classes)
      precision.attach(eval_engine, "precision")
      f1 = MulticlassF(n_classes=n_classes, f_n=1)
      f1.attach(eval_engine, "f1")
      f2 = MulticlassF(n_classes=n_classes, f_n=2)
      f2.attach(eval_engine, "f2")

      avg_recall = MulticlassRecall(n_classes=n_classes, average=True)
      avg_recall.attach(eval_engine, "average recall")
      avg_precision = MulticlassPrecision(n_classes=n_classes, average=True)
      avg_precision.attach(eval_engine, "average precision")
      avg_f1 = MulticlassF(n_classes=n_classes, average=True, f_n=1)
      avg_f1.attach(eval_engine, "average f1")
      avg_f2 = MulticlassF(n_classes=n_classes, average=True, f_n=2)
      avg_f2.attach(eval_engine, "average f2")
  else:
      accuracy = Accuracy()
      accuracy.attach(eval_engine, "accuracy")
      recall = Recall(average=False)
      recall.attach(eval_engine, "recall")
      precision = Precision(average=False)
      precision.attach(eval_engine, "precision")
      confusion_matrix = ConfusionMatrix(num_classes=n_classes)
      confusion_matrix.attach(eval_engine, "confusion_matrix")
      f1 = (precision * recall * 2 / (precision + recall))
      f1.attach(eval_engine, "f1")
      f2 = (precision * recall * 5 / ((4*precision) + recall))
      f2.attach(eval_engine, "f2")

      def Fbeta(r, p, beta):
          return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r + 1e-20)).item()

      avg_f1 = MetricsLambda(Fbeta, recall, precision, 1)
      avg_f1.attach(eval_engine, "average f1")
      avg_f2 = MetricsLambda(Fbeta, recall, precision, 2)
      avg_f2.attach(eval_engine, "average f2")

      avg_recall = Recall(average=True)
      avg_recall.attach(eval_engine, "average recall")
      avg_precision = Precision(average=True)
      avg_precision.attach(eval_engine, "average precision")

      if n_classes == 2:
          top_k = TopK(k=10, label_idx_of_interest=0)
          top_k.attach(eval_engine, "top_k")

  return eval_engine
Example #27
def run(tb, vb, lr, epochs, writer):
  device = os.environ['main-device']
  logging.info('Training program start!')
  logging.info('Configuration:')
  logging.info('\n'+json.dumps(INFO, indent=2))

  # ------------------------------------
  # 1. Define dataloader
  train_loader, train4val_loader, val_loader = get_dataloaders(tb, vb)
  
  # ------------------------------------
  # 2. Define model
  model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
  model = carrier(model)
  
  # ------------------------------------
  # 3. Define optimizer
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
  
  # ------------------------------------
  # 4. Define metrics
  metrics = {
    'accuracy': Accuracy(),
    'loss': Loss(nn.functional.cross_entropy),
    'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(), Recall(), train_loader.dataset.classes),
    'cmatrix': MetricsLambda(CMatrixTable, ConfusionMatrix(7), train_loader.dataset.classes)
  }
  
  # ------------------------------------
  # 5. Create trainer
  trainer = create_supervised_trainer(model, optimizer, nn.functional.cross_entropy, device=device)
  
  # ------------------------------------
  # 6. Create evaluator
  evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

  desc = 'ITERATION - loss: {:.4f}'
  pbar = tqdm(
    initial=0, leave=False, total=len(train_loader),
    desc=desc.format(0)
  )


  # ------------------------------------
  # 7. Create event hooks
  @trainer.on(Events.ITERATION_COMPLETED)
  def log_training_loss(engine):
    log_interval = 5
    iter = (engine.state.iteration - 1) % len(train_loader) + 1
    if iter % log_interval == 0:
      pbar.desc = desc.format(engine.state.output)
      pbar.update(log_interval)

  @trainer.on(Events.EPOCH_COMPLETED)
  def log_training_results(engine):
    pbar.refresh()
    print('Checking on training set.')
    evaluator.run(train4val_loader)
    metrics = evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_loss = metrics['loss']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = GetTemplate('default-log').format('Training',engine.state.epoch,avg_accuracy,avg_loss,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss}, engine.state.epoch)
    # writer.add_scalars('Aggregate/Score', {'Train avg precision': precision_recall['data'][0, -1], 'Train avg recall': precision_recall['data'][1, -1]}, engine.state.epoch)
    # pbar.n = pbar.last_print_n = 0
  
  @trainer.on(Events.EPOCH_COMPLETED)
  def log_validation_results(engine):
    print('Checking on validation set.')
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_loss = metrics['loss']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = GetTemplate('default-log').format('Validating',engine.state.epoch,avg_accuracy,avg_loss,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Loss', {'Val Loss': avg_loss}, engine.state.epoch)
    writer.add_scalars('Aggregate/Score', {'Val avg precision': precision_recall['data'][0, -1], 'Val avg recall': precision_recall['data'][1, -1]}, engine.state.epoch)
    pbar.n = pbar.last_print_n = 0

  # ------------------------------------
  # Run
  trainer.run(train_loader, max_epochs=epochs)
  pbar.close()
Example #28
def run(tb, vb, lr, epochs, writer):
    device = os.environ['main-device']
    logging.info('Training program start!')
    logging.info('Configuration:')
    logging.info('\n' + json.dumps(INFO, indent=2))

    # ------------------------------------
    # 1. Define dataloader
    train_loader, train4val_loader, val_loader, num_of_images, mapping = get_dataloaders(
        tb, vb)
    # train_loader, train4val_loader, val_loader, num_of_images = get_dataloaders(tb, vb)
    weights = (1 / num_of_images) / ((1 / num_of_images).sum().item())
    # weights = (1/num_of_images)/(1/num_of_images + 1/(num_of_images.sum().item()-num_of_images))
    weights = weights.to(device=device)

    # ------------------------------------
    # 2. Define model
    model = EfficientNet.from_pretrained(
        'efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
    model = carrier(model)

    # ------------------------------------
    # 3. Define optimizer
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    ignite_scheduler = LRScheduler(scheduler)

    # ------------------------------------
    # 4. Define metrics

    class DOCLoss(nn.Module):
        def __init__(self, weight):
            super(DOCLoss, self).__init__()
            self.class_weights = weight

        def forward(self, input, target):
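            # Note: 1 - 1 / (1 + exp(-x)) equals 1 - sigmoid(x), the probability that a class
            # is "absent"; the target-class entry is flipped back to sigmoid(x) below, so the
            # loss is a one-vs-rest (DOC-style) binary cross-entropy summed over all classes.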
            sigmoid = 1 - 1 / (1 + torch.exp(-input))
            sigmoid[range(0, sigmoid.shape[0]),
                    target] = 1 - sigmoid[range(0, sigmoid.shape[0]), target]
            sigmoid = torch.log(sigmoid)
            if self.class_weights is not None:
                loss = -torch.sum(sigmoid * self.class_weights)
            else:
                loss = -torch.sum(sigmoid)
            return loss

    train_metrics = {
        'accuracy':
        Accuracy(),
        'loss':
        Loss(DOCLoss(weight=weights)),
        'precision_recall':
        MetricsLambda(PrecisionRecallTable, Precision(), Recall(),
                      train_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(CMatrixTable,
                      ConfusionMatrix(INFO['dataset-info']['num-of-classes']),
                      train_loader.dataset.classes)
    }

    def val_pred_transform(output):
        y_pred, y = output
        new_y_pred = torch.zeros(
            (y_pred.shape[0],
             INFO['dataset-info']['num-of-classes'] + 1)).to(device=device)
        for ind, c in enumerate(train_loader.dataset.classes):
            new_col = val_loader.dataset.class_to_idx[c]
            new_y_pred[:, new_col] += y_pred[:, ind]
        ukn_ind = val_loader.dataset.class_to_idx['UNKNOWN']
        import math
        new_y_pred[:, ukn_ind] = -math.inf
        return new_y_pred, y

    val_metrics = {
        'accuracy':
        Accuracy(),
        'precision_recall':
        MetricsLambda(PrecisionRecallTable, Precision(val_pred_transform),
                      Recall(val_pred_transform), val_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(
            CMatrixTable,
            ConfusionMatrix(INFO['dataset-info']['num-of-classes'] + 1,
                            output_transform=val_pred_transform),
            val_loader.dataset.classes)
    }

    # ------------------------------------
    # 5. Create trainer
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        DOCLoss(weight=weights),
                                        device=device)

    # ------------------------------------
    # 6. Create evaluator
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=train_metrics,
                                                  device=device)
    val_evaluator = create_supervised_evaluator(model,
                                                metrics=val_metrics,
                                                device=device)

    desc = 'ITERATION - loss: {:.4f}'
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))

    # ------------------------------------
    # 7. Create event hooks

    # Update progress bar on each iteration completed.
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        log_interval = 1
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    # Compute metrics on train data on each epoch completed.
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        print('Checking on training set.')
        train_evaluator.run(train4val_loader)
        metrics = train_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['loss']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      Training Results - Epoch: {}
      Avg accuracy: {:.4f}
      Avg loss: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, avg_loss,
                 precision_recall['pretty'], cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss},
                           engine.state.epoch)

    # Compute metrics on val data on each epoch completed.
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        print('Checking on validation set.')
        val_evaluator.run(val_loader)
        metrics = val_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      Validating Results - Epoch: {}
      Avg accuracy: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, precision_recall['pretty'],
                 cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars(
            'Aggregate/Score', {
                'Val avg precision': precision_recall['data'][0, -1],
                'Val avg recall': precision_recall['data'][1, -1]
            }, engine.state.epoch)
        pbar.n = pbar.last_print_n = 0

    # Save model every N epochs.
    save_model_handler = ModelCheckpoint(os.environ['savedir'],
                                         '',
                                         save_interval=50,
                                         n_saved=2)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, save_model_handler,
                              {'model': model})

    # Update learning rate according to the scheduler.
    trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

    # ------------------------------------
    # Run
    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()
Example #29
def create_trainer(
    train_step,
    output_names,
    model,
    ema_model,
    optimizer,
    lr_scheduler,
    supervised_train_loader,
    test_loader,
    cfg,
    logger,
    cta=None,
    unsup_train_loader=None,
    cta_probe_loader=None,
):

    trainer = Engine(train_step)
    trainer.logger = logger

    output_path = os.getcwd()

    to_save = {
        "model": model,
        "ema_model": ema_model,
        "optimizer": optimizer,
        "trainer": trainer,
        "lr_scheduler": lr_scheduler,
    }
    if cta is not None:
        to_save["cta"] = cta

    common.setup_common_training_handlers(
        trainer,
        train_sampler=supervised_train_loader.sampler,
        to_save=to_save,
        save_every_iters=cfg.solver.checkpoint_every,
        output_path=output_path,
        output_names=output_names,
        lr_scheduler=lr_scheduler,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    ProgressBar(persist=False).attach(
        trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED
    )

    unsupervised_train_loader_iter = None
    if unsup_train_loader is not None:
        unsupervised_train_loader_iter = cycle(unsup_train_loader)

    cta_probe_loader_iter = None
    if cta_probe_loader is not None:
        cta_probe_loader_iter = cycle(cta_probe_loader)

    # Setup handler to prepare data batches
    @trainer.on(Events.ITERATION_STARTED)
    def prepare_batch(e):
        sup_batch = e.state.batch
        e.state.batch = {
            "sup_batch": sup_batch,
        }
        if unsupervised_train_loader_iter is not None:
            unsup_batch = next(unsupervised_train_loader_iter)
            e.state.batch["unsup_batch"] = unsup_batch

        if cta_probe_loader_iter is not None:
            cta_probe_batch = next(cta_probe_loader_iter)
            cta_probe_batch["policy"] = [
                deserialize(p) for p in cta_probe_batch["policy"]
            ]
            e.state.batch["cta_probe_batch"] = cta_probe_batch

    # Setup handler to update EMA model
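    # Note: cfg.ema_decay is bound by trainer.on(...) as an extra positional argument,
    # so the handler below receives it as ema_decay on every iteration.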
    @trainer.on(Events.ITERATION_COMPLETED, cfg.ema_decay)
    def update_ema_model(ema_decay):
        # EMA on parameters
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay)

    # Setup handlers for debugging
    if cfg.debug:

        @trainer.on(Events.STARTED | Events.ITERATION_COMPLETED(every=100))
        @idist.one_rank_only()
        def log_weights_norms():
            wn = []
            ema_wn = []
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                wn.append(torch.mean(param.data))
                ema_wn.append(torch.mean(ema_param.data))

            msg = "\n\nWeights norms"
            msg += "\n- Raw model: {}".format(
                to_list_str(torch.tensor(wn[:10] + wn[-10:]))
            )
            msg += "\n- EMA model: {}\n".format(
                to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:]))
            )
            logger.info(msg)

            rmn = []
            rvar = []
            ema_rmn = []
            ema_rvar = []
            for m1, m2 in zip(model.modules(), ema_model.modules()):
                if isinstance(m1, nn.BatchNorm2d) and isinstance(m2, nn.BatchNorm2d):
                    rmn.append(torch.mean(m1.running_mean))
                    rvar.append(torch.mean(m1.running_var))
                    ema_rmn.append(torch.mean(m2.running_mean))
                    ema_rvar.append(torch.mean(m2.running_var))

            msg = "\n\nBN buffers"
            msg += "\n- Raw mean: {}".format(to_list_str(torch.tensor(rmn[:10])))
            msg += "\n- Raw var: {}".format(to_list_str(torch.tensor(rvar[:10])))
            msg += "\n- EMA mean: {}".format(to_list_str(torch.tensor(ema_rmn[:10])))
            msg += "\n- EMA var: {}\n".format(to_list_str(torch.tensor(ema_rvar[:10])))
            logger.info(msg)

        # TODO: Need to inspect a bug
        # if idist.get_rank() == 0:
        #     from ignite.contrib.handlers import ProgressBar
        #
        #     profiler = BasicTimeProfiler()
        #     profiler.attach(trainer)
        #
        #     @trainer.on(Events.ITERATION_COMPLETED(every=200))
        #     def log_profiling(_):
        #         results = profiler.get_results()
        #         profiler.print_results(results)

    # Setup validation engine
    metrics = {
        "accuracy": Accuracy(),
    }

    if not (idist.has_xla_support and idist.backend() == idist.xla.XLA_TPU):
        metrics.update({
            "precision": Precision(average=False),
            "recall": Recall(average=False),
        })

    eval_kwargs = dict(
        metrics=metrics,
        prepare_batch=sup_prepare_batch,
        device=idist.device(),
        non_blocking=True,
    )

    evaluator = create_supervised_evaluator(model, **eval_kwargs)
    ema_evaluator = create_supervised_evaluator(ema_model, **eval_kwargs)

    def log_results(epoch, max_epochs, metrics, ema_metrics):
        msg1 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in metrics.items()]
        )
        msg2 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in ema_metrics.items()]
        )
        logger.info(
            "\nEpoch {}/{}\nRaw:\n{}\nEMA:\n{}\n".format(epoch, max_epochs, msg1, msg2)
        )
        if cta is not None:
            logger.info("\n" + stats(cta))

    @trainer.on(
        Events.EPOCH_COMPLETED(every=cfg.solver.validate_every)
        | Events.STARTED
        | Events.COMPLETED
    )
    def run_evaluation():
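        # Triggered at STARTED, at COMPLETED, and every cfg.solver.validate_every epochs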
        evaluator.run(test_loader)
        ema_evaluator.run(test_loader)
        log_results(
            trainer.state.epoch,
            trainer.state.max_epochs,
            evaluator.state.metrics,
            ema_evaluator.state.metrics,
        )

    # setup TB logging
    if idist.get_rank() == 0:
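        # Only the main process writes TensorBoard / W&B logs, avoiding duplicates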
        tb_logger = common.setup_tb_logging(
            output_path,
            trainer,
            optimizers=optimizer,
            evaluators={"validation": evaluator, "ema validation": ema_evaluator},
            log_every_iters=15,
        )
        if cfg.online_exp_tracking.wandb:
            from ignite.contrib.handlers import WandBLogger

            wb_dir = Path("/tmp/output-fixmatch-wandb")
            if not wb_dir.exists():
                wb_dir.mkdir()

            _ = WandBLogger(
                project="fixmatch-pytorch",
                name=cfg.name,
                config=cfg,
                sync_tensorboard=True,
                dir=wb_dir.as_posix(),
                reinit=True,
            )
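            # With sync_tensorboard=True, W&B mirrors what the TensorBoard logger writes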

    resume_from = cfg.solver.resume_from
    if resume_from is not None:
        resume_from = list(Path(resume_from).rglob("training_checkpoint*.pt*"))
        if len(resume_from) > 0:
            # get latest
            checkpoint_fp = max(resume_from, key=lambda p: p.stat().st_mtime)
            assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
                checkpoint_fp.as_posix()
            )
            logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
            checkpoint = torch.load(checkpoint_fp.as_posix())
            Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    @trainer.on(Events.COMPLETED)
    def release_all_resources():
        nonlocal unsupervised_train_loader_iter, cta_probe_loader_iter

        if idist.get_rank() == 0:
            tb_logger.close()

        if unsupervised_train_loader_iter is not None:
            unsupervised_train_loader_iter = None

        if cta_probe_loader_iter is not None:
            cta_probe_loader_iter = None

    return trainer
Example #30
0
def run(tb, vb, lr, epochs, writer):
    device = os.environ['main-device']
    logging.info('Training program start!')
    logging.info('Configuration:')
    logging.info('\n' + json.dumps(INFO, indent=2))

    # ------------------------------------
    # 1. Define dataloader
    train_loader, train4val_loader, val_loader, num_of_images, mapping = get_dataloaders(
        tb, vb)
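    # Inverse-frequency class weights, normalized to sum to 1, for the weighted losses below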
    weights = (1 / num_of_images) / ((1 / num_of_images).sum().item())
    weights = weights.to(device=device)

    # ------------------------------------
    # 2. Define model
    model = EfficientNet.from_pretrained(
        'efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
    model = carrier(model)

    # ------------------------------------
    # 3. Define optimizer
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
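    # LRScheduler wraps the PyTorch scheduler so it can be attached to the trainer
    # as an event handler (stepped on EPOCH_STARTED further below)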
    ignite_scheduler = LRScheduler(scheduler)

    # ------------------------------------
    # 4. Define metrics
    train_metrics = {
        'accuracy':
        Accuracy(),
        'loss':
        Loss(nn.CrossEntropyLoss(weight=weights)),
        'precision_recall':
        MetricsLambda(PrecisionRecallTable, Precision(), Recall(),
                      train_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(CMatrixTable,
                      ConfusionMatrix(INFO['dataset-info']['num-of-classes']),
                      train_loader.dataset.classes)
    }

    def val_pred_transform(output):
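        # Collapse the model's fine-grained class scores into the validation label
        # space (known classes + UNKNOWN) via `mapping`: scores are summed into the
        # target class, except that classes mapped to UNKNOWN keep the elementwise
        # maximum (class 0 is always summed).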
        y_pred, y = output
        new_y_pred = torch.zeros(
            (y_pred.shape[0],
             len(INFO['dataset-info']['known-classes']) + 1)).to(device=device)
        for c in range(y_pred.shape[1]):
            if c == 0:
                new_y_pred[:, mapping[c]] += y_pred[:, c]
            elif mapping[c] == val_loader.dataset.class_to_idx['UNKNOWN']:
                new_y_pred[:, mapping[c]] = torch.where(
                    new_y_pred[:, mapping[c]] > y_pred[:, c],
                    new_y_pred[:, mapping[c]], y_pred[:, c])
            else:
                new_y_pred[:, mapping[c]] += y_pred[:, c]
        return new_y_pred, y

    val_metrics = {
        'accuracy':
        Accuracy(val_pred_transform),
        'precision_recall':
        MetricsLambda(PrecisionRecallTable, Precision(val_pred_transform),
                      Recall(val_pred_transform), val_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(
            CMatrixTable,
            ConfusionMatrix(len(INFO['dataset-info']['known-classes']) + 1,
                            output_transform=val_pred_transform),
            val_loader.dataset.classes)
    }

    # ------------------------------------
    # 5. Create trainer
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        nn.CrossEntropyLoss(weight=weights),
                                        device=device)

    # ------------------------------------
    # 6. Create evaluator
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=train_metrics,
                                                  device=device)
    val_evaluator = create_supervised_evaluator(model,
                                                metrics=val_metrics,
                                                device=device)

    desc = 'ITERATION - loss: {:.4f}'
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))

    # ------------------------------------
    # 7. Create event hooks
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        log_interval = 1
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        print('Checking on training set.')
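        # train4val_loader presumably re-iterates the training data for evaluation only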
        train_evaluator.run(train4val_loader)
        metrics = train_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['loss']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      <Training> Results - Epoch: {}
      Avg accuracy: {:.4f}
      Avg loss: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, avg_loss,
                 precision_recall['pretty'], cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss},
                           engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        print('Checking on validation set.')
        val_evaluator.run(val_loader)
        metrics = val_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      <Validating> Results - Epoch: {}
      Avg accuracy: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, precision_recall['pretty'],
                 cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars(
            'Aggregate/Score', {
                'Val avg precision': precision_recall['data'][0, -1],
                'Val avg recall': precision_recall['data'][1, -1]
            }, engine.state.epoch)
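        # Reset the progress bar counters so the next epoch starts from zero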
        pbar.n = pbar.last_print_n = 0

    trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

    # ------------------------------------
    # Run
    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()