elif args.which_latent == 'w_tied':
        args.latent_full = args.latent
    else:
        raise NotImplementedError

    args.start_iter = 0
    util.set_log_dir(args)
    util.print_args(parser, args)

    generator = Generator(
        args.size,
        args.latent,
        args.n_mlp,
        channel_multiplier=args.channel_multiplier).to(device)
    discriminator = Discriminator(args.size,
                                  channel_multiplier=args.channel_multiplier,
                                  n_head=args.n_head_d).to(device)
    g_ema = Generator(args.size,
                      args.latent,
                      args.n_mlp,
                      channel_multiplier=args.channel_multiplier).to(device)
    g_ema.eval()
    accumulate(g_ema, generator, 0)

    g_reg_ratio = args.g_reg_every / (args.g_reg_every +
                                      1) if args.g_reg_every > 0 else 1.
    d_reg_ratio = args.d_reg_every / (args.d_reg_every +
                                      1) if args.d_reg_every > 0 else 1.

    g_optim = optim.Adam(
        generator.parameters(),
Exemple #2
0
    def train(self, src_data, tgt_data):
        """Adversarially learn a mapping from source to target embeddings.

        Trains generator ``g`` (the cross-lingual mapping) against
        discriminator ``d`` while also fine-tuning the subword transformer
        ``src_data['F']`` under an MSE penalty that keeps transformed
        vectors close to the originals.

        Args:
            src_data: dict holding the source-language resources; this
                method reads keys 'F' (subword transformer), 'E'
                (embeddings) and 'seqs'.
            tgt_data: dict holding the target-language resources; this
                method reads key 'E'.

        Returns:
            The trained generator ``g``.

        Raises:
            FileNotFoundError: if ``params.data_dir`` does not exist.
        """
        params = self.params
        print(params)
        penalty = 10.0  # weight of the subword MSE penalty added to the G loss
        print('Subword penalty {}'.format(penalty))
        # Load data
        # BUG FIX: raising a plain string is a TypeError in Python 3;
        # raise a real exception instead.
        if not os.path.exists(params.data_dir):
            raise FileNotFoundError(
                "Data path doesn't exist: %s" % params.data_dir)

        src_lang = params.src_lang
        tgt_lang = params.tgt_lang
        self.suffix_str = src_lang + '_' + tgt_lang

        evaluator = Evaluator(params, src_data=src_data, tgt_data=tgt_data)
        monitor = Monitor(params, src_data=src_data, tgt_data=tgt_data)

        # (A dead, commented-out subword-transformer pre-initialization loop
        # was removed here.)

        # Optimizer for the subword embedding transformer.
        src_optimizer = optim.SGD(src_data['F'].parameters(),
                                  lr=params.sw_learning_rate,
                                  momentum=0.9)
        print('Src optim: {}'.format(src_optimizer))
        # Loss function for the adversarial game.
        loss_fn = torch.nn.BCELoss()

        # Create models
        g = Generator(input_size=params.g_input_size,
                      hidden_size=params.g_hidden_size,
                      output_size=params.g_output_size)

        if self.params.model_file:
            print('Load a model from ' + self.params.model_file)
            g.load(self.params.model_file)

        d = Discriminator(input_size=params.d_input_size,
                          hidden_size=params.d_hidden_size,
                          output_size=params.d_output_size,
                          hyperparams=get_hyperparams(params, disc=True))
        seed = params.seed
        self.initialize_exp(seed)

        if not params.disable_cuda and torch.cuda.is_available():
            print('Use GPU')
            # Move the networks and the loss to the GPU
            g.cuda()
            d.cuda()
            loss_fn = loss_fn.cuda()

        if self.params.model_file is None:
            print('Initializing G based on distribution')
            # Stop iterating once the relative change of the loss is <= tol.
            topn = 10000  # number of top words used for the initialization
            tol = 1e-5
            prev_loss, loss = None, None
            g_optimizer = optim.SGD(g.parameters(), lr=0.01, momentum=0.9)

            batches = src_data['seqs'][:topn].split(params.mini_batch_size)
            src_emb = torch.cat([
                src_data['F'](batch, src_data['E']).detach()
                for batch in batches
            ])
            tgt_emb = tgt_data['E'].emb.weight[:topn]
            if not params.disable_cuda and torch.cuda.is_available():
                src_emb = src_emb.cuda()
                tgt_emb = tgt_emb.cuda()
            src_emb = F.normalize(src_emb)
            tgt_emb = F.normalize(tgt_emb)
            src_mean = src_emb.mean(dim=0).detach()
            tgt_mean = tgt_emb.mean(dim=0).detach()

            # Align the mapped source mean with the target mean.
            for _ in trange(1000):  # at most 1000 iterations
                prev_loss = loss
                g_optimizer.zero_grad()
                mapped_src_mean = g(src_mean)
                loss = F.mse_loss(mapped_src_mean, tgt_mean)
                loss.backward()
                g_optimizer.step()
                # Keep the mapping (approximately) orthogonal.
                self.orthogonalize(g.map1.weight.data)
                loss = float(loss)
                # BUG FIX: guard against division by zero when the loss
                # reaches exactly 0.
                if (type(prev_loss) is float and prev_loss != 0
                        and abs(prev_loss - loss) / prev_loss <= tol):
                    break
            print('Done: final loss = {}'.format(float(loss)))
        evaluator.precision(g, src_data, tgt_data)
        sim = monitor.cosine_similarity(g, src_data, tgt_data)
        print('Cos sim.: {:3f} (+/-{:.3})'.format(sim.mean(), sim.std()))

        d_acc_epochs, g_loss_epochs = [], []

        # Define optimizers
        d_optimizer = optim.SGD(d.parameters(), lr=params.d_learning_rate)
        g_optimizer = optim.SGD(g.parameters(), lr=params.g_learning_rate)
        for epoch in range(params.num_epochs):
            d_losses, g_losses = [], []
            hit = 0
            total = 0
            start_time = timer()

            for mini_batch in range(
                    0, params.iters_in_epoch // params.mini_batch_size):
                for d_index in range(params.d_steps):
                    d_optimizer.zero_grad()  # Reset the gradients
                    d.train()

                    X, y, _ = self.get_batch_data(src_data, tgt_data, g)
                    pred = d(X)
                    d_loss = loss_fn(pred, y)
                    d_loss.backward()
                    d_optimizer.step()

                    d_losses.append(d_loss.data.cpu().numpy())
                    discriminator_decision = pred.data.cpu().numpy()
                    # First half of the batch is one class (>= 0.5 counts as
                    # a hit), second half the other (< 0.5 counts as a hit).
                    hit += np.sum(
                        discriminator_decision[:params.mini_batch_size] >= 0.5)
                    hit += np.sum(
                        discriminator_decision[params.mini_batch_size:] < 0.5)

                    # FIX: np.asscalar was removed in NumPy 1.23; float()
                    # is the supported equivalent.
                    sys.stdout.write("[%d/%d] :: Discriminator Loss: %f \r" %
                                     (mini_batch, params.iters_in_epoch //
                                      params.mini_batch_size,
                                      float(np.mean(d_losses))))
                    sys.stdout.flush()

                    total += 2 * params.mini_batch_size * params.d_steps

                for g_index in range(params.g_steps):
                    # 2. Train G on D's response (but DO NOT train D on these labels)
                    g_optimizer.zero_grad()
                    src_optimizer.zero_grad()
                    d.eval()

                    X, y, src_vecs = self.get_batch_data(src_data, tgt_data, g)
                    pred = d(X)
                    g_loss = loss_fn(pred, 1 - y)  # flipped labels: fool D
                    src_loss = F.mse_loss(*src_vecs)
                    if g_loss.is_cuda:
                        src_loss = src_loss.cuda()
                    loss = g_loss + penalty * src_loss
                    loss.backward()
                    g_optimizer.step()  # Only optimizes G's parameters
                    src_optimizer.step()

                    g_losses.append(g_loss.data.cpu().numpy())

                    # Keep the mapping (approximately) orthogonal.
                    self.orthogonalize(g.map1.weight.data)

                    sys.stdout.write(
                        "[%d/%d] ::                                     Generator Loss: %f \r"
                        % (mini_batch,
                           params.iters_in_epoch // params.mini_batch_size,
                           float(np.mean(g_losses))))
                    sys.stdout.flush()

                d_acc_epochs.append(hit / total)
                g_loss_epochs.append(float(np.mean(g_losses)))
            print(
                "Epoch {} : Discriminator Loss: {:.5f}, Discriminator Accuracy: {:.5f}, Generator Loss: {:.5f}, Time elapsed {:.2f} mins"
                .format(epoch, float(np.mean(d_losses)), hit / total,
                        float(np.mean(g_losses)),
                        (timer() - start_time) / 60))

            # Checkpoint both networks every epoch.
            filename = path.join(params.model_dir, 'g_e{}.pth'.format(epoch))
            print('Save a generator to ' + filename)
            g.save(filename)
            filename = path.join(params.model_dir, 's_e{}.pth'.format(epoch))
            print('Save a subword transformer to ' + filename)
            src_data['F'].save(filename)
            if (epoch + 1) % params.print_every == 0:
                evaluator.precision(g, src_data, tgt_data)
                sim = monitor.cosine_similarity(g, src_data, tgt_data)
                print('Cos sim.: {:3f} (+/-{:.3})'.format(
                    sim.mean(), sim.std()))

        return g
Exemple #3
0
def main():
    """Train the conditional GAN on recipe-embedding/image pairs.

    Loads the training split, optionally resumes from ``opts.MODEL_PATH``,
    then alternates discriminator and generator updates, periodically
    printing losses and saving sample images and model checkpoints.
    """
    # Load the data
    data = GANstronomyDataset(opts.DATA_PATH, split=opts.TVT_SPLIT)
    data.set_split_index(0)  # 0 = training split
    data_loader = torch.utils.data.DataLoader(data,
                                              batch_size=opts.BATCH_SIZE,
                                              shuffle=True)

    # Make the output directory
    util.create_dir(opts.RUN_PATH)
    util.create_dir(opts.IMG_OUT_PATH)
    util.create_dir(opts.MODEL_OUT_PATH)

    # Copy opts.py and model.py to opts.RUN_PATH as a record
    shutil.copy2('opts.py', opts.RUN_PATH)
    shutil.copy2('model.py', opts.RUN_PATH)
    shutil.copy2('train.py', opts.RUN_PATH)

    # Instantiate the models
    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G_optimizer = torch.optim.Adam(G.parameters(), lr=opts.ADAM_LR, betas=opts.ADAM_B)

    D = Discriminator(opts.EMBED_SIZE).to(opts.DEVICE)
    D_optimizer = torch.optim.Adam(D.parameters(), lr=opts.ADAM_LR, betas=opts.ADAM_B)

    if opts.MODEL_PATH is None:
        start_iepoch, start_ibatch = 0, 0
    else:
        print('Attempting to resume training using model in %s...' % opts.MODEL_PATH)
        start_iepoch, start_ibatch = load_state_dicts(opts.MODEL_PATH, G, G_optimizer, D, D_optimizer)

    for iepoch in range(opts.NUM_EPOCHS):
        for ibatch, data_batch in enumerate(data_loader):
            # To try to resume training, just continue if iepoch and ibatch are less than their starts
            if iepoch < start_iepoch or (iepoch == start_iepoch and ibatch < start_ibatch):
                if iepoch % opts.INTV_PRINT_LOSS == 0 and not ibatch:
                    print('Skipping epoch %d...' % iepoch)
                continue

            recipe_ids, recipe_embs, img_ids, imgs, classes, noisy_real, noisy_fake = data_batch
            noisy_real, noisy_fake = util.get_variables2(noisy_real, noisy_fake)

            # Make sure we're not training on validation or test data!
            if opts.SAFETY_MODE:
                for recipe_id in recipe_ids:
                    assert data.get_recipe_split_index(recipe_id) == 0

            batch_size, recipe_embs, imgs = util.get_variables3(recipe_ids, recipe_embs, img_ids, imgs)

            # Adversarial ground truth for the generator update (all "real").
            all_real = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False).to(opts.DEVICE)

            # Train Discriminator
            # BUG FIX: a duplicate, unused `z = torch.randn(...)` draw that
            # preceded this one was removed -- only this z is ever used.
            z = torch.randn(batch_size, opts.LATENT_SIZE).to(opts.DEVICE)
            imgs_gen = G(z, recipe_embs)
            for _ in range(opts.NUM_UPDATE_D):
                D_optimizer.zero_grad()
                fake_probs = D(imgs_gen.detach(), recipe_embs)
                real_probs = D(imgs, recipe_embs)
                D_loss = BCELoss(fake_probs, noisy_fake) + BCELoss(real_probs, noisy_real)
                # NOTE(review): retain_graph=True looks unnecessary here
                # (both probs are recomputed every iteration and imgs_gen is
                # detached) -- kept as-is to preserve behavior; confirm.
                D_loss.backward(retain_graph=True)
                D_optimizer.step()

            # Train Generator
            G_optimizer.zero_grad()
            fake_probs = D(imgs_gen, recipe_embs)
            G_loss = BCELoss(fake_probs, all_real)
            G_loss.backward()
            G_optimizer.step()

            # Periodic logging / sampling / checkpointing, once per epoch
            # (only at the first batch of the epoch).
            if iepoch % opts.INTV_PRINT_LOSS == 0 and not ibatch:
                print_loss(G_loss, D_loss, iepoch)
            if iepoch % opts.INTV_SAVE_IMG == 0 and not ibatch:
                # Save a training image
                get_img_gen(data, 0, G, iepoch, opts.IMG_OUT_PATH)
                # Save a validation image
                get_img_gen(data, 1, G, iepoch, opts.IMG_OUT_PATH)
            if iepoch % opts.INTV_SAVE_MODEL == 0 and not ibatch:
                print('Saving model...')
                save_model(G, G_optimizer, D, D_optimizer, iepoch, opts.MODEL_OUT_PATH)

    save_model(G, G_optimizer, D, D_optimizer, 'FINAL', opts.MODEL_OUT_PATH)
    print('\a') # Ring the bell to alert the human
Exemple #4
0
def train(args):
    """Train an image-to-image GAN (L1 + VGG perceptual + adversarial).

    Runs ``args.epochs`` epochs of alternating PatchGAN discriminator and
    generator updates, logs per-epoch average losses through SaveData,
    and saves the generator every ``args.period`` epochs.
    """
    print(args)

    # networks
    netG = Generator()
    netG = netG.cuda()
    netD = Discriminator()
    netD = netD.cuda()

    # losses: L1 for reconstruction, MSE for perceptual, BCE for adversarial
    l1_loss = nn.L1Loss().cuda()
    l2_loss = nn.MSELoss().cuda()
    bce_loss = nn.BCELoss().cuda()

    # optimizers
    optimizerG = optim.Adam(netG.parameters(), lr=args.glr)
    optimizerD = optim.Adam(netD.parameters(), lr=args.dlr)

    # step-decay learning-rate schedulers
    schedulerG = lr_scheduler.StepLR(optimizerG, args.lr_step_size,
                                     args.lr_gamma)
    schedulerD = lr_scheduler.StepLR(optimizerD, args.lr_step_size,
                                     args.lr_gamma)

    # utility for saving models, parameters and logs
    save = SaveData(args.save_dir, args.exp, True)
    save.save_params(args)

    # netG, _ = save.load_model(netG)

    dataset = MyDataset(args.data_dir, is_train=True)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=int(args.n_threads))

    # PatchGAN targets: a patch_gan x patch_gan map of ones / zeros.
    real_label = Variable(
        torch.ones([1, 1, args.patch_gan, args.patch_gan],
                   dtype=torch.float)).cuda()
    fake_label = Variable(
        torch.zeros([1, 1, args.patch_gan, args.patch_gan],
                    dtype=torch.float)).cuda()

    # pool of previously generated images for the D updates
    image_pool = ImagePool(args.pool_size)

    # frozen VGG features for the perceptual loss
    vgg = Vgg16(requires_grad=False)
    vgg.cuda()

    for epoch in range(args.epochs):
        print("* Epoch {}/{}".format(epoch + 1, args.epochs))

        d_total_real_loss = 0
        d_total_fake_loss = 0
        d_total_loss = 0

        g_total_res_loss = 0
        g_total_per_loss = 0
        g_total_gan_loss = 0
        g_total_loss = 0

        netG.train()
        netD.train()

        for batch, images in tqdm(enumerate(dataloader)):
            input_image, target_image = images
            input_image = Variable(input_image.cuda())
            target_image = Variable(target_image.cuda())
            output_image = netG(input_image)

            # Update D
            # NOTE(review): nn.Module has no requires_grad() method (the
            # stock API is requires_grad_) -- this relies on Discriminator
            # defining its own; confirm.
            netD.requires_grad(True)
            netD.zero_grad()

            ## real image
            real_output = netD(target_image)
            d_real_loss = bce_loss(real_output, real_label)
            d_real_loss.backward()
            d_real_loss = d_real_loss.data.cpu().numpy()
            d_total_real_loss += d_real_loss

            ## fake image (detached from G, drawn through the image pool)
            fake_image = output_image.detach()
            fake_image = Variable(image_pool.query(fake_image.data))
            fake_output = netD(fake_image)
            d_fake_loss = bce_loss(fake_output, fake_label)
            d_fake_loss.backward()
            d_fake_loss = d_fake_loss.data.cpu().numpy()
            d_total_fake_loss += d_fake_loss

            ## loss
            d_total_loss += d_real_loss + d_fake_loss

            optimizerD.step()

            # Update G (with D frozen)
            netD.requires_grad(False)
            netG.zero_grad()

            ## reconstruction loss
            g_res_loss = l1_loss(output_image, target_image)
            g_res_loss.backward(retain_graph=True)
            g_res_loss = g_res_loss.data.cpu().numpy()
            g_total_res_loss += g_res_loss

            ## perceptual loss on VGG features
            g_per_loss = args.p_factor * l2_loss(vgg(output_image),
                                                 vgg(target_image))
            g_per_loss.backward(retain_graph=True)
            g_per_loss = g_per_loss.data.cpu().numpy()
            g_total_per_loss += g_per_loss

            ## gan loss
            output = netD(output_image)
            g_gan_loss = args.g_factor * bce_loss(output, real_label)
            g_gan_loss.backward()
            g_gan_loss = g_gan_loss.data.cpu().numpy()
            g_total_gan_loss += g_gan_loss

            ## loss
            g_total_loss += g_res_loss + g_per_loss + g_gan_loss

            optimizerG.step()

        # BUG FIX: step the LR schedulers once per epoch *after* the
        # optimizer updates. The old placement at the top of the epoch
        # loop ran scheduler.step() before any optimizer.step(), which
        # skips the initial learning rate and triggers a PyTorch warning
        # (required ordering since PyTorch 1.1).
        schedulerG.step()
        schedulerD.step()

        d_total_real_loss = d_total_real_loss / (batch + 1)
        d_total_fake_loss = d_total_fake_loss / (batch + 1)
        d_total_loss = d_total_loss / (batch + 1)
        save.add_scalar('D/real', d_total_real_loss, epoch)
        save.add_scalar('D/fake', d_total_fake_loss, epoch)
        save.add_scalar('D/total', d_total_loss, epoch)

        g_total_res_loss = g_total_res_loss / (batch + 1)
        g_total_per_loss = g_total_per_loss / (batch + 1)
        g_total_gan_loss = g_total_gan_loss / (batch + 1)
        g_total_loss = g_total_loss / (batch + 1)
        save.add_scalar('G/res', g_total_res_loss, epoch)
        save.add_scalar('G/per', g_total_per_loss, epoch)
        save.add_scalar('G/gan', g_total_gan_loss, epoch)
        save.add_scalar('G/total', g_total_loss, epoch)

        if epoch % args.period == 0:
            log = "Train d_loss: {:.5f} \t g_loss: {:.5f}".format(
                d_total_loss, g_total_loss)
            print(log)
            save.save_log(log)
            save.save_model(netG, epoch)
