Ejemplo n.º 1
0
def check_cycle_generator():
    """Compare a CycleGenerator restored from a reference file with its stored output.

    Reads ``checker_files/cycle_generator.pt``, which must contain the keys
    ``state_dict`` (generator weights), ``input`` (an input batch) and
    ``output`` (the reference result). It prints two lines, followed by a
    separator line:

    * whether the restored generator reproduces the reference output;
    * whether its parameter count matches the expected value.
    """
    checkpoint = torch.load('checker_files/cycle_generator.pt')

    generator = CycleGenerator(conv_dim=32, init_zero_weights=False)
    generator.load_state_dict(checkpoint['state_dict'])
    sample_batch = checkpoint['input']
    reference_output = checkpoint['output']

    produced = generator(sample_batch).data.cpu().numpy()
    output_verdict = 'EQUAL' if np.allclose(produced, reference_output) else 'NOT EQUAL'
    print('CycleGenerator output: ' + output_verdict)

    param_total = count_parameters(generator)
    target_total = 105856
    param_verdict = 'EQUAL' if param_total == target_total else 'NOT EQUAL'
    print('CycleGenerator #params = {}, expected #params = {}, {}'.format(
        param_total, target_total, param_verdict))

    print('-' * 80)
def create_model(opts):
    """Instantiate the two CycleGAN generators and the two DC discriminators.

    Args:
        opts: Options object; ``g_conv_dim`` and ``init_zero_weights`` are
            passed to both generators, ``d_conv_dim`` to both discriminators.

    Returns:
        Tuple ``(G_XtoY, G_YtoX, D_X, D_Y)``.
    """
    def make_generator():
        return CycleGenerator(conv_dim=opts.g_conv_dim,
                              init_zero_weights=opts.init_zero_weights)

    def make_discriminator():
        return DCDiscriminator(conv_dim=opts.d_conv_dim)

    # Construction order (G, G, D, D) matches the original so any RNG-driven
    # weight initialisation draws happen in the same sequence.
    g_x_to_y = make_generator()
    g_y_to_x = make_generator()
    d_x = make_discriminator()
    d_y = make_discriminator()
    return g_x_to_y, g_y_to_x, d_x, d_y
Ejemplo n.º 3
0
def load_checkpoint(opts):
    """Rebuild the generators and discriminators and restore their saved weights.

    Four networks are created:

    * two ``CycleGenerator`` instances, using ``opts.g_conv_dim`` and
      ``opts.init_zero_weights``;
    * two ``DCDiscriminator`` instances, using ``opts.d_conv_dim``.

    Their state dicts are then loaded from ``G_XtoY.pkl``, ``G_YtoX.pkl``,
    ``D_X.pkl`` and ``D_Y.pkl`` inside the ``opts.load`` directory.

    Returns:
        Tuple ``(G_XtoY, G_YtoX, D_X, D_Y)``. No ``.cuda()`` call is made, so
        the models are returned on the CPU.
    """
    names = ('G_XtoY', 'G_YtoX', 'D_X', 'D_Y')
    paths = [os.path.join(opts.load, name + '.pkl') for name in names]

    models = (
        CycleGenerator(conv_dim=opts.g_conv_dim,
                       init_zero_weights=opts.init_zero_weights),
        CycleGenerator(conv_dim=opts.g_conv_dim,
                       init_zero_weights=opts.init_zero_weights),
        DCDiscriminator(conv_dim=opts.d_conv_dim),
        DCDiscriminator(conv_dim=opts.d_conv_dim),
    )

    # The identity map_location keeps every storage where it is first
    # deserialised (the CPU) instead of moving it back to the device it was
    # saved from, so GPU-trained checkpoints load on CPU-only machines.
    for model, path in zip(models, paths):
        model.load_state_dict(
            torch.load(path, map_location=lambda storage, loc: storage))

    # GPU placement was intentionally disabled here; callers that want CUDA
    # must move the returned models themselves.
    return models
