Example #1
def run(loop: Loop):
    seed_everything(42)
    setup_cudnn_reproducibility(True, False)

    train_ds, valid_ds = get_train_test_datasets("data/cifar")

    model = auto_model(get_model())
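    # auto_model above and auto_dataloader below are presumably
    # ignite.distributed helpers that adapt the model and data loaders
    # to the current distributed configuration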

    train_loader = auto_dataloader(
        train_ds,
        batch_size=512,
        shuffle=True,
        drop_last=True,
        num_workers=4,
    )

    valid_loader = auto_dataloader(
        valid_ds,
        batch_size=512,
        num_workers=4,
        shuffle=False,
    )

    optim = SGD(model.parameters(), lr=0.4, momentum=0.9)

    scheduler = OneCycleLR(optim,
                           max_lr=1,
                           epochs=NUM_EPOCHS,
                           steps_per_epoch=len(train_loader))
    criterion = CrossEntropyLoss()

    precision = Precision(average=False)
    recall = Recall(average=False)

    # Ignite metrics are combinable
    f1 = (precision * recall * 2 / (precision + recall)).mean()
    accuracy = Accuracy()

    # Attach objects so the loop can manage them automatically
    loop.attach(
        # Loop manages train/eval modes, device and requires_grad of attached `nn.Module`s
        criterion=criterion,
        # This criterion doesn't have any state or attribute tensors,
        # so attaching it doesn't introduce any extra behavior
        model=model,
        # Loop saves the state of all attached objects that have
        # state_dict()/load_state_dict() methods to checkpoints
        optimizer=optim,
        scheduler=scheduler,
    )

    def train(loop: Loop):
        for _ in loop.iterate_epochs(NUM_EPOCHS):
            for x, y in loop.iterate_dataloader(train_loader, mode="train"):
                y_pred_logits = model(x)

                loss: torch.Tensor = criterion(y_pred_logits, y)
                loop.backward(loss)
                # Performs the optimizer step and also
                # zeroes the gradients afterwards (the default)
                loop.optimizer_step(optim, zero_grad=True)

                # Here we call scheduler.step() every iteration
                # because we use a one-cycle scheduler;
                # a regular scheduler could instead be stepped once
                # after the whole dataloader loop
                scheduler.step()

                # Log the learning rate. All metrics are written to TensorBoard
                # under the specified names.
                # With iteration='auto' (the default) the step is determined by
                # where the call is made; here it will be per batch
                loop.metrics.log("lr",
                                 scheduler.get_last_lr()[0],
                                 iteration="auto")

            # Loop disables gradients and calls Module.eval() inside the loop
            # on all attached modules when mode="valid" (the default)
            for x, y in loop.iterate_dataloader(valid_loader, mode="valid"):
                y_pred_logits: torch.Tensor = model(x)

                y_pred = to_onehot(y_pred_logits.argmax(dim=-1),
                                   num_classes=10)

                precision.update((y_pred, y))
                recall.update((y_pred, y))
                accuracy.update((y_pred, y))

            # These will be epoch metrics because they are logged outside
            # the dataloader loop.
            # Here we log the metric values without resetting the metrics
            loop.metrics.log("valid/precision", precision.compute().mean())
            loop.metrics.log("valid/recall", recall.compute().mean())

            # The .log() method above accepts plain values (tensors, floats, np.arrays);
            # .consume() accepts a Metric object and resets it after logging
            loop.metrics.consume("valid/f1", f1)
            loop.metrics.consume("valid/accuracy", accuracy)

    loop.run(train)
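
Side note: the F1 score above is built with Ignite's metric arithmetic
(Precision and Recall combine into a MetricsLambda). A minimal standalone
sketch of that idea, using plain pytorch-ignite and made-up toy tensors:

import torch
from ignite.metrics import Precision, Recall
from ignite.utils import to_onehot

precision = Precision(average=False)  # per-class precision
recall = Recall(average=False)        # per-class recall
# Arithmetic on Metric objects yields a MetricsLambda; .mean() macro-averages it
f1 = (precision * recall * 2 / (precision + recall)).mean()

y_true = torch.tensor([0, 1, 2, 1])
y_pred = to_onehot(torch.tensor([0, 2, 2, 1]), num_classes=3)

precision.update((y_pred, y_true))
recall.update((y_pred, y_true))
print(f1.compute())  # macro F1 over everything accumulated so far
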
Example #2
def train(loop: Loop, config: Config):
    seed_everything(22)
    setup_cudnn_reproducibility(True)

    dataloader, num_channels = get_dataloader(config)

    generator = auto_model(
        Generator(config.z_dim, config.g_filters, num_channels))
    discriminator = auto_model(Discriminator(num_channels, config.d_filters))

    bce = BCEWithLogitsLoss()

    opt_G = Adam(generator.parameters(),
                 lr=config.lr * idist.get_world_size(),
                 betas=(config.beta_1, 0.999))
    opt_D = Adam(discriminator.parameters(),
                 lr=config.lr * idist.get_world_size(),
                 betas=(config.beta_1, 0.999))

    device = idist.device()
    real_labels = torch.ones(config.batch_size, device=device)
    fake_labels = torch.zeros(config.batch_size, device=device)
    fixed_noise = torch.randn(16, config.z_dim, 1, 1, device=device)
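    # fixed_noise is reused every epoch, so the image grids dumped to
    # TensorBoard show how the same latent points evolve during training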

    def dump_fake_images_to_tb():
        with loop.mode("valid"):
            fake = make_grid(generator(fixed_noise),
                             normalize=True,
                             range=(-1, 1)).cpu()

        if idist.get_rank() == 0:
            sw: SummaryWriter = get_summary_writer(loop)
            sw.add_image("fake_images",
                         fake,
                         global_step=loop.iterations.current_epoch)

    def get_noise():
        return torch.randn(config.batch_size,
                           config.z_dim,
                           1,
                           1,
                           device=device)

    error_D_avg = Average()
    error_G_avg = Average()

    loop.attach(generator=generator,
                discriminator=discriminator,
                d_opt=opt_D,
                g_opt=opt_G)

    def stage_1(loop: Loop):
        for _ in loop.iterate_epochs(config.epochs):

            for real, _ in loop.iterate_dataloader(dataloader, mode="train"):
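                # train with real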
                output = discriminator(real)
                error_D_real = bce(output, real_labels)
                loop.backward(error_D_real)

                fake = generator(get_noise())

                # train with fake
                output = discriminator(fake.detach())
                error_D_fake = bce(output, fake_labels)

                loop.backward(error_D_fake)
                loop.optimizer_step(opt_D)

                with torch.no_grad():
                    error_D = error_D_fake + error_D_real
                    error_D_avg.update(error_D)

                with no_grad_for_module(discriminator), module_eval(
                        discriminator):
                    # We don't want to compute grads for the
                    # discriminator parameters on the
                    # error_G backward pass
                    output = discriminator(fake)

                error_G = bce(output, real_labels)

                simple_gd_step(loop, opt_G, error_G)
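                # (simple_gd_step is a helper from this example's codebase;
                # presumably it combines loop.backward(error_G) and
                # loop.optimizer_step(opt_G) into a single call)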

                error_G_avg.update(error_G.detach())

                loop.metrics.log("generator/error_batch", error_G.item())
                loop.metrics.log("discriminator/error_batch", error_D.item())

            loop.metrics.consume("generator/error_epoch", error_G_avg)
            loop.metrics.consume("discriminator/error_epoch", error_D_avg)

            dump_fake_images_to_tb()

    loop.run(stage_1)
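
For reference, the generator step above runs the discriminator frozen and in
eval mode while still letting gradients flow back into the generator through
the discriminator's output. A rough plain-PyTorch sketch of that pattern (an
assumption about what no_grad_for_module/module_eval do, not the library's
actual implementation):

import contextlib
import torch.nn as nn

@contextlib.contextmanager
def frozen_eval(module: nn.Module):
    # Freeze parameters and switch to eval mode; restore everything on exit.
    # Unlike torch.no_grad(), gradients can still flow *through* the module
    # to upstream tensors (here: back into the generator).
    was_training = module.training
    flags = [p.requires_grad for p in module.parameters()]
    module.eval()
    for p in module.parameters():
        p.requires_grad_(False)
    try:
        yield module
    finally:
        module.train(was_training)
        for p, flag in zip(module.parameters(), flags):
            p.requires_grad_(flag)

Wrapping the discriminator call in such a context manager lets error_G
backpropagate into the generator while the discriminator's weights receive no
gradients.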