# Preprocessing pipeline: resize, convert to tensor, then normalize each
# channel with mean=std=0.5, i.e. scale pixel values to [-1, 1].
# NOTE(review): this rebinds the name `transforms`, shadowing the imported
# transforms module from here on -- confirm nothing later in the file still
# needs the module itself.
transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5 for _ in range(CHANNELS_IMG)],
                         [0.5 for _ in range(CHANNELS_IMG)]),
])

dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
#comment mnist and uncomment below if you want to train on CelebA dataset
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# initialize gen and disc/critic -- "critic" + RMSprop below suggests a
# WGAN-style setup; presumably Discriminator outputs unbounded scores
# rather than probabilities (confirm against the model definition).
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)

# initialize optimizers
opt_gen = optim.RMSprop(gen.parameters(), lr=LEARNING_RATE)
opt_critic = optim.RMSprop(critic.parameters(), lr=LEARNING_RATE)

# fixed latent batch and writers for tensorboard plotting
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0  # global tensorboard step counter

# put both networks in training mode
gen.train()
critic.train()
Exemple #6
0
from torch.autograd import Variable
import torchvision  # moved up from mid-block so all imports sit together

from data import dataloader

# Loss function for the adversarial game.
adversarial_loss = torch.nn.BCELoss()

# Hyperparameters (comments translated from Chinese).
EPOCH = 10  # number of passes over the full training set
LR = 0.0002  # learning rate
DOWNLOAD_MNIST = False  # if MNIST is already downloaded, this is skipped
len_Z = 100  # size of the random input channel for the Generator
g_hidden_channal = 64  # generator hidden-channel width
d_hidden_channal = 64  # discriminator hidden-channel width
image_channal = 1  # MNIST images are single-channel (grayscale)
generator = Generator(len_Z, g_hidden_channal, image_channal)
# BUG FIX: the discriminator previously received g_hidden_channal while
# d_hidden_channal was defined but never used. Both happen to be 64, so
# numerical behavior is unchanged -- the intent is now explicit.
discriminator = Discriminator(image_channal, d_hidden_channal)
BATCH_SIZE = 32
cuda = True if torch.cuda.is_available() else False
if cuda:
    # Move the networks and the loss to the GPU.
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Optimizers (Adam with the standard GAN betas).
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
Exemple #7
0
    ]))

dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=workers)

# choose cuda
device = torch.device("cuda:0" if (
    torch.cuda.is_available() and ngpu > 0) else "cpu")

# generate G_model
G = Generator(latent_depth=nz, feature_depth=ngf).to(device)

# generate D_model
D = Discriminator(feature_depth=ndf).to(device)

# Wrap both nets in DataParallel when several GPUs are available.
if (device.type == 'cuda') and (ngpu > 1):
    G = nn.DataParallel(G, list(range(ngpu)))
    D = nn.DataParallel(D, list(range(ngpu)))

G.apply(weights_init)
D.apply(weights_init)

criterion = nn.BCELoss()

# Fixed latent batch used to visualize training progress.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)


# Noisy / smoothed labels, drawn fresh on every call:
# real ~ U(0.7, 1.2), fake ~ U(0.0, 0.3).
# FIX (idiom, PEP 8 E731): plain `def`s instead of lambda assignments.
def real_label():
    return random.uniform(0.7, 1.2)


def fake_label():
    return random.uniform(0.0, 0.3)
Exemple #8
0
def main(args):
    print("Loading data")
    dataset = args.data.rstrip('/').split('/')[-1]
    if dataset in ['mnist']:
        train_loader, test_loader = get_mnist(args.batch_size, args.data)
    elif dataset in ['cifar']:
        train_loader, test_loader, classes = get_cifar(args.batch_size,
                                                       args.data)
    elif dataset in ['svhn']:
        train_loader, test_loader, extra_loader = get_svhn(
            args.batch_size, args.data)
    elif dataset in ['fashion']:
        train_loader, test_loader = get_fashion_mnist(args.batch_size,
                                                      args.data)
    elif dataset in ['stl10']:
        train_loader, test_loader, unlabeled_loader = get_stl10(
            args.batch_size, args.data)
    elif dataset in ['yale']:
        train_loader = get_yale(args.batch_size, args.data)
    else:
        raise NotImplementedError
    torch.cuda.set_device(args.device_id)
    for _, (batch, _) in enumerate(train_loader):
        size = batch.size()
        break

    model = VAE(size, args.code_dim, args.batch_size,
                data=dataset).to(args.device)
    D = Discriminator(args.code_dim, 4).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    optimizer_D = torch.optim.Adam(model.parameters(), args.lr)

    start_epoch = 1
    print('\nStarting Training')
    if args.base:
        try:
            for epoch in range(start_epoch, args.epochs):
                nll, re_loss, kl_divergence, d_loss = baseline(args,
                                                               train_loader,
                                                               model,
                                                               optimizer,
                                                               epoch,
                                                               train=True)
                print('-' * 90)
                meta = "| epoch {:2d} ".format(epoch)
                print(
                    meta +
                    "| Train NLL: {:5.2f} | Train loss: {:5.2f} ({:5.2f}) | D loss {:5.2f} |"
                    .format(nll, re_loss, kl_divergence, d_loss))

                nll, re_loss, kl_divergence, d_loss = baseline(args,
                                                               test_loader,
                                                               model,
                                                               optimizer,
                                                               1,
                                                               train=False)
                print(
                    len(meta) * " " +
                    "| Test NLL: {:5.2f} | Test loss: {:5.2f} ({:5.2f}) | D loss {:5.2f} |"
                    .format(nll, re_loss, kl_divergence, d_loss))

        except KeyboardInterrupt:
            print('-' * 50)
            print('Quit Training')

        nll, re_loss, kl_divergence, d_loss = baseline(args,
                                                       test_loader,
                                                       model,
                                                       optimizer,
                                                       epoch,
                                                       train=False)
        print('=' * 90)
        print(
            "| Train NLL: {:5.2f} | Train loss: {:5.2f} ({:5.2f}) | D loss {:5.2f} |"
            .format(nll, re_loss, kl_divergence, d_loss))

    else:
        try:
            for epoch in range(start_epoch, args.epochs):
                nll, re_loss, kl_divergence, d_loss = run(args,
                                                          train_loader,
                                                          model,
                                                          D,
                                                          optimizer,
                                                          optimizer_D,
                                                          epoch,
                                                          train=True)
                print('-' * 90)
                meta = "| epoch {:2d} ".format(epoch)
                print(
                    meta +
                    "| Train NLL: {:5.2f} | Train loss: {:5.2f} ({:5.2f}) | D loss {:5.2f} |"
                    .format(nll, re_loss, kl_divergence, d_loss))

                nll, re_loss, kl_divergence, d_loss = run(args,
                                                          test_loader,
                                                          model,
                                                          D,
                                                          optimizer,
                                                          optimizer_D,
                                                          1,
                                                          train=False)
                print(
                    len(meta) * " " +
                    "| Test NLL: {:5.2f} | Test loss: {:5.2f} ({:5.2f}) | D loss {:5.2f} |"
                    .format(nll, re_loss, kl_divergence, d_loss))

        except KeyboardInterrupt:
            print('-' * 50)
            print('Quit Training')

        nll, re_loss, kl_divergence, d_loss = run(args,
                                                  test_loader,
                                                  model,
                                                  D,
                                                  optimizer,
                                                  optimizer_D,
                                                  epoch,
                                                  train=False)
        print('=' * 90)
        print(
            "| Train NLL: {:5.2f} | Train loss: {:5.2f} ({:5.2f}) | D loss {:5.2f} |"
            .format(nll, re_loss, kl_divergence, d_loss))
    np.random.seed(seed_num)
    torch.manual_seed(seed_num)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    loss_D_h2l_log = []
    loss_D_l2h_log = []
    loss_G_h2l_log = []
    loss_G_l2h_log = []
    loss_cycle_log = []

    max_epoch = int(args.epoch)
    learn_rate = 1e-4

    G_h2l = High2Low().cuda()
    D_h2l = Discriminator(16).cuda()
    G_l2h = Low2High().cuda()
    D_l2h = Discriminator(64).cuda()
    mse = nn.MSELoss()
    TVloss = TVLoss()
    L1 = nn.L1Loss()
    upsample = nn.Upsample(scale_factor=4, mode='bicubic')

    optim_D_h2l = optim.Adam(filter(lambda p: p.requires_grad,
                                    D_h2l.parameters()),
                             lr=learn_rate,
                             betas=(0.0, 0.9))
    optim_G_h2l = optim.Adam(G_h2l.parameters(),
                             lr=learn_rate,
                             betas=(0.0, 0.9))
    optim_D_l2h = optim.Adam(filter(lambda p: p.requires_grad,
Exemple #10
0
def main(args):
    """Train a least-squares GAN on MNIST with Chainer.

    Builds the generator/discriminator, runs ``args.num_epochs`` epochs of
    alternating LSGAN updates in mini-batches of 100, saves sample images
    to ``args.output`` every 10 epochs, and finally plots the loss curves.
    """
    # initialize models and load mnist dataset
    G = Generator()
    D = Discriminator()
    x = load_dataset()

    # build optimizer of generator
    opt_generator = chainer.optimizers.Adam().setup(G)
    opt_generator.use_cleargrads()

    # build optimizer of discriminator
    opt_discriminator = chainer.optimizers.Adam().setup(D)
    # BUG FIX: this line previously called opt_generator.use_cleargrads()
    # a second time, leaving the discriminator optimizer unconfigured.
    opt_discriminator.use_cleargrads()

    # make the output folder
    if not os.path.exists(args.output):
        os.makedirs(args.output, exist_ok=True)

    # loss history, one entry (the epoch's last batch) per epoch
    Glosses = []
    Dlosses = []

    print("Now starting training loop...")

    # begin training process
    for train_iter in range(1, args.num_epochs + 1):

        for i in range(0, len(x), 100):

            # Clears all gradient arrays.
            # The following should be called before the backward computation
            # at every iteration of the optimization.
            G.cleargrads()
            D.cleargrads()

            # Train the generator: LSGAN loss 0.5 * sum((D(G(z)) - 1)^2)
            noise_samples = sample(100)
            Gloss = 0.5 * F.sum(F.square(D(G(np.asarray(noise_samples))) - 1))
            Gloss.backward()
            opt_generator.update()

            # As above
            G.cleargrads()
            D.cleargrads()

            # Train the discriminator: push D(real) toward 1, D(fake) toward 0.
            noise_samples = sample(100)
            Dreal = D(np.asarray(x[i:i + 100]))
            Dgen = D(G(np.asarray(noise_samples)))
            Dloss = 0.5 * F.sum(F.square(
                (Dreal - 1.0))) + 0.5 * F.sum(F.square(Dgen))
            Dloss.backward()
            opt_discriminator.update()

        # record the losses of the epoch's final batch
        Glosses.append(Gloss.data)
        Dlosses.append(Dloss.data)

        if train_iter % 10 == 0:

            print("epoch {0:04d}".format(train_iter), end=", ")
            print("Gloss: {}".format(Gloss.data), end=", ")
            print("Dloss: {}".format(Dloss.data))

            # Render a fresh batch of samples to an image file.
            noise_samples = sample(100)
            print_sample(
                os.path.join(args.output,
                             "epoch_{0:04}.png".format(train_iter)),
                noise_samples, G)

    print("The training process is finished.")

    plotLoss(train_iter, Dlosses, Glosses)
Exemple #11
0
# Mini-batch loaders: shuffled for training, fixed order for sampling.
train_loader = DataLoader(dataset=train_data,
                          batch_size=config['batch_size'],
                          shuffle=True,
                          pin_memory=True)
sample_loader = DataLoader(dataset=sample_data,
                           batch_size=config['batch_size'],
                           shuffle=False,
                           pin_memory=True)

# Generator over windowed graph data -- presumably models node_num x
# node_num adjacency structure across a time window with an LSTM; confirm
# against the Generator definition.
generator = Generator(window_size=window_size,
                      node_num=node_num,
                      in_features=config['in_features'],
                      out_features=config['out_features'],
                      lstm_features=config['lstm_features'])

# Discriminator consumes flattened node_num x node_num matrices.
discriminator = Discriminator(input_size=node_num * node_num,
                              hidden_size=config['disc_hidden'])

generator = generator.cuda()
discriminator = discriminator.cuda()

# Summed (not averaged) squared error.
mse = nn.MSELoss(reduction='sum')

# Three RMSprop optimizers: one for generator pretraining and one each for
# the adversarial generator/discriminator phases, with separate LRs.
pretrain_optimizer = optim.RMSprop(generator.parameters(),
                                   lr=config['pretrain_learning_rate'])
generator_optimizer = optim.RMSprop(generator.parameters(),
                                    lr=config['g_learning_rate'])
discriminator_optimizer = optim.RMSprop(discriminator.parameters(),
                                        lr=config['d_learning_rate'])
#
print('pretrain generator')
Exemple #12
0
    'sum_0', 'sum_1', 'sum_2', 'sum_3', 'sum_4', 'sum_5', 'sum_6', 'sum_7'
]

# Generator and its optimizer (Chainer; alpha is the learning rate).
generator = Generator()
generator.to_gpu()
gen_opt = set_optimizer(generator, alpha=0.0002)

#segmentation_generator = Generator()
#segmentation_generator.to_gpu()
#seg_gen_opt = set_optimizer(segmentation_generator)

#key_point_detector = KeyPointDetector()
#key_point_detector.to_gpu()
#key_opt = set_optimizer(key_point_detector)

# Discriminator and its optimizer.
discriminator = Discriminator()
discriminator.to_gpu()
dis_opt = set_optimizer(discriminator, alpha=0.0002)

#segmentation_discriminator = SimpleDiscriminator()
#segmentation_discriminator.to_gpu()
#seg_dis_opt = set_optimizer(segmentation_discriminator)

#keypoint_discriminator = SimpleDiscriminator()
#keypoint_discriminator.to_gpu()
#key_dis_opt = set_optimizer(keypoint_discriminator)

# Fixed latent batch (uniform in [-1, 1)) for monitoring sample quality --
# presumably 256 is the generator's latent size; confirm.
ztest = chainer.as_variable(
    xp.random.uniform(-1, 1, (batchsize, 256)).astype(xp.float32))

for epoch in range(epochs):
    transform.append(T.ToTensor())
    transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transforms.Lambda(lambda x: x + 1. / 128 * torch.rand(x.size()))
    transform = T.Compose(transform)

    dataset = torchvision.datasets.CIFAR10(root='./cifar10',
                                           transform=transform,
                                           download=True)
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=64,
                                              shuffle=True,
                                              drop_last=True,
                                              num_workers=8)

    G = Generator().cuda()
    D = Discriminator().cuda()
    D.load_state_dict(torch.load('./checkpoint/25000_D.pth'))
    G.load_state_dict(torch.load('./checkpoint/25000_G.pth'))
    optimizer_G = torch.optim.RMSprop(G.parameters(),
                                      lr=1e-4,
                                      alpha=0.99,
                                      eps=1e-8)
    optimizer_D = torch.optim.RMSprop(D.parameters(),
                                      lr=1e-4,
                                      alpha=0.99,
                                      eps=1e-8)
    print("start...")
    dataiter = iter(data_loader)
    for idx in range(2000000):
        time_start = datetime.datetime.now()
        try:
