def train(unused_argv):
    """Train a Glow-style normalizing flow and log NLL / bits-per-dim.

    Reads hyper-parameters from the module-level FLAGS object, optionally
    resumes from a checkpoint, and writes TensorBoard summaries plus
    periodic checkpoints and sample grids under FLAGS.logs/<timestamp>.
    """
    # Image size must be divisible by 2 ** L (L = number of flow levels).
    if FLAGS.dataset == "CIFAR":
        K = 32           # flow steps per level
        L = 3            # multi-scale levels
        num_channels = 3
    elif FLAGS.dataset == "COCO":
        K = 32
        L = 3
        num_channels = 3
    else:
        # Previously an unknown dataset fell through and raised a confusing
        # NameError on K/L below; fail fast with a clear message instead.
        raise ValueError("unsupported dataset: %s" % FLAGS.dataset)

    flow = Flow(K, L, num_channels).cuda()
    optimizer = optim.Adam(flow.parameters(), lr=FLAGS.lr)
    # Linear learning-rate warm-up over the first 100 optimizer steps.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                            lambda step: min(1.0, step / 100))
    epochs_done = 1
    if FLAGS.resume_from:
        ckpt = torch.load(FLAGS.resume_from)
        flow.load_state_dict(ckpt["flow_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        # The checkpoint stores the last *completed* epoch, so resume at the
        # following one (was ckpt["epoch"], which repeated a finished epoch).
        epochs_done = ckpt["epoch"] + 1
        # BUG FIX: the original called flow.eval() here, which left the model
        # in inference mode for the whole resumed training run.
        flow.train()

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_folder = Path(FLAGS.logs) / timestamp
    writer = SummaryWriter(str(log_folder))

    zs_shapes = None  # latent tensor shapes, captured once for sampling
    bpd = None        # last bits-per-dim value, saved into checkpoints

    for epoch in range(epochs_done, FLAGS.num_epochs + epochs_done):
        trainloader = get_trainloader(FLAGS.crop_size)
        pbar = tqdm(trainloader, desc="epoch %d" % epoch, leave=False)
        for n, image_batch in enumerate(pbar):
            # Some loaders yield (image, label) tuples; keep images only.
            if not isinstance(image_batch, torch.Tensor):
                image_batch = image_batch[0]

            image_batch = image_batch.to(cuda_device)
            batch_size = image_batch.shape[0]

            optimizer.zero_grad()
            zs, log_det = flow(image_batch)
            image_dim = image_batch.numel() / batch_size
            # Dequantization constant: -D * log(1/256) for 8-bit inputs.
            nll = -image_dim * torch.log(1 / 256 * one)
            nll = nll + nll_loss(zs, log_det)
            # Bits per dimension: nll / (D * ln 2).
            bpd = nll / (image_dim * torch.log(2 * one))

            pbar.set_postfix(nll=nll.item(), bpd=bpd.item())

            nll.backward()
            optimizer.step()

            if zs_shapes is None:
                zs_shapes = [z.size() for z in zs]

            step = (epoch - 1) * len(trainloader) + n
            writer.add_scalar("nll", nll, step)
            writer.add_scalar("bpd", bpd, step)
            # Passing the explicit global step keeps the warm-up schedule
            # correct across resumes (NOTE: the epoch argument to step() is
            # deprecated in recent PyTorch releases).
            scheduler.step(step)
        pbar.close()
        if epoch % FLAGS.test_every == 0:
            torch.save(
                {
                    "epoch": epoch,
                    "flow_state_dict": flow.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "zs_shapes": zs_shapes,
                    "bpd": bpd,
                }, str(log_folder / ("ckpt-%d.pth" % epoch)))
            # Draw samples from the flow and log them as an image grid.
            image_samples, _ = sample(flow, zs_shapes)
            image_samples = make_grid(image_samples,
                                      nrow=10,
                                      normalize=True,
                                      range=(0, 1),
                                      pad_value=10)
            writer.add_image("samples", image_samples, epoch)
    writer.close()
# Example #2
                    default=False,
                    help='Resume from checkpoint')
# Parse command-line options defined by the (truncated) argparse setup above.
args = parser.parse_args()

# ------------------------------
# I. MODEL
# ------------------------------
# Normalizing-flow model (Flow is a project-defined nn.Module), moved to the
# module-level `device`. data_dim=2 — presumably 2-D toy/point data; TODO
# confirm against the Flow implementation.
flow = Flow(width=args.width,
            depth=args.depth,
            n_levels=args.n_levels,
            data_dim=2).to(device)

# ------------------------------
# II. OPTIMIZER
# ------------------------------
# Adam with user-configurable learning rate and beta coefficients.
optim_flow = torch.optim.Adam(flow.parameters(),
                              lr=args.lr,
                              betas=(args.b1, args.b2))

# ------------------------------
# III. DATA LOADER
# ------------------------------
# util.get_data builds both the dataset and its DataLoader from the CLI args.
dataset, dataloader = util.get_data(args)


# ------------------------------
# IV. TRAINING
# ------------------------------
def main(args):
    start_epoch = 0
    if args.resume:
# Example #3
  # transform = transforms.Compose([transforms.Resize((img_sz, img_sz)), transforms.ToTensor()])
  # train_dataset = MNIST(root=opt.root, train=True, transform=transform, \
  #   download=True)
  # train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, \
  # **kwargs)

  dataset = CustomDataset(data_root, img_sz)
  train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

  model = Flow(in_channel, 64, n_block, n_flows, img_sz, num_scales)

  if test_model:
    test(os.path.join(root, model_path), model)
  
  else:
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = RealNVPLoss()
    
    if device=="cuda" and torch.cuda.device_count()>1:
      model = DataParallel(model, device_ids=[0, 1, 2, 3])
      model = model.to(device)


    old_loss = 1e6
    best_loss = 0
    is_best = 0
    for epoch in range(epochs):
      for b, x in enumerate(train_loader, 0):
        model.train()
        optimizer.zero_grad()
        x = x.to(device)