Example no. 1
def init_net(depth, dropout, window, cgan):

    input_shape = (1 if not cgan else 2, window)
    # Create the 3 networks
    gen = Generator(depth, dropout, verbose=0)
    discr = Discriminator(depth, dropout, input_shape,
                          verbose=0) if not cgan else ConditionalDiscriminator(
                              depth, dropout, input_shape, verbose=0)
    ae = AutoEncoder(depth, dropout, verbose=0)

    # Put them on cuda if available
    if torch.cuda.is_available():
        gen.cuda()
        discr.cuda()
        ae.cuda()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Using : " + str(device))
    print("Network initialized\n")

    return gen, discr, ae, device
Example no. 2
def main():
    G = Generator(args.dim_disc + args.dim_cont)
    D = Discriminator()

    if os.path.isfile(args.model):
        model = torch.load(args.model)
        G.load_state_dict(model[0])
        D.load_state_dict(model[1])

    if use_cuda:
        G.cuda()
        D.cuda()

    if args.mode == "train":
        G, D = train(G, D)
        if args.model:
            torch.save([G.state_dict(), D.state_dict()],
                       args.model,
                       pickle_protocol=4)
    elif args.mode == "gen":
        gen(G)
Example no. 3
def train(opt, dataloader_m=None):
    # Loss function
    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    generator = TinyNet()
    discriminator = Discriminator()

    cuda = True if torch.cuda.is_available() else False
    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    # Configure data loader: use the loader passed in, otherwise fall back to MNIST
    if dataloader_m is not None:
        dataloader = dataloader_m
    else:
        os.makedirs('./data/mnist', exist_ok=True)
        dataloader = torch.utils.data.DataLoader(
            datasets.MNIST('./data/mnist',
                           train=True,
                           download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               # MNIST images are single-channel, so one mean/std value each
                               transforms.Normalize((0.5,), (0.5,))
                           ])),
            batch_size=opt.batch_size,
            shuffle=True)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr,
                                   betas=(0.9, 0.99))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr,
                                   betas=(0.9, 0.99))

    # ----------
    #  Training
    # ----------
    for epoch in range(opt.n_epochs):
        for i, (imgs, imgs_ns, labels) in enumerate(dataloader):  # expects (imgs, imgs_ns, labels) triples; the stock MNIST dataset yields only (image, label) pairs

            # Configure input
            real_imgs = Variable(labels.type(Tensor))

            # Adversarial ground truths
            valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0),
                             requires_grad=False)
            fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0),
                            requires_grad=False)

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()
            # Sample noise as generator input
            #z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # [64, 256]
            # Generate a batch of images
            gen_imgs = generator(imgs.cuda()) + imgs_ns.cuda()  #[64, 64, 64]
            # Loss measures generator's ability to fool the discriminator
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()
            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()),
                                         fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                  (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(),
                   g_loss.item()))

            batches_done = epoch * len(dataloader) + i
            if batches_done % opt.sample_interval == 0:
                save_image(gen_imgs.data[:25],
                           'images/%d.png' % batches_done,
                           nrow=5,
                           normalize=True)
                weights_path = checkpoint_path.format(epoch=batches_done,
                                                      loss=g_loss.item())
                print(
                    'saving generator weights file to {}'.format(weights_path))
                torch.save(generator.state_dict(), weights_path)
Example no. 4
"""
G12 is a generator that learns to translate images from class 1 to class 2
G21 is a generator that learns to translate images from class 2 to class 1
D1 is a discriminator that separates fakes generated by G21 from real images of class 1
D2 is a discriminator that separates fakes generated by G12 from real images of class 2

"""

G12 = Generator(3, 3)
G21 = Generator(3, 3)
D1 = Discriminator(3)
D2 = Discriminator(3)

# shift models to cuda if possible
if torch.cuda.is_available():
    G12.cuda()
    G21.cuda()
    D1.cuda()
    D2.cuda()

# %%
# optimizer and loss
LGAN = MSELoss()
LCYC = L1Loss()
LIdentity = L1Loss()

optimizer_G = Adam(itertools.chain(G12.parameters(), G21.parameters()),
                   lr=0.001)
optimizer_D1 = Adam(D1.parameters(), lr=0.001)
optimizer_D2 = Adam(D2.parameters(), lr=0.001)

# %%
# train models
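The snippet ends at the "# train models" marker. Below is a minimal, hypothetical sketch of one training iteration built only from the objects defined above; the loader yielding paired (real_1, real_2) batches and the cycle/identity weights of 10 and 5 are illustrative assumptions, not part of the original code.

# Hypothetical training step (not from the source). Assumes `loader` yields
# (real_1, real_2) batches already on the same device as the models, and uses
# illustrative loss weights of 10 (cycle) and 5 (identity).
for real_1, real_2 in loader:
    # update both generators jointly
    optimizer_G.zero_grad()
    fake_1, fake_2 = G21(real_2), G12(real_1)
    pred_f1, pred_f2 = D1(fake_1), D2(fake_2)
    loss_gan = LGAN(pred_f1, torch.ones_like(pred_f1)) + LGAN(pred_f2, torch.ones_like(pred_f2))
    loss_cyc = LCYC(G12(fake_1), real_2) + LCYC(G21(fake_2), real_1)
    loss_idt = LIdentity(G21(real_1), real_1) + LIdentity(G12(real_2), real_2)
    (loss_gan + 10 * loss_cyc + 5 * loss_idt).backward()
    optimizer_G.step()

    # update each discriminator on a real batch and a detached fake batch
    for D, real, fake, opt_D in ((D1, real_1, fake_1, optimizer_D1),
                                 (D2, real_2, fake_2, optimizer_D2)):
        opt_D.zero_grad()
        pred_real, pred_fake = D(real), D(fake.detach())
        d_loss = 0.5 * (LGAN(pred_real, torch.ones_like(pred_real)) +
                        LGAN(pred_fake, torch.zeros_like(pred_fake)))
        d_loss.backward()
        opt_D.step()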
Example no. 5
def main(params):
    if params['load_dataset']:
        dataset = load_pkl(params['load_dataset'])
    elif params['dataset_class']:
        dataset = globals()[params['dataset_class']](**params[params['dataset_class']])
        if params['save_dataset']:
            save_pkl(params['save_dataset'], dataset)
    else:
        raise Exception('One of either load_dataset (path to pkl) or dataset_class needs to be specified.')
    result_dir = create_result_subdir(params['result_dir'], params['exp_name'])

    losses = ['G_loss', 'D_loss', 'D_real', 'D_fake']
    stats_to_log = [
        'tick_stat',
        'kimg_stat',
    ]
    if params['progressive_growing']:
        stats_to_log.extend([
            'depth',
            'alpha',
            'lod',
            'minibatch_size'
        ])
    stats_to_log.extend([
        'time',
        'sec.tick',
        'sec.kimg'
    ] + losses)
    logger = TeeLogger(os.path.join(result_dir, 'log.txt'), stats_to_log, [(1, 'epoch')])
    logger.log(params_to_str(params))
    if params['resume_network']:
        G, D = load_models(params['resume_network'], params['result_dir'], logger)
    else:
        G = Generator(dataset.shape, **params['Generator'])
        D = Discriminator(dataset.shape, **params['Discriminator'])
    if params['progressive_growing']:
        assert G.max_depth == D.max_depth
    G.cuda()
    D.cuda()
    latent_size = params['Generator']['latent_size']

    logger.log(str(G))
    logger.log('Total number of parameters in Generator: {}'.format(
        sum(map(lambda x: reduce(lambda a, b: a*b, x.size()), G.parameters()))
    ))
    logger.log(str(D))
    logger.log('Total number of parameters in Discriminator: {}'.format(
        sum(map(lambda x: reduce(lambda a, b: a*b, x.size()), D.parameters()))
    ))

    def get_dataloader(minibatch_size):
        return DataLoader(dataset, minibatch_size, sampler=InfiniteRandomSampler(dataset),
                          num_workers=params['num_data_workers'], pin_memory=False, drop_last=True)

    def rl(bs):
        return lambda: random_latents(bs, latent_size)

    # Setting up learning rate and optimizers
    opt_g = Adam(G.parameters(), params['G_lr_max'], **params['Adam'])
    opt_d = Adam(D.parameters(), params['D_lr_max'], **params['Adam'])

    def rampup(cur_nimg):
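        # LR warm-up: the LambdaLR factor ramps from exp(-5) up to 1 over the
        # first lr_rampup_kimg thousand images, then stays at 1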
        if cur_nimg < params['lr_rampup_kimg'] * 1000:
            p = max(0.0, 1 - cur_nimg / (params['lr_rampup_kimg'] * 1000))
            return np.exp(-p * p * 5.0)
        else:
            return 1.0
    lr_scheduler_d = LambdaLR(opt_d, rampup)
    lr_scheduler_g = LambdaLR(opt_g, rampup)

    mb_def = params['minibatch_size']
    D_loss_fun = partial(wgan_gp_D_loss, return_all=True, iwass_lambda=params['iwass_lambda'],
                         iwass_epsilon=params['iwass_epsilon'], iwass_target=params['iwass_target'])
    G_loss_fun = wgan_gp_G_loss
    trainer = Trainer(D, G, D_loss_fun, G_loss_fun,
                      opt_d, opt_g, dataset, iter(get_dataloader(mb_def)), rl(mb_def), **params['Trainer'])
    # plugins
    if params['progressive_growing']:
        max_depth = min(G.max_depth, D.max_depth)
        trainer.register_plugin(DepthManager(get_dataloader, rl, max_depth, **params['DepthManager']))
    for i, loss_name in enumerate(losses):
        trainer.register_plugin(EfficientLossMonitor(i, loss_name))

    checkpoints_dir = params['checkpoints_dir'] if params['checkpoints_dir'] else result_dir
    trainer.register_plugin(SaverPlugin(checkpoints_dir, **params['SaverPlugin']))

    def substitute_samples_path(d):
        return {k:(os.path.join(result_dir, v) if k == 'samples_path' else v) for k,v in d.items()}
    postprocessors = [ globals()[x](**substitute_samples_path(params[x])) for x in params['postprocessors'] ]
    trainer.register_plugin(OutputGenerator(lambda x: random_latents(x, latent_size),
                                            postprocessors, **params['OutputGenerator']))
    trainer.register_plugin(AbsoluteTimeMonitor(params['resume_time']))
    trainer.register_plugin(LRScheduler(lr_scheduler_d, lr_scheduler_g))
    trainer.register_plugin(logger)
    init_comet(params, trainer)
    trainer.run(params['total_kimg'])
    dataset.close()
Example no. 6

#Setting the DataLoader
directory = "images/images/"
dataset = AnimeDataset(directory)
training_data_loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)
'''
FOR TESTING
temp = training_data_loader.dataset.__getitem__(1)
print(temp.shape[1] == 64)
''' 
#BUILDING NETWORK STRUCTURE
generator = Generator(128)
discriminator = Discriminator(128)
generator.cuda()
discriminator.cuda()

def weights_init_normal(m):
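    # DCGAN-style initialization: conv weights drawn from N(0.0, 0.02),
    # BatchNorm2d weights from N(1.0, 0.02) with zero bias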
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

criterion = nn.BCELoss()
criterion.cuda()
optimizer_disc = optim.Adam(discriminator.parameters(),lr=learning_rate, betas=(0.5, 0.999))
Example no. 7
def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    start_epoch = 0
    train_image_dataset = image_preprocessing(opt.dataset, 'train')
    data_loader = DataLoader(train_image_dataset, batch_size=opt.batch_size,
                            shuffle=True, num_workers=opt.num_workers)
    criterion = least_squares
    euclidean_l1 = nn.L1Loss()

    G = Generator(ResidualBlock, layer_count=9)
    F = Generator(ResidualBlock, layer_count=9)
    Dx = Discriminator()
    Dy = Discriminator()

    G_optimizer = optim.Adam(G.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
    F_optimizer = optim.Adam(F.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
    Dx_optimizer = optim.Adam(Dx.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
    Dy_optimizer = optim.Adam(Dy.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))

    if torch.cuda.is_available():
        G = nn.DataParallel(G)
        F = nn.DataParallel(F)
        Dx = nn.DataParallel(Dx)
        Dy = nn.DataParallel(Dy)

        G = G.cuda()
        F = F.cuda()
        Dx = Dx.cuda()
        Dy = Dy.cuda()
    
    if opt.checkpoint is not None:
        G, F, Dx, Dy, G_optimizer, F_optimizer, Dx_optimizer, Dy_optimizer, start_epoch = load_ckp(opt.checkpoint, G, F, Dx, Dy, G_optimizer, F_optimizer, Dx_optimizer, Dy_optimizer)

    print('[Start] : Cycle GAN Training')

    logger = Logger(opt.epochs, len(data_loader), image_step=10)

    for epoch in range(opt.epochs):
        epoch = epoch + start_epoch + 1
        print("Epoch[{epoch}] : Start".format(epoch=epoch))
        
        for step, data in enumerate(data_loader):
            real_A = to_variable(data['A'])
            real_B = to_variable(data['B'])

            fake_B = G(real_A)
            fake_A = F(real_B)

            # Train Dx
            Dx_optimizer.zero_grad()

            Dx_real = Dx(real_A)
            Dx_fake = Dx(fake_A)

            Dx_loss = patch_loss(criterion, Dx_real, True) + patch_loss(criterion, Dx_fake, False)

            Dx_loss.backward(retain_graph=True)
            Dx_optimizer.step()

            # Train Dy
            Dy_optimizer.zero_grad()

            Dy_real = Dy(real_B)
            Dy_fake = Dy(fake_B)

            Dy_loss = patch_loss(criterion, Dy_real, True) + patch_loss(criterion, Dy_fake, False)

            Dy_loss.backward(retain_graph=True)
            Dy_optimizer.step()

            # Train G
            G_optimizer.zero_grad()

            Dy_fake = Dy(fake_B)

            G_loss = patch_loss(criterion, Dy_fake, True)

            # Train F
            F_optimizer.zero_grad()

            Dx_fake = Dx(fake_A)

            F_loss = patch_loss(criterion, Dx_fake, True)

            # identity loss
            loss_identity = euclidean_l1(real_A, fake_A) + euclidean_l1(real_B, fake_B)

            # cycle consistency
            loss_cycle = euclidean_l1(F(fake_B), real_A) + euclidean_l1(G(fake_A), real_B)

            # Optimize G & F
            loss = G_loss + F_loss + opt.lamda * loss_cycle + opt.lamda * loss_identity * (0.5)

            loss.backward()
            G_optimizer.step()
            F_optimizer.step()

            if (step + 1 ) % opt.save_step == 0:
                print("Epoch[{epoch}]| Step [{now}/{total}]| Dx Loss: {Dx_loss}, Dy_Loss: {Dy_loss}, G_Loss: {G_loss}, F_Loss: {F_loss}".format(
                    epoch=epoch, now=step + 1, total=len(data_loader), Dx_loss=Dx_loss.item(), Dy_loss=Dy_loss,
                    G_loss=G_loss.item(), F_loss=F_loss.item()))
                batch_image = torch.cat((torch.cat((real_A, real_B), 3), torch.cat((fake_A, fake_B), 3)), 2)

                torchvision.utils.save_image(denorm(batch_image[0]), opt.training_result + 'result_{result_name}_ep{epoch}_{step}.jpg'.format(result_name=opt.result_name,epoch=epoch, step=(step + 1) * opt.batch_size))
            
            # http://localhost:8097
            logger.log(
                losses={
                    'loss_G': G_loss,
                    'loss_F': F_loss,
                    'loss_identity': loss_identity,
                    'loss_cycle': loss_cycle,
                    'total_G_loss': loss,
                    'loss_Dx': Dx_loss,
                    'loss_Dy': Dy_loss,
                    'total_D_loss': (Dx_loss + Dy_loss),
                },
                images={
                    'real_A': real_A,
                    'real_B': real_B,
                    'fake_A': fake_A,
                    'fake_B': fake_B,
                },
            )


        torch.save({
            'epoch': epoch,
            'G_model': G.state_dict(),
            'G_optimizer': G_optimizer.state_dict(),
            'F_model': F.state_dict(),
            'F_optimizer': F_optimizer.state_dict(),
            'Dx_model': Dx.state_dict(),
            'Dx_optimizer': Dx_optimizer.state_dict(),
            'Dy_model': Dy.state_dict(),
            'Dy_optimizer': Dy_optimizer.state_dict(),
        }, opt.save_model + 'model_{result_name}_CycleGAN_ep{epoch}.ckp'.format(result_name=opt.result_name, epoch=epoch))
Example no. 8
def train(args):
    # set the logger
    logger = Logger('./logs')

    # GPU enabling (default to CPU so use_cuda and dtype are always defined)
    use_cuda = False
    dtype = torch.FloatTensor
    if args.gpu is not None:
        use_cuda = True
        dtype = torch.cuda.FloatTensor
        torch.cuda.set_device(args.gpu)
        print("Current device: %s" % torch.cuda.get_device_name(args.gpu))

    # define networks
    g_AtoB = Generator().type(dtype)
    g_BtoA = Generator().type(dtype)
    d_A = Discriminator().type(dtype)
    d_B = Discriminator().type(dtype)

    # optimizers
    optimizer_generators = Adam(
        list(g_AtoB.parameters()) + list(g_BtoA.parameters()), INITIAL_LR)
    optimizer_d_A = Adam(d_A.parameters(), INITIAL_LR)
    optimizer_d_B = Adam(d_B.parameters(), INITIAL_LR)

    # loss criterion
    criterion_mse = torch.nn.MSELoss()
    criterion_l1 = torch.nn.L1Loss()

    # get training data
    dataset_transform = transforms.Compose([
        transforms.Resize(int(IMAGE_SIZE * 1),
                          Image.BICUBIC),  # scale shortest side to image_size
        transforms.RandomCrop(
            (IMAGE_SIZE, IMAGE_SIZE)),  # random center image_size out
        transforms.ToTensor(),  # turn image from [0-255] to [0-1]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5))  # normalize
    ])
    dataloader = DataLoader(ImgPairDataset(args.dataroot, dataset_transform,
                                           'train'),
                            batch_size=BATCH_SIZE,
                            shuffle=True)

    # get some test data to display periodically
    test_data_A = torch.tensor([]).type(dtype)
    test_data_B = torch.tensor([]).type(dtype)
    for i in range(NUM_TEST_SAMPLES):
        imgA = ImgPairDataset(args.dataroot, dataset_transform,
                              'test')[i]['A'].type(dtype).unsqueeze(0)
        imgB = ImgPairDataset(args.dataroot, dataset_transform,
                              'test')[i]['B'].type(dtype).unsqueeze(0)
        test_data_A = torch.cat((test_data_A, imgA), dim=0)
        test_data_B = torch.cat((test_data_B, imgB), dim=0)

        fileStrA = 'visualization/test_%d/%s/' % (i, 'B_inStyleofA')
        fileStrB = 'visualization/test_%d/%s/' % (i, 'A_inStyleofB')
        if not os.path.exists(fileStrA):
            os.makedirs(fileStrA)
        if not os.path.exists(fileStrB):
            os.makedirs(fileStrB)

        fileStrA = 'visualization/test_original_%s_%04d.png' % ('A', i)
        fileStrB = 'visualization/test_original_%s_%04d.png' % ('B', i)
        utils.save_image(
            fileStrA,
            ImgPairDataset(args.dataroot, dataset_transform,
                           'test')[i]['A'].data)
        utils.save_image(
            fileStrB,
            ImgPairDataset(args.dataroot, dataset_transform,
                           'test')[i]['B'].data)

    # replay buffers holding recent generator outputs (capacity 50); the
    # discriminators are trained on samples drawn from these histories
    replayBufferA = utils.ReplayBuffer(50)
    replayBufferB = utils.ReplayBuffer(50)

    # training loop
    step = 0
    for e in range(EPOCHS):
        startTime = time.time()
        for idx, batch in enumerate(dataloader):
            real_A = batch['A'].type(dtype)
            real_B = batch['B'].type(dtype)

            # some examples seem to have only 1 color channel instead of 3
            if (real_A.shape[1] != 3):
                continue
            if (real_B.shape[1] != 3):
                continue

            # -----------------
            #  train generators
            # -----------------
            optimizer_generators.zero_grad()
            utils.learning_rate_decay(INITIAL_LR, e, EPOCHS,
                                      optimizer_generators)

            # GAN loss
            fake_A = g_BtoA(real_B)
            disc_fake_A = d_A(fake_A)
            fake_B = g_AtoB(real_A)
            disc_fake_B = d_B(fake_B)

            replayBufferA.push(fake_A.detach().clone())
            replayBufferB.push(fake_B.detach().clone())

            target_real = Variable(torch.ones_like(disc_fake_A)).type(dtype)
            target_fake = Variable(torch.zeros_like(disc_fake_A)).type(dtype)

            loss_gan_AtoB = criterion_mse(disc_fake_B, target_real)
            loss_gan_BtoA = criterion_mse(disc_fake_A, target_real)
            loss_gan = loss_gan_AtoB + loss_gan_BtoA

            # cyclic reconstruction loss
            cyclic_A = g_BtoA(fake_B)
            cyclic_B = g_AtoB(fake_A)
            loss_cyclic_AtoBtoA = criterion_l1(cyclic_A,
                                               real_A) * CYCLIC_WEIGHT
            loss_cyclic_BtoAtoB = criterion_l1(cyclic_B,
                                               real_B) * CYCLIC_WEIGHT
            loss_cyclic = loss_cyclic_AtoBtoA + loss_cyclic_BtoAtoB

            # identity loss
            loss_identity = 0
            loss_identity_A = 0
            loss_identity_B = 0
            if args.use_identity:
                identity_A = g_BtoA(real_A)
                identity_B = g_AtoB(real_B)
                loss_identity_A = criterion_l1(identity_A,
                                               real_A) * 0.5 * CYCLIC_WEIGHT
                loss_identity_B = criterion_l1(identity_B,
                                               real_B) * 0.5 * CYCLIC_WEIGHT
                loss_identity = loss_identity_A + loss_identity_B

            loss_generators = loss_gan + loss_cyclic + loss_identity
            loss_generators.backward()
            optimizer_generators.step()

            # -----------------
            #  train discriminators
            # -----------------
            optimizer_d_A.zero_grad()
            utils.learning_rate_decay(INITIAL_LR, e, EPOCHS, optimizer_d_A)

            fake_A = replayBufferA.sample(1).detach()
            disc_fake_A = d_A(fake_A)
            disc_real_A = d_A(real_A)
            loss_d_A = 0.5 * (criterion_mse(disc_real_A, target_real) +
                              criterion_mse(disc_fake_A, target_fake))

            loss_d_A.backward()
            optimizer_d_A.step()

            optimizer_d_B.zero_grad()
            utils.learning_rate_decay(INITIAL_LR, e, EPOCHS, optimizer_d_B)

            fake_B = replayBufferB.sample(1).detach()
            disc_fake_B = d_B(fake_B)
            disc_real_B = d_B(real_B)
            loss_d_B = 0.5 * (criterion_mse(disc_real_B, target_real) +
                              criterion_mse(disc_fake_B, target_fake))

            loss_d_B.backward()
            optimizer_d_B.step()

            #log info and save sample images
            if ((idx % 250) == 0):
                # eval on some sample images
                g_AtoB.eval()
                g_BtoA.eval()

                test_B_hat = g_AtoB(test_data_A).cpu()
                test_A_hat = g_BtoA(test_data_B).cpu()

                fileBaseStr = 'test_%d_%d' % (e, idx)
                for i in range(NUM_TEST_SAMPLES):
                    fileStrA = 'visualization/test_%d/%s/%03d_%04d.png' % (
                        i, 'B_inStyleofA', e, idx)
                    fileStrB = 'visualization/test_%d/%s/%03d_%04d.png' % (
                        i, 'A_inStyleofB', e, idx)
                    utils.save_image(fileStrA, test_A_hat[i].data)
                    utils.save_image(fileStrB, test_B_hat[i].data)

                g_AtoB.train()
                g_BtoA.train()

                endTime = time.time()
                timeForIntervalIterations = endTime - startTime
                startTime = endTime

                print(
                    'Epoch [{:3d}/{:3d}], Training [{:4d}/{:4d}], Time Spent (s): [{:4.4f}], Losses: [G_GAN: {:4.4f}][G_CYC: {:4.4f}][G_IDT: {:4.4f}][D_A: {:4.4f}][D_B: {:4.4f}]'
                    .format(e, EPOCHS, idx, len(dataloader),
                            timeForIntervalIterations, loss_gan, loss_cyclic,
                            loss_identity, loss_d_A, loss_d_B))

                # tensorboard logging
                info = {
                    'loss_generators':
                    loss_generators.item(),
                    'loss_gan_AtoB':
                    loss_gan_AtoB.item(),
                    'loss_gan_BtoA':
                    loss_gan_BtoA.item(),
                    'loss_cyclic_AtoBtoA':
                    loss_cyclic_AtoBtoA.item(),
                    'loss_cyclic_BtoAtoB':
                    loss_cyclic_BtoAtoB.item(),
                    'loss_cyclic':
                    loss_cyclic.item(),
                    'loss_d_A':
                    loss_d_A.item(),
                    'loss_d_B':
                    loss_d_B.item(),
                    'lr_optimizer_generators':
                    optimizer_generators.param_groups[0]['lr'],
                    'lr_optimizer_d_A':
                    optimizer_d_A.param_groups[0]['lr'],
                    'lr_optimizer_d_B':
                    optimizer_d_B.param_groups[0]['lr'],
                }
                if (args.use_identity):
                    info['loss_identity_A'] = loss_identity_A.item()
                    info['loss_identity_B'] = loss_identity_B.item()
                for tag, value in info.items():
                    logger.scalar_summary(tag, value, step)

                info = {
                    'test_A_hat':
                    test_A_hat.data.numpy().transpose(0, 2, 3, 1),
                    'test_B_hat':
                    test_B_hat.data.numpy().transpose(0, 2, 3, 1),
                }
                for tag, images in info.items():
                    logger.image_summary(tag, images, step)

            step += 1

        # save after every epoch
        g_AtoB.eval()
        g_BtoA.eval()
        d_A.eval()
        d_B.eval()

        if use_cuda:
            g_AtoB.cpu()
            g_BtoA.cpu()
            d_A.cpu()
            d_B.cpu()

        if not os.path.exists("models"):
            os.makedirs("models")
        filename_gAtoB = "models/" + str('g_AtoB') + "_epoch_" + str(
            e) + ".model"
        filename_gBtoA = "models/" + str('g_BtoA') + "_epoch_" + str(
            e) + ".model"
        filename_dA = "models/" + str('d_A') + "_epoch_" + str(e) + ".model"
        filename_dB = "models/" + str('d_B') + "_epoch_" + str(e) + ".model"
        torch.save(g_AtoB.state_dict(), filename_gAtoB)
        torch.save(g_BtoA.state_dict(), filename_gBtoA)
        torch.save(d_A.state_dict(), filename_dA)
        torch.save(d_B.state_dict(), filename_dB)

        if use_cuda:
            g_AtoB.cuda()
            g_BtoA.cuda()
            d_A.cuda()
            d_B.cuda()
Example no. 9
def main(params):
    if params['load_dataset']:
        dataset = load_pkl(params['load_dataset'])
    elif params['dataset_class']:
        dataset = globals()[params['dataset_class']](
            **params[params['dataset_class']])
        if params['save_dataset']:
            save_pkl(params['save_dataset'], dataset)
    else:
        raise Exception(
            'One of either load_dataset (path to pkl) or dataset_class needs to be specified.'
        )
    result_dir = create_result_subdir(params['result_dir'], params['exp_name'])

    losses = ['G_loss', 'D_loss', 'D_real', 'D_fake']
    stats_to_log = [
        'tick_stat',
        'kimg_stat',
    ]
    if params['progressive_growing']:
        stats_to_log.extend(['depth', 'alpha', 'lod', 'minibatch_size'])
    stats_to_log.extend(['time', 'sec.tick', 'sec.kimg'] + losses)
    logger = TeeLogger(os.path.join(result_dir, 'log.txt'), stats_to_log,
                       [(1, 'epoch')])
    logger.log(params_to_str(params))
    if params['resume_network']:
        G, D = load_models(params['resume_network'], params['result_dir'],
                           logger)
    else:
        G = Generator(dataset.shape, **params['Generator'])
        D = Discriminator(dataset.shape, **params['Discriminator'])
    if params['progressive_growing']:
        assert G.max_depth == D.max_depth
    G.cuda()
    D.cuda()
    latent_size = params['Generator']['latent_size']

    logger.log(str(G))
    logger.log('Total number of parameters in Generator: {}'.format(
        sum(map(lambda x: reduce(lambda a, b: a * b, x.size()),
                G.parameters()))))
    logger.log(str(D))
    logger.log('Total number of parameters in Discriminator: {}'.format(
        sum(map(lambda x: reduce(lambda a, b: a * b, x.size()),
                D.parameters()))))

    def get_dataloader(minibatch_size):
        return DataLoader(dataset,
                          minibatch_size,
                          sampler=InfiniteRandomSampler(dataset),
                          num_workers=params['num_data_workers'],
                          pin_memory=False,
                          drop_last=True)

    def rl(bs):
        return lambda: random_latents(bs, latent_size)

    # Setting up learning rate and optimizers
    opt_g = Adam(G.parameters(), params['G_lr_max'], **params['Adam'])
    opt_d = Adam(D.parameters(), params['D_lr_max'], **params['Adam'])

    def rampup(cur_nimg):
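        # LR warm-up: the LambdaLR factor ramps from exp(-5) up to 1 over the
        # first lr_rampup_kimg thousand images, then stays at 1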
        if cur_nimg < params['lr_rampup_kimg'] * 1000:
            p = max(0.0, 1 - cur_nimg / (params['lr_rampup_kimg'] * 1000))
            return np.exp(-p * p * 5.0)
        else:
            return 1.0

    lr_scheduler_d = LambdaLR(opt_d, rampup)
    lr_scheduler_g = LambdaLR(opt_g, rampup)

    mb_def = params['minibatch_size']
    D_loss_fun = partial(wgan_gp_D_loss,
                         return_all=True,
                         iwass_lambda=params['iwass_lambda'],
                         iwass_epsilon=params['iwass_epsilon'],
                         iwass_target=params['iwass_target'])
    G_loss_fun = wgan_gp_G_loss
    trainer = Trainer(D, G, D_loss_fun, G_loss_fun, opt_d, opt_g, dataset,
                      iter(get_dataloader(mb_def)), rl(mb_def),
                      **params['Trainer'])
    # plugins
    if params['progressive_growing']:
        max_depth = min(G.max_depth, D.max_depth)
        trainer.register_plugin(
            DepthManager(get_dataloader, rl, max_depth,
                         **params['DepthManager']))
    for i, loss_name in enumerate(losses):
        trainer.register_plugin(EfficientLossMonitor(i, loss_name))

    checkpoints_dir = params['checkpoints_dir'] if params[
        'checkpoints_dir'] else result_dir
    trainer.register_plugin(
        SaverPlugin(checkpoints_dir, **params['SaverPlugin']))

    def substitute_samples_path(d):
        return {
            k: (os.path.join(result_dir, v) if k == 'samples_path' else v)
            for k, v in d.items()
        }

    postprocessors = [
        globals()[x](**substitute_samples_path(params[x]))
        for x in params['postprocessors']
    ]
    trainer.register_plugin(
        OutputGenerator(lambda x: random_latents(x, latent_size),
                        postprocessors, **params['OutputGenerator']))
    trainer.register_plugin(AbsoluteTimeMonitor(params['resume_time']))
    trainer.register_plugin(LRScheduler(lr_scheduler_d, lr_scheduler_g))
    trainer.register_plugin(logger)
    init_comet(params, trainer)
    trainer.run(params['total_kimg'])
    dataset.close()
Example no. 10
import torch
import torch.nn as nn
from torch.autograd import Variable

import torchvision.transforms as transforms
from torchvision.utils import save_image

from utils import imshow_grid, mse_loss, reparameterize, l1_loss
from network import Generator, Discriminator

batch_size = 8
num_epochs = 500
image_size = 128

generator = Generator(nc_dim=80)
generator.apply(weights_init)
generator = generator.cuda()

discriminator = Discriminator()
discriminator.apply(weights_init)
discriminator = discriminator.cuda()

ones_label = torch.ones(batch_size)
ones_label = Variable(ones_label.cuda())
zeros_label = torch.zeros(batch_size)
zeros_label = Variable(zeros_label.cuda())

loss = nn.BCEWithLogitsLoss()


def to_img(x):
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 3, 128, 128)
    return x

Example no. 11
def main():
    start_epoch = 0
    train_image_dataset = image_preprocessing(args.dataset)
    data_loader = DataLoader(train_image_dataset, batch_size=args.batch_size,
                            shuffle=True, num_workers=args.num_workers)
    
    criterion = nn.BCELoss()
    euclidean_l1 = nn.L1Loss()
    G = Generator()
    D = Discriminator()

    D_optimizer = optim.Adam(D.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    G_optimizer = optim.Adam(G.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    if torch.cuda.is_available():
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

        G = G.cuda()
        D = D.cuda()

    if args.checkpoint is not None:
        G, D, G_optimizer, D_optimizer, start_epoch = load_ckp(args.checkpoint, G, D, G_optimizer, D_optimizer)

    print('[Start] : pix2pix Training')
    for epoch in range(args.epochs):
        epoch = epoch + start_epoch + 1
        print("Epoch[{epoch}] : Start".format(epoch=epoch))
        for step, data in enumerate(data_loader):
            real_A = to_variable(data['A'])
            real_B = to_variable(data['B'])
            fake_B = G(real_A)

            # Train Discriminator
            D_fake = D(torch.cat((real_A, fake_B), 1))
            D_real = D(torch.cat((real_A, real_B), 1))

            D_loss = 0.5 * patch_loss(criterion, D_fake, False) + 0.5 * patch_loss(criterion, D_real, True)
            
            D_optimizer.zero_grad()
            D_loss.backward(retain_graph=True)
            D_optimizer.step()

            # Train Generator

            D_fake = D(torch.cat((real_A, fake_B), 1))

            G_loss = patch_loss(criterion, D_fake, True) + euclidean_l1(fake_B, real_B) * args.lamda

            G_optimizer.zero_grad()
            G_loss.backward(retain_graph=True)
            G_optimizer.step()

            if (step + 1) % args.save_step == 0:
                print("Epoch[{epoch}] |  Step [{now}/{total}] : D Loss : {D_loss}, Patch_loss : {patch_loss}, G_losss : {G_loss}".format(epoch=epoch, now=step + 1, total=len(data_loader), D_loss=D_loss.item(), patch_loss=patch_loss(criterion, D_fake, True), G_loss=G_loss.item()))
                #check 
                batch_image = (torch.cat((torch.cat((real_A, fake_B), 3), real_B), 3))
                torchvision.utils.save_image(denorm(batch_image[0]), args.training_result + 'result_{result_name}_ep{epoch}_{step}.jpg'.format(result_name=args.result_name,epoch=epoch, step=(step + 1) * 4))
        
        torch.save({
            'epoch': epoch,
            'G_model': G.state_dict(),
            'G_optimizer': G_optimizer.state_dict(),
            'D_model': D.state_dict(),
            'D_optimizer': D_optimizer.state_dict(),
        }, args.save_model + 'model_{result_name}_pix2pix_ep{epoch}.ckp'.format(result_name=args.result_name, epoch=epoch))
Example no. 12
	network = Network()
	network.apply(weights_init)

	d = Discriminator()
	d.apply(weights_init)
	bce_loss = torch.nn.BCEWithLogitsLoss()
	true_crit, fake_crit = torch.ones(args.batch_size, 1, device='cuda'), torch.zeros(args.batch_size, 1, device='cuda')
	
	print(network)
	train_loss = []
	val_loss = []
	if args.train:
		optimizer = torch.optim.Adam(network.parameters(), lr=args.lr)
		d_optimizer = torch.optim.Adam(d.parameters(), lr=args.lr)
		network.cuda()
		d.cuda()
		for epoch in range(args.num_iters):
			train_loader,val_loader,test_loader = read_data(args.batch_size)
			network.train()
			for idx, x in enumerate(train_loader):
				img64 = x['img64'].cuda()
				img128 = x['img128'].cuda()
				imgname = x['img_name']
				optimizer.zero_grad()
				g_img128 = network(img64)
				l2_loss = ((255*(img128-g_img128))**2).mean()
				l1_loss = (abs(255*(img128-g_img128))).mean()
				rmse_loss = rmse(img128,g_img128)
				ssim_loss = ssim(img128,g_img128)
				# tv_losss = tv_loss(255*img128,255*g_img128)
				# dloss = bce_loss(d(g_img128,img64),true_crit)
                       num_pts=opt.num_pts,
                       transform=transform,
                       augmentation=opt.augmentation)
train_loader = torch.utils.data.DataLoader(trainset,
                                           batch_size=opt.batch_size,
                                           shuffle=True,
                                           num_workers=opt.num_workers)
""" Networks : Generator & Discriminator """
G = Generator(opt.num_pts)
D = Discriminator(opt.num_pts)

G.weight_init(mean=0.0, std=0.02)
D.weight_init(mean=0.0, std=0.02)
""" set CUDA """
G.cuda()
D.cuda()
""" Optimizer """
G_optimizer = optim.Adam(G.parameters(), lr=cur_lrG, betas=(0.9, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=cur_lrD, betas=(0.9, 0.999))
""" Restore """
if opt.restore:
    print('==> Restoring from checkpoint..', opt.restore)
    state = torch.load(opt.restore)

    G.load_state_dict(state['G'])
    D.load_state_dict(state['D'])
    G_optimizer.load_state_dict(state["G_optimizer"])
    D_optimizer.load_state_dict(state["D_optimizer"])
    epoch = state["epoch"]
    global_iter += state["iter"]
    cur_lrG = state["lrG"]