Exemple #14
0
    # Generator trainer; beta1=0.0 as is conventional for StyleGAN training.
    g_optimizer = gluon.Trainer(generator.collect_params(),
                                optimizer='adam',
                                optimizer_params={
                                    'learning_rate': args.lr_default,
                                    'beta1': 0.0,
                                    'beta2': 0.99
                                },
                                kvstore='local')

    # Set a different learning rate for style by setting the lr_mult of 0.01
    # NOTE(review): relies on the style mapping block being auto-named
    # 'hybridsequential2' — fragile if the model definition changes; verify.
    for k in generator.collect_params().keys():
        if k.startswith('hybridsequential2'):
            generator.collect_params()[k].lr_mult = 0.01

    discriminator = Discriminator(
        from_rgb_activate=not args.no_from_rgb_activate)
    discriminator.initialize(ctx=context)
    discriminator.collect_params().reset_ctx(context)

    d_optimizer = gluon.Trainer(discriminator.collect_params(),
                                optimizer='adam',
                                optimizer_params={
                                    'learning_rate': args.lr_default,
                                    'beta1': 0.0,
                                    'beta2': 0.99
                                },
                                kvstore='local')

    # Running (EMA) copy of the generator used for evaluation, kept on GPU 0.
    g_running = StyledGenerator(code_size)
    g_running.initialize(ctx=mx.gpu(0))
    g_running.collect_params().reset_ctx(mx.gpu(0))
Exemple #15
0
def train(resume=False):
    """Train a chest X-ray Discriminator classifier with per-epoch validation.

    Args:
        resume: accepted for interface compatibility; unused in this body.

    Side effects:
        - Writes scalars/figures to TensorBoard at ../runs/<hparams.exp_name>.
        - Saves a checkpoint every epoch (hparams.model + '.<epoch>') and a
          '.best' checkpoint whenever validation F1 improves.
    """

    writer = SummaryWriter('../runs/' + hparams.exp_name)

    # Log every hyperparameter as text so the run is self-describing.
    for k in hparams.__dict__.keys():
        writer.add_text(str(k), str(hparams.__dict__[k]))

    train_dataset = ChestData(
        data_csv=hparams.train_csv,
        data_dir=hparams.train_dir,
        augment=hparams.augment,
        transform=transforms.Compose([
            transforms.Resize(hparams.image_shape),
            transforms.ToTensor(),
            #                             transforms.Normalize((0.5027, 0.5027, 0.5027), (0.2915, 0.2915, 0.2915))
        ]))

    validation_dataset = ChestData(
        data_csv=hparams.valid_csv,
        data_dir=hparams.valid_dir,
        transform=transforms.Compose([
            transforms.Resize(hparams.image_shape),
            transforms.ToTensor(),
            #                             transforms.Normalize((0.5027, 0.5027, 0.5027), (0.2915, 0.2915, 0.2915))
        ]))

    # train_sampler = WeightedRandomSampler()

    train_loader = DataLoader(train_dataset,
                              batch_size=hparams.batch_size,
                              shuffle=True,
                              num_workers=2)

    validation_loader = DataLoader(validation_dataset,
                                   batch_size=hparams.batch_size,
                                   shuffle=True,
                                   num_workers=2)

    print('loaded train data of length : {}'.format(len(train_dataset)))

    # Despite the GAN-style naming, this is a plain cross-entropy classifier.
    adversarial_loss = torch.nn.CrossEntropyLoss().to(hparams.gpu_device)
    discriminator = Discriminator().to(hparams.gpu_device)

    if hparams.cuda:
        discriminator = nn.DataParallel(discriminator,
                                        device_ids=hparams.device_ids)

    # Count trainable parameters for the startup banner.
    params_count = 0
    for param in discriminator.parameters():
        params_count += np.prod(param.size())
    print('Model has {0} trainable parameters'.format(params_count))

    if not hparams.pretrained:
        #         discriminator.apply(weights_init_normal)
        pass

    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=hparams.learning_rate,
                                   betas=(0.9, 0.999))

    # Reduce LR by 10x after one epoch without validation-loss improvement.
    scheduler_D = ReduceLROnPlateau(optimizer_D,
                                    mode='min',
                                    factor=0.1,
                                    patience=1,
                                    verbose=True,
                                    cooldown=0)

    Tensor = torch.cuda.FloatTensor if hparams.cuda else torch.FloatTensor

    def validation(discriminator, send_stats=False, epoch=0):
        """Run hparams.repeat_infer passes over the validation set.

        Returns:
            (accuracy_metrics(labels, preds), val_loss) — metrics over all
            collected predictions plus the CE loss on them.
        """
        print('Validating model on {0} examples. '.format(
            len(validation_dataset)))
        discriminator_ = discriminator.eval()

        with torch.no_grad():
            pred_logits_list = []
            pred_labels_list = []
            labels_list = []
            for infer_num in range(hparams.repeat_infer):
                for (img, labels, imgs_names) in tqdm(validation_loader):
                    img = Variable(img.float(), requires_grad=False)
                    labels = Variable(labels.long(), requires_grad=False)

                    img_ = img.to(hparams.gpu_device)
                    labels = labels.to(hparams.gpu_device)

                    pred_logits = discriminator_(img_)
                    _, pred_labels = torch.max(pred_logits, axis=1)

                    pred_logits_list.append(pred_logits)
                    pred_labels_list.append(pred_labels)
                    labels_list.append(labels)

            # Concatenate all passes before computing loss/metrics.
            pred_logits = torch.cat(pred_logits_list, dim=0)
            pred_labels = torch.cat(pred_labels_list, dim=0)
            labels = torch.cat(labels_list, dim=0)

            val_loss = adversarial_loss(pred_logits, labels)

        return accuracy_metrics(labels.long(), pred_labels.long()), val_loss

    print('Starting training.. (log saved in:{})'.format(hparams.exp_name))
    start_time = time.time()
    best_valid_f1 = 0

    # print(model)
    for epoch in range(hparams.num_epochs):
        for batch, (imgs, labels, imgs_name) in enumerate(tqdm(train_loader)):

            imgs = Variable(imgs.float(), requires_grad=False)
            labels = Variable(labels.long(), requires_grad=False)

            imgs_ = imgs.to(hparams.gpu_device)
            labels = labels.to(hparams.gpu_device)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            pred_logits = discriminator(imgs_)

            d_loss = adversarial_loss(pred_logits, labels)

            d_loss.backward()
            optimizer_D.step()

            writer.add_scalar('d_loss',
                              d_loss.item(),
                              global_step=batch + epoch * len(train_loader))

            _, pred_labels = torch.max(pred_logits, axis=1)
            pred_labels = pred_labels.float()

            # if batch % hparams.print_interval == 0:
            #     auc, f1, acc, _, _ = accuracy_metrics(pred_labels, labels.long(), pred_logits)
            #     print('[Epoch - {0:.1f}, batch - {1:.3f}, d_loss - {2:.6f}, acc - {3:.4f}, f1 - {4:.5f}, auc - {5:.4f}]'.\
            #     format(1.0*epoch, 100.0*batch/len(train_loader), d_loss.item(), acc['avg'], f1[hparams.avg_mode], auc[hparams.avg_mode]))
        (val_f1, val_acc, val_conf_mat), val_loss = validation(discriminator,
                                                               epoch=epoch)

        # NOTE(review): this confusion-matrix figure is logged again below
        # under the same tag when F1 improves — confirm that is intentional.
        if val_conf_mat is not None:
            fig = plot_cf(val_conf_mat)
            writer.add_figure('val_conf', fig, global_step=epoch)
            plt.close(fig)
        writer.add_scalar('val_loss', val_loss, global_step=epoch)
        writer.add_scalar('val_f1', val_f1, global_step=epoch)
        writer.add_scalar('val_acc', val_acc, global_step=epoch)
        scheduler_D.step(val_loss)
        writer.add_scalar('learning_rate',
                          optimizer_D.param_groups[0]['lr'],
                          global_step=epoch)

        # Unconditional per-epoch checkpoint.
        torch.save(
            {
                'epoch': epoch,
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
            }, hparams.model + '.' + str(epoch))
        if best_valid_f1 <= val_f1:
            best_valid_f1 = val_f1
            if val_conf_mat is not None:
                fig = plot_cf(val_conf_mat)
                writer.add_figure('val_conf', fig, global_step=epoch)
                plt.close(fig)
            # Best-so-far checkpoint, overwritten each time F1 improves.
            torch.save(
                {
                    'epoch': epoch,
                    'discriminator_state_dict': discriminator.state_dict(),
                    'optimizer_D_state_dict': optimizer_D.state_dict(),
                }, hparams.model + '.best')
            print('best model on validation set saved.')
        print('[Epoch - {0:.1f} ---> val_f1 - {1:.4f}, val_acc - {2:.4f}, val_loss - {3:.4f}, best_val_f1 - {4:.4f}, curr_lr - {5:.4f}] - time - {6:.1f}'\
            .format(1.0*epoch, val_f1, val_acc, val_loss, best_valid_f1, optimizer_D.param_groups[0]['lr'], time.time()-start_time))
        start_time = time.time()
