Example #1
def create_eval_engine(model, device):

    process_function = get_process_function(model, device)

    eval_engine = Engine(process_function)

    accuracy = Accuracy()
    accuracy.attach(eval_engine, "accuracy")
    recall = Recall(average=False)
    recall.attach(eval_engine, "recall")
    precision = Precision(average=False)
    precision.attach(eval_engine, "precision")
    f1 = (precision * recall * 2 / (precision + recall))
    f1.attach(eval_engine, "f1")
    f2 = (precision * recall * 5 / ((4 * precision) + recall))
    f2.attach(eval_engine, "f2")

    def Fbeta(r, p, beta):
        return torch.mean(
            (1 + beta**2) * p * r / (beta**2 * p + r + 1e-20)).item()

    avg_f1 = MetricsLambda(Fbeta, recall, precision, 1)
    avg_f1.attach(eval_engine, "average f1")
    avg_f2 = MetricsLambda(Fbeta, recall, precision, 2)
    avg_f2.attach(eval_engine, "average f2")
    avg_recall = Recall(average=True)
    avg_recall.attach(eval_engine, "average recall")
    avg_precision = Precision(average=True)
    avg_precision.attach(eval_engine, "average precision")

    return eval_engine
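A minimal usage sketch for the engine built above, assuming get_process_function returns an update function that yields (y_pred, y) pairs and that model, device and test_loader already exist in the calling code:

# Hypothetical usage; model, device and test_loader are assumptions.
eval_engine = create_eval_engine(model, device)
state = eval_engine.run(test_loader)
print(state.metrics["precision"])   # per-class tensor of shape (n_classes,)
print(state.metrics["average f1"])  # scalar float produced by the Fbeta MetricsLambda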
Example #2
    def create_generalized_dice_metric(self, cm: ConfusionMatrix,
                                       weight: torch.Tensor):
        """
        Computes the Generalized Sørensen–Dice Coefficient (https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient)
        Args:
            cm (:obj:`ignite.metrics.ConfusionMatrix`): A confusion matrix representing the classification of data.
            weight (:obj:`torch.Tensor`): A weight vector whose length equals the number of classes.
        Returns:
            ignite.Metric: The Generalized Dice Coefficient Metric object.
        """

        # Increase floating point precision
        cm = cm.type(torch.float64)
        dice = 2 * (cm.diag() * weight) / ((
            (cm.sum(dim=1) + cm.sum(dim=0)) * weight) + EPSILON)

        if self._ignore_index != -100:

            def remove_index(dice_vector):
                try:
                    indices = list(range(len(dice_vector)))
                    indices.remove(self._ignore_index)
                    return dice_vector[indices]
                except ValueError as e:
                    raise IndexError(
                        "'ignore_index' must be non-negative, and lower than the number of classes in confusion matrix, but {} was given. "
                        .format(self._ignore_index))

            return MetricsLambda(remove_index, dice)
        else:
            return dice
Example #3
def test_metrics_lambda_reset():
    m0 = ListGatherMetric(0)
    m1 = ListGatherMetric(1)
    m2 = ListGatherMetric(2)
    m0.update([1, 10, 100])
    m1.update([1, 10, 100])
    m2.update([1, 10, 100])

    def fn(x, y, z, t):
        return 1

    m = MetricsLambda(fn, m0, m1, z=m2, t=0)

    # instantiating a new MetricsLambda must reset its argument metrics
    assert m0.list_ is None
    assert m1.list_ is None
    assert m2.list_ is None

    m0.update([1, 10, 100])
    m1.update([1, 10, 100])
    m2.update([1, 10, 100])
    m.reset()
    assert m0.list_ is None
    assert m1.list_ is None
    assert m2.list_ is None
Example #4
def test_state_metrics_ingredients_not_attached():

    y_pred = torch.randint(0, 2, size=(15, 10, 4)).float()
    y = torch.randint(0, 2, size=(15, 10, 4)).long()

    def update_fn(engine, batch):
        y_pred, y = batch
        return y_pred, y

    evaluator = Engine(update_fn)

    precision = Precision(average=False)
    recall = Recall(average=False)
    F1 = precision * recall * 2 / (precision + recall + 1e-20)
    F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1)

    F1.attach(evaluator, "F1")

    def data(y_pred, y):
        for i in range(y_pred.shape[0]):
            yield (y_pred[i], y[i])

    d = data(y_pred, y)
    state = evaluator.run(d, max_epochs=1)

    assert set(state.metrics.keys()) == set(["F1"])
Example #5
def IoU(cm, ignore_background=True):
    assert isinstance(cm, ConfusionMatrix)
    iou = cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag())
    if ignore_background:
        return MetricsLambda(lambda res: res[1:], iou)
    else:
        return iou
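A hedged sketch of how this helper can be combined with MetricsLambda to also report a mean IoU; num_classes and evaluator are assumed to exist:

import torch
from ignite.metrics import ConfusionMatrix, MetricsLambda

cm = ConfusionMatrix(num_classes=num_classes)               # num_classes is an assumption
iou = IoU(cm, ignore_background=True)                       # per-class IoU, background dropped
miou = MetricsLambda(lambda v: torch.mean(v).item(), iou)   # reduce to a single scalar
iou.attach(evaluator, "IoU")                                # evaluator is an assumed Engine
miou.attach(evaluator, "mIoU")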
Example #6
def iou_pytorch(cm, ignore_index=None):
    if not isinstance(cm, ConfusionMatrixPytorch):
        raise TypeError("Argument cm should be instance of ConfusionMatrix, "
                        "but given {}".format(type(cm)))

    if ignore_index is not None:
        if (not (isinstance(ignore_index, numbers.Integral)
                 and 0 <= ignore_index < cm.num_classes)):
            raise ValueError("ignore_index should be non-negative integer, "
                             "but given {}".format(ignore_index))

    # Increase floating point precision and pass to CPU
    cm = cm.type(torch.DoubleTensor)
    iou = cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-15)
    if ignore_index is not None:

        def ignore_index_fn(iou_vector):
            if ignore_index >= len(iou_vector):
                raise ValueError("ignore_index {} is larger than the length "
                                 "of IoU vector {}".format(
                                     ignore_index, len(iou_vector)))
            indices = list(range(len(iou_vector)))
            indices.remove(ignore_index)
            return iou_vector[indices]

        return MetricsLambda(ignore_index_fn, iou)
    else:
        return iou
Example #7
def test_integration():
    np.random.seed(1)

    n_iters = 10
    batch_size = 10
    n_classes = 10

    y_true = np.arange(0, n_iters * batch_size) % n_classes
    y_pred = 0.2 * np.random.rand(n_iters * batch_size, n_classes)
    for i in range(n_iters * batch_size):
        if np.random.rand() > 0.4:
            y_pred[i, y_true[i]] = 1.0
        else:
            j = np.random.randint(0, n_classes)
            y_pred[i, j] = 0.7

    y_true_batch_values = iter(y_true.reshape(n_iters, batch_size))
    y_pred_batch_values = iter(y_pred.reshape(n_iters, batch_size, n_classes))

    def update_fn(engine, batch):
        y_true_batch = next(y_true_batch_values)
        y_pred_batch = next(y_pred_batch_values)
        return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

    evaluator = Engine(update_fn)

    precision = Precision(average=False)
    recall = Recall(average=False)

    def Fbeta(r, p, beta):
        return torch.mean((1 + beta**2) * p * r / (beta**2 * p + r)).item()

    F1 = MetricsLambda(Fbeta, recall, precision, 1)

    precision.attach(evaluator, "precision")
    recall.attach(evaluator, "recall")
    F1.attach(evaluator, "f1")

    data = list(range(n_iters))
    state = evaluator.run(data, max_epochs=1)

    precision_true = precision_score(y_true,
                                     np.argmax(y_pred, axis=-1),
                                     average=None)
    recall_true = recall_score(y_true,
                               np.argmax(y_pred, axis=-1),
                               average=None)
    f1_true = f1_score(y_true, np.argmax(y_pred, axis=-1), average='macro')

    precision = state.metrics['precision'].numpy()
    recall = state.metrics['recall'].numpy()

    assert precision_true == approx(precision), "{} vs {}".format(
        precision_true, precision)
    assert recall_true == approx(recall), "{} vs {}".format(
        recall_true, recall)
    assert f1_true == approx(state.metrics['f1']), "{} vs {}".format(
        f1_true, state.metrics['f1'])
Example #8
    def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
        """
           Computes the Tversky loss based on https://arxiv.org/pdf/1706.05721.pdf
           Note that PyTorch optimizers minimize a loss. In this case, we would like to maximize the dice loss so we
           return the negated dice loss.
           Args:
               inputs (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The model prediction on which the loss has to
                be computed.
               targets (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The ground truth.
           Returns:
               :obj:`torch.Tensor`: The Tversky loss for each class or reduced according to reduction method.
           """
        if not inputs.size() == targets.size():
            raise ValueError(
                "'Inputs' and 'Targets' must have the same shape.")

        inputs = flatten(inputs)
        targets = flatten(targets).float()
        ones = torch.Tensor().new_ones((inputs.size()),
                                       dtype=torch.float,
                                       device=inputs.device)

        P_G = (inputs * targets).sum(-1)
        if self.weight is not None:
            P_G = self.weight * P_G

        P_NG = (inputs * (ones - targets)).sum(-1)
        NP_G = ((ones - inputs) * targets).sum(-1)

        ones = torch.Tensor().new_ones((inputs.size(0), ),
                                       dtype=torch.float,
                                       device=inputs.device)
        tversky = P_G / (P_G + self._alpha * P_NG + self._beta * NP_G +
                         EPSILON)

        tversky_loss = ones - tversky

        if self._ignore_index != -100:

            def ignore_index_fn(tversky_vector):
                try:
                    indices = list(range(len(tversky_vector)))
                    indices.remove(self._ignore_index)
                    return tversky_vector[indices]
                except ValueError as e:
                    raise IndexError(
                        "'ignore_index' must be non-negative, and lower than the number of classes in confusion matrix, but {} was given. "
                        .format(self._ignore_index))

            tversky_loss = MetricsLambda(ignore_index_fn,
                                         tversky_loss).compute()

        if self.reduction == "mean":
            tversky_loss = tversky_loss.mean()

        return tversky_loss
Example #9
def test_metrics_lambda():
    m0 = ListGatherMetric(0)
    m1 = ListGatherMetric(1)
    m2 = ListGatherMetric(2)

    def process_function(engine, data):
        return data

    engine = Engine(process_function)

    def plus(this, other):
        return this + other

    m0_plus_m1 = MetricsLambda(plus, m0, other=m1)
    m2_plus_2 = MetricsLambda(plus, m2, 2)
    m0_plus_m1.attach(engine, "m0_plus_m1")
    m2_plus_2.attach(engine, "m2_plus_2")

    engine.run([[1, 10, 100]])
    assert engine.state.metrics["m0_plus_m1"] == 11
    assert engine.state.metrics["m2_plus_2"] == 102
    engine.run([[2, 20, 200]])
    assert engine.state.metrics["m0_plus_m1"] == 22
    assert engine.state.metrics["m2_plus_2"] == 202

    # metrics are partially attached
    assert not m0.is_attached(engine)
    assert not m1.is_attached(engine)
    assert not m2.is_attached(engine)

    # a dependency is detached
    m0.detach(engine)
    # so the lambda metric is too
    assert not m0_plus_m1.is_attached(engine)
    # the lambda is attached again
    m0_plus_m1.attach(engine, "m0_plus_m1")
    assert m0_plus_m1.is_attached(engine)
    # metrics are always partially attached
    assert not m0.is_attached(engine)
    m0_plus_m1.detach(engine)
    assert not m0_plus_m1.is_attached(engine)
    # detached (and no longer partially attached)
    assert not m0.is_attached(engine)
Example #10
def run(args):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(colored("Using device: ", "white") + colored(device, "green"))

    print(colored("Initializing test dataset...", color="white"))
    _, _, test_dataset = get_datasets(args.data)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True)
    model_factory = {
        'fcn-resnet50': lambda: torchvision.models.segmentation.fcn_resnet50(num_classes=NUM_CLASSES,
                                                                             pretrained=False),
        'fcn-resnet101': lambda: torchvision.models.segmentation.fcn_resnet101(num_classes=NUM_CLASSES,
                                                                               pretrained=False),
        'deeplab-resnet50': lambda: torchvision.models.segmentation.deeplabv3_resnet50(num_classes=NUM_CLASSES,
                                                                                       pretrained=False),
        'deeplab-resnet101': lambda: torchvision.models.segmentation.deeplabv3_resnet101(num_classes=NUM_CLASSES,
                                                                                         pretrained=False)
    }
    model = model_factory[args.model]()
    model.load_state_dict(torch.load(args.weights))
    model.to(device)

    cm_metric = ConfusionMatrix(num_classes=NUM_CLASSES, output_transform=output_transform_seg)
    metrics = {'dice': MetricsLambda(lambda x: torch.mean(x).item(), DiceCoefficient(cm_metric)),
               'iou': MetricsLambda(lambda x: torch.mean(x).item(), IoU(cm_metric)),
               'dice_background': MetricsLambda(lambda x: x[0].item(), DiceCoefficient(cm_metric)),
               'dice_head': MetricsLambda(lambda x: x[1].item(), DiceCoefficient(cm_metric)),
               'dice_mid': MetricsLambda(lambda x: x[2].item(), DiceCoefficient(cm_metric)),
               'dice_tail': MetricsLambda(lambda x: x[3].item(), DiceCoefficient(cm_metric)),
               'iou_background': MetricsLambda(lambda x: x[0].item(), IoU(cm_metric)),
               'iou_head': MetricsLambda(lambda x: x[1].item(), IoU(cm_metric)),
               'iou_mid': MetricsLambda(lambda x: x[2].item(), IoU(cm_metric)),
               'iou_tail': MetricsLambda(lambda x: x[3].item(), IoU(cm_metric))
               }

    print(colored("Evaluating...\n", color="white"))
    test_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, prepare_batch=prepare_batch)

    @test_evaluator.on(Events.COMPLETED)
    def log_training_loss(engine):
        for k, v in engine.state.metrics.items():
            print(f"{k}: {v:.4f}")

    test_evaluator.run(test_loader)
Example #11
def test_metrics_lambda():
    m0 = ListGatherMetric(0)
    m1 = ListGatherMetric(1)
    m2 = ListGatherMetric(2)

    def process_function(engine, data):
        return data

    engine = Engine(process_function)
    m0_plus_m1 = MetricsLambda(lambda x, y: x + y, m0, m1)
    m2_plus_2 = MetricsLambda(lambda x, y: x + y, m2, 2)
    m0_plus_m1.attach(engine, 'm0_plus_m1')
    m2_plus_2.attach(engine, 'm2_plus_2')

    engine.run([[1, 10, 100]])
    assert engine.state.metrics['m0_plus_m1'] == 11
    assert engine.state.metrics['m2_plus_2'] == 102
    engine.run([[2, 20, 200]])
    assert engine.state.metrics['m0_plus_m1'] == 22
    assert engine.state.metrics['m2_plus_2'] == 202
Example #12
def create_evaluator(model, criterion, cfg):
    def _validation_step(_, batch):
        model.eval()
        with torch.no_grad():
            x, y = batch_to_tensor(batch, cfg)
            x, y = x.to(cfg.device), y.to(cfg.device)

            y_pred, hidden = model(x)
            loss = criterion(y_pred, y)

            if cfg.multi_label:
                y_pred = (y_pred > 0).float()

            return y_pred, y, loss, hidden

    evaluator = Engine(_validation_step)

    accuracy = Accuracy(lambda x: x[0:2], is_multilabel=cfg.multi_label)
    accuracy.attach(evaluator, "acc")

    precision = Precision(lambda x: x[0:2],
                          average=False,
                          is_multilabel=cfg.multi_label)
    precision.attach(evaluator, 'precision')
    MetricsLambda(lambda t: torch.mean(t).item(),
                  precision).attach(evaluator, "MP")

    recall = Recall(lambda x: x[0:2],
                    average=False,
                    is_multilabel=cfg.multi_label)
    recall.attach(evaluator, 'recall')
    MetricsLambda(lambda t: torch.mean(t).item(),
                  recall).attach(evaluator, "MR")

    F1 = 2. * precision * recall / (precision + recall + 1e-20)
    f1 = MetricsLambda(lambda t: torch.mean(t).item(), F1)
    f1.attach(evaluator, "F1")

    Average(lambda x: x[2]).attach(evaluator, 'loss')

    return evaluator
Example #13
def IoU(cm: ConfusionMatrix,
        ignore_index: Optional[int] = None) -> MetricsLambda:
    """Calculates Intersection over Union using :class:`~ignite.metrics.ConfusionMatrix` metric.

    Args:
        cm (ConfusionMatrix): instance of confusion matrix metric
        ignore_index (int, optional): index to ignore, e.g. background index

    Returns:
        MetricsLambda

    Examples:

    .. code-block:: python

        train_evaluator = ...

        cm = ConfusionMatrix(num_classes=num_classes)
        IoU(cm, ignore_index=0).attach(train_evaluator, 'IoU')

        state = train_evaluator.run(train_dataset)
        # state.metrics['IoU'] -> tensor of shape (num_classes - 1, )

    """
    if not isinstance(cm, ConfusionMatrix):
        raise TypeError(
            "Argument cm should be instance of ConfusionMatrix, but given {}".
            format(type(cm)))

    if ignore_index is not None:
        if not (isinstance(ignore_index, numbers.Integral)
                and 0 <= ignore_index < cm.num_classes):
            raise ValueError(
                "ignore_index should be non-negative integer, but given {}".
                format(ignore_index))

    # Increase floating point precision and pass to CPU
    cm = cm.type(torch.DoubleTensor)
    iou = cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-15)
    if ignore_index is not None:

        def ignore_index_fn(iou_vector):
            if ignore_index >= len(iou_vector):
                raise ValueError(
                    "ignore_index {} is larger than the length of IoU vector {}"
                    .format(ignore_index, len(iou_vector)))
            indices = list(range(len(iou_vector)))
            indices.remove(ignore_index)
            return iou_vector[indices]

        return MetricsLambda(ignore_index_fn, iou)
    else:
        return iou
Example #14
def test_metrics_lambda_update_and_attach_together():

    y_pred = torch.randint(0, 2, size=(15, 10, 4)).float()
    y = torch.randint(0, 2, size=(15, 10, 4)).long()

    def update_fn(engine, batch):
        y_pred, y = batch
        return y_pred, y

    engine = Engine(update_fn)

    precision = Precision(average=False)
    recall = Recall(average=False)

    def Fbeta(r, p, beta):
        return torch.mean((1 + beta**2) * p * r / (beta**2 * p + r)).item()

    F1 = MetricsLambda(Fbeta, recall, precision, 1)

    F1.attach(engine, "f1")
    with pytest.raises(
            ValueError,
            match=r"MetricsLambda is already attached to an engine"):
        F1.update((y_pred, y))

    y_pred = torch.randint(0, 2, size=(15, 10, 4)).float()
    y = torch.randint(0, 2, size=(15, 10, 4)).long()

    F1 = MetricsLambda(Fbeta, recall, precision, 1)
    F1.update((y_pred, y))

    engine = Engine(update_fn)

    with pytest.raises(ValueError,
                       match=r"The underlying metrics are already updated"):
        F1.attach(engine, "f1")

    F1.reset()
    F1.attach(engine, "f1")
Example #15
    def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
        """
        Computes the Sørensen–Dice loss.
        Note that PyTorch optimizers minimize a loss; since we want to maximize the Dice coefficient, the loss
        returned is one minus the per-class Dice coefficient.
        Args:
            inputs (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The model prediction on which the loss has to
             be computed.
            targets (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The ground truth.
        Returns:
            :obj:`torch.Tensor`: The Sørensen–Dice loss for each class or reduced according to reduction method.
        """
        if not inputs.size() == targets.size():
            raise ValueError(
                "'Inputs' and 'Targets' must have the same shape.")

        inputs = flatten(inputs)
        targets = flatten(targets).float()

        # Compute per channel Dice Coefficient
        intersection = (inputs * targets).sum(-1)

        if self.weight is not None:
            intersection = self.weight * intersection

        cardinality = (inputs + targets).sum(-1)

        ones = torch.Tensor().new_ones((inputs.size(0), ),
                                       dtype=torch.float,
                                       device=inputs.device)

        dice = ones - (2.0 * intersection / cardinality.clamp(min=EPSILON))

        if self._ignore_index != -100:

            def ignore_index_fn(dice_vector):
                try:
                    indices = list(range(len(dice_vector)))
                    indices.remove(self._ignore_index)
                    return dice_vector[indices]
                except ValueError as e:
                    raise IndexError(
                        "'ignore_index' must be non-negative, and lower than the number of classes in confusion matrix, but {} was given. "
                        .format(self._ignore_index))

            dice = MetricsLambda(ignore_index_fn, dice).compute()

        if self.reduction == "mean":
            dice = dice.mean()

        return dice
Example #16
    def create_dice_metric(self, cm: ConfusionMatrix):
        """
        Computes the Sørensen–Dice Coefficient (https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient)
        Args:
            cm (:obj:`ignite.metrics.ConfusionMatrix`): A confusion matrix representing the classification of data.
        Returns:
            :obj:`ignite.metrics.MetricsLambda`: The Sørensen–Dice Coefficient for each class or the mean Sørensen–Dice Coefficient.
        """
        # Increase floating point precision
        cm = cm.type(torch.float64)
        dice = 2 * cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) + EPSILON)

        if self._ignore_index != -100:

            def remove_index(dice_vector):
                try:
                    indices = list(range(len(dice_vector)))
                    indices.remove(self._ignore_index)
                    return dice_vector[indices]
                except ValueError as e:
                    raise IndexError(
                        "'ignore_index' must be non-negative, and lower than the number of classes in confusion matrix, but {} was given. "
                        .format(self._ignore_index))

            dice = MetricsLambda(remove_index, dice)

        if self._weight is not None:

            def multiply_weights(dice_vector):
                return self._weight * dice_vector

            dice = MetricsLambda(multiply_weights, dice)

        if self._reduction == "mean":
            dice = dice.mean()

        return dice
Example #17
 def metrics(self):
     """ Metrics.
         nll     negative log-likelihood
         acc     classification accuracy for next-response selection
         ppl     perplexity
     """
     metrics = {
         "nll":
         Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
              output_transform=lambda x: (x[0][0], x[1][0])),
         "acc":
         Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
     }
     metrics["ppl"] = MetricsLambda(math.exp, metrics["nll"])
     return metrics
Example #18
 def __load_metrics(self):
     precision = Precision(average=False)
     recall = Recall(average=False)
     F1 = precision * recall * 2 / (precision + recall + 1e-20)
     F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1)
     confusion_matrix = ConfusionMatrix(self.n_class, average="recall")
     # TODO: Add metric by patient
     self.metrics = {
         'accuracy': Accuracy(),
         "f1": F1,
         "confusion_matrix": confusion_matrix,
         "precision": precision.mean(),
         "recall": recall.mean(),
         'loss': Loss(self.loss)
     }
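A brief sketch of the attach loop a metrics dictionary like the one built in __load_metrics is usually fed into; validation_step and val_loader are assumptions:

# Sketch: attach every metric in a {name: metric} dictionary and run the evaluator.
evaluator = Engine(validation_step)   # validation_step is an assumed (y_pred, y) update function
for name, metric in metrics.items():  # "metrics" stands in for the dictionary built above
    metric.attach(evaluator, name)
state = evaluator.run(val_loader)     # val_loader is an assumption
print(state.metrics["f1"])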
Example #19
def test_metrics_lambda():
    m0 = ListGatherMetric(0)
    m1 = ListGatherMetric(1)
    m2 = ListGatherMetric(2)

    def process_function(engine, data):
        return data

    engine = Engine(process_function)

    def plus(this, other):
        return this + other

    m0_plus_m1 = MetricsLambda(plus, m0, other=m1)
    m2_plus_2 = MetricsLambda(plus, m2, 2)
    m0_plus_m1.attach(engine, "m0_plus_m1")
    m2_plus_2.attach(engine, "m2_plus_2")

    engine.run([[1, 10, 100]])
    assert engine.state.metrics["m0_plus_m1"] == 11
    assert engine.state.metrics["m2_plus_2"] == 102
    engine.run([[2, 20, 200]])
    assert engine.state.metrics["m0_plus_m1"] == 22
    assert engine.state.metrics["m2_plus_2"] == 202
Example #20
    def _test():
        y_true = np.arange(0, n_iters * batch_size * dist.get_world_size()) % n_classes
        y_pred = 0.2 * np.random.rand(
            n_iters * batch_size * dist.get_world_size(), n_classes
        )
        for i in range(n_iters * batch_size * dist.get_world_size()):
            if np.random.rand() > 0.4:
                y_pred[i, y_true[i]] = 1.0
            else:
                j = np.random.randint(0, n_classes)
                y_pred[i, j] = 0.7

        y_true = y_true.reshape(n_iters * dist.get_world_size(), batch_size)
        y_pred = y_pred.reshape(n_iters * dist.get_world_size(), batch_size, n_classes)

        def update_fn(engine, i):
            y_true_batch = y_true[i + rank * n_iters, ...]
            y_pred_batch = y_pred[i + rank * n_iters, ...]
            return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

        evaluator = Engine(update_fn)

        precision = Precision(average=False, device=device)
        recall = Recall(average=False, device=device)

        def Fbeta(r, p, beta):
            return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r)).item()

        F1 = MetricsLambda(Fbeta, recall, precision, 1)
        F1.attach(evaluator, "f1")

        another_f1 = (
            (1.0 + precision * recall * 2 / (precision + recall + 1e-20)).mean().item()
        )
        another_f1.attach(evaluator, "ff1")

        data = list(range(n_iters))
        state = evaluator.run(data, max_epochs=1)

        assert "f1" in state.metrics
        assert "ff1" in state.metrics
        f1_true = f1_score(
            y_true.ravel(),
            np.argmax(y_pred.reshape(-1, n_classes), axis=-1),
            average="macro",
        )
        assert f1_true == approx(state.metrics["f1"])
        assert 1.0 + f1_true == approx(state.metrics["ff1"])
Example #21
    def train_model(self, n_epochs, train_loader, val_loader, eval_before_start=True):
        # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: self.evaluator.run(val_loader))
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: self.update_epoch())
        if eval_before_start:
            self.trainer.add_event_handler(Events.STARTED, lambda _: self.evaluator.run(val_loader))

        # Linearly decrease the learning rate from lr to zero
        scheduler = PiecewiseLinear(self.optimizer, "lr", [(0, self.lr), (n_epochs * len(train_loader), 0.0)])
        self.trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

        # Prepare metrics
        RunningAverage(output_transform=lambda x: x).attach(self.trainer, "loss")
        metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])),
                   "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
        metrics["average_ppl"] = MetricsLambda(math.exp, metrics["nll"])
        for name, metric in metrics.items():
            metric.attach(self.evaluator, name)

        # On the main process: add progress bar, tensorboard, checkpoints and save model
        pbar = ProgressBar(persist=True)
        pbar.attach(self.trainer, metric_names=["loss"])

        if not self.verbose:
            pbar_eval = ProgressBar(persist=False)
            pbar_eval.attach(self.evaluator)

        self.evaluator.add_event_handler(Events.STARTED, lambda _: self.logger.info(f'Beginning validation for epoch {self.epoch}...'))
        self.evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(self.evaluator.state.metrics)))

        self.tb_logger.attach(self.trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
        self.tb_logger.attach(self.trainer, log_handler=OptimizerParamsHandler(self.optimizer), event_name=Events.ITERATION_STARTED)
        self.tb_logger.attach(self.evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=self.trainer),
                              event_name=Events.EPOCH_COMPLETED)

        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.checkpoint_handler,
                                       {'mymodel': getattr(self.model, 'module', self.model)})  # "getattr" takes care of distributed encapsulation

        # Run the training
        self.trainer.run(train_loader, max_epochs=n_epochs)

        # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
        if n_epochs > 0:
            os.rename(self.checkpoint_handler._saved[-1][1][-1], os.path.join(cfg.checkpoint_log_folder, self.name, WEIGHTS_NAME))
            self.tb_logger.close()
Example #22
def test_integration_ingredients_not_attached():
    np.random.seed(1)

    n_iters = 10
    batch_size = 10
    n_classes = 10

    y_true = np.arange(0, n_iters * batch_size, dtype="int64") % n_classes
    y_pred = 0.2 * np.random.rand(n_iters * batch_size, n_classes)
    for i in range(n_iters * batch_size):
        if np.random.rand() > 0.4:
            y_pred[i, y_true[i]] = 1.0
        else:
            j = np.random.randint(0, n_classes)
            y_pred[i, j] = 0.7

    y_true_batch_values = iter(y_true.reshape(n_iters, batch_size))
    y_pred_batch_values = iter(y_pred.reshape(n_iters, batch_size, n_classes))

    def update_fn(engine, batch):
        y_true_batch = next(y_true_batch_values)
        y_pred_batch = next(y_pred_batch_values)
        return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

    evaluator = Engine(update_fn)

    precision = Precision(average=False)
    recall = Recall(average=False)

    def Fbeta(r, p, beta):
        return torch.mean((1 + beta**2) * p * r / (beta**2 * p + r)).item()

    F1 = MetricsLambda(Fbeta, recall, precision, 1)
    F1.attach(evaluator, "f1")

    data = list(range(n_iters))
    state = evaluator.run(data, max_epochs=1)
    f1_true = f1_score(y_true, np.argmax(y_pred, axis=-1), average="macro")
    assert f1_true == approx(
        state.metrics["f1"]), f"{f1_true} vs {state.metrics['f1']}"
Example #23
def DiceCoefficient(cm: ConfusionMatrix,
                    ignore_index: Optional[int] = None) -> MetricsLambda:
    """Calculates Dice Coefficient for a given :class:`~ignite.metrics.ConfusionMatrix` metric.

    Args:
        cm (ConfusionMatrix): instance of confusion matrix metric
        ignore_index (int, optional): index to ignore, e.g. background index
    """

    if not isinstance(cm, ConfusionMatrix):
        raise TypeError(
            "Argument cm should be instance of ConfusionMatrix, but given {}".
            format(type(cm)))

    if ignore_index is not None:
        if not (isinstance(ignore_index, numbers.Integral)
                and 0 <= ignore_index < cm.num_classes):
            raise ValueError(
                "ignore_index should be non-negative integer, but given {}".
                format(ignore_index))

    # Increase floating point precision and pass to CPU
    cm = cm.type(torch.DoubleTensor)
    dice = 2.0 * cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) + 1e-15)

    if ignore_index is not None:

        def ignore_index_fn(dice_vector: torch.Tensor) -> torch.Tensor:
            if ignore_index >= len(dice_vector):
                raise ValueError(
                    "ignore_index {} is larger than the length of Dice vector {}"
                    .format(ignore_index, len(dice_vector)))
            indices = list(range(len(dice_vector)))
            indices.remove(ignore_index)
            return dice_vector[indices]

        return MetricsLambda(ignore_index_fn, dice)
    else:
        return dice
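As with IoU, the per-class Dice vector can be reduced to a scalar with another MetricsLambda; a small sketch, reusing the imports above and assuming cm and evaluator exist:

dice = DiceCoefficient(cm, ignore_index=0)                   # per-class Dice, background at index 0 dropped
mdice = MetricsLambda(lambda v: torch.mean(v).item(), dice)  # mean Dice as a float
mdice.attach(evaluator, "mean_dice")                         # evaluator is an assumed Engine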
Example #24
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="", help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache', help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint", type=str, default="openai-gpt", help="Path, url or short name of the model")
    parser.add_argument("--num_candidates", type=int, default=2, help="Number of candidates for training")
    parser.add_argument("--max_history", type=int, default=2, help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=4, help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Accumulate gradients on several steps")
    parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate")
    parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient")
    parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--n_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--personality_permutations", type=int, default=1, help="Number of permutations of personality sentences")
    parser.add_argument("--eval_before_start", action='store_true', help="If true start with a first evaluation before training")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="", help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = GPT2LMHeadModel if "gpt2" in args.model_checkpoint else OpenAIGPTLMHeadModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(args.device)
    optimizer = OpenAIAdam(model.parameters(), lr=args.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        lm_loss, mc_loss = model(*batch)
        loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()
    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)
    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics 
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])),
               "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
                    "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})  # "getattr" take care of distributed encapsulation

        torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
Example #25
File: metric.py  Project: py361/ignite
 def __rmul__(self, other):
     from ignite.metrics import MetricsLambda
     return MetricsLambda(lambda x, y: x * y, other, self)
Example #26
File: metric.py  Project: py361/ignite
 def __sub__(self, other):
     from ignite.metrics import MetricsLambda
     return MetricsLambda(lambda x, y: x - y, self, other)
Example #27
File: metric.py  Project: py361/ignite
 def __floordiv__(self, other):
     from ignite.metrics import MetricsLambda
     return MetricsLambda(lambda x, y: x // y, self, other)
Example #28
File: metric.py  Project: py361/ignite
 def __rtruediv__(self, other):
     from ignite.metrics import MetricsLambda
     return MetricsLambda(lambda x, y: x.__truediv__(y), other, self)
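The four operator overloads above (together with their siblings in ignite's metric.py) are what enable the arithmetic-style metric composition used throughout these examples; a hedged illustration, assuming an existing evaluator Engine:

from ignite.metrics import Precision, Recall

# Each operator below builds a lazily evaluated MetricsLambda; nothing is computed
# until the composed metric is attached to an engine and the engine runs.
precision = Precision(average=False)
recall = Recall(average=False)

f1 = 2 * precision * recall / (precision + recall + 1e-20)  # __rmul__, __mul__, __truediv__, __add__
gap = precision - recall                                    # __sub__, as in Example #26
f1.attach(evaluator, "f1")                                  # evaluator is an assumed Engine
gap.attach(evaluator, "precision_minus_recall")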
Example #29
def train():
    config_file = "configs/train_daily_dialog_emotion_action_config.json"
    config = Config.from_json_file(config_file)

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", config.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(config))

    # Initialize distributed training if needed
    config.distributed = (config.local_rank != -1)
    if config.distributed:
        torch.cuda.set_device(config.local_rank)
        config.device = torch.device("cuda", config.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint)
    model_class = GPT2DoubleHeadsModel if "gpt2" in config.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(config.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(config.device)
    optimizer = OpenAIAdam(model.parameters(), lr=config.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if config.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=config.fp16)
    if config.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[config.local_rank],
                                        output_device=config.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        config, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = tuple(
            input_tensor.to(config.device) for input_tensor in batch)
        lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels,
                                 token_type_ids, token_emotion_ids,
                                 token_action_ids)
        loss = (lm_loss * config.lm_coef +
                mc_loss * config.mc_coef) / config.gradient_accumulation_steps
        if config.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           config.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_norm)
        if engine.state.iteration % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(config.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = batch
            #logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids,
                                  mc_token_ids,
                                  token_type_ids=token_type_ids,
                                  token_emotion_ids=token_emotion_ids,
                                  token_action_ids=token_action_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[
                1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if config.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if config.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if config.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, config.lr),
                                 (config.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], config),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], config)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if config.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=config.log_dir)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" take care of distributed encapsulation

        torch.save(config,
                   tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=config.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if config.local_rank in [-1, 0] and config.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
Example #30
def create_eval_engine(model, is_multilabel, n_classes, cpu):

  def process_function(engine, batch):
      X, y = batch

      if cpu:
          pred = model(X.cpu())
          gold = y.cpu()
      else:
          pred = model(X.cuda())
          gold = y.cuda()

      return pred, gold

  eval_engine = Engine(process_function)

  if is_multilabel:
      accuracy = MulticlassOverallAccuracy(n_classes=n_classes)
      accuracy.attach(eval_engine, "accuracy")
      per_class_accuracy = MulticlassPerClassAccuracy(n_classes=n_classes)
      per_class_accuracy.attach(eval_engine, "per class accuracy")
      recall = MulticlassRecall(n_classes=n_classes)
      recall.attach(eval_engine, "recall")
      precision = MulticlassPrecision(n_classes=n_classes)
      precision.attach(eval_engine, "precision")
      f1 = MulticlassF(n_classes=n_classes, f_n=1)
      f1.attach(eval_engine, "f1")
      f2 = MulticlassF(n_classes=n_classes, f_n=2)
      f2.attach(eval_engine, "f2")

      avg_recall = MulticlassRecall(n_classes=n_classes, average=True)
      avg_recall.attach(eval_engine, "average recall")
      avg_precision = MulticlassPrecision(n_classes=n_classes, average=True)
      avg_precision.attach(eval_engine, "average precision")
      avg_f1 = MulticlassF(n_classes=n_classes, average=True, f_n=1)
      avg_f1.attach(eval_engine, "average f1")
      avg_f2 = MulticlassF(n_classes=n_classes, average=True, f_n=2)
      avg_f2.attach(eval_engine, "average f2")
  else:
      accuracy = Accuracy()
      accuracy.attach(eval_engine, "accuracy")
      recall = Recall(average=False)
      recall.attach(eval_engine, "recall")
      precision = Precision(average=False)
      precision.attach(eval_engine, "precision")
      confusion_matrix = ConfusionMatrix(num_classes=n_classes)
      confusion_matrix.attach(eval_engine, "confusion_matrix")
      f1 = (precision * recall * 2 / (precision + recall))
      f1.attach(eval_engine, "f1")
      f2 = (precision * recall * 5 / ((4 * precision) + recall))
      f2.attach(eval_engine, "f2")

      def Fbeta(r, p, beta):
          return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r + 1e-20)).item()

      avg_f1 = MetricsLambda(Fbeta, recall, precision, 1)
      avg_f1.attach(eval_engine, "average f1")
      avg_f2 = MetricsLambda(Fbeta, recall, precision, 2)
      avg_f2.attach(eval_engine, "average f2")

      avg_recall = Recall(average=True)
      avg_recall.attach(eval_engine, "average recall")
      avg_precision = Precision(average=True)
      avg_precision.attach(eval_engine, "average precision")

      if n_classes == 2:
          top_k = TopK(k=10, label_idx_of_interest=0)
          top_k.attach(eval_engine, "top_k")

  return eval_engine