Ejemplo n.º 4
0
def check_cycle_generator():
    """Check the output and number of parameters of the CycleGenerator class.

    Steps:

    1. Call ``set_random_seeds(RANDOM_SEED)`` and build a ``CycleGenerator``
       with ``conv_dim=32``.
    2. Run it on the first batch of the first loader returned by
       ``get_emoji_loader(emoji_type='Apple')``.
    3. Compare the result with the reference array stored in
       ``checker_files/cycle_generator.npy``.
    4. Compare the parameter count with ``expected_params``.

    Both results are printed; nothing is returned.
    """
    set_random_seeds(RANDOM_SEED)
    G_XtoY = CycleGenerator(conv_dim=32, init_zero_weights=False)

    # Only the first loader is needed; the second one and the labels are unused.
    dataloader_X, _ = get_emoji_loader(emoji_type='Apple')
    # Use the builtin next(): it works on every Python 3 iterator, whereas the
    # Python-2 style ``iterator.next()`` alias is not provided by recent
    # PyTorch DataLoader iterators and raises AttributeError there.
    images, _ = next(iter(dataloader_X))
    images = Variable(images)

    output = G_XtoY(images)
    output_np = output.data.cpu().numpy()

    # To regenerate the reference after an intentional architecture change:
    # np.save('checker_files/cycle_generator.npy', output_np)
    cycle_generator_expected = np.load('checker_files/cycle_generator.npy')

    if np.allclose(output_np, cycle_generator_expected):
        print('CycleGenerator output: EQUAL')
    else:
        print('CycleGenerator output: NOT EQUAL')

    num_params = count_parameters(G_XtoY)
    expected_params = 105856

    print('CycleGenerator #params = {}, expected #params = {}, {}'.format(
        num_params, expected_params,
        'EQUAL' if num_params == expected_params else 'NOT EQUAL'))

    print('-' * 80)
Ejemplo n.º 5
0
def create_model(opts):
    """Build both generators and both discriminators and print their summary.

    Args:
        opts: Options object; ``g_conv_dim`` and ``init_zero_weights``
            configure the two ``CycleGenerator`` instances, ``d_conv_dim`` the
            two ``DCDiscriminator`` instances.

    Returns:
        Tuple ``(G_XtoY, G_YtoX, D_X, D_Y)``. The models are not moved to the
        GPU here (the CUDA transfer was deliberately disabled).
    """
    generator_kwargs = dict(conv_dim=opts.g_conv_dim,
                            init_zero_weights=opts.init_zero_weights)

    G_XtoY = CycleGenerator(**generator_kwargs)
    G_YtoX = CycleGenerator(**generator_kwargs)
    D_X = DCDiscriminator(conv_dim=opts.d_conv_dim)
    D_Y = DCDiscriminator(conv_dim=opts.d_conv_dim)

    models = (G_XtoY, G_YtoX, D_X, D_Y)
    print_models(*models)
    return models
Ejemplo n.º 6
0
def create_model(opts):
    """Build the CycleGAN generators and PatchGAN discriminators.

    Only ``opts.init_zero_weights`` is read. Every other constructor argument
    uses the defaults of ``CycleGenerator`` and ``PatchGANDiscriminator``.
    When CUDA is available, all four networks are moved to the GPU and a
    message is printed.

    Returns:
        Tuple ``(G_XtoY, G_YtoX, D_X, D_Y)``.
    """
    zero_init = opts.init_zero_weights
    G_XtoY = CycleGenerator(init_zero_weights=zero_init)
    G_YtoX = CycleGenerator(init_zero_weights=zero_init)
    D_X = PatchGANDiscriminator()
    D_Y = PatchGANDiscriminator()
    networks = (G_XtoY, G_YtoX, D_X, D_Y)

    if torch.cuda.is_available():
        # Module.cuda() moves parameters in place, so the tuple stays valid.
        for net in networks:
            net.cuda()
        print('Models moved to GPU.')

    return networks
Ejemplo n.º 7
0
def create_model(opts):
    """Build the generators and discriminators, moving them to the GPU if possible.

    Args:
        opts: Options object; ``g_conv_dim`` and ``init_zero_weights`` are
            used for the two ``CycleGenerator`` instances, ``d_conv_dim`` for
            the two ``DCDiscriminator`` instances.

    Returns:
        Tuple ``(G_XtoY, G_YtoX, D_X, D_Y)``, on CUDA when
        ``torch.cuda.is_available()`` is true.
    """
    generators = [
        CycleGenerator(conv_dim=opts.g_conv_dim,
                       init_zero_weights=opts.init_zero_weights)
        for _ in range(2)
    ]
    discriminators = [DCDiscriminator(conv_dim=opts.d_conv_dim) for _ in range(2)]

    if torch.cuda.is_available():
        for module in generators + discriminators:
            module.cuda()

    G_XtoY, G_YtoX = generators
    D_X, D_Y = discriminators
    return G_XtoY, G_YtoX, D_X, D_Y