Exemple #16
0
def main(arguments, neg_labels, pos_labels):
    """Train an anomaly-detection GAN on selected MNIST digit classes.

    Args:
        arguments: parsed CLI namespace (output_dir, adam_*, n_dis, l2_lam,
            noise_std, gpu_id, batch_size, weight_decay, iteration).
        neg_labels: digit labels treated as normal — used for training.
        pos_labels: digit labels treated as anomalous — evaluation only.
    """
    output_dir_path = Path(arguments.output_dir)
    if not output_dir_path.exists():
        output_dir_path.mkdir()

    # settings
    adam_setting = {
        "alpha": arguments.adam_alpha,
        "beta1": arguments.adam_beta1,
        "beta2": arguments.adam_beta2
    }

    updater_setting = {
        "n_dis": arguments.n_dis,
        "l2_lam": arguments.l2_lam,
        "noise_std": arguments.noise_std
    }
    chainer.config.user_gpu_mode = (arguments.gpu_id >= 0)
    if chainer.config.user_gpu_mode:
        chainer.backends.cuda.get_device_from_id(arguments.gpu_id).use()

    # Normal ("negative") data used for training.
    mnist_neg = get_mnist_num(neg_labels)

    # Build the training iterator.
    iterator_setting = {
        "batch_size": arguments.batch_size,
        "shuffle": True,
        "repeat": True
    }
    neg_iter = iterators.SerialIterator(mnist_neg, **iterator_setting)

    generator = Generator()
    discriminator = Discriminator()
    if chainer.config.user_gpu_mode:
        generator.to_gpu()
        discriminator.to_gpu()

    opt_g = optimizers.Adam(**adam_setting)
    opt_g.setup(generator)
    opt_d = optimizers.Adam(**adam_setting)
    opt_d.setup(discriminator)
    if arguments.weight_decay > 0.0:
        opt_g.add_hook(chainer.optimizer.WeightDecay(arguments.weight_decay))
        opt_d.add_hook(chainer.optimizer.WeightDecay(arguments.weight_decay))

    updater = GANUpdater(neg_iter, opt_g, opt_d, **updater_setting)
    trainer = Trainer(updater, (arguments.iteration, "iteration"),
                      out=str(output_dir_path))

    # Fetch the test data.
    test_neg = get_mnist_num(neg_labels, train=False)
    test_pos = get_mnist_num(pos_labels, train=False)
    # Assign label 0 to normal samples and label 1 to anomalous ones.
    test_neg = chainer.datasets.TupleDataset(
        test_neg, np.zeros(len(test_neg), dtype=np.int32))
    test_pos = chainer.datasets.TupleDataset(
        test_pos, np.ones(len(test_pos), dtype=np.int32))
    test_ds = chainer.datasets.ConcatenatedDataset(test_neg, test_pos)
    test_iter = iterators.SerialIterator(test_ds,
                                         repeat=False,
                                         shuffle=True,
                                         batch_size=500)

    ev_target = EvalModel(generator, discriminator, arguments.noise_std)
    ev_target = ExtendedClassifier(ev_target)
    if chainer.config.user_gpu_mode:
        ev_target.to_gpu()
    evaluator = extensions.Evaluator(
        test_iter,
        ev_target,
        device=arguments.gpu_id if chainer.config.user_gpu_mode else None)
    trainer.extend(evaluator)

    # Reporting / plotting / snapshot extensions, all on a 5000-iter trigger.
    trigger = (5000, "iteration")
    trainer.extend(extensions.LogReport(trigger=trigger))
    trainer.extend(extensions.PrintReport(
        ["iteration", "generator/loss", "generator/l2", "discriminator/loss"]),
                   trigger=trigger)
    trainer.extend(extensions.ProgressBar())
    trainer.extend(
        extensions.PlotReport(("generator/loss", "discriminator/loss"),
                              "iteration",
                              file_name="loss_plot.eps",
                              trigger=trigger))
    trainer.extend(
        extensions.PlotReport(["generator/l2"],
                              "iteration",
                              file_name="gen_l2_plot.eps",
                              trigger=trigger))
    trainer.extend(
        extensions.PlotReport(
            ("validation/main/F", "validation/main/accuracy"),
            "iteration",
            file_name="acc_plot.eps",
            trigger=trigger))
    trainer.extend(ext_save_img(generator, test_pos, test_neg,
                                output_dir_path / "out_images",
                                arguments.noise_std),
                   trigger=trigger)
    trainer.extend(extensions.snapshot_object(
        generator, "gen_iter_{.updater.iteration:06d}.model"),
                   trigger=trigger)
    trainer.extend(extensions.snapshot_object(
        discriminator, "dis_iter_{.updater.iteration:06d}.model"),
                   trigger=trigger)

    # Start training.
    trainer.run()
