D = []
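# one discriminator per target style directory; each resumes from the latest
# checkpoint when one is given (mapped to CPU if CUDA is unavailable)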
for i in range(len(targets_dir)):
    tmpD = networks.discriminator(args.in_ndc, args.out_ndc, args.ndf)
    if args.latest_discriminator_model != '':
        if torch.cuda.is_available():
            tmpD.load_state_dict(
                torch.load(targets_dir[i] + args.latest_discriminator_model))
        else:
            tmpD.load_state_dict(
                torch.load(targets_dir[i] + args.latest_discriminator_model,
                           map_location=lambda storage, loc: storage))
    tmpD.to(device)
    tmpD.train()
    D.append(tmpD)

VGG = networks.VGG19(init_weights=args.vgg_model, feature_mode=True)
VGG.to(device)
VGG.eval()
print('---------- Networks initialized -------------')
utils.print_network(G)
utils.print_network(D[0])
utils.print_network(VGG)
print('-----------------------------------------------')

# loss
BCE_loss = nn.BCELoss().to(device)
L1_loss = nn.L1Loss().to(device)

# Adam optimizer
G_optimizer = optim.Adam(G.parameters(),
                         lr=args.lrG,
                         betas=(args.beta1, args.beta2))
Example #2
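These snippets assume the script's shared header; a minimal sketch of it, assuming the usual PyTorch/torchvision aliases (repo-local names such as networks, utils, args, mask_gen, prepare_result, make_edge_promoting_img, CreateTrainDataLoader and CreateTestDataLoader are defined elsewhere in the project):

import os
import pickle
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
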
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.backends.cudnn.enabled:
        torch.backends.cudnn.benchmark = True

    prepare_result()
    make_edge_promoting_img()

    # data_loader
    landscape_dataloader = CreateTrainDataLoader(args, "landscape")
    anime_dataloader = CreateTrainDataLoader(args, "anime")
    landscape_test_dataloader = CreateTestDataLoader(args, "landscape")
    anime_test_dataloader = CreateTestDataLoader(args, "anime")

    generator = networks.Generator(args.ngf)
    if args.latest_generator_model != '':
        if torch.cuda.is_available():
            generator.load_state_dict(torch.load(args.latest_generator_model))
        else:
            # cpu mode
            generator.load_state_dict(
                torch.load(args.latest_generator_model,
                           map_location=lambda storage, loc: storage))

    discriminator = networks.Discriminator(args.in_ndc, args.out_ndc, args.ndf)
    if args.latest_discriminator_model != '':
        if torch.cuda.is_available():
            discriminator.load_state_dict(
                torch.load(args.latest_discriminator_model))
        else:
            discriminator.load_state_dict(
                torch.load(args.latest_discriminator_model,
                           map_location=lambda storage, loc: storage))

    VGG = networks.VGG19(init_weights=args.vgg_model, feature_mode=True)

    generator.to(device)
    discriminator.to(device)
    VGG.to(device)

    generator.train()
    discriminator.train()

    VGG.eval()

    G_optimizer = optim.Adam(generator.parameters(),
                             lr=args.lrG,
                             betas=(args.beta1, args.beta2))
    D_optimizer = optim.Adam(discriminator.parameters(),
                             lr=args.lrD,
                             betas=(args.beta1, args.beta2))
    # G_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=G_optimizer, milestones=[args.train_epoch // 2, args.train_epoch // 4 * 3], gamma=0.1)
    # D_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=D_optimizer, milestones=[args.train_epoch // 2, args.train_epoch // 4 * 3], gamma=0.1)

    print('---------- Networks initialized -------------')
    utils.print_network(generator)
    utils.print_network(discriminator)
    utils.print_network(VGG)
    print('-----------------------------------------------')

    BCE_loss = nn.BCELoss().to(device)
    Hinge_loss = nn.HingeEmbeddingLoss().to(device)
    L1_loss = nn.L1Loss().to(device)
    MSELoss = nn.MSELoss().to(device)

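    # only Adv_loss (aliased to BCE) is used below; Hinge_loss and MSELoss are
    # instantiated as drop-in alternatives (e.g. MSE for an LSGAN-style objective)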
    Adv_loss = BCE_loss

    pre_train_hist = {}
    pre_train_hist['Recon_loss'] = []
    pre_train_hist['per_epoch_time'] = []
    pre_train_hist['total_time'] = []
    """ Pre-train reconstruction """
    if args.latest_generator_model == '':
        print('Pre-training start!')
        start_time = time.time()
        for epoch in range(args.pre_train_epoch):
            epoch_start_time = time.time()
            generator.train()  # re-enable train mode; eval() is set during sampling below
            Recon_losses = []
            for lcimg, lhimg, lsimg in landscape_dataloader:
                lcimg, lhimg, lsimg = lcimg.to(device), lhimg.to(
                    device), lsimg.to(device)

                # train generator G
                G_optimizer.zero_grad()

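                # loader images are normalized to [-1, 1]; (x + 1) / 2 rescales
                # them to [0, 1] before VGG feature extraction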
                x_feature = VGG((lcimg + 1) / 2)

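                # mask_gen() presumably yields a random binary mask; the masked
                # color hint and the mask are concatenated channel-wise so the
                # generator sees the sparse hints and where they are valid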
                mask = mask_gen()
                hint = torch.cat((lhimg * mask, mask), 1)
                gen_img = generator(lsimg, hint)
                G_feature = VGG((gen_img + 1) / 2)

                Recon_loss = 10 * L1_loss(G_feature, x_feature.detach())
                Recon_losses.append(Recon_loss.item())
                pre_train_hist['Recon_loss'].append(Recon_loss.item())

                Recon_loss.backward()
                G_optimizer.step()

            per_epoch_time = time.time() - epoch_start_time
            pre_train_hist['per_epoch_time'].append(per_epoch_time)
            print('[%d/%d] - time: %.2f, Recon loss: %.3f' %
                  ((epoch + 1), args.pre_train_epoch, per_epoch_time,
                   torch.mean(torch.FloatTensor(Recon_losses))))

            # Save
            if (epoch + 1) % 5 == 0:
                with torch.no_grad():
                    generator.eval()
                    for n, (lcimg, lhimg,
                            lsimg) in enumerate(landscape_dataloader):
                        lcimg, lhimg, lsimg = lcimg.to(device), lhimg.to(
                            device), lsimg.to(device)
                        mask = mask_gen()
                        hint = torch.cat((lhimg * mask, mask), 1)
                        g_recon = generator(lsimg, hint)
                        result = torch.cat((lcimg[0], g_recon[0]), 2)
                        path = os.path.join(
                            args.name + '_results', 'Reconstruction',
                            f'{args.name}_train_recon_epoch_{epoch}_{n + 1}.png')
                        plt.imsave(
                            path,
                            (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                        if n == 4:
                            break

                    for n, (lcimg, lhimg,
                            lsimg) in enumerate(landscape_test_dataloader):
                        lcimg, lhimg, lsimg = lcimg.to(device), lhimg.to(
                            device), lsimg.to(device)
                        mask = mask_gen()
                        hint = torch.cat((lhimg * mask, mask), 1)
                        g_recon = generator(lsimg, hint)
                        result = torch.cat((lcimg[0], g_recon[0]), 2)
                        path = os.path.join(
                            args.name + '_results', 'Reconstruction',
                            f'{args.name}_test_recon_epoch_{epoch}_{n + 1}.png')
                        plt.imsave(
                            path,
                            (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                        if n == 4:
                            break

        total_time = time.time() - start_time
        pre_train_hist['total_time'].append(total_time)
        with open(os.path.join(args.name + '_results', 'pre_train_hist.pkl'),
                  'wb') as f:
            pickle.dump(pre_train_hist, f)
        torch.save(
            generator.state_dict(),
            os.path.join(args.name + '_results', 'generator_pretrain.pkl'))

    else:
        print('Loaded the latest generator model; skipping pre-training')

    train_hist = {}
    train_hist['Disc_loss'] = []
    train_hist['Gen_loss'] = []
    train_hist['Con_loss'] = []
    train_hist['per_epoch_time'] = []
    train_hist['total_time'] = []
    print('Training start!')
    start_time = time.time()

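    # PatchGAN-style labels: target maps sized to match the discriminator's
    # (input_size // 4) x (input_size // 4) output instead of a single scalar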
    real = torch.ones(args.batch_size, 1, args.input_size // 4,
                      args.input_size // 4).to(device)
    fake = torch.zeros(args.batch_size, 1, args.input_size // 4,
                       args.input_size // 4).to(device)
    for epoch in range(args.train_epoch):
        epoch_start_time = time.time()
        generator.train()
        Disc_losses = []
        Gen_losses = []
        Con_losses = []
        for i, ((acimg, ac_smooth_img, _), (lcimg, lhimg, lsimg)) in enumerate(
                zip(anime_dataloader, landscape_dataloader)):
            acimg, ac_smooth_img, lcimg, lhimg, lsimg = acimg.to(
                device), ac_smooth_img.to(device), lcimg.to(device), lhimg.to(
                    device), lsimg.to(device)

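            # the generator is updated only every n_dis iterations;
            # the discriminator trains on every batch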
            if i % args.n_dis == 0:
                # train G
                G_optimizer.zero_grad()

                mask = mask_gen()
                hint = torch.cat((lhimg * mask, mask), 1)
                gen_img = generator(lsimg, hint)
                D_fake = discriminator(gen_img)
                D_fake_loss = Adv_loss(D_fake, real)

                x_feature = VGG((lcimg + 1) / 2)
                G_feature = VGG((gen_img + 1) / 2)
                Con_loss = args.con_lambda * L1_loss(G_feature,
                                                     x_feature.detach())

                Gen_loss = D_fake_loss + Con_loss
                Gen_losses.append(D_fake_loss.item())
                train_hist['Gen_loss'].append(D_fake_loss.item())
                Con_losses.append(Con_loss.item())
                train_hist['Con_loss'].append(Con_loss.item())

                Gen_loss.backward()
                G_optimizer.step()
                # G_scheduler.step()

            # train D
            D_optimizer.zero_grad()

            D_real = discriminator(acimg)
            D_real_loss = Adv_loss(D_real, real)  # Hinge Loss (?)

            mask = mask_gen()
            hint = torch.cat((lhimg * mask, mask), 1)

            gen_img = generator(lsimg, hint)
            D_fake = discriminator(gen_img)
            D_fake_loss = Adv_loss(D_fake, fake)

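            # edge-smoothed anime images are labeled fake: this CartoonGAN-style
            # edge-promoting term pushes the networks toward crisp edges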
            D_edge = discriminator(ac_smooth_img)
            D_edge_loss = Adv_loss(D_edge, fake)

            Disc_loss = D_real_loss + D_fake_loss + D_edge_loss
            # Disc_loss = D_real_loss + D_fake_loss
            Disc_losses.append(Disc_loss.item())
            train_hist['Disc_loss'].append(Disc_loss.item())

            Disc_loss.backward()
            D_optimizer.step()

        # G_scheduler.step()
        # D_scheduler.step()

        per_epoch_time = time.time() - epoch_start_time
        train_hist['per_epoch_time'].append(per_epoch_time)
        print(
            '[%d/%d] - time: %.2f, Disc loss: %.3f, Gen loss: %.3f, Con loss: %.3f'
            % ((epoch + 1), args.train_epoch, per_epoch_time,
               torch.mean(torch.FloatTensor(Disc_losses)),
               torch.mean(torch.FloatTensor(Gen_losses)),
               torch.mean(torch.FloatTensor(Con_losses))))

        if epoch % 2 == 1 or epoch == args.train_epoch - 1:
            with torch.no_grad():
                generator.eval()
                for n, (lcimg, lhimg,
                        lsimg) in enumerate(landscape_dataloader):
                    lcimg, lhimg, lsimg = lcimg.to(device), lhimg.to(
                        device), lsimg.to(device)
                    mask = mask_gen()
                    hint = torch.cat((lhimg * mask, mask), 1)
                    g_recon = generator(lsimg, hint)
                    result = torch.cat((lcimg[0], g_recon[0]), 2)
                    path = os.path.join(
                        args.name + '_results', 'Transfer',
                        str(epoch + 1) + '_epoch_' + args.name + '_train_' +
                        str(n + 1) + '.png')
                    plt.imsave(path,
                               (result.cpu().numpy().transpose(1, 2, 0) + 1) /
                               2)
                    if n == 4:
                        break

                for n, (lcimg, lhimg,
                        lsimg) in enumerate(landscape_test_dataloader):
                    lcimg, lhimg, lsimg = lcimg.to(device), lhimg.to(
                        device), lsimg.to(device)
                    mask = mask_gen()
                    hint = torch.cat((lhimg * mask, mask), 1)
                    g_recon = generator(lsimg, hint)
                    result = torch.cat((lcimg[0], g_recon[0]), 2)
                    path = os.path.join(
                        args.name + '_results', 'Transfer',
                        str(epoch + 1) + '_epoch_' + args.name + '_test_' +
                        str(n + 1) + '.png')
                    plt.imsave(path,
                               (result.cpu().numpy().transpose(1, 2, 0) + 1) /
                               2)
                    if n == 4:
                        break

                torch.save(
                    generator.state_dict(),
                    os.path.join(args.name + '_results',
                                 'generator_latest.pkl'))
                torch.save(
                    discriminator.state_dict(),
                    os.path.join(args.name + '_results',
                                 'discriminator_latest.pkl'))

    total_time = time.time() - start_time
    train_hist['total_time'].append(total_time)

    print("Avg one epoch time: %.2f, total %d epochs time: %.2f" %
          (torch.mean(torch.FloatTensor(
              train_hist['per_epoch_time'])), args.train_epoch, total_time))
    print("Training finished! Saving training results")

    torch.save(generator.state_dict(),
               os.path.join(args.name + '_results', 'generator_param.pkl'))
    torch.save(discriminator.state_dict(),
               os.path.join(args.name + '_results', 'discriminator_param.pkl'))
    with open(os.path.join(args.name + '_results', 'train_hist.pkl'),
              'wb') as f:
        pickle.dump(train_hist, f)
Example #3
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.backends.cudnn.enabled:
        torch.backends.cudnn.benchmark = True

    prepare_result()
    make_edge_promoting_img()

    # data_loader
    src_transform = transforms.Compose([
        transforms.Resize((args.input_size, args.input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    tgt_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    train_loader_src = utils.data_load(os.path.join('data', args.src_data),
                                       'train',
                                       src_transform,
                                       args.batch_size,
                                       shuffle=True,
                                       drop_last=True)
    train_loader_tgt = utils.data_load(os.path.join('data', args.tgt_data),
                                       'pair',
                                       tgt_transform,
                                       args.batch_size,
                                       shuffle=True,
                                       drop_last=True)
    test_loader_src = utils.data_load(os.path.join('data', args.src_data),
                                      'test',
                                      src_transform,
                                      1,
                                      shuffle=True,
                                      drop_last=True)

    # network
    G = networks.generator(args.in_ngc, args.out_ngc, args.ngf, args.nb)
    if args.latest_generator_model != '':
        if torch.cuda.is_available():
            G.load_state_dict(torch.load(args.latest_generator_model))
        else:
            # cpu mode
            G.load_state_dict(
                torch.load(args.latest_generator_model,
                           map_location=lambda storage, loc: storage))

    D = networks.discriminator(args.in_ndc, args.out_ndc, args.ndf)
    if args.latest_discriminator_model != '':
        if torch.cuda.is_available():
            D.load_state_dict(torch.load(args.latest_discriminator_model))
        else:
            D.load_state_dict(
                torch.load(args.latest_discriminator_model,
                           map_location=lambda storage, loc: storage))
    VGG = networks.VGG19(init_weights=args.vgg_model, feature_mode=True)

    G.to(device)
    D.to(device)
    VGG.to(device)

    G.train()
    D.train()

    VGG.eval()

    print('---------- Networks initialized -------------')
    utils.print_network(G)
    utils.print_network(D)
    utils.print_network(VGG)
    print('-----------------------------------------------')

    # loss
    BCE_loss = nn.BCELoss().to(device)
    L1_loss = nn.L1Loss().to(device)

    # Adam optimizer
    G_optimizer = optim.Adam(G.parameters(),
                             lr=args.lrG,
                             betas=(args.beta1, args.beta2))
    D_optimizer = optim.Adam(D.parameters(),
                             lr=args.lrD,
                             betas=(args.beta1, args.beta2))
    G_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer=G_optimizer,
        milestones=[args.train_epoch // 2, args.train_epoch // 4 * 3],
        gamma=0.1)
    D_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer=D_optimizer,
        milestones=[args.train_epoch // 2, args.train_epoch // 4 * 3],
        gamma=0.1)
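    # both schedulers cut the learning rate by 10x at 50% and 75% of training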

    pre_train_hist = {}
    pre_train_hist['Recon_loss'] = []
    pre_train_hist['per_epoch_time'] = []
    pre_train_hist['total_time'] = []
    """ Pre-train reconstruction """
    if args.latest_generator_model == '':
        print('Pre-training start!')
        start_time = time.time()
        for epoch in range(args.pre_train_epoch):
            epoch_start_time = time.time()
            Recon_losses = []
            for x, _ in train_loader_src:
                x = x.to(device)

                # train generator G
                G_optimizer.zero_grad()

                x_feature = VGG((x + 1) / 2)
                G_ = G(x)
                G_feature = VGG((G_ + 1) / 2)

                Recon_loss = 10 * L1_loss(G_feature, x_feature.detach())
                Recon_losses.append(Recon_loss.item())
                pre_train_hist['Recon_loss'].append(Recon_loss.item())

                Recon_loss.backward()
                G_optimizer.step()

            per_epoch_time = time.time() - epoch_start_time
            pre_train_hist['per_epoch_time'].append(per_epoch_time)
            print('[%d/%d] - time: %.2f, Recon loss: %.3f' %
                  ((epoch + 1), args.pre_train_epoch, per_epoch_time,
                   torch.mean(torch.FloatTensor(Recon_losses))))

        total_time = time.time() - start_time
        pre_train_hist['total_time'].append(total_time)
        with open(os.path.join(args.name + '_results', 'pre_train_hist.pkl'),
                  'wb') as f:
            pickle.dump(pre_train_hist, f)

        with torch.no_grad():
            G.eval()
            for n, (x, _) in enumerate(train_loader_src):
                x = x.to(device)
                G_recon = G(x)
                result = torch.cat((x[0], G_recon[0]), 2)
                path = os.path.join(
                    args.name + '_results', 'Reconstruction',
                    args.name + '_train_recon_' + str(n + 1) + '.png')
                plt.imsave(path,
                           (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                if n == 4:
                    break

            for n, (x, _) in enumerate(test_loader_src):
                x = x.to(device)
                G_recon = G(x)
                result = torch.cat((x[0], G_recon[0]), 2)
                path = os.path.join(
                    args.name + '_results', 'Reconstruction',
                    args.name + '_test_recon_' + str(n + 1) + '.png')
                plt.imsave(path,
                           (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                if n == 4:
                    break
    else:
        print('Loaded the latest generator model; skipping pre-training')

    train_hist = {}
    train_hist['Disc_loss'] = []
    train_hist['Gen_loss'] = []
    train_hist['Con_loss'] = []
    train_hist['per_epoch_time'] = []
    train_hist['total_time'] = []
    print('Training start!')
    start_time = time.time()
    real = torch.ones(args.batch_size, 1, args.input_size // 4,
                      args.input_size // 4).to(device)
    fake = torch.zeros(args.batch_size, 1, args.input_size // 4,
                       args.input_size // 4).to(device)
    for epoch in range(args.train_epoch):
        epoch_start_time = time.time()
        G.train()
        Disc_losses = []
        Gen_losses = []
        Con_losses = []
        for (x, _), (y, _) in zip(train_loader_src, train_loader_tgt):
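            # each 'pair' sample is a side-by-side image: the left half is the
            # target-style image, the right half its edge-smoothed counterpart
            # from make_edge_promoting_img()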
            e = y[:, :, :, args.input_size:]
            y = y[:, :, :, :args.input_size]
            x, y, e = x.to(device), y.to(device), e.to(device)

            # train D
            D_optimizer.zero_grad()

            D_real = D(y)
            D_real_loss = BCE_loss(D_real, real)

            G_ = G(x)
            D_fake = D(G_)
            D_fake_loss = BCE_loss(D_fake, fake)

            D_edge = D(e)
            D_edge_loss = BCE_loss(D_edge, fake)

            Disc_loss = D_real_loss + D_fake_loss + D_edge_loss
            Disc_losses.append(Disc_loss.item())
            train_hist['Disc_loss'].append(Disc_loss.item())

            Disc_loss.backward()
            D_optimizer.step()

            # train G
            G_optimizer.zero_grad()

            G_ = G(x)
            D_fake = D(G_)
            D_fake_loss = BCE_loss(D_fake, real)

            x_feature = VGG((x + 1) / 2)
            G_feature = VGG((G_ + 1) / 2)
            Con_loss = args.con_lambda * L1_loss(G_feature, x_feature.detach())

            Gen_loss = D_fake_loss + Con_loss
            Gen_losses.append(D_fake_loss.item())
            train_hist['Gen_loss'].append(D_fake_loss.item())
            Con_losses.append(Con_loss.item())
            train_hist['Con_loss'].append(Con_loss.item())

            Gen_loss.backward()
            G_optimizer.step()

        G_scheduler.step()
        D_scheduler.step()

        per_epoch_time = time.time() - epoch_start_time
        train_hist['per_epoch_time'].append(per_epoch_time)
        print(
            '[%d/%d] - time: %.2f, Disc loss: %.3f, Gen loss: %.3f, Con loss: %.3f'
            % ((epoch + 1), args.train_epoch, per_epoch_time,
               torch.mean(torch.FloatTensor(Disc_losses)),
               torch.mean(torch.FloatTensor(Gen_losses)),
               torch.mean(torch.FloatTensor(Con_losses))))

        if epoch % 2 == 1 or epoch == args.train_epoch - 1:
            with torch.no_grad():
                G.eval()
                for n, (x, _) in enumerate(train_loader_src):
                    x = x.to(device)
                    G_recon = G(x)
                    result = torch.cat((x[0], G_recon[0]), 2)
                    path = os.path.join(
                        args.name + '_results', 'Transfer',
                        str(epoch + 1) + '_epoch_' + args.name + '_train_' +
                        str(n + 1) + '.png')
                    plt.imsave(path,
                               (result.cpu().numpy().transpose(1, 2, 0) + 1) /
                               2)
                    if n == 4:
                        break

                for n, (x, _) in enumerate(test_loader_src):
                    x = x.to(device)
                    G_recon = G(x)
                    result = torch.cat((x[0], G_recon[0]), 2)
                    path = os.path.join(
                        args.name + '_results', 'Transfer',
                        str(epoch + 1) + '_epoch_' + args.name + '_test_' +
                        str(n + 1) + '.png')
                    plt.imsave(path,
                               (result.cpu().numpy().transpose(1, 2, 0) + 1) /
                               2)
                    if n == 4:
                        break

                torch.save(
                    G.state_dict(),
                    os.path.join(args.name + '_results',
                                 'generator_latest.pkl'))
                torch.save(
                    D.state_dict(),
                    os.path.join(args.name + '_results',
                                 'discriminator_latest.pkl'))

    total_time = time.time() - start_time
    train_hist['total_time'].append(total_time)

    print("Avg one epoch time: %.2f, total %d epochs time: %.2f" %
          (torch.mean(torch.FloatTensor(
              train_hist['per_epoch_time'])), args.train_epoch, total_time))
    print("Training finished! Saving training results")

    torch.save(G.state_dict(),
               os.path.join(args.name + '_results', 'generator_param.pkl'))
    torch.save(D.state_dict(),
               os.path.join(args.name + '_results', 'discriminator_param.pkl'))
    with open(os.path.join(args.name + '_results', 'train_hist.pkl'),
              'wb') as f:
        pickle.dump(train_hist, f)