Exemplo n.º 1
0
def sample_images(batches_done):
    """Save a grid of A<->B translations for one validation batch.

    Switches ``G_AB``, ``G_BA``, ``Trans_AB`` and ``Trans_BA`` to eval mode.
    They are left in eval mode on return, so the caller must switch them
    back before training.

    The function takes the first batch of ``val_dataloader`` and, for each
    direction:

    1. decomposes the full-resolution image with
       ``pyramid.pyramid_decom(..., max_levels=2)``,
    2. feeds the last pyramid level to the generator,
    3. adapts the pyramid with ``pyramid_transform`` and the matching
       ``Trans_*`` network,
    4. rebuilds a full-resolution image with ``pyramid.pyramid_recons``.

    The rows real_A / fake_B / real_B / fake_A are stacked vertically and
    written to ``images/<opt.dataset_name>/<batches_done>.png``. The
    directory is created if missing.

    Args:
        batches_done: Training step counter; only used as the file name.
    """
    import os

    imgs = next(iter(val_dataloader))
    for net in (G_AB, G_BA, Trans_AB, Trans_BA):
        net.eval()

    # Inference only: no autograd graph is needed for visualisation, and
    # building one would keep every pyramid activation alive in memory.
    with torch.no_grad():
        # A -> B
        real_A_full = imgs["A"].type(Tensor)
        pyr_A = pyramid.pyramid_decom(img=real_A_full, max_levels=2)
        real_A = pyr_A[-1]  # the level that is fed to the generator
        fake_B = G_AB(real_A)
        pyr_A_trans = pyramid_transform(pyr_A, real_A, fake_B, Trans_AB)
        fake_B_full = pyramid.pyramid_recons(pyr_A_trans)

        # B -> A
        real_B_full = imgs["B"].type(Tensor)
        pyr_B = pyramid.pyramid_decom(img=real_B_full, max_levels=2)
        real_B = pyr_B[-1]
        fake_A = G_BA(real_B)
        pyr_B_trans = pyramid_transform(pyr_B, real_B, fake_A, Trans_BA)
        fake_A_full = pyramid.pyramid_recons(pyr_B_trans)

    # One make_grid per batch lays its images out along the x-axis
    # (5 per row); the row order fixes the layout of the saved image.
    rows = [
        make_grid(t.cpu(), nrow=5, normalize=True)
        for t in (real_A_full, fake_B_full, real_B_full, fake_A_full)
    ]
    # Stack the grids along the y-axis (dim 1 of a C x H x W grid).
    image_grid = torch.cat(rows, 1)

    out_dir = "images/%s" % opt.dataset_name
    os.makedirs(out_dir, exist_ok=True)
    save_image(image_grid, "%s/%s.png" % (out_dir, batches_done), normalize=False)
