Example #1
    # Excerpt from a distributed metrics test: `rank`, `device` and the NumPy
    # reference `top_k_accuracy` are defined by the enclosing test; `idist` is
    # ignite.distributed, and Engine / TopKCategoricalAccuracy come from
    # ignite.engine / ignite.metrics.
    def _test(n_epochs, metric_device):
        n_iters = 100
        s = 16
        n_classes = 10

        # Build one global dataset; each rank reads its own contiguous shard.
        offset = n_iters * s
        y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device)
        y_preds = torch.rand(offset * idist.get_world_size(), n_classes).to(device)

        def update(engine, i):
            # Return this rank's batch for iteration i: a slice of its shard.
            return (
                y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :],
                y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
            )

        engine = Engine(update)

        k = 5
        acc = TopKCategoricalAccuracy(k=k, device=metric_device)
        acc.attach(engine, "acc")

        data = list(range(n_iters))
        engine.run(data=data, max_epochs=n_epochs)

        assert "acc" in engine.state.metrics
        res = engine.state.metrics["acc"]
        if isinstance(res, torch.Tensor):
            res = res.cpu().numpy()

        true_res = top_k_accuracy(y_true.cpu().numpy(), y_preds.cpu().numpy(), k=k)

        assert pytest.approx(res) == true_res
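
The assertion above compares Ignite's metric against a NumPy reference named
top_k_accuracy, which is not shown in the snippet. A minimal sketch of such a
helper, assuming the conventional definition (a prediction counts as correct
when the true label is among the k highest-scoring classes):

import numpy as np

def top_k_accuracy(y_true, y_pred, k=5):
    # y_true: (N,) integer class labels; y_pred: (N, n_classes) scores.
    # argsort is ascending, so the last k columns hold the top-k classes.
    top_k = np.argsort(y_pred, axis=1)[:, -k:]
    return np.mean([label in row for label, row in zip(y_true, top_k)])
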
Example #2
import torch
from torch.utils.data import DataLoader

from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping, ModelCheckpoint
from ignite.metrics import Accuracy, Loss, RunningAverage, TopKCategoricalAccuracy

# LstmClassifier, CrossEntropyLossOneHot and DIR_MODELS are project-specific;
# a sketch of CrossEntropyLossOneHot is given after this example.


def multiclass_train_lstm(
    model: LstmClassifier,
    dataloader_train: DataLoader,
    dataloader_val: DataLoader,
    filename_prefix: str,
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-4,
                                 weight_decay=1e-3)
    criterion = CrossEntropyLossOneHot()

    def process_function(_engine, batch):
        # One training step: forward pass, loss, backward pass, optimizer update.
        model.train()
        optimizer.zero_grad()
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        return y_pred, y, loss.item()

    def eval_function(_engine, batch):
        # Inference-only step shared by both evaluators.
        model.eval()
        with torch.no_grad():
            x, y = batch
            y = y.to(device)
            x = x.to(device)
            y_pred = model(x)
            return y_pred, y

    def score_function(engine):
        # Higher is better; used by EarlyStopping and ModelCheckpoint below.
        return engine.state.metrics['top3-accuracy']

    model.to(device)

    trainer = Engine(process_function)
    train_evaluator = Engine(eval_function)
    validation_evaluator = Engine(eval_function)

    # Targets are one-hot encoded (as CrossEntropyLossOneHot expects), so
    # convert them to class indices for the classification metrics; Accuracy
    # and TopKCategoricalAccuracy expect integer labels, not one-hot vectors.
    # A fresh metric instance is created per engine: sharing one instance
    # across trainer and evaluators would let the evaluators reset the
    # trainer's running-average state mid-training.
    def to_class_indices(output):
        y_pred, y = output[0], output[1]
        return y_pred, y.argmax(dim=1)

    RunningAverage(Accuracy(output_transform=to_class_indices,
                            device=device)).attach(trainer, 'accuracy')
    RunningAverage(TopKCategoricalAccuracy(output_transform=to_class_indices,
                                           k=3,
                                           device=device)).attach(trainer, 'top3-accuracy')
    RunningAverage(output_transform=lambda x: x[2]).attach(trainer, 'loss')

    for evaluator in (train_evaluator, validation_evaluator):
        Accuracy(output_transform=to_class_indices,
                 device=device).attach(evaluator, 'accuracy')
        TopKCategoricalAccuracy(output_transform=to_class_indices,
                                k=3,
                                device=device).attach(evaluator, 'top3-accuracy')
        Loss(criterion).attach(evaluator, 'loss')

    pbar = ProgressBar(persist=True, bar_format="")
    pbar.attach(engine=trainer, metric_names='all')

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(dataloader_train)
        message = f'Training results - Epoch: {engine.state.epoch}.'
        for metric_name, score in train_evaluator.state.metrics.items():
            message += f' {metric_name}: {score:.2f}.'
        pbar.log_message(message)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        validation_evaluator.run(dataloader_val)
        message = f'Validation results - Epoch: {engine.state.epoch}.'
        for metric_name, score in validation_evaluator.state.metrics.items():
            message += f' {metric_name}: {score:.2f}.'
        pbar.log_message(message)

    validation_evaluator.add_event_handler(
        Events.COMPLETED,
        EarlyStopping(patience=5,
                      score_function=score_function,
                      trainer=trainer))

    # Note: save_as_state_dict was removed in newer Ignite releases;
    # ModelCheckpoint saves state_dicts by default.
    checkpointer = ModelCheckpoint(dirname=DIR_MODELS,
                                   filename_prefix=filename_prefix,
                                   score_function=score_function,
                                   score_name='top3-accuracy',
                                   n_saved=2,
                                   create_dir=True,
                                   require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer,
                              {'v2': model})

    trainer.run(dataloader_train, max_epochs=20)
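
The criterion used above, CrossEntropyLossOneHot, is project-specific and not
shown. A minimal sketch of one plausible implementation, assuming it is
standard cross-entropy applied to one-hot encoded targets:

import torch
import torch.nn as nn

class CrossEntropyLossOneHot(nn.Module):
    # Hypothetical stand-in: reduce the one-hot target to class indices and
    # apply built-in cross-entropy to the raw logits.
    def __init__(self):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()

    def forward(self, y_pred, y_one_hot):
        return self.ce(y_pred, y_one_hot.argmax(dim=1))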