Exemple #17
0
def main(_):
    """Train a WGAN-GP under tf.distribute.MirroredStrategy.

    Configuration comes from absl FLAGS (dataset, image_size, batch_size,
    lr, latent_vector, num_filters, disc_iters, epochs, num_examples,
    save_folder). Writes a grid of generated samples per epoch to
    FLAGS.save_folder.
    """

    strategy = tf.distribute.MirroredStrategy()

    NUM_GPU = len(tf.config.experimental.list_physical_devices('GPU'))

    train_ds, ds_info = tfds.load(
        FLAGS.dataset, split='train', shuffle_files=True, with_info=True)

    #dataset is very big, don't want to wait long
    if FLAGS.dataset == 'lsun/bedroom':
        train_ds = train_ds.take(300000)
        output_channels = 3
    if FLAGS.dataset == 'cifar10':
        output_channels = 3
    if FLAGS.dataset == 'mnist':
        output_channels = 1

    # Flattened image length; the discriminator consumes flat vectors.
    OUTPUT_DIM = FLAGS.image_size * FLAGS.image_size * output_channels

    def preprocess(image):
        """Normalize the images to [-1.0, 1.0]"""
        image = image['image']
        image = tf.image.resize_with_pad(image, FLAGS.image_size,
                                         FLAGS.image_size)

        return (tf.cast(image, tf.float32) - 127.5) / 127.5

    train_ds = train_ds.map(
        preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    train_ds = train_ds.cache()
    train_ds = train_ds.shuffle(ds_info.splits['train'].num_examples)
    train_ds = train_ds.batch(FLAGS.batch_size)
    train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)
    train_ds = strategy.experimental_distribute_dataset(train_ds)

    # Build optimizers and models under the strategy scope so variables are
    # mirrored across replicas.
    with strategy.scope():

        generator_optimizer = tf.keras.optimizers.Adam(
            learning_rate=FLAGS.lr, beta_1=0.5, beta_2=0.9)
        discriminator_optimizer = tf.keras.optimizers.Adam(
            learning_rate=FLAGS.lr, beta_1=0.5, beta_2=0.9)

        inputs = tf.keras.Input(
            shape=(FLAGS.latent_vector, ), name="latent_vector")
        outputs = Generator(FLAGS.num_filters, FLAGS.latent_vector,
                            output_channels, OUTPUT_DIM)(inputs)
        generator = tf.keras.Model(inputs=inputs, outputs=outputs)

        inputs = tf.keras.Input(
            shape=(
                FLAGS.image_size * FLAGS.image_size * output_channels),
            name="imgs")
        outputs = Discriminator(FLAGS.num_filters, FLAGS.image_size, output_channels)(inputs)
        discriminator = tf.keras.Model(inputs=inputs, outputs=outputs)

    @tf.function
    def train_gen():
        """One per-replica generator step: maximize D's score on fakes."""

        noise = tf.random.normal([FLAGS.batch_size // NUM_GPU, FLAGS.latent_vector])

        with tf.GradientTape() as gen_tape:

            generated_images = generator(noise, training=True)
            fake_output = discriminator(
                tf.reshape(generated_images, [-1, OUTPUT_DIM]), training=False)
            gen_loss = -tf.reduce_mean(fake_output)

        gradients_of_generator = gen_tape.gradient(
            gen_loss, generator.trainable_variables)

        generator_optimizer.apply_gradients(
            zip(gradients_of_generator, generator.trainable_variables))

        return tf.reduce_mean(gen_loss)

    @tf.function
    def train_disc(images):
        """One per-replica critic step with WGAN-GP gradient penalty."""
        image = tf.reshape(images, [-1, OUTPUT_DIM])
        noise = tf.random.normal([images.shape[0], FLAGS.latent_vector])
        fake_images = generator(noise, training=True)
        with tf.GradientTape() as disc_tape:

            # NOTE(review): `disc_real` is fed the un-flattened `images`
            # while the penalty below uses the flattened `image` — confirm
            # the discriminator accepts both shapes (likely should be
            # `image` here too).
            disc_real = discriminator(images, training=True)
            disc_fake = discriminator(fake_images, training=True)
            disc_loss = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)

            alpha = tf.random.uniform(
                shape=[
                    images.shape[0],
                    1,
                ], minval=0., maxval=1.)

            differences = fake_images - image

            # Gradient penalty on random interpolations between real/fake.
            interpolates = image + (alpha * differences)
            gradients = tf.gradients(
                discriminator(interpolates), [interpolates])[0]

            slopes = tf.math.sqrt(
                tf.reduce_sum(tf.square(gradients), axis=[1]))
            gradient_penalty = tf.reduce_mean((slopes - 1.)**2)

            disc_loss += 10 * gradient_penalty

        gradients_of_discriminator = disc_tape.gradient(
            disc_loss, discriminator.trainable_variables)

        discriminator_optimizer.apply_gradients(
            zip(gradients_of_discriminator, discriminator.trainable_variables))
        return tf.reduce_mean(disc_loss)

    @tf.function
    def distributed_disc_step(dist_inputs):
        per_replica_disc_loss = strategy.run(train_disc, args=[dist_inputs])
        return strategy.reduce(
            tf.distribute.ReduceOp.SUM, per_replica_disc_loss, axis=None)

    @tf.function
    def distributed_gen_step():
        per_replica_gen_loss = strategy.run(train_gen, args=())
        return strategy.reduce(
            tf.distribute.ReduceOp.SUM, per_replica_gen_loss, axis=None)

    def save_images(model, ep, vector):
        """Render a 4x4 grid of samples from `vector` into save_folder."""

        predictions = tf.clip_by_value(model(vector, training=False), -1, 1)
        plt.figure(figsize=(5, 5))

        for i in range(predictions.shape[0]):
            plt.subplot(4, 4, i + 1)
            pred = tf.reshape(
                predictions[i],
                [FLAGS.image_size, FLAGS.image_size, output_channels])
            plt.imshow((pred.numpy() * 127.5 + 127.5).astype(np.uint8))
            plt.axis('off')

        plt.savefig(FLAGS.save_folder +'/image_at_epoch_{:02d}.png'.format(ep))

    if not os.path.exists(FLAGS.save_folder):
         os.makedirs(FLAGS.save_folder)

    # Fixed noise so per-epoch sample grids are comparable.
    noise_vector = tf.random.normal([FLAGS.num_examples, FLAGS.latent_vector])

    for epoch in tqdm(range(FLAGS.epochs)):
        iterator = iter(train_ds)

        gen_loss = 0
        disc_loss = 0
        num_batch = 0
        iterations = 0
        flag = True
        # One generator step per FLAGS.disc_iters critic steps, until the
        # dataset iterator is exhausted.
        while flag:
            gen_loss += distributed_gen_step()
            iterations += 1
            for _ in range(FLAGS.disc_iters):
                optional = iterator.get_next_as_optional()
                if optional.has_value().numpy() == False:
                    flag = False
                else:
                    data = optional.get_value()
                    d_loss = distributed_disc_step(data)
                    disc_loss += d_loss
                    num_batch += 1

        disc_loss /= num_batch
        gen_loss /= iterations
        print("Epoch {}, gen_loss  {:.5f} \n disc_loss {:.5f}\n".format(
            epoch, gen_loss, disc_loss))

        save_images(generator, epoch, noise_vector)

    save_images(generator, FLAGS.epochs, noise_vector)
def train(dataloader):
    """Train a DCGAN on batches drawn from `dataloader`.

    Args:
        dataloader : (DataLoader) a torch.utils.data.DataLoader object that
            fetches training data.

    Side effects:
        - Saves a 5x5 grid of generated images to ``images/`` every 400
          batches.
        - Checkpoints G and D (with their optimizers) to ``./ckpt/`` after
          every epoch.

    Relies on module-level names: Generator, Discriminator, device, ch,
    epochs, lr, b1, b2, latent_dim, init_weights, save_image,
    save_checkpoints.
    """

    # Define Generator, Discriminator
    G = Generator(out_channel=ch).to(device) # MNIST channel: 1, CIFAR-10 channel: 3
    D = Discriminator(in_channel=ch).to(device)

    # adversarial loss
    loss_fn = nn.BCELoss()

    # Initialize weights (DCGAN-style init)
    G.apply(init_weights)
    D.apply(init_weights)

    optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(b1, b2))
    optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(b1, b2))

    # Establish convention for real and fake labels during training
    real_label = 1.
    fake_label = 0.

    # -----Training----- #
    for epoch in range(epochs):
        # For each batch in the dataloader

        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            D.zero_grad()
            # Format batch
            real_cpu = data[0].to(device) # load image batch size
            b_size = real_cpu.size(0) # batch size
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device, requires_grad=False) # real batch

            # Forward pass **real batch** through D
            output = D(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = loss_fn(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()

            ## Train with **all-fake** batch
            # Generate noise batch of latent vectors
            noise = torch.randn(b_size, latent_dim, 1, 1, device=device)
            # Generate fake image batch with G
            fake = G(noise)
            label.fill_(fake_label) # fake batch

            # Classify all fake batch with D; detach() keeps G's graph out
            # of D's backward pass.
            output = D(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = loss_fn(output, label)
            # Calculate the gradients for this batch
            errD_fake.backward()
            # Add the gradients from the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizer_D.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            G.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost

            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = D(fake).view(-1)
            # Calculate G's loss based on this output
            errG = loss_fn(output, label)
            # Calculate gradients for G
            errG.backward()
            # Update G
            optimizer_G.step()

            # Save fake images generated by Generator
            batches_done = epoch * len(dataloader) + i
            if batches_done % 400 == 0:
                save_image(fake.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

        print(f"[Epoch {epoch + 1}/{epochs}] [D loss: {errD.item():.4f}] [G loss: {errG.item():.4f}]")

        # Save Generator model's parameters.
        # BUG FIX: checkpoints previously recorded 'epoch': i + 1, where `i`
        # is the last *batch* index of the epoch — record the epoch instead.
        save_checkpoints(
            {'epoch': epoch + 1,
             'state_dict': G.state_dict(),
             'optim_dict': optimizer_G.state_dict()},
            checkpoint='./ckpt/',
            is_G=True
        )

        # Save Discriminator model's parameters
        save_checkpoints(
            {'epoch': epoch + 1,
             'state_dict': D.state_dict(),
             'optim_dict': optimizer_D.state_dict()},
            checkpoint='./ckpt/',
            is_G=False
        )
    return parser.parse_args()


if __name__ == "__main__":
    args = arg_parse()

    # Resolve the output directory under ./outs and create it if missing.
    args.save_dir = "%s/outs/%s" % (os.getcwd(), args.save_dir)
    if os.path.exists(args.save_dir) is False:
        os.mkdir(args.save_dir)

    CUDA = True if torch.cuda.is_available() else False

    # Restrict visible GPUs before creating the torch device.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    torch_device = torch.device("cuda") if CUDA else torch.device('cpu')

    data_path = "../../data/celeba-"  # resolution string will be concatenated in ScalableLoader
    loader = ScalableLoader(data_path,
                            shuffle=True,
                            drop_last=True,
                            num_workers=args.cpus,
                            shuffled_cycle=True)

    # Wrap G/D in DataParallel so multi-GPU runs work transparently.
    g = nn.DataParallel(Generator()).to(torch_device)
    d = nn.DataParallel(Discriminator()).to(torch_device)

    tensorboard = TensorboardLogger("%s/tb" % (args.save_dir))

    # The progressive-growing GAN driver owns the whole training loop.
    pggan = PGGAN(args, g, d, loader, torch_device, args.loss, tensorboard)
    pggan.train()
    # NOTE(review): fragment of a TF->PyTorch StyleGAN weight-conversion
    # script; the enclosing function starts outside this view.
    state_dict = fill_statedict(state_dict, g_ema.vars, size)

    g.load_state_dict(state_dict)

    # Running average of the W latent, used for truncation at sampling time.
    latent_avg = torch.from_numpy(g_ema.vars['dlatent_avg'].value().eval())

    ckpt = {'g_ema': state_dict, 'latent_avg': latent_avg}

    if args.gen:
        # Optionally also convert the trainable (non-EMA) generator weights.
        g_train = Generator(size, 512, 8)
        g_train_state = g_train.state_dict()
        g_train_state = fill_statedict(g_train_state, generator.vars, size)
        ckpt['g'] = g_train_state

    if args.disc:
        disc = Discriminator(size)
        d_state = disc.state_dict()
        d_state = discriminator_fill_statedict(d_state, discriminator.vars,
                                               size)
        ckpt['d'] = d_state

    # Save the converted checkpoint named after the source file.
    name = os.path.splitext(os.path.basename(args.path))[0]
    torch.save(ckpt, name + '.pt')

    # Per-resolution sample batch sizes for the sanity-check render below.
    batch_size = {256: 16, 512: 9, 1024: 4}
    n_sample = batch_size.get(size, 25)

    g = g.to(device)

    # Fixed seed so converted-model samples are reproducible/comparable.
    z = np.random.RandomState(0).randn(n_sample, 512).astype('float32')
    def build_model(self):
        """Build the TF1 text-to-image GAN graph.

        Creates generator/discriminator networks, input placeholders,
        DCGAN-style sigmoid cross-entropy losses over four (image, text)
        pairings, and Adam update ops. Run once before training.
        """

        self.g_net = Generator(max_seq_length=self.data.tags_idx.shape[1],
                               vocab_size=self.vocab_size,
                               embedding_size=self.FLAGS.embedding_dim,
                               hidden_size=self.FLAGS.hidden,
                               img_row=self.img_row,
                               img_col=self.img_col)
        self.d_net = Discriminator(max_seq_length=self.data.tags_idx.shape[1],
                                   vocab_size=self.vocab_size,
                                   embedding_size=self.FLAGS.embedding_dim,
                                   hidden_size=self.FLAGS.hidden,
                                   img_row=self.img_row,
                                   img_col=self.img_col)

        # Placeholders: right (matching) caption/image, latent z, and a
        # wrong caption / wrong image for the mismatched negative pairs.
        self.seq = tf.placeholder(
            tf.float32,
            [None, len(self.data.eyes_idx) + len(self.data.hair_idx)],
            name="seq")
        self.img = tf.placeholder(tf.float32,
                                  [None, self.img_row, self.img_col, 3],
                                  name="img")
        self.z = tf.placeholder(tf.float32, [None, self.FLAGS.z_dim])

        self.w_seq = tf.placeholder(
            tf.float32,
            [None, len(self.data.eyes_idx) + len(self.data.hair_idx)],
            name="w_seq")
        self.w_img = tf.placeholder(tf.float32,
                                    [None, self.img_row, self.img_col, 3],
                                    name="w_img")

        r_img, r_seq = self.img, self.seq

        self.f_img = self.g_net(r_seq, self.z)

        # Inference-mode generator reusing the same variables, exported
        # under the name 'sampler' for later graph lookup.
        self.sampler = tf.identity(self.g_net(r_seq,
                                              self.z,
                                              reuse=True,
                                              train=False),
                                   name='sampler')

        # TODO
        """
			r img, r text -> 1
			f img, r text -> 0
			r img, w text -> 0
			w img, r text -> 0
		"""
        self.d = self.d_net(r_seq, r_img, reuse=False)  # r img, r text
        self.d_1 = self.d_net(r_seq, self.f_img)  # f img, r text
        self.d_2 = self.d_net(self.w_seq, self.img)  # r img, w text
        self.d_3 = self.d_net(r_seq, self.w_img)  # w img, r text

        # epsilon = tf.random_uniform([], 0.0, 1.0)
        # img_hat = epsilon * r_img + (1 - epsilon) * self.f_img
        # d_hat = self.d_net(r_seq, img_hat)

        # ddx = tf.gradients(d_hat, img_hat)[0]
        # ddx = tf.reshape(ddx, [-1, self.img_row * self.img_col * 3])
        # ddx = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=1))
        # ddx = tf.reduce_mean(tf.square(ddx - 1.0) * self.alpha)

        # self.g_loss = -tf.reduce_mean(self.d_1)
        # self.d_loss = tf.reduce_mean(self.d) - (tf.reduce_mean(self.d_1)+tf.reduce_mean(self.d_2)+tf.reduce_mean(self.d_3))/3.
        # self.d_loss = -(self.d_loss - ddx)

        # dcgan
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.d_1,
                                                    labels=tf.ones_like(
                                                        self.d_1)))

        # Real pair -> 1; the three mismatched/fake pairs -> 0, averaged.
        self.d_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.d, labels=tf.ones_like(self.d))) \
           + (tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.d_1, labels=tf.zeros_like(self.d_1))) + \
              tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.d_2, labels=tf.zeros_like(self.d_2))) +\
              tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.d_3, labels=tf.zeros_like(self.d_3))) ) / 3

        self.global_step = tf.Variable(0,
                                       name='g_global_step',
                                       trainable=False)

        # Run batch-norm update ops alongside the optimizer steps.
        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            self.d_updates = tf.train.AdamOptimizer(
                self.FLAGS.lr, beta1=0.5,
                beta2=0.9).minimize(self.d_loss, var_list=self.d_net.vars)
            self.g_updates = tf.train.AdamOptimizer(
                self.FLAGS.lr, beta1=0.5,
                beta2=0.9).minimize(self.g_loss,
                                    var_list=self.g_net.vars,
                                    global_step=self.global_step)

        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(tf.global_variables())
    # NOTE(review): fragment of a TF->PyTorch StyleGAN2 weight-conversion
    # script; the enclosing function starts outside this view.
    state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp)

    g.load_state_dict(state_dict)

    # Average dlatent (W), used for truncation at sampling time.
    latent_avg = torch.from_numpy(g_ema.vars["dlatent_avg"].value().eval())

    ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}

    if args.gen:
        # Optionally also convert the trainable (non-EMA) generator weights.
        g_train = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
        g_train_state = g_train.state_dict()
        g_train_state = fill_statedict(g_train_state, generator.vars, size, n_mlp)
        ckpt["g"] = g_train_state

    if args.disc:
        disc = Discriminator(size, channel_multiplier=args.channel_multiplier)
        d_state = disc.state_dict()
        d_state = discriminator_fill_statedict(d_state, discriminator.vars, size)
        ckpt["d"] = d_state

    # Save the converted checkpoint named after the source file.
    name = os.path.splitext(os.path.basename(args.path))[0]
    torch.save(ckpt, name + ".pt")

    # Per-resolution sample batch sizes for the sanity-check render below.
    batch_size = {256: 16, 512: 9, 1024: 4}
    n_sample = batch_size.get(size, 25)

    g = g.to(device)

    # Fixed seed so converted-model samples are reproducible/comparable.
    z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")

    with torch.no_grad():