Exemplo n.º 2
0
def sample_images(batches_done):
    """Save input / output / target samples for one validation batch.

    Switches ``G`` and ``Trans`` to eval mode. They are left in eval mode on
    return.

    The function takes the first batch of ``val_dataloader`` and translates
    ``imgs["A"]`` through the pyramid pipeline:

    - ``pyramid.pyramid_decom`` with ``opt.levels`` levels,
    - ``G`` on the last level,
    - ``pyramid_transform`` with ``Trans``,
    - ``pyramid.pyramid_recons``.

    It prints the elapsed pipeline time divided by ``opt.validate_size``.
    It then writes the rows input A / output / ``imgs["B"]`` to
    ``../images/<opt.dataset_name>/<batches_done>.png``, creating the
    directory if needed.

    Args:
        batches_done: Training step counter; only used as the file name.
    """
    imgs = next(iter(val_dataloader))
    G.eval()
    Trans.eval()

    real_full = imgs["A"].type(Tensor)
    real_B_full = imgs["B"].type(Tensor)

    # Inference only: skipping autograd bookkeeping saves memory and makes
    # the timing reflect inference cost rather than graph construction.
    with torch.no_grad():
        # CUDA kernels run asynchronously. Synchronize on both ends of the
        # timed region; otherwise mostly kernel-launch time is measured.
        if real_full.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        pyr = pyramid.pyramid_decom(img=real_full, max_levels=opt.levels)
        fake = G(pyr[-1])
        pyr_trans = pyramid_transform(pyr, pyr[-1], fake, Trans, if_conv)
        fake_full = pyramid.pyramid_recons(pyr_trans)
        if real_full.is_cuda:
            torch.cuda.synchronize()
        # NOTE(review): assumes each validation batch holds opt.validate_size
        # images; real_full.size(0) would be exact — confirm with the loader.
        cost = (time.perf_counter() - start) / opt.validate_size
    print('time cost for one image: {:.4f}'.format(cost))

    # Clip to [-1, 1] so outliers do not skew make_grid's min/max rescaling,
    # then lay each batch out along the x-axis (5 images per row).
    real_A, fake_B, real_B = [
        make_grid(torch.clamp(t, -1, 1).cpu(), nrow=5, normalize=True)
        for t in (real_full, fake_full, real_B_full)
    ]
    # Stack the grids along the y-axis.
    image_grid = torch.cat((real_A, fake_B, real_B), 1)

    out_dir = "../images/%s" % (opt.dataset_name)
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(out_dir, exist_ok=True)
    save_image(image_grid, "%s/%s.png" % (out_dir, batches_done), normalize=False)
Exemplo n.º 3
0
def sample_images(batches_done):
    """Save A<->B translations for one validation batch and time B->A.

    Switches ``G_AB``, ``G_BA``, ``Trans_AB`` and ``Trans_BA`` to eval mode.
    They are left in eval mode on return.

    Each direction runs the pyramid pipeline:

    - ``pyramid.pyramid_decom`` with ``opt.levels`` levels,
    - the generator on the last level,
    - ``pyramid_transform`` with the matching ``Trans_*`` and ``if_conv``,
    - ``pyramid.pyramid_recons``.

    Only the B -> A pass is timed. Its duration divided by
    ``opt.validate_size`` is printed.

    The rows real_A / fake_B / real_B / fake_A (clipped to [-1, 1]) are
    written to ``../images/<opt.dataset_name>/<batches_done>.png``. The
    directory is created if missing.

    Args:
        batches_done: Training step counter; only used as the file name.
    """
    import os

    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()
    Trans_AB.eval()
    Trans_BA.eval()

    real_A_full = imgs["A"].type(Tensor)
    real_B_full = imgs["B"].type(Tensor)

    # Inference only: no autograd graph is needed for sampling or timing.
    with torch.no_grad():
        # A -> B (not timed).
        pyr_A = pyramid.pyramid_decom(img=real_A_full, max_levels=opt.levels)
        real_A = pyr_A[-1]
        fake_B = G_AB(real_A)
        pyr_A_trans = pyramid_transform(pyr_A, real_A, fake_B, Trans_AB,
                                        if_conv)
        fake_B_full = pyramid.pyramid_recons(pyr_A_trans)

        # B -> A (timed). CUDA work is asynchronous. The first synchronize
        # keeps pending A -> B kernels out of this measurement. The second
        # makes the stop time wait for the B -> A kernels to finish.
        if real_B_full.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        pyr_B = pyramid.pyramid_decom(img=real_B_full, max_levels=opt.levels)
        real_B = pyr_B[-1]
        fake_A = G_BA(real_B)
        pyr_B_trans = pyramid_transform(pyr_B, real_B, fake_A, Trans_BA,
                                        if_conv)
        fake_A_full = pyramid.pyramid_recons(pyr_B_trans)
        if real_B_full.is_cuda:
            torch.cuda.synchronize()
        # NOTE(review): assumes each validation batch holds opt.validate_size
        # images — confirm with the loader.
        cost = (time.perf_counter() - start) / opt.validate_size
    print('time cost for one image: {:.4f}'.format(cost))

    # Clip to [-1, 1] so outliers do not skew make_grid's min/max rescaling,
    # then lay each batch out along the x-axis (5 images per row).
    rows = [
        make_grid(torch.clamp(t, -1, 1).cpu(), nrow=5, normalize=True)
        for t in (real_A_full, fake_B_full, real_B_full, fake_A_full)
    ]
    # Stack the grids along the y-axis.
    image_grid = torch.cat(rows, 1)

    out_dir = "../images/%s" % opt.dataset_name
    os.makedirs(out_dir, exist_ok=True)
    save_image(image_grid,
               "%s/%s.png" % (out_dir, batches_done),
               normalize=False)
