def train(unused_argv):
    # image size must be divisible by 2 ** L
    if FLAGS.dataset == "CIFAR":
        K = 32
        L = 3
        num_channels = 3
    elif FLAGS.dataset == "COCO":
        K = 32
        L = 3
        num_channels = 3

    flow = Flow(K, L, num_channels).cuda()
    optimizer = optim.Adam(flow.parameters(), lr=FLAGS.lr)
    # linear learning-rate warm-up over the first 100 optimizer steps
    scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                            lambda step: min(1.0, step / 100))
    epochs_done = 1
    if FLAGS.resume_from:
        ckpt = torch.load(FLAGS.resume_from)
        flow.load_state_dict(ckpt["flow_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        epochs_done = ckpt["epoch"]
        flow.train()  # resume in training mode; eval() here would leave norm layers in inference behaviour

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_folder = Path(FLAGS.logs) / timestamp
    writer = SummaryWriter(str(log_folder))

    zs_shapes = None
    bpd = None

    for epoch in range(epochs_done, FLAGS.num_epochs + epochs_done):
        trainloader = get_trainloader(FLAGS.crop_size)
        pbar = tqdm(trainloader, desc="epoch %d" % epoch, leave=False)
        for n, image_batch in enumerate(pbar):
            if not isinstance(image_batch, torch.Tensor):
                image_batch = image_batch[0]

            image_batch = image_batch.to(cuda_device)
            batch_size = image_batch.shape[0]

            optimizer.zero_grad()
            zs, log_det = flow(image_batch)
            image_dim = image_batch.numel() / batch_size
            # discretization constant: each of the image_dim pixel values lies in a 1/256-wide bin
            nll = -image_dim * torch.log(1 / 256 * one)
            nll = nll + nll_loss(zs, log_det)
            # bits per dimension = nll / (image_dim * ln 2)
            bpd = nll / (image_dim * torch.log(2 * one))

            pbar.set_postfix(nll=nll.item(), bpd=bpd.item())

            nll.backward()
            optimizer.step()

            if zs_shapes is None:
                zs_shapes = [z.size() for z in zs]

            step = (epoch - 1) * len(trainloader) + n
            writer.add_scalar("nll", nll, step)
            writer.add_scalar("bpd", bpd, step)
            scheduler.step(step)
        pbar.close()
        if epoch % FLAGS.test_every == 0:
            torch.save(
                {
                    "epoch": epoch,
                    "flow_state_dict": flow.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "zs_shapes": zs_shapes,
                    "bpd": bpd,
                }, str(log_folder / ("ckpt-%d.pth" % epoch)))
            image_samples, _ = sample(flow, zs_shapes)
            image_samples = make_grid(image_samples,
                                      nrow=10,
                                      normalize=True,
                                      range=(0, 1),
                                      pad_value=10)
            writer.add_image("samples", image_samples, epoch)
    writer.close()
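# Note: nll_loss and sample above are helpers defined elsewhere in this example and are
# not shown here. As a point of reference only, a minimal sketch of nll_loss, assuming an
# isotropic unit-Gaussian prior over the latents (the prior choice is an assumption, not
# taken from the snippet), could look like this:
import math

def nll_loss(zs, log_det):
    # log p(z) under a standard normal, summed over every latent dimension, per sample
    log_pz = sum(
        (-0.5 * z.pow(2) - 0.5 * math.log(2 * math.pi)).flatten(1).sum(dim=1)
        for z in zs
    )
    # change of variables: log p(x) = log p(z) + log|det dz/dx|
    log_px = log_pz + log_det
    # mean negative log-likelihood over the batch
    return -log_px.mean()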
Example #2
        optimizer.zero_grad()
        x = x.to(device)
        # show(x[0].detach().permute(1, 2, 0), os.path.join(root, "input/1.png"), 0)
        z, sldj = model(x)  # forward pass: latents and summed log-determinant of the Jacobian
        # x = model.reverse(z)
        # show(x[0].detach().permute(1, 2, 0), os.path.join(root, "input/1.png"), 0)
        loss = loss_fn(z, sldj)
        loss.backward()
        grad_clipping(optimizer)
        optimizer.step()
        if loss.item() < old_loss:
          is_best = 1
          save_checkpoint(model.state_dict(), is_best, root, "./models/checkpoint.pth.tar")
          # store a plain float so the comparison above does not keep the autograd graph alive
          old_loss = loss.item()
        print(f"loss at batch {b} epoch {epoch}: {loss.item()}")
        if b == 0:  # or: b == len(train_loader) - 1
          # draw 16 latents from the standard-normal prior and decode them
          z_sample = torch.randn((16, 3, 32, 32), dtype=torch.float32).to(device)
          model.eval()
          # unwrap nn.DataParallel before calling reverse when running on GPU
          if use_cuda:
            x = model.module.reverse(z_sample)
          else:
            x = model.reverse(z_sample)
          images = torch.sigmoid(x)
          images_concat = utils.make_grid(images, nrow=int(16 ** 0.5), padding=2, pad_value=255)
          utils.save_image(images_concat, os.path.join(root, f"./samples/{epoch}_{b}_{loss.item():.3f}.png"))
          # show(x.squeeze().permute(1, 2, 0).detach().cpu(), os.path.join(root, f"./samples/{epoch}_{b}_{loss.item():.3f}.png"), 1)
          model.train()  # switch back to training mode after drawing samples
    best_loss = old_loss
    print(f"Best loss at {best_loss}")