Exemple #23
0
    parser.add_argument('--mixing',
                        action='store_true',
                        help='use mixing regularization')
    parser.add_argument(
        '--loss',
        type=str,
        default='wgan-gp',
        choices=['wgan-gp', 'r1'],
        help='class of gan loss',
    )

    args = parser.parse_args()

    # Generator/discriminator wrapped in DataParallel for multi-GPU training.
    generator = nn.DataParallel(StyledGenerator(code_size)).cuda()
    discriminator = nn.DataParallel(
        Discriminator(from_rgb_activate=not args.no_from_rgb_activate)).cuda()
    # g_running is kept out of training mode; presumably it accumulates a
    # running average of generator weights elsewhere -- confirm.
    g_running = StyledGenerator(code_size).cuda()
    g_running.train(False)

    class_loss = nn.CrossEntropyLoss()

    # The style-mapping network trains at lr * 0.01 via a separate param
    # group ('mult' records the multiplier for any later lr rescheduling).
    g_optimizer = optim.Adam(generator.module.generator.parameters(),
                             lr=args.lr,
                             betas=(0.0, 0.99))
    g_optimizer.add_param_group({
        'params': generator.module.style.parameters(),
        'lr': args.lr * 0.01,
        'mult': 0.01,
    })
    d_optimizer = optim.Adam(discriminator.parameters(),
                             lr=args.lr,
Exemple #24
0
    def build_model(self):
        """Instantiate G, D (and optionally an identity network) plus their
        Adam optimizers, print their layouts, and move them to GPU.

        The discriminator variant is selected from the config flags
        ``loss_id_cls`` / ``id_cls_loss`` / ``use_sn``.
        """
        if self.config.loss_identity:
            # Identity loss needs a pretrained face-feature extractor; it is
            # frozen (eval mode) and loaded from a fixed checkpoint.
            from light_cnn import LightCNN_29Layers
            self.l_model = LightCNN_29Layers(num_classes=294)
            self.l_model.eval()
            self.l_model = torch.nn.DataParallel(self.l_model).cuda()
            ckpt = torch.load("data/lightCNN_160_checkpoint.pth.tar")
            self.l_model.load_state_dict(ckpt['state_dict'])

        # At test time the discriminator is asked for features, not scores.
        feature = self.config.mode == 'test'

        # Generator: optionally the gpb variant, spread over all GPUs.
        if self.config.use_gpb:
            from model import Generator_gpb
            self.G = torch.nn.DataParallel(
                Generator_gpb(self.config.g_conv_dim, self.config.c_dim,
                              self.config.g_repeat_num),
                device_ids=list(range(torch.cuda.device_count())))
        else:
            from model import Generator
            self.G = Generator(self.config.g_conv_dim, self.config.c_dim,
                               self.config.g_repeat_num)

        # Discriminator: pick the class matching the identity-classification
        # loss type ('angle' or 'cross') and optional spectral norm.
        if self.config.loss_id_cls:
            if self.config.id_cls_loss == 'angle':
                if self.config.use_sn:
                    from model import Discriminator_idcls_angle_SN as d_cls
                else:
                    from model import Discriminator_idcls_angle as d_cls
                self.D = d_cls(self.config.face_crop_size,
                               self.config.d_conv_dim,
                               self.config.c_dim,
                               self.config.d_repeat_num,
                               feature=feature,
                               classnum=self.config.num_id)
            elif self.config.id_cls_loss == 'cross':
                if self.config.use_sn:
                    from model import Discriminator_idcls_cross_SN as d_cls
                else:
                    from model import Discriminator_idcls_cross as d_cls
                self.D = d_cls(self.config.face_crop_size,
                               self.config.d_conv_dim,
                               self.config.c_dim,
                               self.config.d_repeat_num,
                               feature=feature,
                               classnum=self.config.num_id)
        else:
            if self.config.use_sn:
                from model import Discriminator_SN as d_cls
            else:
                from model import Discriminator as d_cls
            self.D = d_cls(self.config.face_crop_size,
                           self.config.d_conv_dim,
                           self.config.c_dim,
                           self.config.d_repeat_num)

        # One Adam optimizer per network, shared betas, separate lrs.
        self.g_optimizer = torch.optim.Adam(
            self.G.parameters(), self.config.g_lr,
            [self.config.beta1, self.config.beta2])
        self.d_optimizer = torch.optim.Adam(
            self.D.parameters(), self.config.d_lr,
            [self.config.beta1, self.config.beta2])

        # Log network architectures and parameter counts.
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()
Exemple #25
0
    def init_model_for_testing(self):
        """Create a fresh generator/discriminator pair on the GPU for
        testing (no optimizers, no checkpoint loading)."""
        self.generator = Generator().cuda()
        self.discriminator = Discriminator().cuda()
Exemple #26
0
    filename = "trim_free_" + str(rnd) + ".png"
    line,rnd1,rnd2 = prepare_dataset_line(line_path + filename)
    color = prepare_dataset_color(color_path + filename,rnd1,rnd2)
    line_box.append(line)
    color_box.append(color)

line_test = xp.array(line_box).astype(xp.float32)
line_test = chainer.as_variable(line_test)
color_test = xp.array(color_box).astype(xp.float32)
color_test = chainer.as_variable(color_test)

global_generator=Global_Generator()
global_generator.to_gpu()
gg_opt=set_optimizer(global_generator)

discriminator=Discriminator()
discriminator.to_gpu()
dis_opt=set_optimizer(discriminator)

discriminator_2=Discriminator()
discriminator_2.to_gpu()
dis2_opt=set_optimizer(discriminator_2)

discriminator_4=Discriminator()
discriminator_4.to_gpu()
dis4_opt=set_optimizer(discriminator_4)

for epoch in range(epochs):
    sum_gen_loss=0
    sum_dis_loss=0
    for batch in range(0,iterations,batchsize):
        "../testing/small/testing",
        crop_size=CROP_SIZE,
        upscale_factor=UPSCALE_FACTOR,
        interpolation=INTERPOLATION,
    )
    val_set = ValDatasetFromFolder(
        "../testing/nowe", upscale_factor=UPSCALE_FACTOR, interpolation=INTERPOLATION
    )
    train_loader = DataLoader(
        dataset=train_set, num_workers=4, batch_size=BATCH_SIZE, shuffle=True
    )
    val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

    netG = Generator(16, UPSCALE_FACTOR)
    print("# generator parameters:", sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print("# discriminator parameters:", sum(param.numel() for param in netD.parameters()))

    generator_criterion = GeneratorLoss()
    mse_loss = nn.MSELoss()
    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()
        mse_loss = mse_loss.cuda()

    optimizerG = optim.Adam(netG.parameters(), lr=0.0001)

    results = {
        "d_loss": [],
        "g_loss": [],
Exemple #28
0
    }

    # Persist run settings alongside the outputs.
    # NOTE(review): `dict` shadows the builtin and `json` rebinds the module
    # name if `json` was imported; also prefer `with open(...)` so the
    # handle closes on error -- cannot safely change in this excerpt.
    json = js.dumps(dict)
    f = open(path + "/settings.json", "w")
    f.write(json)
    f.close()

    # Defining the generator and discriminator (recurrent sequence GAN).
    generator = Generator(seq_length,
                          sample_size,
                          hidden_dim=hidden_nodes_g,
                          num_layers=layers,
                          tanh_output=tanh_layer).cuda()
    discriminator = Discriminator(seq_length,
                                  sample_size,
                                  minibatch_normal_init=minibatch_normal_init_,
                                  hidden_dim=hidden_nodes_d,
                                  num_layers=layers,
                                  minibatch=minibatch_layer).cuda()
    # Separate Adam optimizers sharing a single learning rate.
    d_optimizer = torch.optim.Adam(discriminator.parameters(),
                                   lr=learning_rate)
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate)
    # Loss function: binary cross-entropy on real/fake scores.
    loss = torch.nn.BCELoss()

    generator.train()
    discriminator.train()

    # Training histories; series_list collects generated sequences and is
    # seeded with one all-zero row.
    G_losses = []
    D_losses = []
    mmd_list = []
    series_list = np.zeros((1, seq_length))
Exemple #29
0
        # Distributed init: one process per GPU, NCCL backend, rendezvous
        # via environment variables; synchronize() barriers all ranks.
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # StyleGAN2 defaults: 512-dim latent, 8-layer mapping network.
    args.latent = 512
    args.n_mlp = 8

    args.start_iter = 0

    generator = Generator(
        args.size,
        args.latent,
        args.n_mlp,
        channel_multiplier=args.channel_multiplier).to(device)
    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier).to(device)
    # g_ema tracks averaged generator weights and is never trained;
    # accumulate(..., 0) presumably copies the current generator weights
    # into it (decay 0) -- confirm accumulate()'s semantics.
    g_ema = Generator(args.size,
                      args.latent,
                      args.n_mlp,
                      channel_multiplier=args.channel_multiplier).to(device)
    g_ema.eval()
    accumulate(g_ema, generator, 0)

    # Lazy-regularization ratios: regularizers run only every *_reg_every
    # steps, so lr and betas are rescaled to keep the effective schedule.
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = optim.Adam(
        generator.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0**g_reg_ratio, 0.99**g_reg_ratio),
    )