Exemplo n.º 4
0
        # Put the pyramid transform networks into training mode.
        Trans_AB.train()
        Trans_BA.train()

        # Clear gradients accumulated since the last generator update.
        optimizer_G.zero_grad()

        # Identity loss: each generator is applied to an image that is already
        # in its output domain and compared against that same image.
        # NOTE(review): real_A / real_B are presumably the last pyramid levels
        # (pyr_A[-1] / pyr_B[-1]) — confirm where they are defined.
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)

        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss: translate the pyramid level, adapt the pyramid with the
        # matching Trans network, reconstruct the full-resolution image and
        # score it with the target-domain discriminator against `valid`.
        fake_B = G_AB(real_A)
        pyr_A_trans = pyramid_transform(pyr_A, real_A, fake_B, Trans_AB,
                                        if_conv)
        fake_B_full = pyramid.pyramid_recons(pyr_A_trans)
        loss_GAN_AB = criterion_GAN(D_B(fake_B_full), valid)

        fake_A = G_BA(real_B)
        pyr_B_trans = pyramid_transform(pyr_B, real_B, fake_A, Trans_BA,
                                        if_conv)
        fake_A_full = pyramid.pyramid_recons(pyr_B_trans)
        loss_GAN_BA = criterion_GAN(D_A(fake_A_full), valid)

        # Average the two adversarial terms.
        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss: decompose the translated full-resolution B image again
        # and map its last pyramid level back to domain A with G_BA.
        pyr_A_recons = pyramid.pyramid_decom(img=fake_B_full,
                                             max_levels=opt.levels)
        recov_A_low = G_BA(pyr_A_recons[-1])
        pyr_A_recons_trans = pyramid_transform(pyr_A_recons, pyr_A_recons[-1],
Exemplo n.º 5
0
        # ------------------
        #  Train Generators
        # ------------------

        # Switch the generator and the pyramid transform network to training mode.
        G.train()
        Trans.train()
        # G and Trans have separate optimizers; clear both before backward().
        optimizer_G.zero_grad()
        optimizer_Trans.zero_grad()

        # Identity loss: G applied to real_A is compared against real_A itself.
        loss_identity = criterion_identity(G(real_A), real_A)

        # GAN loss: translate real_A with G, adapt the pyramid pyr_A with Trans,
        # reconstruct the full-resolution output and score it with D against
        # the `valid` labels.
        # NOTE(review): G(real_A) is evaluated a second time here (it is also
        # computed for the identity loss); it could likely be reused if G is
        # deterministic in train mode — confirm before changing.
        fake_B = G(real_A)
        pyr_A_trans = pyramid_transform(pyr_A, real_A, fake_B, Trans, if_conv)
        fake_B_full = pyramid.pyramid_recons(pyr_A_trans)
        # Clamping the reconstruction to [-1, 1] is currently disabled:
        # fake_B_full = torch.clamp(fake_B_full, -1, 1)
        loss_GAN = criterion_GAN(D(fake_B_full), valid)

        # Reconstruction loss: the full-resolution output is compared with the
        # full-resolution input real_A_full using criterion_cycle.
        loss_recons = criterion_cycle(fake_B_full, real_A_full)

        # Total loss: weighted sum of the adversarial, reconstruction and
        # identity terms (weights opt.lambda_adv / lambda_cyc / lambda_id).
        loss_G = opt.lambda_adv * loss_GAN + opt.lambda_cyc * loss_recons + opt.lambda_id * loss_identity

        # One backward pass, then step both optimizers.
        loss_G.backward()
        optimizer_G.step()
        optimizer_Trans.step()

        # -----------------------
        #  Train Discriminator