# Standard-library and third-party imports assumed by this example; Flow, get_trainloader,
# nll_loss, sample, FLAGS, cuda_device and the scalar tensor `one` are assumed to be defined
# elsewhere in the surrounding module.
from datetime import datetime
from pathlib import Path

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from tqdm import tqdm


def train(unused_argv):
    # image size must be divisible by 2 ** L
    if FLAGS.dataset == "CIFAR":
        K = 32
        L = 3
        num_channels = 3
    elif FLAGS.dataset == "COCO":
        K = 32
        L = 3
        num_channels = 3

    flow = Flow(K, L, num_channels).cuda()
    optimizer = optim.Adam(flow.parameters(), lr=FLAGS.lr)
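    # Linear learning-rate warm-up: scale the base LR by min(1, step / 100) over the first 100 steps.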
    scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                            lambda step: min(1.0, step / 100))
    epochs_done = 1
    if FLAGS.resume_from:
        ckpt = torch.load(FLAGS.resume_from)
        flow.load_state_dict(ckpt["flow_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        epochs_done = ckpt["epoch"] + 1  # resume after the last completed epoch
        flow.train()  # keep the model in training mode while resuming

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_folder = Path(FLAGS.logs) / timestamp
    writer = SummaryWriter(str(log_folder))

    zs_shapes = None
    bpd = None

    for epoch in range(epochs_done, FLAGS.num_epochs + epochs_done):
        trainloader = get_trainloader(FLAGS.crop_size)
        pbar = tqdm(trainloader, desc="epoch %d" % epoch, leave=False)
        for n, image_batch in enumerate(pbar):
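            # Some dataloaders yield (image, label) tuples; keep only the image tensor.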
            if not isinstance(image_batch, torch.Tensor):
                image_batch = image_batch[0]

            image_batch = image_batch.to(cuda_device)
            batch_size = image_batch.shape[0]

            optimizer.zero_grad()
            zs, log_det = flow(image_batch)
            image_dim = image_batch.numel() / batch_size
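            # Dequantization offset: each of the image_dim 8-bit values contributes -log(1/256)
            # nats (`one` is a scalar tensor defined elsewhere in the module); dividing the total
            # NLL by image_dim * log(2) below converts nats per image into bits per dimension.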
            nll = -image_dim * torch.log(1 / 256 * one)
            nll = nll + nll_loss(zs, log_det)
            bpd = nll / (image_dim * torch.log(2 * one))

            pbar.set_postfix(nll=nll.item(), bpd=bpd.item())

            nll.backward()
            optimizer.step()

            if zs_shapes is None:
                # Record the latent shapes once so that sampling can build matching noise tensors.
                zs_shapes = [z.size() for z in zs]

            step = (epoch - 1) * len(trainloader) + n
            writer.add_scalar("nll", nll, step)
            writer.add_scalar("bpd", bpd, step)
            scheduler.step()  # LambdaLR keeps its own step counter; passing `step` explicitly is deprecated
        pbar.close()
        if epoch % FLAGS.test_every == 0:
            torch.save(
                {
                    "epoch": epoch,
                    "flow_state_dict": flow.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "zs_shapes": zs_shapes,
                    "bpd": bpd,
                }, str(log_folder / ("ckpt-%d.pth" % epoch)))
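            # Draw samples from the flow with the recorded latent shapes and log a grid to TensorBoard.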
            image_samples, _ = sample(flow, zs_shapes)
            image_samples = make_grid(image_samples,
                                      nrow=10,
                                      normalize=True,
                                      range=(0, 1),
                                      pad_value=10)
            writer.add_image("samples", image_samples, epoch)
    writer.close()
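
# Not part of the example above: a minimal sketch of the two project-level helpers it calls,
# assuming a standard-Gaussian prior over every latent tensor and a Flow class that exposes a
# reverse() method. Names and signatures here are assumptions, not the project's actual API.
import math

import torch


def nll_loss(zs, log_det):
    """Mean negative log-likelihood over the batch, in nats per image (sketch)."""
    log_p = 0.0
    for z in zs:
        # Log-density of a standard Gaussian, summed over all dimensions of each image.
        log_p = log_p - 0.5 * (z ** 2 + math.log(2 * math.pi)).flatten(1).sum(dim=1)
    return -(log_p + log_det).mean()


def sample(flow, zs_shapes, temperature=0.7):
    """Draw latents of the recorded shapes and invert the flow (sketch)."""
    device = next(flow.parameters()).device
    zs = [temperature * torch.randn(shape, device=device) for shape in zs_shapes]
    with torch.no_grad():
        images = flow.reverse(zs)  # assumes Flow implements the inverse pass as reverse()
    return images, zs
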
Example #2
# The names epochs, train_loader, model, optimizer, device, use_cuda, root, loss_fn,
# grad_clipping, save_checkpoint and utils (torchvision.utils) come from the surrounding
# module of this example and are not shown here.
old_loss = float("inf")  # best loss seen so far, initialised before the loop
for epoch in range(epochs):
    for b, x in enumerate(train_loader, 0):
        model.train()
        optimizer.zero_grad()
        x = x.to(device)
        z, sldj = model(x)  # forward pass: latents and sum of log-det-Jacobians
        loss = loss_fn(z, sldj)
        loss.backward()
        grad_clipping(optimizer)  # clip gradients before the parameter update
        optimizer.step()
        if loss.item() < old_loss:
            is_best = 1
            save_checkpoint(model.state_dict(), is_best, root, "./models/checkpoint.pth.tar")
            old_loss = loss.item()  # keep a plain float rather than a tensor holding the graph
        print(f"loss at batch {b} epoch {epoch}: {loss.item()}")
        if b == 0:  # alternatively: b == len(train_loader) - 1 to sample at the end of the epoch
            z_sample = torch.randn((16, 3, 32, 32), dtype=torch.float32).to(device)
            model.eval()
            if use_cuda:
                x = model.module.reverse(z_sample)  # DataParallel wraps the model under .module
            else:
                x = model.reverse(z_sample)
            images = torch.sigmoid(x)  # map the reconstructions into [0, 1]
            images_concat = utils.make_grid(images, nrow=int(16 ** 0.5), padding=2, pad_value=255)
            utils.save_image(
                images_concat,
                os.path.join(root, f"./samples/{epoch}_{b}_{round(loss.item(), 3)}.png"))
best_loss = old_loss
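
# None of the helpers below appear in the example above; they are a minimal sketch of what
# loss_fn, grad_clipping and save_checkpoint could look like for this loop, assuming a
# standard-Gaussian prior and norm-based gradient clipping. Names and defaults are assumptions.
import math
import os
import shutil

import torch


def loss_fn(z, sldj):
    """Negative log-likelihood under a standard-Gaussian prior, averaged over the batch (sketch)."""
    prior_ll = -0.5 * (z ** 2 + math.log(2 * math.pi)).flatten(1).sum(dim=1)
    return -(prior_ll + sldj).mean()


def grad_clipping(optimizer, max_norm=50.0):
    """Clip the gradient norm of every parameter group handled by the optimizer (sketch)."""
    for group in optimizer.param_groups:
        torch.nn.utils.clip_grad_norm_(group["params"], max_norm)


def save_checkpoint(state_dict, is_best, root, filename):
    """Write the checkpoint under root and copy it to best.pth.tar when it improves (sketch)."""
    path = os.path.join(root, filename)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(state_dict, path)
    if is_best:
        shutil.copyfile(path, os.path.join(os.path.dirname(path), "best.pth.tar"))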