def inference():

    # Inference Path #
    make_dirs(config.inference_path)

    # Prepare Data Loader #
    test_loader_selfie, test_loader_anime = get_selfie2anime_loader('test', config.batch_size)

    # Prepare Generator #
    G_A2B = Generator(image_size=config.crop_size, num_blocks=config.num_blocks).to(device)

    G_A2B.load_state_dict(torch.load(os.path.join(config.weights_path, 'U-GAT-IT_G_A2B_Epoch_{}.pkl'.format(config.num_epochs))))

    # Inference #
    print("U-GAT-IT | Generating Selfie2Anime images started...")
    with torch.no_grad():
        for i, (selfie, anime) in enumerate(zip(test_loader_selfie, test_loader_anime)):

            # Prepare Data #
            real_A = selfie.to(device)

            # Generate Fake Images #
            fake_B = G_A2B(real_A)[0]  # the generator returns a tuple; the first element is the translated image

            # Save Images (Selfie -> Anime) #
            result = torch.cat((real_A, fake_B), dim=0)
            save_image(denorm(result.data),
                       os.path.join(config.inference_path, 'U-GAT-IT_Selfie2Anime_Results_%03d.png' % (i + 1))
                       )

    # Make a GIF file #
    make_gifs_test("U-GAT-IT", "Selfie2Anime", config.inference_path)
def inference():

    # Inference Path #
    make_dirs(config.inference_path)

    # Prepare Data Loader #
    val_loader = get_edges2shoes_loader(purpose='val',
                                        batch_size=config.val_batch_size)

    # Prepare Generator #
    G_A2B = Generator().to(device)
    G_B2A = Generator().to(device)

    G_A2B.load_state_dict(
        torch.load(
            os.path.join(
                config.weights_path,
                'DiscoGAN_Generator_A2B_Epoch_{}.pkl'.format(
                    config.num_epochs))))
    G_B2A.load_state_dict(
        torch.load(
            os.path.join(
                config.weights_path,
                'DiscoGAN_Generator_B2A_Epoch_{}.pkl'.format(
                    config.num_epochs))))

    # Test #
    print("DiscoGAN | Generating Edges2Shoes images started...")
    for i, (real_A, real_B) in enumerate(val_loader):

        # Prepare Data #
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # Generate Fake Images #
        fake_B = G_A2B(real_A)
        fake_A = G_B2A(real_B)

        # Generate Reconstructed Images #
        fake_ABA = G_B2A(fake_B)
        fake_BAB = G_A2B(fake_A)

        # Save Images #
        result = torch.cat(
            (real_A, fake_A, fake_BAB, real_B, fake_B, fake_ABA), dim=0)
        save_image(denorm(result.data),
                   os.path.join(
                       config.inference_path,
                       'DiscoGAN_Edges2Shoes_Results_%03d.png' % (i + 1)),
                   nrow=8,
                   normalize=True)

    # Make a GIF file #
    make_gifs_test("DiscoGAN", config.inference_path)
def inference():

    # Inference Path #
    make_dirs(config.inference_path)

    # Prepare Data Loader #
    val_loader = get_edges2handbags_loader('val', config.val_batch_size)

    # Prepare Generator #
    G = Generator(z_dim=config.z_dim).to(device)
    G.load_state_dict(
        torch.load(
            os.path.join(
                config.weights_path,
                'BicycleGAN_Generator_Epoch_{}.pkl'.format(
                    config.num_epochs))))
    G.eval()

    # Fixed Noise #
    fixed_noise = torch.randn(config.test_size, config.num_images,
                              config.z_dim).to(device)

    # Test #
    print("BiCycleGAN | Generating Edges2Handbags Images started...")
    for iters, (sketch, ground_truth) in enumerate(val_loader):

        # Prepare Data #
        N = sketch.size(0)
        sketch = sketch.to(device)
        results = torch.FloatTensor(N * (1 + config.num_images), 3,
                                    config.crop_size, config.crop_size)

        # Generate Fake Images #
        for i in range(N):
            results[i * (1 + config.num_images)] = sketch[i].data

            for j in range(config.num_images):
                image = sketch[i].unsqueeze(dim=0)
                noise_to_generator = fixed_noise[i, j, :].unsqueeze(dim=0)

                out = G(image, noise_to_generator)
                results[i * (1 + config.num_images) + j + 1] = out

        # Save Images (once per batch, after all rows are filled) #
        save_image(
            denorm(results.data),
            os.path.join(
                config.inference_path,
                'BicycleGAN_Edges2Handbags_Results_%03d.png' %
                (iters + 1)),
            nrow=(1 + config.num_images),
        )

    make_gifs_test("BicycleGAN", config.inference_path)
def inference():

    # Inference Path #
    make_dirs(config.inference_path)

    # Prepare Data Loader #
    test_loader = get_celeba_loader('test', config.batch_size,
                                    config.selected_attrs)

    # Prepare Generator #
    G = Generator(num_classes=len(config.selected_attrs)).to(device)
    G.load_state_dict(
        torch.load(
            os.path.join(
                config.weights_path,
                'StarGAN_Generator_Epoch_{}.pkl'.format(config.num_epochs))))

    # Test #
    print("StarGAN | Generating Aligned CelebA Images started...")
    for i, (image, label) in enumerate(test_loader):

        # Prepare Data #
        image = image.to(device)
        fixed_labels = create_labels(label,
                                     selected_attrs=config.selected_attrs)

        # Generate Fake Images #
        x_fake_list = [image]

        for c_fixed in fixed_labels:
            x_fake_list.append(G(image, c_fixed))
        x_concat = torch.cat(x_fake_list, dim=3)

        # Save Images #
        save_image(denorm(x_concat.data.cpu()),
                   os.path.join(
                       config.inference_path,
                       'StarGAN_Aligned_CelebA_Results_%04d.png' % (i + 1)),
                   nrow=1,
                   padding=0)

    make_gifs_test("StarGAN", config.inference_path)
def inference():

    # Inference Path #
    make_dirs(config.inference_path)

    # Prepare Data Loader #
    test_loader = get_facades_loader('test', config.test_batch_size)

    # Prepare Generator #
    G = Generator().to(device)
    G.load_state_dict(
        torch.load(
            os.path.join(
                config.weights_path,
                'Pix2Pix_Generator_Epochs_{}.pkl'.format(config.num_epochs))))
    G.eval()

    # Test #
    print("Pix2Pix | Generating facades images started...")
    for i, (input, target) in enumerate(test_loader):

        # Prepare Data #
        input = input.to(device)
        target = target.to(device)

        # Generate Fake Image #
        generated = G(input)

        # Save Images #
        result = torch.cat((target, input, generated), dim=0)
        save_image(result,
                   os.path.join(config.inference_path,
                                'Pix2Pix_Results_%03d.png' % (i + 1)),
                   nrow=8,
                   normalize=True)

    make_gifs_test("Pix2Pix", config.inference_path)
def inference():

    # Inference Path #
    paths = [config.inference_path_H2Z, config.inference_path_Z2H]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    test_horse_loader, test_zebra_loader = get_horse2zebra_loader(
        'test', config.val_batch_size)

    # Prepare Attention and Generator #
    Attn_A = Attention().to(device)
    Attn_B = Attention().to(device)

    G_A2B = Generator().to(device)
    G_B2A = Generator().to(device)

    Attn_A.load_state_dict(
        torch.load(
            os.path.join(
                config.weights_path,
                'UAG-GAN_Attention_A_Epoch_{}.pkl'.format(config.num_epochs))))
    Attn_B.load_state_dict(
        torch.load(
            os.path.join(
                config.weights_path,
                'UAG-GAN_Attention_B_Epoch_{}.pkl'.format(config.num_epochs))))

    G_A2B.load_state_dict(
        torch.load(
            os.path.join(
                config.weights_path,
                'UAG-GAN_Generator_A2B_Epoch_{}.pkl'.format(
                    config.num_epochs))))
    G_B2A.load_state_dict(
        torch.load(
            os.path.join(
                config.weights_path,
                'UAG-GAN_Generator_B2A_Epoch_{}.pkl'.format(
                    config.num_epochs))))

    # Test #
    print("UAG-GAN | Generating Horse2Zebra images started...")
    for i, (horse, zebra) in enumerate(
            zip(test_horse_loader, test_zebra_loader)):

        # Prepare Data #
        real_A = horse.to(device)
        real_B = zebra.to(device)

        # Generate Attention Images #
        attn_A = Attn_A(real_A.detach())
        attn_A = attn_A.repeat(1, 3, 1, 1)  # single-channel mask -> 3 channels
        attn_A = 2 * attn_A - 1             # map [0, 1] to [-1, 1] so denorm applies

        attn_B = Attn_B(real_B.detach())
        attn_B = attn_B.repeat(1, 3, 1, 1)
        attn_B = 2 * attn_B - 1

        # Generate Fake Images #
        fake_B = G_A2B(real_A.detach())
        fake_A = G_B2A(real_B.detach())

        # Save Images (Horse -> Zebra) #
        result = torch.cat((real_A, attn_A, fake_B), dim=0)
        save_image(
            denorm(result.data),
            os.path.join(config.inference_path_H2Z,
                         'UAG-GAN_Horse2Zebra_Results_%03d.png' % (i + 1)))

        # Save Images (Zebra -> Horse) #
        result = torch.cat((real_B, attn_B, fake_A), dim=0)
        save_image(
            denorm(result.data),
            os.path.join(config.inference_path_Z2H,
                         'UAG-GAN_Zebra2Horse_Results_%03d.png' % (i + 1)))

    # Make a GIF file #
    make_gifs_test("UAG-GAN", "Horse2Zebra", config.inference_path_H2Z)
    make_gifs_test("UAG-GAN", "Zebra2Horse", config.inference_path_Z2H)
def inference():

    # Inference Path #
    paths = [config.inference_path_H2Z, config.inference_path_Z2H]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    test_horse_loader, test_zebra_loader = get_horse2zebra_loader(
        'test', config.val_batch_size)

    # Prepare Generator #
    G_A2B = Generator().to(device)
    G_B2A = Generator().to(device)

    G_A2B.load_state_dict(
        torch.load(
            os.path.join(
                config.weights_path,
                'CycleGAN_Generator_A2B_Epoch_{}.pkl'.format(
                    config.num_epochs))))
    G_B2A.load_state_dict(
        torch.load(
            os.path.join(
                config.weights_path,
                'CycleGAN_Generator_B2A_Epoch_{}.pkl'.format(
                    config.num_epochs))))

    # Test #
    print("CycleGAN | Generating Horse2Zebra images started...")
    for i, (horse, zebra) in enumerate(
            zip(test_horse_loader, test_zebra_loader)):

        # Prepare Data #
        real_A = horse.to(device)
        real_B = zebra.to(device)

        # Generate Fake Images #
        fake_B = G_A2B(real_A)
        fake_A = G_B2A(real_B)

        # Generate Reconstructed Images #
        fake_ABA = G_B2A(fake_B)
        fake_BAB = G_A2B(fake_A)

        # Save Images (Horse -> Zebra) #
        result = torch.cat((real_A, fake_B, fake_ABA), dim=0)
        save_image(denorm(result.data),
                   os.path.join(
                       config.inference_path_H2Z,
                       'CycleGAN_Horse2Zebra_Results_%03d.png' % (i + 1)),
                   nrow=3,
                   normalize=True)

        # Save Images (Zebra -> Horse) #
        result = torch.cat((real_B, fake_A, fake_BAB), dim=0)
        save_image(denorm(result.data),
                   os.path.join(
                       config.inference_path_Z2H,
                       'CycleGAN_Zebra2Horse_Results_%03d.png' % (i + 1)),
                   nrow=3,
                   normalize=True)

    # Make a GIF file #
    make_gifs_test("CycleGAN", "Horse2Zebra", config.inference_path_H2Z)
    make_gifs_test("CycleGAN", "Zebra2Horse", config.inference_path_Z2H)
def inference():

    # Inference Path #
    make_dirs(config.inference_random_path)
    make_dirs(config.inference_ex_guided_path)

    # Prepare Data Loader #
    test_loader_A, test_loader_B = get_edges2shoes_loader(
        'test', config.val_batch_size)

    # Prepare Generator #
    G_A2B = AdaIN_Generator().to(device)
    G_B2A = AdaIN_Generator().to(device)

    G_A2B.load_state_dict(
        torch.load(
            os.path.join(
                config.weights_path,
                'MUNIT_Generator_A2B_Epoch_{}.pkl'.format(config.num_epochs))))
    G_B2A.load_state_dict(
        torch.load(
            os.path.join(
                config.weights_path,
                'MUNIT_Generator_B2A_Epoch_{}.pkl'.format(config.num_epochs))))

    G_A2B.eval()
    G_B2A.eval()

    # Test #
    print("MUNIT | Generating Edges2Shoes images started...")

    for i, (real_A, real_B) in enumerate(zip(test_loader_A, test_loader_B)):

        # Prepare Data #
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        if config.style == "Random":
            random_style = torch.randn(real_A.size(0), config.style_dim, 1,
                                       1).to(device)
            style = random_style
            results = [real_A]

        elif config.style == "Ex_Guided":
            _, style = G_A2B.encode(real_B)
            results = [real_A, real_B]

        else:
            raise NotImplementedError

        for j in range(config.num_inference):

            content, _ = G_B2A.encode(real_A[j].unsqueeze(dim=0))
            results.append(G_A2B.decode(content, style[j].unsqueeze(0)))

        # Save Images (once per batch, after all samples are appended) #
        result = torch.cat(results, dim=0)
        title = 'MUNIT_Edges2Shoes_%s_Results_%03d.png' % (config.style, i + 1)

        if config.style == "Random":
            path = os.path.join(config.inference_random_path, title)
        elif config.style == "Ex_Guided":
            path = os.path.join(config.inference_ex_guided_path, title)
        else:
            raise NotImplementedError

        save_image(result.data,
                   path,
                   nrow=config.num_inference,
                   normalize=True)

    # Make a GIF file #
    if config.style == "Random":
        make_gifs_test("MUNIT", config.style, config.inference_random_path)

    elif config.style == "Ex_Guided":
        make_gifs_test("MUNIT", config.style, config.inference_ex_guided_path)

    else:
        raise NotImplementedError
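
# Usage note: with config.style == "Random" the style code is drawn from
# a standard normal, producing a different plausible shoe per noise
# vector; with "Ex_Guided" it is encoded from the reference image real_B,
# so outputs inherit that shoe's appearance. Both paths reuse the content
# code encoded from the edge image real_A.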