Ejemplo n.º 8
0
def train_loop(opts):
    """Train a CycleGAN (two generators, two discriminators) on ``opts.dataset_name``.

    Losses:

    * least-squares adversarial loss (``MSELoss``);
    * L1 cycle-consistency loss weighted by ``lambda_cycle``;
    * L1 identity loss weighted by ``lambda_id``.

    All three optimizers use Adam. Their learning rates follow
    ``get_lambda_rule(opts)`` and are stepped once per epoch. Samples are
    written every ``opts.sample_every`` batches. Checkpoints are written every
    ``opts.checkpoint_every`` epochs to ``<checkpoint_dir>/<dataset_name>/``.

    Args:
        opts: Options object. Reads ``image_height``, ``image_width``,
            ``a_channels``, ``b_channels``, ``d_conv_dim``, ``load``, ``lr``,
            ``beta1``, ``beta2``, ``dataroot_dir``, ``dataset_name``,
            ``batch_size``, ``n_cpu``, ``epochs``, ``start_epoch``,
            ``sample_every``, ``checkpoint_every`` and ``checkpoint_dir``.
    """
    # 6 residual blocks below 256px, 9 from 256px up. The previous code only
    # handled exactly 128 and >=256 and raised UnboundLocalError otherwise.
    res_blocks = 9 if opts.image_height >= 256 else 6

    # Create networks
    G_AB = CycleGenerator(opts.a_channels, opts.b_channels, res_blocks).to(device)
    G_BA = CycleGenerator(opts.b_channels, opts.a_channels, res_blocks).to(device)
    D_A = Discriminator(opts.a_channels, opts.d_conv_dim).to(device)
    D_B = Discriminator(opts.b_channels, opts.d_conv_dim).to(device)

    # Print network architecture
    for title, net in (("G_AtoB", G_AB), ("G_BtoA", G_BA),
                       ("D_A", D_A), ("D_B", D_B)):
        _print_network(title, net)

    # Create losses
    criterion_gan = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    if opts.load:
        # TODO: resume from a checkpoint. Until that exists, say so instead of
        # silently ignoring the option.
        print("WARNING: opts.load is set but checkpoint loading is not "
              "implemented; training starts from freshly initialised weights.")

    # Weights of the cycle-consistency and identity losses.
    lambda_cycle = 10
    lambda_id = 0.5 * lambda_cycle

    # Create optimizers
    betas = (opts.beta1, opts.beta2)
    g_optimizer = torch.optim.Adam(
        itertools.chain(G_AB.parameters(), G_BA.parameters()),
        lr=opts.lr, betas=betas)
    d_a_optimizer = torch.optim.Adam(D_A.parameters(), lr=opts.lr, betas=betas)
    d_b_optimizer = torch.optim.Adam(D_B.parameters(), lr=opts.lr, betas=betas)

    # Learning-rate schedulers, stepped once per epoch.
    lr_rule = get_lambda_rule(opts)
    schedulers = [
        torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_rule)
        for optimizer in (g_optimizer, d_a_optimizer, d_b_optimizer)
    ]

    # Image transformations: upscale by 12%, random-crop back to the target
    # size, random horizontal flip, then normalise each channel to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize(int(opts.image_height * 1.12), Image.BICUBIC),
        transforms.RandomCrop((opts.image_height, opts.image_width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    train_dataloader = DataLoader(CycleGANDataset(opts.dataroot_dir,
                                                  opts.dataset_name,
                                                  transform),
                                  batch_size=opts.batch_size,
                                  shuffle=True,
                                  num_workers=opts.n_cpu)
    test_dataloader = DataLoader(CycleGANDataset(opts.dataroot_dir,
                                                 opts.dataset_name,
                                                 transform,
                                                 mode='test'),
                                 batch_size=5,
                                 shuffle=False,
                                 num_workers=1)

    end_epoch = opts.start_epoch + opts.epochs
    n_batches = len(train_dataloader)
    for epoch in range(opts.start_epoch, end_epoch):
        for index, batch in enumerate(train_dataloader):
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)

            # Fakes for the discriminator steps only. No graph is built, so
            # the discriminator backward passes do not traverse the generators.
            with torch.no_grad():
                fake_A, fake_B = G_BA(real_B), G_AB(real_A)

            # Train discriminator A
            d_a_optimizer.zero_grad()
            loss_d_a = _discriminator_loss(D_A, real_A, fake_A, criterion_gan)
            loss_d_a.backward()
            d_a_optimizer.step()

            # Train discriminator B
            d_b_optimizer.zero_grad()
            loss_d_b = _discriminator_loss(D_B, real_B, fake_B, criterion_gan)
            loss_d_b.backward()
            d_b_optimizer.step()

            # Train generators (fresh forward pass, this time with gradients).
            g_optimizer.zero_grad()
            fake_A, fake_B = G_BA(real_B), G_AB(real_A)
            reconstructed_A, reconstructed_B = G_BA(fake_B), G_AB(fake_A)

            # Adversarial loss: generators want their fakes scored as real (1).
            patch_a = D_A(fake_A)
            loss_gan_ba = criterion_gan(patch_a, torch.ones_like(patch_a))
            patch_b = D_B(fake_B)
            loss_gan_ab = criterion_gan(patch_b, torch.ones_like(patch_b))
            loss_gan = (loss_gan_ab + loss_gan_ba) / 2

            # Cycle-consistency loss: A -> B -> A and B -> A -> B.
            loss_cycle_a = criterion_cycle(reconstructed_A, real_A)
            loss_cycle_b = criterion_cycle(reconstructed_B, real_B)
            loss_cycle = (loss_cycle_a + loss_cycle_b) / 2

            # Identity loss: a generator fed an image already in its target
            # domain should return it unchanged.
            # NOTE(review): feeding real_A to G_BA (and real_B to G_AB) assumes
            # a_channels == b_channels.
            loss_id_a = criterion_identity(G_BA(real_A), real_A)
            loss_id_b = criterion_identity(G_AB(real_B), real_B)
            loss_identity = (loss_id_a + loss_id_b) / 2

            # Total generator loss
            loss_g = loss_gan + lambda_cycle * loss_cycle + lambda_id * loss_identity
            loss_g.backward()
            g_optimizer.step()

            current_batch = epoch * n_batches + index
            sys.stdout.write(
                f"\r[Epoch {epoch + 1}/{end_epoch}] [Index {index}/{n_batches}] [D_A loss: {loss_d_a.item():.4f}] [D_B loss: {loss_d_b.item():.4f}] [G loss: adv: {loss_gan.item():.4f}, cycle: {loss_cycle.item():.4f}, identity: {loss_identity.item():.4f}]"
            )

            if current_batch % opts.sample_every == 0:
                save_sample(G_AB, G_BA, current_batch, opts, test_dataloader)

        # Update learning rate
        for scheduler in schedulers:
            scheduler.step()

        if epoch % opts.checkpoint_every == 0:
            _save_checkpoints(opts, epoch,
                              {'G_AB': G_AB, 'G_BA': G_BA, 'D_A': D_A, 'D_B': D_B})


def _print_network(title, net):
    """Print ``net`` under a centred ``title`` framed by separator lines."""
    rule = "-" * 39
    print(title.center(39))
    print(rule)
    print(net)
    print(rule)


def _discriminator_loss(discriminator, real, fake, criterion):
    """Return the LSGAN discriminator loss: real patches -> 1, fake patches -> 0."""
    patch_real = discriminator(real)
    loss_real = criterion(patch_real, torch.ones_like(patch_real))
    patch_fake = discriminator(fake)
    loss_fake = criterion(patch_fake, torch.zeros_like(patch_fake))
    return (loss_real + loss_fake) / 2


def _save_checkpoints(opts, epoch, networks):
    """Save each state_dict as ``<checkpoint_dir>/<dataset_name>/<name>_<epoch>.pth``."""
    import os

    out_dir = os.path.join(opts.checkpoint_dir, opts.dataset_name)
    # torch.save does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    for name, net in networks.items():
        torch.save(net.state_dict(), os.path.join(out_dir, f'{name}_{epoch}.pth'))
Ejemplo n.º 9
0
# Hard-coded option overrides. They replace whatever values `opt` held before
# this point (`opt` is created earlier in the script, outside this view).
# NOTE(review): lr / epoch_count / n_epochs / n_epochs_decay / port / test are
# not used in the visible part of this script; confirm they are needed here.
opt.B_nc = 3
# NOTE(review): forcing opt.cuda = 1 means the CUDA warning below can never
# fire. On a machine without a GPU, the .cuda() calls further down will raise.
# Consider deriving this from torch.cuda.is_available() instead.
opt.cuda = 1
opt.lr = 0.0002
opt.epoch_count = 0
opt.n_epochs = 100 
opt.n_epochs_decay = 100 
opt.port = 35850
opt.test = 1
print(opt)

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

###### Definition of variables ######
# Networks
# One generator per direction. Channel counts are swapped between the two.
# The positional args are presumably (input_nc, output_nc); confirm against
# CycleGenerator's signature.
netG_A2B = CycleGenerator(opt.A_nc, opt.B_nc)
netG_B2A = CycleGenerator(opt.B_nc, opt.A_nc)

if opt.cuda:
    netG_A2B.cuda()
    netG_B2A.cuda()

# Load the trained generator weights from the paths in opt.generator_A2B and
# opt.generator_B2A. No map_location is given, so tensors are restored onto
# the device they were saved from.
netG_A2B.load_state_dict(torch.load(opt.generator_A2B))
netG_B2A.load_state_dict(torch.load(opt.generator_B2A))

# Put both models in evaluation (test) mode.
netG_A2B.eval()
netG_B2A.eval()

# Inputs & targets memory allocation