def train(unused_argv):
    """Train the normalizing flow on the configured dataset.

    Reads hyperparameters from the module-level ``FLAGS`` (dataset, lr,
    num_epochs, crop_size, logs, resume_from, test_every), logs nll/bpd to
    TensorBoard, and periodically checkpoints the model and writes sample
    grids.

    Args:
        unused_argv: ignored (absl-style entry-point signature).

    Raises:
        ValueError: if ``FLAGS.dataset`` is not a supported dataset name.
    """
    # image size must be divisible by 2 ** L
    if FLAGS.dataset == "CIFAR":
        K = 32
        L = 3
        num_channels = 3
    elif FLAGS.dataset == "COCO":
        K = 32
        L = 3
        num_channels = 3
    else:
        # Fail fast: previously an unknown dataset left K/L/num_channels
        # unbound and crashed later with an opaque NameError.
        raise ValueError("unsupported dataset: %r" % FLAGS.dataset)

    flow = Flow(K, L, num_channels).cuda()
    optimizer = optim.Adam(flow.parameters(), lr=FLAGS.lr)
    # Linear learning-rate warm-up over the first 100 steps, then constant.
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer, lambda step: min(1.0, step / 100))

    epochs_done = 1
    if FLAGS.resume_from:
        ckpt = torch.load(FLAGS.resume_from)
        flow.load_state_dict(ckpt["flow_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        epochs_done = ckpt["epoch"]

    # NOTE(review): eval() before a *training* loop disables train-mode
    # behavior of any norm layers; presumably intentional for actnorm-style
    # flows, but confirm — flow.train() would be the conventional call.
    flow.eval()

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_folder = Path(FLAGS.logs) / timestamp
    writer = SummaryWriter(str(log_folder))

    zs_shapes = None  # per-level latent shapes, recorded once for sampling
    bpd = None
    for epoch in range(epochs_done, FLAGS.num_epochs + epochs_done):
        trainloader = get_trainloader(FLAGS.crop_size)
        pbar = tqdm(trainloader, desc="epoch %d" % epoch, leave=False)
        for n, image_batch in enumerate(pbar):
            # Loaders that yield (image, label) tuples: keep only the image.
            if not isinstance(image_batch, torch.Tensor):
                image_batch = image_batch[0]
            image_batch = image_batch.to(cuda_device)
            batch_size = image_batch.shape[0]

            optimizer.zero_grad()
            zs, log_det = flow(image_batch)
            image_dim = image_batch.numel() / batch_size
            # Dequantization term: -D * log(1/256) for 8-bit pixel data.
            nll = -image_dim * torch.log(1 / 256 * one)
            nll = nll + nll_loss(zs, log_det)
            bpd = nll / (image_dim * torch.log(2 * one))
            pbar.set_postfix(nll=nll.item(), bpd=bpd.item())
            nll.backward()
            optimizer.step()

            if zs_shapes is None:
                zs_shapes = [z.size() for z in zs]

            step = (epoch - 1) * len(trainloader) + n
            writer.add_scalar("nll", nll, step)
            writer.add_scalar("bpd", bpd, step)
            # NOTE(review): passing an explicit step to scheduler.step() is
            # deprecated in recent PyTorch; kept as-is for compatibility.
            scheduler.step(step)
        pbar.close()

        if epoch % FLAGS.test_every == 0:
            torch.save(
                {
                    "epoch": epoch,
                    "flow_state_dict": flow.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "zs_shapes": zs_shapes,
                    "bpd": bpd,
                },
                str(log_folder / ("ckpt-%d.pth" % epoch)))
            # Draw fresh samples from the flow and log them as an image grid.
            image_samples, _ = sample(flow, zs_shapes)
            # NOTE(review): make_grid's `range=` kwarg was renamed to
            # `value_range` in newer torchvision; kept as-is.
            image_samples = make_grid(
                image_samples, nrow=10, normalize=True, range=(0, 1),
                pad_value=10)
            writer.add_image("samples", image_samples, epoch)
    writer.close()
# Script-level training loop. Relies on globals defined elsewhere in this
# file: epochs, train_loader, model, device, optimizer, loss_fn,
# grad_clipping, save_checkpoint, old_loss, root, use_cuda, utils, os, torch.
for epoch in range(epochs):
    for b, x in enumerate(train_loader, 0):
        model.train()
        optimizer.zero_grad()
        x = x.to(device)
        # show(x[0].detach().permute(1, 2, 0), os.path.join(root, "input/1.png"), 0)
        # Forward pass: latent z and sum of log-determinants of the Jacobian.
        z, sldj = model(x)
        # x = model.reverse(z)
        # show(x[0].detach().permute(1, 2, 0), os.path.join(root, "input/1.png"), 0)
        loss = loss_fn(z, sldj)
        loss.backward()
        grad_clipping(optimizer)
        optimizer.step()
        # Checkpoint whenever this batch's loss beats the best seen so far.
        if loss.item() < old_loss:
            is_best = 1
            save_checkpoint(model.state_dict(), is_best, root,
                            f"./models/checkpoint.pth.tar")
            # NOTE(review): this stores the loss *tensor*, not a float;
            # old_loss = loss.item() would avoid keeping the tensor alive —
            # confirm before changing.
            old_loss = loss
        print(f"loss at batch {b} epoch {epoch}: {loss.item()}")
        # Sample once per epoch (first batch only); the commented alternative
        # would sample on the last batch instead.
        if b == 0:  # len(train_loader) - 1:
            # Invert the flow on standard-normal noise to generate images.
            z_sample = torch.randn((16, 3, 32, 32), dtype=torch.float32).to(device)
            model.eval()
            # DataParallel wraps the model, so reverse lives on .module.
            if use_cuda:
                x = model.module.reverse(z_sample)
            else:
                x = model.reverse(z_sample)
            images = torch.sigmoid(x)
            images_concat = utils.make_grid(images, nrow=int(16 ** 0.5),
                                            padding=2, pad_value=255)
            # File name embeds epoch, batch index, and the loss rounded to 3
            # decimal places.
            utils.save_image(images_concat, os.path.join(
                root,
                f"./samples/{epoch}_{b}_{str((torch.round(loss * 10**3)/10**3).item())}.png"))
            # show(x.squeeze().detach().cpu(), os.path.join(root, f"./samples/{epoch}_{b}_{str((torch.round(loss * 10**3)/10**3).item())}.png"), 1)
            # show(x.squeeze().permute(1, 2, 0).detach().cpu(), os.path.join(root, f"./samples/{epoch}_{b}_{str((torch.round(loss * 10**3)/10**3).item())}.png"), 1)
# Expose the best loss seen during training under its conventional name.
best_loss = old_loss