Exemple #30
0
def main():
	n_epoch_pretrain = 2
	use_tensorboard = True

	parser = argparse.ArgumentParser(description='SRGAN Train')
	parser.add_argument('--crop_size', default=128, type=int, help='training images crop size')
	parser.add_argument('--num_epochs', default=1000, type=int, help='training epoch')
	parser.add_argument('--batch_size', default=64, type=int, help='training batch size')
	parser.add_argument('--train_set', default='data/train', type=str, help='train set path')
	parser.add_argument('--check_point', type=int, default=-1, help="continue with previous check_point")

	opt = parser.parse_args()

	input_size = opt.crop_size
	n_epoch = opt.num_epochs
	batch_size = opt.batch_size
	check_point = opt.check_point

	check_point_path = 'cp/'
	if not os.path.exists(check_point_path):
		os.makedirs(check_point_path)

	train_set = TrainDataset(opt.train_set, crop_size=input_size, upscale_factor=4)
	train_loader = DataLoader(dataset=train_set, num_workers=2, batch_size=batch_size, shuffle=True)

	dev_set = DevDataset('data/dev', upscale_factor=4)
	dev_loader = DataLoader(dataset=dev_set, num_workers=1, batch_size=1, shuffle=False)

	mse = nn.MSELoss()
	bce = nn.BCELoss()
	#tv = TVLoss()
		
	if not torch.cuda.is_available():
		print ('!!!!!!!!!!!!!!USING CPU!!!!!!!!!!!!!')

	netG = Generator()
	print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
	netD = Discriminator()
	print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

	if torch.cuda.is_available():
		netG.cuda()
		netD.cuda()
		#tv.cuda()
		mse.cuda()
		bce.cuda()

	if use_tensorboard:
		writer = SummaryWriter()

	# Pre-train generator using only MSE loss
	if check_point == -1:
		optimizerG = optim.Adam(netG.parameters())
		#schedulerG = MultiStepLR(optimizerG, milestones=[20], gamma=0.1)
		for epoch in range(1, n_epoch_pretrain + 1):
			#schedulerG.step()		
			train_bar = tqdm(train_loader)
			
			netG.train()
			
			cache = {'g_loss': 0}
			
			for lowres, real_img_hr in train_bar:
				if torch.cuda.is_available():
					real_img_hr = real_img_hr.cuda()
					
				if torch.cuda.is_available():
					lowres = lowres.cuda()
					
				fake_img_hr = netG(lowres)

				# Train G
				netG.zero_grad()
				
				image_loss = mse(fake_img_hr, real_img_hr)
				cache['g_loss'] += image_loss
				
				image_loss.backward()
				optimizerG.step()

				# Print information by tqdm
				train_bar.set_description(desc='[%d/%d] Loss_G: %.4f' % (epoch, n_epoch_pretrain, image_loss))
				
		# Save model parameters	
		#if torch.cuda.is_available():
		#	torch.save(netG.state_dict(), 'cp/netG_epoch_pre_gpu.pth')
		#else:
		#	torch.save(netG.state_dict(), 'cp/netG_epoch_pre_cpu.pth')
	
	optimizerG = optim.Adam(netG.parameters())
	optimizerD = optim.Adam(netD.parameters())
	
	if check_point != -1:
		if torch.cuda.is_available():
			netG.load_state_dict(torch.load('cp/netG_epoch_' + str(check_point) + '_gpu.pth'))
			netD.load_state_dict(torch.load('cp/netD_epoch_' + str(check_point) + '_gpu.pth'))
			optimizerG.load_state_dict(torch.load('cp/optimizerG_epoch_' + str(check_point) + '_gpu.pth'))
			optimizerD.load_state_dict(torch.load('cp/optimizerD_epoch_' + str(check_point) + '_gpu.pth'))
		else :
			netG.load_state_dict(torch.load('cp/netG_epoch_' + str(check_point) + '_cpu.pth'))
			netD.load_state_dict(torch.load('cp/netD_epoch_' + str(check_point) + '_cpu.pth'))
			optimizerG.load_state_dict(torch.load('cp/optimizerG_epoch_' + str(check_point) + '_cpu.pth'))
			optimizerD.load_state_dict(torch.load('cp/optimizerD_epoch_' + str(check_point) + '_cpu.pth'))
	
	for epoch in range(1 + max(check_point, 0), n_epoch + 1 + max(check_point, 0)):
		train_bar = tqdm(train_loader)
		
		netG.train()
		netD.train()
		
		cache = {'mse_loss': 0, 'tv_loss': 0, 'adv_loss': 0, 'g_loss': 0, 'd_loss': 0, 'ssim': 0, 'psnr': 0, 'd_top_grad' : 0, 'd_bot_grad' : 0, 'g_top_grad' : 0, 'g_bot_grad' : 0}
		
		for lowres, real_img_hr in train_bar:
			#print ('lr size : ' + str(data.size()))
			#print ('hr size : ' + str(target.size()))
			if torch.cuda.is_available():
				real_img_hr = real_img_hr.cuda()
				lowres = lowres.cuda()
			
			# Train D
			
			#if not check_grads(netD, 'D'):
			#	return
			netD.zero_grad()
			
			logits_real = netD(real_img_hr)
			logits_fake = netD(netG(lowres).detach())
			
			# Lable smoothing
			real = torch.tensor(torch.rand(logits_real.size())*0.25 + 0.85)
			fake = torch.tensor(torch.rand(logits_fake.size())*0.15)
			
			# Lable flipping
			prob = (torch.rand(logits_real.size()) < 0.05)
			
			#print ('logits real size : ' + str(logits_real.size()))
			#print ('logits fake size : ' + str(logits_fake.size()))
			
			if torch.cuda.is_available():
				real = real.cuda()
				fake = fake.cuda()
				prob = prob.cuda()
				
			real_clone = real.clone()
			real[prob] = fake[prob]
			fake[prob] = real_clone[prob]
            
			d_loss = bce(logits_real, real) + bce(logits_fake, fake)
			
			cache['d_loss'] += d_loss.item()
			
			d_loss.backward()
			optimizerD.step()
			
			dtg, dbg = get_grads_D(netD)

			cache['d_top_grad'] += dtg
			cache['d_bot_grad'] += dbg

			# Train G
					
			#if not check_grads(netG, 'G'):
			#	return
			netG.zero_grad()
			
			fake_img_hr = netG(lowres)
			image_loss = mse(fake_img_hr, real_img_hr)
			
			logits_fake_new = netD(fake_img_hr)
			adversarial_loss = bce(logits_fake_new, torch.ones_like(logits_fake_new))
			
			#tv_loss = tv(fake_img_hr)
			
			g_loss = image_loss + 1e-2*adversarial_loss

			cache['mse_loss'] += image_loss.item()
			#cache['tv_loss'] += tv_loss.item()
			cache['adv_loss'] += adversarial_loss.item()
			cache['g_loss'] += g_loss.item()

			g_loss.backward()
			optimizerG.step()
			
			gtg, gbg = get_grads_G(netG)

			cache['g_top_grad'] += gtg
			cache['g_bot_grad'] += gbg

			# Print information by tqdm
			train_bar.set_description(desc='[%d/%d] D grads:(%f, %f) G grads:(%f, %f) Loss_D: %.4f Loss_G: %.4f = %.4f + %.4f' % (epoch, n_epoch, dtg, dbg, gtg, gbg, d_loss, g_loss, image_loss, adversarial_loss))
		
		if use_tensorboard:
			writer.add_scalar('d_loss', cache['d_loss']/len(train_loader), epoch)
		
			writer.add_scalar('mse_loss', cache['mse_loss']/len(train_loader), epoch)
			#writer.add_scalar('tv_loss', cache['tv_loss']/len(train_loader), epoch)
			writer.add_scalar('adv_loss', cache['adv_loss']/len(train_loader), epoch)
			writer.add_scalar('g_loss', cache['g_loss']/len(train_loader), epoch)
			
			writer.add_scalar('D top layer gradient', cache['d_top_grad']/len(train_loader), epoch)
			writer.add_scalar('D bot layer gradient', cache['d_bot_grad']/len(train_loader), epoch)
			writer.add_scalar('G top layer gradient', cache['g_top_grad']/len(train_loader), epoch)
			writer.add_scalar('G bot layer gradient', cache['g_bot_grad']/len(train_loader), epoch)
		
		# Save model parameters	
		if torch.cuda.is_available():
			torch.save(netG.state_dict(), 'cp/netG_epoch_%d_gpu.pth' % (epoch))
			if epoch%5 == 0:
				torch.save(netD.state_dict(), 'cp/netD_epoch_%d_gpu.pth' % (epoch))
				torch.save(optimizerG.state_dict(), 'cp/optimizerG_epoch_%d_gpu.pth' % (epoch))
				torch.save(optimizerD.state_dict(), 'cp/optimizerD_epoch_%d_gpu.pth' % (epoch))
		else:
			torch.save(netG.state_dict(), 'cp/netG_epoch_%d_cpu.pth' % (epoch))
			if epoch%5 == 0:
				torch.save(netD.state_dict(), 'cp/netD_epoch_%d_cpu.pth' % (epoch))
				torch.save(optimizerG.state_dict(), 'cp/optimizerG_epoch_%d_cpu.pth' % (epoch))
				torch.save(optimizerD.state_dict(), 'cp/optimizerD_epoch_%d_cpu.pth' % (epoch))
				
		# Visualize results
		with torch.no_grad():
			netG.eval()
			out_path = 'vis/'
			if not os.path.exists(out_path):
				os.makedirs(out_path)
				
			dev_bar = tqdm(dev_loader)
			valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
			dev_images = []
			for val_lr, val_hr_restore, val_hr in dev_bar:
				batch_size = val_lr.size(0)
				lr = val_lr
				hr = val_hr
				if torch.cuda.is_available():
					lr = lr.cuda()
					hr = hr.cuda()
				
				sr = netG(lr)
				
				psnr = 10 * log10(1 / ((sr - hr) ** 2).mean().item())
				ssim = pytorch_ssim.ssim(sr, hr).item()
				dev_bar.set_description(desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (psnr, ssim))
				
				cache['ssim'] += ssim
				cache['psnr'] += psnr
				
				# Avoid out of memory crash on 8G GPU
				if len(dev_images) < 60 :
					dev_images.extend([to_image()(val_hr_restore.squeeze(0)), to_image()(hr.data.cpu().squeeze(0)), to_image()(sr.data.cpu().squeeze(0))])
			
			dev_images = torch.stack(dev_images)
			dev_images = torch.chunk(dev_images, dev_images.size(0) // 3)
			
			dev_save_bar = tqdm(dev_images, desc='[saving training results]')
			index = 1
			for image in dev_save_bar:
				image = utils.make_grid(image, nrow=3, padding=5)
				utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
				index += 1
		
			if use_tensorboard:			
				writer.add_scalar('ssim', cache['ssim']/len(dev_loader), epoch)
				writer.add_scalar('psnr', cache['psnr']/len(dev_loader), epoch)