Example #1
def main(args, data_root):
    os.listdir(data_root)
    img_shape = (args.channels, args.img_size, args.img_size)

    cuda = True if torch.cuda.is_available() else False

    # Loss function
    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    generator = Generator(args.img_size, args.latent_dim, args.channels)
    discriminator = Discriminator(args.img_size, args.channels)

    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    dataset = datasets.ImageFolder(root=data_root,
                           transform=transforms.Compose([
                               transforms.Resize(args.img_size),
                               transforms.CenterCrop(args.img_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
    # Create the dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                             shuffle=True, num_workers=2)


    train(generator, discriminator, dataloader, args, cuda, adversarial_loss)
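
Most of the examples on this page call a weights_init_normal helper that is never shown. The sketch below is the conventional DCGAN-style initializer (an assumption about the exact helper, consistent with the inline initializers in Examples #5 and #17):

import torch.nn as nn

def weights_init_normal(m):
    # Conv layers: N(0, 0.02); BatchNorm layers: weight N(1, 0.02), bias 0.
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)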
Example #2
def main(args):

    # Create sample and checkpoint directories
    os.makedirs("images/{}".format(args.data_root), exist_ok=True)
    os.makedirs("saved_models/{}".format(args.data_root), exist_ok=True)

    # Loss weight of L1 pixel-wise loss between translated image and real image
    lambda_pixel = 100

    # Initialize generator and discriminator
    generator = GeneratorUNet()
    discriminator = Discriminator()

    cuda = torch.cuda.is_available()
    if cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()

    if args.epoch != 0:
        # Load pretrained models
        generator.load_state_dict(torch.load("saved_models/{}/generator_{}.pth".format(args.data_root, args.epoch)))
        discriminator.load_state_dict(torch.load("saved_models/{}/discriminator_{}.pth".format(args.data_root, args.epoch)))
    else:
        # Initialize weights
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)


    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(args.b1, args.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.b1, args.b2))

    transforms_ = [
    transforms.Resize((args.img_height, args.img_width), Image.BICUBIC),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    # Training data loader
    dataloader = DataLoader(
        FacadeDataset("../../datasets/{}".format(args.data_root), transforms_=transforms_),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.n_cpu,
    )
    # Test data loader
    # global val_dataloader
    val_dataloader = DataLoader(
        FacadeDataset("../../datasets/{}".format(args.data_root), transforms_=transforms_,  mode="val"),
        batch_size=10,
        shuffle=True,
        num_workers=1,
    )

    optimizer_list = [optimizer_G, optimizer_D]
    network_list = [generator, discriminator]

    dataloaders = [dataloader, val_dataloader]
    train(args, network_list, optimizer_list, dataloaders)
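
lambda_pixel is only handed to the (not shown) train function. As a hedged sketch of how a pix2pix-style generator update typically weighs the L1 term against the adversarial term: valid, real_A and real_B are assumed to come from the omitted training loop, and the criteria below are assumptions, not the project's own code.

criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

optimizer_G.zero_grad()
fake_B = generator(real_A)
pred_fake = discriminator(fake_B, real_A)
loss_GAN = criterion_GAN(pred_fake, valid)        # adversarial term
loss_pixel = criterion_pixelwise(fake_B, real_B)  # L1 reconstruction term
loss_G = loss_GAN + lambda_pixel * loss_pixel     # lambda_pixel = 100 dominates
loss_G.backward()
optimizer_G.step()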
Example #3
def models(channels):
    """
    Creates and initializes the models
    :return: Encoder, Generator, Discriminator
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = Encoder(channels).to(device)
    encoder.apply(init_weights)

    generator = Generator(channels).to(device)
    generator.apply(init_weights)

    discriminator = Discriminator(channels).to(device)
    discriminator.apply(init_weights)

    return encoder, generator, discriminator
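
A short usage example for the helper above (the channel count is an assumption for RGB data):

# The helper picks CUDA automatically when available.
encoder, generator, discriminator = models(channels=3)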
Example #4
def CycleGANmapper(inC, outC, options, init_weights=True):
	GA2B = Generator(inC, outC)
	DB = Discriminator(outC)
	GB2A = Generator(outC, inC)
	DA = Discriminator(inC)
	if options["cuda"]:
		GA2B.cuda()
		DB.cuda()
		GB2A.cuda()
		DA.cuda()
	if init_weights:
		GA2B.apply(weights_init_normal)
		DB.apply(weights_init_normal)
		GB2A.apply(weights_init_normal)
		DA.apply(weights_init_normal)
	if options["continued"]:
		GA2B.load_state_dict(torch.load("output/GEN_AtoB.pth"))
		DB.load_state_dict(torch.load("output/DIS_B.pth"))
		GB2A.load_state_dict(torch.load("output/GEN_BtoA.pth"))
		DA.load_state_dict(torch.load("output/DIS_A.pth"))
	return (GA2B, DB), (GB2A, DA)
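
Example usage (the options keys follow the ones read inside the function; the values are placeholders):

options = {"cuda": torch.cuda.is_available(), "continued": False}
(GA2B, DB), (GB2A, DA) = CycleGANmapper(inC=3, outC=3, options=options, init_weights=True)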
Example #5
def main():

    config = get_config()

    # general
    general = config['general']
    dataroot = general['dataroot']
    workers = general['workers']
    gpu = general['gpu']
    batch_size = general['batch_size']
    lr = general['lr']
    beta = general['beta']
    epoch = general['epoch']
    image_size = general['image_size']

    # model config
    model_config = config['model']
    nz = model_config['nz']
    ndf = model_config['ndf']
    ngf = model_config['ngf']

    dataset = datasets.ImageFolder(root=dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(image_size),
                                       transforms.CenterCrop(image_size),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5)),
                                   ]))
    dataloader = data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
    )

    net_D = Discriminator(ndf)
    net_G = Generator(nz, ngf)

    if gpu:
        device = torch.device('cuda')
        net_D = nn.DataParallel(net_D.to(device))
        net_G = nn.DataParallel(net_G.to(device))
    else:
        device = 'cpu'

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    net_D.apply(weights_init)
    net_G.apply(weights_init)

    optimizers = {
        'Discriminator': optim.Adam(net_D.parameters(),
                                    lr,
                                    betas=(beta, 0.999)),
        'Generator': optim.Adam(net_G.parameters(), lr, betas=(beta, 0.999))
    }
    models = {
        'Discriminator': net_D,
        'Generator': net_G,
    }
    updater = DCGANUpdater(optimizers, models, dataloader, device)
    trainer = Trainer(updater, {'epoch': epoch}, 'test')
    trainer.extend(
        LogReport([
            'iteration',
            'training/D_real',
            'training/D_fake',
            'training/D_loss',
            'training/G_loss',
            'elapsed_time',
        ], {'iteration': 100}))

    trainer.extend(ProgressBar(10))

    save_trigger = MinValueTrigger('training/G_loss',
                                   trigger={'iteration': 100})
    trainer.extend(SnapshotModel(trigger=save_trigger))

    trainer.run()
    print(config)
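
get_config is not shown; judging by the keys read above, it returns a nested dict (or parsed YAML) roughly shaped like the hedged sketch below. The concrete values are placeholders, not the project's defaults.

# Hypothetical config layout inferred from the keys accessed in main().
config = {
    "general": {
        "dataroot": "./data",
        "workers": 2,
        "gpu": True,
        "batch_size": 128,
        "lr": 0.0002,
        "beta": 0.5,
        "epoch": 25,
        "image_size": 64,
    },
    "model": {"nz": 100, "ndf": 64, "ngf": 64},
}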
Example #6
###### Definition of variables ######
# Networks
netG_A2B = Generator(opt.input_nc, opt.output_nc)
netG_B2A = Generator(opt.output_nc, opt.input_nc)
netD_A = Discriminator(opt.input_nc)
netD_B = Discriminator(opt.output_nc)

if opt.cuda:
    netG_A2B.cuda()
    netG_B2A.cuda()
    netD_A.cuda()
    netD_B.cuda()

netG_A2B.apply(weights_init_normal)
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                               netG_B2A.parameters()),
                               lr=opt.lr,
                               betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
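
The snippet stops before the training loop. As a reference for how the three criteria above are usually combined into the generator objective (mirroring the structure of Example #24 further down), a hedged sketch; real_A, real_B and target_real are assumed to come from the omitted loop.

optimizer_G.zero_grad()

# Identity terms: G_A2B(B) ~ B, G_B2A(A) ~ A
loss_id = criterion_identity(netG_A2B(real_B), real_B) * 5.0 \
        + criterion_identity(netG_B2A(real_A), real_A) * 5.0

# Adversarial terms
fake_B = netG_A2B(real_A)
fake_A = netG_B2A(real_B)
loss_gan = criterion_GAN(netD_B(fake_B), target_real) \
         + criterion_GAN(netD_A(fake_A), target_real)

# Cycle-consistency terms
loss_cycle = criterion_cycle(netG_B2A(fake_B), real_A) * 10.0 \
           + criterion_cycle(netG_A2B(fake_A), real_B) * 10.0

loss_G = loss_id + loss_gan + loss_cycle
loss_G.backward()
optimizer_G.step()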
Example #7
device = torch.device("cuda:0" if opt.cuda else "cpu")
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
nc = 3

netG = Generator(ngpu, nz, ngf, nc).to(device)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)

netD = Discriminator(ngpu, nc, ndf).to(device)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

criterion = nn.BCELoss()

fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
real_label = 1
fake_label = 0

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

for epoch in range(opt.niter):
Example #8
                                   device_ids=range(torch.cuda.device_count()))

    netD_A_self.cuda()
    netD_A_self = torch.nn.DataParallel(netD_A,
                                        device_ids=range(
                                            torch.cuda.device_count()))

    # netD_A_content.cuda()
    # netD_A_content = torch.nn.DataParallel(netD_A_content, device_ids=range(torch.cuda.device_count()))
    #
    # netD_B_content.cuda()
    # netD_B_content = torch.nn.DataParallel(netD_B_content, device_ids=range(torch.cuda.device_count()))

netG_A2B.apply(weights_init_normal)

netD_A.apply(weights_init_normal)

netD_A_self.apply(weights_init_normal)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters()),
                               lr=opt.lr,
                               betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
Example #9
def train(args):

    data_root = args.path
    total_iterations = args.iter
    checkpoint = args.ckpt
    batch_size = args.batch_size
    im_size = args.im_size
    ndf = 64
    ngf = 64
    nz = 256
    nlr = 0.0002
    nbeta1 = 0.5
    use_cuda = True
    multi_gpu = False
    dataloader_workers = 8
    current_iteration = 0
    save_interval = 100
    saved_model_folder, saved_image_folder = get_dir(args)

    device = torch.device("cpu")
    if use_cuda:
        device = torch.device("cuda:0")

    transform_list = [
        transforms.Resize((int(im_size), int(im_size))),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ]
    trans = transforms.Compose(transform_list)

    dataset = ImageFolder(root=data_root, transform=trans)
    dataloader = iter(
        DataLoader(dataset,
                   batch_size=batch_size,
                   shuffle=False,
                   sampler=InfiniteSamplerWrapper(dataset),
                   num_workers=dataloader_workers,
                   pin_memory=True))

    netD = Discriminator(ndf=ndf, im_size=im_size)
    netD.apply(weights_init)

    net_decoder = SimpleDecoder(ndf * 16)
    net_decoder.apply(weights_init)

    net_decoder.to(device)

    ckpt = torch.load(checkpoint)
    netD.load_state_dict(ckpt['d'])
    netD.to(device)
    netD.eval()

    optimizerG = optim.Adam(net_decoder.parameters(),
                            lr=nlr,
                            betas=(nbeta1, 0.999))

    log_rec_loss = 0

    for iteration in tqdm(range(current_iteration, total_iterations + 1)):
        real_image = next(dataloader)
        real_image = real_image.to(device)
        current_batch_size = real_image.size(0)

        net_decoder.zero_grad()

        feat = netD.get_feat(real_image)
        g_imag = net_decoder(feat)

        target_image = F.interpolate(real_image, g_imag.shape[2])

        rec_loss = percept(g_imag, target_image).sum()

        rec_loss.backward()

        optimizerG.step()

        log_rec_loss += rec_loss.item()

        if iteration % 100 == 0:
            print("lpips loss d: %.5f " % (log_rec_loss / 100))
            log_rec_loss = 0

        if iteration % (save_interval * 10) == 0:

            with torch.no_grad():
                vutils.save_image(
                    torch.cat([target_image, g_imag]).add(1).mul(0.5),
                    saved_image_folder + '/rec_%d.jpg' % iteration)

        if iteration % (save_interval *
                        50) == 0 or iteration == total_iterations:
            torch.save(
                {
                    'd': netD.state_dict(),
                    'dec': net_decoder.state_dict()
                }, saved_model_folder + '/%d.pth' % iteration)
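
percept is used above but never defined in this snippet; it is presumably a perceptual (LPIPS-style) loss. A hedged stand-in using the pip-installable lpips package, as an assumption rather than the repository's own wrapper:

import lpips

# Assumed substitute for the snippet's perceptual-loss object.
percept = lpips.LPIPS(net='vgg').to(device)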
Example #10
    disc_net.to(device='cuda:0')
    gen1_net.to(device='cuda:0')
    gen2_net.to(device='cuda:0')

if args.parallel and cuda:
    disc_net = torch.nn.DataParallel(Discriminator())
    gen1_net = torch.nn.DataParallel(Generator1_CAN8(is_anm=args.is_anm))
    gen2_net = torch.nn.DataParallel(Generator2_UCAN64(is_anm=args.is_anm))

    disc_net.to(device='cuda:0')
    gen1_net.to(device='cuda:0')
    gen2_net.to(device='cuda:0')


# weight init
disc_net.apply(init_linear_weights)
gen1_net.apply(init_conv2_weights)
gen2_net.apply(init_conv2_weights)

# optimizers
disc_optimizer = torch.optim.Adam(disc_net.parameters(), lr=args.d_lr, betas=(0.5, 0.999))
gen1_optimizer = torch.optim.Adam(gen1_net.parameters(), lr=args.g1_lr, betas=(0.9, 0.999))
gen2_optimizer = torch.optim.Adam(gen2_net.parameters(), lr=args.g2_lr, betas=(0.9, 0.999))

# dataset
logging.info("Preparing dataset...")

composed = transforms.Compose(
    [
        transforms.Grayscale(1),
        transforms.ToTensor(),
Example #11
File: train.py Project: nnuq/tpu
def main(index, args):
    device = xm.xla_device()

    gen_net = Generator(args).to(device)
    dis_net = Discriminator(args).to(device)
    enc_net = Encoder(args).to(device)

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown initial type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)
    enc_net.apply(weights_init)

    ae_recon_optimizer = torch.optim.Adam(
        itertools.chain(enc_net.parameters(), gen_net.parameters()),
        args.ae_recon_lr, (args.beta1, args.beta2))
    ae_reg_optimizer = torch.optim.Adam(
        itertools.chain(enc_net.parameters(), gen_net.parameters()),
        args.ae_reg_lr, (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(dis_net.parameters(), args.d_lr,
                                     (args.beta1, args.beta2))
    gen_optimizer = torch.optim.Adam(gen_net.parameters(), args.g_lr,
                                     (args.beta1, args.beta2))

    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train
    valid_loader = dataset.valid
    para_loader = pl.ParallelLoader(train_loader, [device])

    fid_stat = str(pathlib.Path(
        __file__).parent.absolute()) + '/fid_stat/fid_stat_cifar10_test.npz'
    if not os.path.exists(fid_stat):
        download_stat_cifar10_test()

    is_best = True
    args.num_epochs = np.ceil(args.num_iter / len(train_loader))

    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0,
                                  args.num_iter / 2, args.num_iter)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0,
                                  args.num_iter / 2, args.num_iter)
    ae_recon_scheduler = LinearLrDecay(ae_recon_optimizer, args.ae_recon_lr, 0,
                                       args.num_iter / 2, args.num_iter)
    ae_reg_scheduler = LinearLrDecay(ae_reg_optimizer, args.ae_reg_lr, 0,
                                     args.num_iter / 2, args.num_iter)

    # initial
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        enc_net.load_state_dict(checkpoint['enc_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        ae_recon_optimizer.load_state_dict(checkpoint['ae_recon_optimizer'])
        ae_reg_optimizer.load_state_dict(checkpoint['ae_reg_optimizer'])
        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # create new log dir
        assert args.exp_name
        logs_dir = str(pathlib.Path(__file__).parent.parent) + '/logs'
        args.path_helper = set_log_dir(logs_dir, args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.num_epochs)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler, dis_scheduler, ae_recon_scheduler,
                         ae_reg_scheduler)
        train(device, args, gen_net, dis_net, enc_net, gen_optimizer,
              dis_optimizer, ae_recon_optimizer, ae_reg_optimizer, para_loader,
              epoch, writer_dict, lr_schedulers)
        if epoch and epoch % args.val_freq == 0 or epoch == args.num_epochs - 1:
            fid_score = validate(args, fid_stat, gen_net, writer_dict,
                                 valid_loader)
            logger.info(f'FID score: {fid_score} || @ epoch {epoch}.')
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'enc_state_dict': enc_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'ae_recon_optimizer': ae_recon_optimizer.state_dict(),
                'ae_reg_optimizer': ae_reg_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
Example #12
#init
Net_G = Generator(opt.z_dim, (3, opt.img_height, opt.img_width))
Net_E = Encoder(opt.z_dim)
D_VAE = Discriminator()
D_LR = Discriminator()
l1_loss = torch.nn.L1Loss()
#cuda
Net_G.cuda()
Net_E.cuda()
D_VAE.cuda()
D_LR.cuda()
l1_loss.cuda()

#weight_init
Net_G.apply(weights_init)
D_VAE.apply(weights_init)
D_LR.apply(weights_init)
#optimizer
optimizer_E = torch.optim.Adam(Net_E.parameters(),
                               lr=opt.lr,
                               betas=(opt.b1, opt.b2))
optimizer_G = torch.optim.Adam(Net_G.parameters(),
                               lr=opt.lr,
                               betas=(opt.b1, opt.b2))
optimizer_D_VAE = torch.optim.Adam(D_VAE.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
optimizer_D_LR = torch.optim.Adam(D_LR.parameters(),
                                  lr=opt.lr,
                                  betas=(opt.b1, opt.b2))
Example #13
except OSError:
    pass

useCuda = torch.cuda.is_available() and not opt.disableCuda
device = torch.device("cuda:0" if useCuda else "cpu")
Tensor = torch.cuda.FloatTensor if useCuda else torch.Tensor

# Networks.
G = Generator(nZ=opt.nZ, nDis=opt.nDis, dDis=opt.dDis, nCon=opt.nCon, nC=opt.nC, nfG=opt.nfG)\
    .to(device)
D = Discriminator(nC=opt.nC, nfD=opt.nfD, nZ=opt.nZ, nDis=opt.nDis, dDis=opt.dDis, nCon=opt.nCon)\
    .to(device)

# Initialize the weights.
G.apply(weightsInit)
D.apply(weightsInit)

# Load weights if provided.
if opt.netG != '':
    G.load_state_dict(torch.load(opt.netG))
if opt.netD != '':
    D.load_state_dict(torch.load(opt.netD))

# Define the losses.
binaryCrossEntropy = nn.BCELoss()  # Discriminate between real and fake images.
crossEntropy = nn.CrossEntropyLoss()  # Loss for discrete latent variables.
normalNLL = NormalNLLLoss()  # Loss for continuous latent variables.

# Optimizers.
optimizerG = torch.optim.Adam(G.parameters(), lr=opt.lrG, betas=(0.5, 0.999))
optimizerD = torch.optim.Adam(list(D.commonDQ.parameters()) +
Example #14
def train(args):
    data_root = args.path
    total_iterations = args.iter
    checkpoint = args.ckpt
    batch_size = args.batch_size
    im_size = args.im_size
    ndf = 64
    ngf = 64
    nz = 256
    nlr = 0.0002
    nbeta1 = 0.5
    multi_gpu = False
    dataloader_workers = 8
    current_iteration = 0
    save_interval = 100

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda:0")

    transform_list = [
        transforms.Resize((int(im_size), int(im_size))),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ]
    trans = transforms.Compose(transform_list)

    if 'lmdb' in data_root:
        from operation import MultiResolutionDataset
        dataset = MultiResolutionDataset(data_root, trans, 1024)
    else:
        dataset = ImageFolder(root=data_root, transform=trans)

    dataloader = iter(
        DataLoader(dataset,
                   batch_size=batch_size,
                   shuffle=False,
                   sampler=InfiniteSamplerWrapper(dataset),
                   num_workers=0,
                   pin_memory=True))
    '''
    loader = MultiEpochsDataLoader(dataset, batch_size=batch_size, 
                               shuffle=True, num_workers=dataloader_workers, 
                               pin_memory=True)
    dataloader = CudaDataLoader(loader, 'cuda')
    '''

    #from model_s import Generator, Discriminator
    netG = Generator(ngf=ngf, nz=nz, im_size=im_size)
    netG.apply(weights_init)

    netD = Discriminator(ndf=ndf, im_size=im_size)
    netD.apply(weights_init)

    netG.to(device)
    netD.to(device)

    avg_param_G = copy_G_params(netG)

    fixed_noise = torch.FloatTensor(8, nz).normal_(0, 1).to(device)

    if torch.cuda.is_available():
        netG = nn.DataParallel(netG)
        netD = nn.DataParallel(netD)

    optimizerG = optim.Adam(netG.parameters(), lr=nlr, betas=(nbeta1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=nlr, betas=(nbeta1, 0.999))

    if checkpoint != 'None':
        ckpt = torch.load(checkpoint)
        netG.load_state_dict(ckpt['g'])
        netD.load_state_dict(ckpt['d'])
        avg_param_G = ckpt['g_ema']
        optimizerG.load_state_dict(ckpt['opt_g'])
        optimizerD.load_state_dict(ckpt['opt_d'])
        current_iteration = int(checkpoint.split('_')[-1].split('.')[0])
        del ckpt

    for iteration in range(current_iteration, total_iterations + 1):
        real_image = next(dataloader)
        if torch.cuda.is_available():
            real_image = real_image.cuda(non_blocking=True)
        current_batch_size = real_image.size(0)
        noise = torch.Tensor(current_batch_size, nz).normal_(0, 1).to(device)

        fake_images = netG(noise)

        real_image = DiffAugment(real_image, policy=policy)
        fake_images = [
            DiffAugment(fake, policy=policy) for fake in fake_images
        ]

        ## 2. train Discriminator
        netD.zero_grad()

        err_dr, rec_img_all, rec_img_small, rec_img_part = train_d(
            netD, real_image, label="real")
        train_d(netD, [fi.detach() for fi in fake_images], label="fake")
        optimizerD.step()

        ## 3. train Generator
        netG.zero_grad()
        pred_g = netD(fake_images, "fake")
        err_g = -pred_g.mean()

        err_g.backward()
        optimizerG.step()

        for p, avg_p in zip(netG.parameters(), avg_param_G):
            avg_p.mul_(0.999).add_(0.001 * p.data)

        if iteration % save_interval == 0:
            print("GAN: loss d: %.5f    loss g: %.5f" %
                  (err_dr.item(), -err_g.item()))

        if iteration % (save_interval) == 0:
            backup_para = copy_G_params(netG)
            load_params(netG, avg_param_G)
            saved_model_folder, saved_image_folder = get_dir(args)
            with torch.no_grad():
                vutils.save_image(netG(fixed_noise)[0].add(1).mul(0.5),
                                  saved_image_folder + '/%d.jpg' % iteration,
                                  nrow=4)
                vutils.save_image(
                    torch.cat([
                        F.interpolate(real_image, 128), rec_img_all,
                        rec_img_small, rec_img_part
                    ]).add(1).mul(0.5),
                    saved_image_folder + '/rec_%d.jpg' % iteration)
            load_params(netG, backup_para)

        if iteration % (save_interval *
                        50) == 0 or iteration == total_iterations:
            backup_para = copy_G_params(netG)
            load_params(netG, avg_param_G)
            torch.save({
                'g': netG.state_dict(),
                'd': netD.state_dict()
            }, saved_model_folder + '/%d.pth' % iteration)
            load_params(netG, backup_para)
            torch.save(
                {
                    'g': netG.state_dict(),
                    'd': netD.state_dict(),
                    'g_ema': avg_param_G,
                    'opt_g': optimizerG.state_dict(),
                    'opt_d': optimizerD.state_dict()
                }, saved_model_folder + '/all_%d.pth' % iteration)
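
Two module-level pieces are assumed by this snippet: the DiffAugment policy string and the InfiniteSamplerWrapper that lets iter(DataLoader(...)) be driven with next() forever. A hedged sketch of both follows; the names match the snippet, the bodies are assumptions.

import numpy as np
from torch.utils.data.sampler import Sampler

policy = 'color,translation'  # typical DiffAugment policy string (assumption)

class InfiniteSamplerWrapper(Sampler):
    """Endless stream of shuffled indices, so the wrapped DataLoader never exhausts."""
    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        order = np.random.permutation(self.num_samples)
        i = 0
        while True:
            yield int(order[i])
            i += 1
            if i >= self.num_samples:
                order = np.random.permutation(self.num_samples)
                i = 0

    def __len__(self):
        return 2 ** 31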
Example #15
    )

netG_A2B = Generator(opt.input_nc, opt.output_nc)
netG_B2A = Generator(opt.input_nc, opt.output_nc)
netD_A = Discriminator(opt.input_nc)
netD_B = Discriminator(opt.input_nc)

if opt.cuda:
    netG_A2B.cuda()
    netG_B2A.cuda()
    netD_A.cuda()
    netD_B.cuda()

netG_A2B.apply(init_weight)
netG_B2A.apply(init_weight)
netD_A.apply(init_weight)
netD_B.apply(init_weight)

criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                               netG_B2A.parameters()),
                               lr=opt.lr,
                               betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                 lr=opt.lr,
Example #16
                                    shuffle=True,
                                    num_workers=1)

# prepare network
D = Discriminator(ndf=args.ndf).cuda()
G = Generator(100, ngf=args.ngf).cuda()


## initialize the network parameters
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        init.normal_(m.weight, mean=args.mu, std=args.sigma)


D.apply(weights_init)
G.apply(weights_init)

# criterion
criterion = nn.BCELoss()

# prepare optimizer
d_optimizer = optim.Adam(D.parameters(), lr=args.lr)
g_optimizer = optim.Adam(G.parameters(), lr=args.lr)

# train
training_history = np.zeros((4, args.epochs))
for i in tqdm(range(args.epochs)):
    running_d_loss = 0
    running_g_loss = 0
    running_d_true = 0
Example #17

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


discriminator = Discriminator(num_channels, img_size).to(device)
generator = Generator(z_dim, num_channels, img_size).to(device)

generator.apply(weights_init)
discriminator.apply(weights_init)

fixed_noise = torch.randn((batch_size, z_dim, 1, 1)).to(device)

transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

opt_disc = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
opt_gen = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
criterion = nn.BCELoss()

dataset = datasets.ImageFolder(root=args.data_path, transform=transform)
Example #18
def main(args):
    # Set random seed for reproducibility
    manualSeed = 999
    #manualSeed = random.randint(1, 10000) # use if you want new results
    print("Random Seed: ", manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    dataroot = args.dataroot
    workers = args.workers
    batch_size = args.batch_size
    nc = args.nc
    ngf = args.ngf
    ndf = args.ndf
    nhd = args.nhd
    num_epochs = args.num_epochs
    lr = args.lr
    beta1 = args.beta1
    ngpu = args.ngpu
    resume = args.resume
    record_pnt = args.record_pnt
    log_pnt = args.log_pnt
    mse = args.mse
    '''
    # We can use an image folder dataset the way we have it setup.
    # Create the dataset
    dataset = dset.ImageFolder(root=dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.CenterCrop(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    # Create the dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=workers)
    '''
    dataset = dset.MNIST(
        root=dataroot,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize(0.5, 0.5),
        ]))

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers)

    # Decide which device we want to run on
    device = torch.device("cuda:0" if (
        torch.cuda.is_available() and ngpu > 0) else "cpu")

    # Create the generator
    netG = AutoEncoder(nc, ngf, nhd=nhd).to(device)

    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (ngpu > 1):
        netG = nn.DataParallel(netG, list(range(ngpu)))

    # Apply the weights_init function to randomly initialize all weights
    #  to mean=0, stdev=0.2.
    netG.apply(weights_init)

    # Create the Discriminator
    netD = Discriminator(nc, ndf, ngpu).to(device)

    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (ngpu > 1):
        netD = nn.DataParallel(netD, list(range(ngpu)))

    # Apply the weights_init function to randomly initialize all weights
    #  to mean=0, stdev=0.2.
    netD.apply(weights_init)

    #resume training if args.resume is True
    if resume:
        ckpt = torch.load('ckpts/recent.pth')
        netG.load_state_dict(ckpt["netG"])
        netD.load_state_dict(ckpt["netD"])

    # Initialize BCELoss function
    criterion = nn.BCELoss()
    MSE = nn.MSELoss()
    mse_coeff = 1.
    center_coeff = 0.001

    # Establish convention for real and fake flags during training
    real_flag = 1
    fake_flag = 0

    # Setup Adam optimizers for both G and D
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.dec.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerAE = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    # Training Loop

    # Lists to keep track of progress
    iters = 0

    R_errG = 0
    R_errD = 0
    R_errAE = 0
    R_std = 0
    R_mean = 0

    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_img, label = data
            real_img, label = real_img.to(device), to_one_hot_vector(
                10, label).to(device)

            b_size = real_img.size(0)
            flag = torch.full((b_size, ), real_flag, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = netD(real_img, label).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, flag)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate fake image batch with G
            noise = torch.randn(b_size, nhd, 1, 1).to(device)
            fake = netG.dec(noise, label)
            flag.fill_(fake_flag)
            # Classify all fake batch with D
            output = netD(fake.detach(), label).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, flag)
            # Calculate the gradients for this batch
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Add the gradients from the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.dec.zero_grad()
            flag.fill_(real_flag)  # fake flags are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = netD(fake, label).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, flag)
            # Calculate gradients for G
            errG.backward()
            # Update G
            optimizerG.step()

            ############################
            # (3) Update AE network: minimize reconstruction loss
            ###########################
            netG.zero_grad()
            new_img = netG(real_img, label, label)
            hidden = netG.enc(real_img, label)
            central_loss = MSE(hidden, torch.zeros(hidden.shape).to(device))
            errAE = mse_coeff* MSE(real_img, new_img) \
                    + center_coeff* central_loss
            errAE.backward()
            optimizerAE.step()

            R_errG += errG.item()
            R_errD += errD.item()
            R_errAE += errAE.item()
            R_std += (hidden**2).mean().item()
            R_mean += hidden.mean().item()
            # Output training stats
            if i % log_pnt == 0:
                print(
                    '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tLoss_AE: %.4f\t'
                    % (epoch, num_epochs, i, len(dataloader), R_errD / log_pnt,
                       R_errG / log_pnt, R_errAE / log_pnt))
                print('mean: %.4f\tstd: %.4f\tcentral/msecoeff: %4f' %
                      (R_mean / log_pnt, R_std / log_pnt,
                       center_coeff / mse_coeff))
                R_errG = 0.
                R_errD = 0.
                R_errAE = 0.
                R_std = 0.
                R_mean = 0.

            # Check how the generator is doing by saving G's output on fixed_noise
            if (iters % record_pnt == 0) or ((epoch == num_epochs - 1) and
                                             (i == len(dataloader) - 1)):
                vutils.save_image(
                    fake.to("cpu"),
                    './samples/image_{}.png'.format(iters // record_pnt))
                torch.save(
                    {
                        "netG": netG.state_dict(),
                        "netD": netD.state_dict(),
                        "nc": nc,
                        "ngf": ngf,
                        "ndf": ndf
                    }, 'ckpts/recent.pth')

            iters += 1
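
to_one_hot_vector is not defined in this example; a minimal sketch of what it presumably does (the name and signature are taken from the call above, the body is an assumption):

def to_one_hot_vector(num_classes, labels):
    # labels: LongTensor of shape (batch,) -> FloatTensor of shape (batch, num_classes)
    return torch.zeros(labels.size(0), num_classes).scatter_(1, labels.unsqueeze(1), 1.0)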
Example #19
gen = Generator(z_dim).to(device)
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator().to(device)
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

# Here, we want to initialize the weights to the normal distribution
# with mean 0 and standard deviation 0.02
def initialize_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 0.0, 0.02)
        nn.init.constant_(m.bias, 0)

gen = gen.apply(initialize_weights)
disc = disc.apply(initialize_weights)

print(gen)
print(disc)


######################### Train DCGAN ###############################

""" 
    Finally, we can train the GAN model! For each epoch, we will process the entire dataset in batches. 
    For every batch, we will update the discriminator and generator. Then, we can see DCGAN's results!
"""

step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
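
The training loop itself is cut off here. Purely as an illustration of the per-batch update described above, a minimal sketch; dataloader, the BCE-with-logits criterion and the noise shape are assumptions, not the original code.

from torch import nn

criterion = nn.BCEWithLogitsLoss()  # assumed criterion

for real, _ in dataloader:
    real = real.to(device)
    cur_batch_size = real.size(0)

    # Discriminator update
    opt_disc.zero_grad()
    noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)  # conv generator assumed
    fake = gen(noise)
    disc_fake = disc(fake.detach()).view(-1)
    disc_real = disc(real).view(-1)
    disc_loss = (criterion(disc_fake, torch.zeros_like(disc_fake)) +
                 criterion(disc_real, torch.ones_like(disc_real))) / 2
    disc_loss.backward()
    opt_disc.step()

    # Generator update
    opt_gen.zero_grad()
    disc_fake = disc(fake).view(-1)
    gen_loss = criterion(disc_fake, torch.ones_like(disc_fake))
    gen_loss.backward()
    opt_gen.step()

    mean_discriminator_loss += disc_loss.item()
    mean_generator_loss += gen_loss.item()
    step += 1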
Example #20
def main(config):
    model = load_model(config)
    train_loader, val_loader = get_loaders(model, config)

    # Make dirs
    if not os.path.exists(config.checkpoints):
        os.makedirs(config.checkpoints, exist_ok=True)
    if not os.path.exists(config.save_path):
        os.makedirs(config.save_path, exist_ok=True)

    # Loss Functions
    criterion_GAN = mse_loss

    # Calculate output of image discriminator (PatchGAN)
    patch = (1, config.image_size // 2**4, config.image_size // 2**4)

    # Initialize
    vgg = Vgg16().to(config.device)
    resnet = ResNet18(requires_grad=True, pretrained=True).to(config.device)
    generator = GeneratorUNet().to(config.device)
    discriminator = Discriminator().to(config.device)

    if config.epoch != 0:
        # Load pretrained models
        resnet.load_state_dict(
            torch.load(
                os.path.join(config.checkpoints, 'epoch_%d_%s.pth' %
                             (config.epoch - 1, 'resnet'))))
        generator.load_state_dict(
            torch.load(
                os.path.join(
                    config.checkpoints,
                    'epoch_%d_%s.pth' % (config.epoch - 1, 'generator'))))
        discriminator.load_state_dict(
            torch.load(
                os.path.join(
                    config.checkpoints,
                    'epoch_%d_%s.pth' % (config.epoch - 1, 'discriminator'))))
    else:
        # Initialize weights
        # resnet.apply(weights_init_normal)
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)

    # Optimizers
    optimizer_resnet = torch.optim.Adam(resnet.parameters(),
                                        lr=config.lr,
                                        betas=(config.b1, config.b2))
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=config.lr,
                                   betas=(config.b1, config.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=config.lr,
                                   betas=(config.b1, config.b2))

    # ----------
    #  Training
    # ----------

    resnet.train()
    generator.train()
    discriminator.train()
    for epoch in range(config.epoch, config.n_epochs):
        for i, (im1, m1, im2, m2) in enumerate(train_loader):
            assert im1.size(0) == im2.size(0)
            valid = Variable(torch.Tensor(np.ones(
                (im1.size(0), *patch))).to(config.device),
                             requires_grad=False)
            fake = Variable(torch.Tensor(np.zeros(
                (im1.size(0), *patch))).to(config.device),
                            requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------

            optimizer_resnet.zero_grad()
            optimizer_G.zero_grad()

            # GAN loss
            z = resnet(im2 * m2)
            if epoch < config.gan_epochs:
                fake_im = generator(im1 * (1 - m1), im2 * m2, z)
            else:
                fake_im = generator(im1, im2, z)
            if epoch < config.gan_epochs:
                pred_fake = discriminator(fake_im, im2)
                gan_loss = config.lambda_gan * criterion_GAN(pred_fake, valid)
            else:
                gan_loss = torch.Tensor([0]).to(config.device)

            # Hair, Face loss
            fake_m2 = torch.argmax(model(fake_im),
                                   1).unsqueeze(1).type(torch.uint8).repeat(
                                       1, 3, 1, 1).to(config.device)
            if 0.5 * torch.sum(m1) <= torch.sum(
                    fake_m2) <= 1.5 * torch.sum(m1):
                hair_loss = config.lambda_style * calc_style_loss(
                    fake_im * fake_m2, im2 * m2, vgg) + calc_content_loss(
                        fake_im * fake_m2, im2 * m2, vgg)
                face_loss = calc_content_loss(fake_im, im1, vgg)
            else:
                hair_loss = config.lambda_style * calc_style_loss(
                    fake_im * m1, im2 * m2, vgg) + calc_content_loss(
                        fake_im * m1, im2 * m2, vgg)
                face_loss = calc_content_loss(fake_im, im1, vgg)
            hair_loss *= config.lambda_hair
            face_loss *= config.lambda_face

            # Total loss
            loss = gan_loss + hair_loss + face_loss

            loss.backward()
            optimizer_resnet.step()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            if epoch < config.gan_epochs:
                optimizer_D.zero_grad()

                # Real loss
                pred_real = discriminator(im1 * (1 - m1) + im2 * m2, im2)
                loss_real = criterion_GAN(pred_real, valid)
                # Fake loss
                pred_fake = discriminator(fake_im.detach(), im2)
                loss_fake = criterion_GAN(pred_fake, fake)
                # Total loss
                loss_D = 0.5 * (loss_real + loss_fake)

                loss_D.backward()
                optimizer_D.step()

            if i % config.sample_interval == 0:
                msg = "Train || Gan loss: %.6f, hair loss: %.6f, face loss: %.6f, loss: %.6f\n" % \
                    (gan_loss.item(), hair_loss.item(), face_loss.item(), loss.item())
                sys.stdout.write("Epoch: %d || Batch: %d\n" % (epoch, i))
                sys.stdout.write(msg)
                fname = os.path.join(
                    config.save_path,
                    "Train_Epoch:%d_Batch:%d.png" % (epoch, i))
                sample_images([im1[0], im2[0], fake_im[0]],
                              ["img1", "img2", "img1+img2"], fname)
                for j, (im1, m1, im2, m2) in enumerate(val_loader):
                    with torch.no_grad():
                        valid = Variable(torch.Tensor(
                            np.ones((im1.size(0), *patch))).to(config.device),
                                         requires_grad=False)
                        fake = Variable(torch.Tensor(
                            np.zeros((im1.size(0), *patch))).to(config.device),
                                        requires_grad=False)

                        # GAN loss
                        z = resnet(im2 * m2)
                        if epoch < config.gan_epochs:
                            fake_im = generator(im1 * (1 - m1), im2 * m2, z)
                        else:
                            fake_im = generator(im1, im2, z)

                        if epoch < config.gan_epochs:
                            pred_fake = discriminator(fake_im, im2)
                            gan_loss = config.lambda_gan * criterion_GAN(
                                pred_fake, valid)
                        else:
                            gan_loss = torch.Tensor([0]).to(config.device)

                        # Hair, Face loss
                        fake_m2 = torch.argmax(
                            model(fake_im),
                            1).unsqueeze(1).type(torch.uint8).repeat(
                                1, 3, 1, 1).to(config.device)
                        if 0.5 * torch.sum(m1) <= torch.sum(
                                fake_m2) <= 1.5 * torch.sum(m1):
                            hair_loss = config.lambda_style * calc_style_loss(
                                fake_im * fake_m2, im2 * m2,
                                vgg) + calc_content_loss(
                                    fake_im * fake_m2, im2 * m2, vgg)
                            face_loss = calc_content_loss(fake_im, im1, vgg)
                        else:
                            hair_loss = config.lambda_style * calc_style_loss(
                                fake_im * m1, im2 * m2,
                                vgg) + calc_content_loss(
                                    fake_im * m1, im2 * m2, vgg)
                            face_loss = calc_content_loss(fake_im, im1, vgg)
                        hair_loss *= config.lambda_hair
                        face_loss *= config.lambda_face

                        # Total loss
                        loss = gan_loss + hair_loss + face_loss

                        msg = "Validation || Gan loss: %.6f, hair loss: %.6f, face loss: %.6f, loss: %.6f\n" % \
                                (gan_loss.item(), hair_loss.item(), face_loss.item(), loss.item())
                        sys.stdout.write(msg)
                        fname = os.path.join(
                            config.save_path,
                            "Validation_Epoch:%d_Batch:%d.png" % (epoch, i))
                        sample_images([im1[0], im2[0], fake_im[0]],
                                      ["img1", "img2", "img1+img2"], fname)
                        break

        if epoch % config.checkpoint_interval == 0:
            if epoch < config.gan_epochs:
                models = [resnet, generator, discriminator]
                fnames = ['resnet', 'generator', 'discriminator']
            else:
                models = [resnet, generator]
                fnames = ['resnet', 'generator']
            fnames = [
                os.path.join(config.checkpoints,
                             'epoch_%d_%s.pth' % (epoch, s)) for s in fnames
            ]
            save_weights(models, fnames)
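
save_weights is not shown; presumably it just serializes each model's state_dict to the matching filename. A minimal sketch under that assumption:

def save_weights(models, fnames):
    # Assumed helper: dump each model's parameters to its checkpoint path.
    for model, fname in zip(models, fnames):
        torch.save(model.state_dict(), fname)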
Example #21
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    imSize = 256 if torch.cuda.is_available() else 128

    # Define the two generators and discriminators

    # GXY : Takes images from X and translates them to Y domain
    GAB = Generator((3, imSize, imSize)).double().to(device)
    GBA = Generator((3, imSize, imSize)).double().to(device)

    DA = Discriminator().double().to(device) # Discriminate A images
    DB = Discriminator().double().to(device) # Discriminate B images

    # Initialize the weights of the networks as described in the paper
    GAB.apply(init_weights)
    GBA.apply(init_weights)
    DA.apply(init_weights)
    DB.apply(init_weights)

    # Select the different losses 
    LGan = torch.nn.MSELoss().to(device)
    LCyc = torch.nn.L1Loss().to(device)
    LId = torch.nn.L1Loss().to(device)

    # Create the optimizers
    # We have to chain because the losses make use of both networks
    optimGen = torch.optim.Adam(itertools.chain(GAB.parameters(), GBA.parameters()), lr = 0.0002, betas=(0.5,0.999))
    optimDis = torch.optim.Adam(itertools.chain(DA.parameters(), DB.parameters()), lr=0.0001, betas=(0.5,0.999))

    # Create custom lr schedulers since they have a particular behaviour
    schedGen = torch.optim.lr_scheduler.LambdaLR(optimGen, lr_lambda = CustomLR(args.epochs, args.offset).step)
    schedDis = torch.optim.lr_scheduler.LambdaLR(optimDis, lr_lambda = CustomLR(args.epochs, args.offset).step)
Example #22
if torch.cuda.device_count() > 1 and cuda:
    print('Use %d gpus.' % torch.cuda.device_count())
    G = nn.DataParallel(G)
    F = nn.DataParallel(F)
    Dx = nn.DataParallel(Dx)
    Dy = nn.DataParallel(Dy)

G.to(device)
F.to(device)
Dx.to(device)
Dy.to(device)

if transfer is False:
    G.apply(utils.init_weights_normal)
    F.apply(utils.init_weights_normal)
    Dx.apply(utils.init_weights_normal)
    Dy.apply(utils.init_weights_normal)

# loss functions
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# optimizer
optimizer_G = torch.optim.Adam(itertools.chain(G.parameters(), F.parameters()),
                               lr=lr,
                               betas=(0.5, 0.999))
optimizer_Dx = torch.optim.Adam(Dx.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_Dy = torch.optim.Adam(Dy.parameters(), lr=lr, betas=(0.5, 0.999))

# tensor wrapper
Example #23
torch.manual_seed(opt.random_seed)

dataset = helper.DisVectorData('./GAN-data-10.xlsx')
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True)

# Building generator
generator = Generator(opt.vector_size)
gen_optimizer = torch.optim.Adam(generator.parameters(),
                                 lr=opt.lr_rate,
                                 betas=(opt.beta, opt.beta1))

# Building discriminator
discriminator = Discriminator(opt.vector_size)
discriminator.apply(init_weights)
dis_optimizer = torch.optim.Adam(discriminator.parameters(),
                                 lr=opt.lr_rate,
                                 betas=(opt.beta, opt.beta1))

# Loss functions
bce_loss = torch.nn.BCELoss()

LT = torch.LongTensor
FT = torch.FloatTensor

if is_cuda:
    generator.cuda()
    discriminator.cuda()
    bce_loss.cuda()
    LT = torch.cuda.LongTensor
Example #24
def train(opt):
    opt = dotDict(opt)

    if not os.path.exists(opt.checkpoints_dir):
        os.makedirs(opt.checkpoints_dir)

    if not os.path.exists(os.path.join(opt.out_dir, opt.run_name)):
        os.makedirs(os.path.join(opt.out_dir, opt.run_name))

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    ###### Definition of variables ######
    # Networks
    G0 = GeometrySynthesizer()
    G1 = Generator(opt.input_nc, opt.output_nc)
    G2 = Generator(opt.input_nc, opt.output_nc)
    D1 = Discriminator(opt.input_nc)
    D2 = Discriminator(opt.output_nc)

    if opt.cuda:
        G0.cuda()
        G1.cuda()
        G2.cuda()
        D1.cuda()
        D2.cuda()

    G1.apply(weights_init_normal)
    G2.apply(weights_init_normal)
    D2.apply(weights_init_normal)
    D1.apply(weights_init_normal)

    # Optimizers & LR schedulers
    optimizer_G0 = torch.optim.Adam(G0.parameters(),
                                    lr=opt.lr_GS,
                                    betas=(0.5, 0.999))
    optimizer_G = torch.optim.Adam(itertools.chain(G1.parameters(),
                                                   G2.parameters()),
                                   lr=opt.lr_AS,
                                   betas=(0.5, 0.999))
    optimizer_D1 = torch.optim.Adam(D1.parameters(),
                                    lr=opt.lr_AS,
                                    betas=(0.5, 0.999))
    optimizer_D2 = torch.optim.Adam(D2.parameters(),
                                    lr=opt.lr_AS,
                                    betas=(0.5, 0.999))

    if opt.G0_checkpoint is not None:
        G0 = load_G0_ckp(opt.G0_checkpoint, G0)

    if opt.AS_checkpoint is not None:
        _, G1, D1, G2, D2, optimizer_G, optimizer_D1, optimizer_D2 = load_AS_ckp(
            opt.AS_checkpoint, G1, D1, G2, D2, optimizer_G, optimizer_D1,
            optimizer_D2)

    if opt.resume_checkpoint is not None:
        opt.epoch, G0, G1, D1, G2, D2, optimizer_G0, optimizer_G, optimizer_D1, optimizer_D2 = load_ckp(
            opt.resume_checkpoint, G0, G1, D1, G2, D2, optimizer_G0,
            optimizer_G, optimizer_D1, optimizer_D2)

    lr_scheduler_G0 = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G0,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D1 = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D1,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D2 = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D2,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    # Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    background_t = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    foregound_t = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    real_t = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)

    target_real = Variable(Tensor(opt.batchSize).fill_(1.0),
                           requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0),
                           requires_grad=False)

    composed_buffer = ReplayBuffer()
    fake_real_buffer = ReplayBuffer()
    fake_composed_buffer = ReplayBuffer()

    # Dataset loader
    transforms_dataset = [
        transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
        transforms.RandomCrop(opt.size),
        transforms.ToTensor(),
        # transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
    ]

    transforms_masks = [transforms.ToTensor()]

    text = TextUtils(opt.root, transforms_=transforms_masks)

    dataset = MyDataset(opt.root, transforms_=transforms_dataset)
    print("No. of Examples = ", len(dataset))
    dataloader = DataLoader(dataset,
                            batch_size=opt.batchSize,
                            shuffle=True,
                            num_workers=opt.n_cpu)

    # Loss plot
    logger = Logger(opt.n_epochs, len(dataloader),
                    os.path.join(opt.out_dir, opt.run_name), opt.epoch + 1)
    ###################################

    ###### Training ######
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):
            # Set model input
            background = Variable(background_t.copy_(batch['X']),
                                  requires_grad=True)
            # foreground = Variable(foregound_t.copy_(batch['Y']), requires_grad=True)
            real = Variable(real_t.copy_(batch['Z']), requires_grad=True)
            foreground = Variable(foregound_t.copy_(
                text.get_text_masks(opt.batchSize)),
                                  requires_grad=True)

            ###### Geometric Synthesizer ######
            composed_GS = G0(
                background,
                foreground)  # concatenate background and foreground object

            ## optimize G0 loss
            optimizer_G0.zero_grad()
            loss_G0 = criterion_discriminator(D2(composed_GS), target_fake)
            loss_G0.backward()
            optimizer_G0.step()

            ###### Appearance Synthesizer ######
            composed = composed_buffer.push_and_pop(composed_GS)
            ###### Generators G1 and G2 ######
            optimizer_G.zero_grad()

            ## Identity loss
            # G1(X) should equal X if X = real
            same_real = G1(real)
            loss_identity_1 = criterion_identity(real, same_real) * 5.0
            # G2(X) should equal X if X = composed
            same_composed = G2(composed)
            loss_identity_2 = criterion_identity(composed, same_composed) * 5.0

            loss_identity = loss_identity_1 + loss_identity_2

            ## GAN loss
            fake_real = G1(composed)
            loss_G1 = criterion_generator(D1(fake_real), target_real)

            fake_composed = G2(real)
            loss_G2 = criterion_generator(D2(fake_composed), target_real)

            loss_GAN = loss_G1 + loss_G2

            ## Cycle loss
            recovered_real = G1(fake_composed)
            loss_cycle_real = criterion_cycle(recovered_real, real) * 10.0

            recovered_composed = G2(fake_real)
            loss_cycle_composed = criterion_cycle(recovered_composed,
                                                  composed) * 10.0

            loss_cycle = loss_cycle_composed + loss_cycle_real

            # Total loss
            loss_G = loss_identity + loss_GAN + loss_cycle

            loss_G.backward()
            optimizer_G.step()
            #####################################

            ###### Discriminator D1 ######
            optimizer_D1.zero_grad()

            # real loss
            loss_D1_real = criterion_discriminator(D1(real), target_real)

            # fake loss
            new_fake_real = fake_real_buffer.push_and_pop(fake_real)
            loss_D1_fake = criterion_discriminator(D1(new_fake_real.detach()),
                                                   target_fake)

            loss_D1 = (loss_D1_real + loss_D1_fake) * 0.5
            loss_D1.backward()
            optimizer_D1.step()

            ###### Discriminator D2 ######
            optimizer_D2.zero_grad()

            # real loss
            new_composed = composed_buffer.push_and_pop(composed)
            loss_D2_real = criterion_discriminator(D2(new_composed.detach()),
                                                   target_real)

            # fake loss
            new_fake_composed = fake_composed_buffer.push_and_pop(
                fake_composed)
            loss_D2_fake = criterion_discriminator(
                D2(new_fake_composed.detach()), target_fake)

            loss_D2 = (loss_D2_real + loss_D2_fake) * 0.5
            loss_D2.backward()
            optimizer_D2.step()

            #####################################

            # Progress report (http://localhost:8097)
            losses = {
                'loss_G0': loss_G0,
                'loss_G': loss_G,
                'loss_D1': loss_D1,
                'loss_D2': loss_D2
            }
            images = {
                'background': background,
                'foreground': foreground,
                'real': real,
                'composed_GS': composed_GS,
                'composed': composed,
                'synthesized': fake_real,
                'adapted_real': fake_composed
            }

            logger.log(losses, images)

        # Update learning rates
        lr_scheduler_G0.step()
        lr_scheduler_G.step()
        lr_scheduler_D1.step()
        lr_scheduler_D2.step()

        # Save models checkpoints
        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': {
                "G0": G0.state_dict(),
                "G1": G1.state_dict(),
                "D1": D1.state_dict(),
                "G2": G2.state_dict(),
                "D2": D2.state_dict()
            },
            'optimizer': {
                "G0": optimizer_G0.state_dict(),
                "G": optimizer_G.state_dict(),
                "D1": optimizer_D1.state_dict(),
                "D2": optimizer_D2.state_dict()
            }
        }
        save_ckp(checkpoint,
                 os.path.join(opt.checkpoints_dir, opt.run_name + '.pth'))
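
# The LambdaLR helper passed to every torch.optim.lr_scheduler.LambdaLR above is
# defined elsewhere in the repo. A minimal, self-contained sketch that is
# consistent with how it is called here (LambdaLR(n_epochs, offset, decay_epoch).step)
# keeps the factor at 1.0 until decay_epoch and then decays it linearly to 0 at
# n_epochs. The class name below is an assumption, not the repo's own code.
class LambdaLRSketch:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "decay must start before training ends"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        # Multiplicative LR factor: 1.0 until decay_start_epoch, then linear to 0.0.
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (
            self.n_epochs - self.decay_start_epoch)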
Example #25
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
    parser.add_argument('--n_epochs',
                        type=int,
                        default=200,
                        help='number of epochs of training')
    parser.add_argument('--batchSize',
                        type=int,
                        default=1,
                        help='size of the batches')
    parser.add_argument('--dataroot',
                        type=str,
                        default='datasets/data/',
                        help='root directory of the dataset')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='initial learning rate')
    parser.add_argument(
        '--decay_epoch',
        type=int,
        default=100,
        help='epoch to start linearly decaying the learning rate to 0')
    parser.add_argument('--size',
                        type=int,
                        default=256,
                        help='size of the data crop (squared assumed)')
    parser.add_argument('--input_nc',
                        type=int,
                        default=3,
                        help='number of channels of input data')
    parser.add_argument('--output_nc',
                        type=int,
                        default=3,
                        help='number of channels of output data')
    parser.add_argument('--cuda',
                        action='store_true',
                        help='use GPU computation')
    parser.add_argument(
        '--n_cpu',
        type=int,
        default=8,
        help='number of cpu threads to use during batch generation')
    opt = parser.parse_args()
    print(opt)

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    ###### Definition of variables ######
    # Networks
    netG_A2B = Generator(opt.input_nc, opt.output_nc)
    netG_B2A = Generator(opt.output_nc, opt.input_nc)
    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    if opt.cuda:
        netG_A2B.cuda()
        netG_B2A.cuda()
        netD_A.cuda()
        netD_B.cuda()

    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    # Optimizers & LR schedulers
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                                   netG_B2A.parameters()),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    # Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0),
                           requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0),
                           requires_grad=False)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Dataset loader
    transforms_ = [
        transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
        transforms.RandomCrop(opt.size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    dataloader = DataLoader(ImageDataset(opt.dataroot,
                                         transforms_=transforms_,
                                         unaligned=True),
                            batch_size=opt.batchSize,
                            shuffle=True,
                            num_workers=opt.n_cpu)

    # Loss plot
    logger = Logger(opt.n_epochs, len(dataloader))
    ###################################

    ###### Training ######
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B) * 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A) * 5.0

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

            # Total loss
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
            loss_G.backward()

            optimizer_G.step()
            ###################################

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()

            optimizer_D_B.step()
            ###################################

            # Progress report (http://localhost:8097)
            logger.log(
                {
                    'loss_G': loss_G,
                    'loss_G_identity': (loss_identity_A + loss_identity_B),
                    'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A),
                    'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB),
                    'loss_D': (loss_D_A + loss_D_B)
                },
                images={
                    'real_A': real_A,
                    'real_B': real_B,
                    'fake_A': fake_A,
                    'fake_B': fake_B
                })

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Save models checkpoints
        torch.save(netG_A2B.state_dict(), 'output/netG_A2B.pth')
        torch.save(netG_B2A.state_dict(), 'output/netG_B2A.pth')
        torch.save(netD_A.state_dict(), 'output/netD_A.pth')
        torch.save(netD_B.state_dict(), 'output/netD_B.pth')
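
# The ReplayBuffer used for fake_A_buffer / fake_B_buffer above is also defined
# elsewhere. A minimal sketch, assuming the usual CycleGAN image pool: keep up to
# max_size previously generated images and, half of the time, feed the
# discriminator an older fake instead of the newest one. The class name is an
# assumption, not this repo's code.
import random

import torch


class ReplayBufferSketch:
    def __init__(self, max_size=50):
        assert max_size > 0, "the buffer size must be positive"
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, batch):
        out = []
        for element in batch:
            element = element.unsqueeze(0)
            if len(self.data) < self.max_size:
                # Buffer not full yet: store the new fake and return it unchanged.
                self.data.append(element)
                out.append(element)
            elif random.uniform(0, 1) > 0.5:
                # Swap the incoming fake for a stored one half of the time.
                i = random.randint(0, self.max_size - 1)
                out.append(self.data[i].clone())
                self.data[i] = element
            else:
                out.append(element)
        return torch.cat(out)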
Example #26
0
def main(args):
    torch.manual_seed(0)
    if args.mb_D:
        raise NotImplementedError('mb_D not implemented')
        assert args.batch_size > 1, 'batch size needs to be larger than 1 if mb_D'

    if args.img_norm != 'znorm':
        raise NotImplementedError('{} not implemented'.format(args.img_norm))

    assert args.act in ['relu', 'mish'], 'args.act = {}'.format(args.act)

    modelarch = 'C_{0}_{1}_{2}_{3}_{4}{5}{6}{7}{8}{9}{10}{11}{12}{13}{14}{15}{16}{17}{18}{19}{20}{21}{22}'.format(
        args.size, args.batch_size, args.lr,  args.n_epochs, args.decay_epoch, # 0, 1, 2, 3, 4
        '_G' if args.G_extra else '',  # 5
        '_D' if args.D_extra else '',  # 6
        '_U' if args.upsample else '',  # 7
        '_S' if args.slow_D else '',  # 8
        '_RL{}-{}'.format(args.start_recon_loss_val, args.end_recon_loss_val),  # 9
        '_GL{}-{}'.format(args.start_gan_loss_val, args.end_gan_loss_val),  # 10
        '_prop' if args.keep_prop else '',  # 11
        '_' + args.img_norm,  # 12
        '_WL' if args.wasserstein else '',  # 13
        '_MBD' if args.mb_D else '',  # 14
        '_FM' if args.fm_loss else '',  # 15
        '_BF{}'.format(args.buffer_size) if args.buffer_size != 50 else '',  # 16
        '_N' if args.add_noise else '',  # 17
        '_L{}'.format(args.load_iter) if args.load_iter > 0 else '',  # 18
        '_res{}'.format(args.n_resnet_blocks),  # 19
        '_n{}'.format(args.data_subset) if args.data_subset is not None else '',  # 20
        '_{}'.format(args.optim),  # 21
        '_{}'.format(args.act))  # 22

    samples_path = os.path.join(args.output_dir, modelarch, 'samples')
    safe_mkdirs(samples_path)
    model_path = os.path.join(args.output_dir, modelarch, 'models')
    safe_mkdirs(model_path)
    test_path = os.path.join(args.output_dir, modelarch, 'test')
    safe_mkdirs(test_path)

    # Definition of variables ######
    # Networks
    netG_A2B = Generator(args.input_nc, args.output_nc, img_size=args.size,
                         extra_layer=args.G_extra, upsample=args.upsample,
                         keep_weights_proportional=args.keep_prop,
                         n_residual_blocks=args.n_resnet_blocks,
                         act=args.act)
    netG_B2A = Generator(args.output_nc, args.input_nc, img_size=args.size,
                         extra_layer=args.G_extra, upsample=args.upsample,
                         keep_weights_proportional=args.keep_prop,
                         n_residual_blocks=args.n_resnet_blocks,
                         act=args.act)
    netD_A = Discriminator(args.input_nc, extra_layer=args.D_extra, mb_D=args.mb_D, x_size=args.size)
    netD_B = Discriminator(args.output_nc, extra_layer=args.D_extra, mb_D=args.mb_D, x_size=args.size)

    if args.cuda:
        netG_A2B.cuda()
        netG_B2A.cuda()
        netD_A.cuda()
        netD_B.cuda()

    if args.load_iter == 0:
        netG_A2B.apply(weights_init_normal)
        netG_B2A.apply(weights_init_normal)
        netD_A.apply(weights_init_normal)
        netD_B.apply(weights_init_normal)
    else:
        netG_A2B.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'G_A2B_{}.pth'.format(args.load_iter))))
        netG_B2A.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'G_B2A_{}.pth'.format(args.load_iter))))
        netD_A.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'D_A_{}.pth'.format(args.load_iter))))
        netD_B.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'D_B_{}.pth'.format(args.load_iter))))

        netG_A2B.train()
        netG_B2A.train()
        netD_A.train()
        netD_B.train()

    # Losses
    criterion_GAN = wasserstein_loss if args.wasserstein else torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    feat_criterion = torch.nn.HingeEmbeddingLoss()

    # I could also update D only if iters % 2 == 0
    lr_G = args.lr
    lr_D = args.lr / 2 if args.slow_D else args.lr

    # Optimizers & LR schedulers
    if args.optim == 'adam':
        optim = torch.optim.Adam
    elif args.optim == 'radam':
        optim = RAdam
    elif args.optim == 'ranger':
        optim = Ranger
    elif args.optim == 'rangerlars':
        optim = RangerLars
    else:
        raise NotImplementedError('args.optim = {} not implemented'.format(args.optim))

    optimizer_G = optim(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                        lr=lr_G, betas=(0.5, 0.999))
    optimizer_D_A = optim(netD_A.parameters(), lr=lr_D, betas=(0.5, 0.999))
    optimizer_D_B = optim(netD_B.parameters(), lr=lr_D, betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(args.n_epochs, args.load_iter, args.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(args.n_epochs, args.load_iter, args.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(args.n_epochs, args.load_iter, args.decay_epoch).step)

    # Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if args.cuda else torch.Tensor
    input_A = Tensor(args.batch_size, args.input_nc, args.size, args.size)
    input_B = Tensor(args.batch_size, args.output_nc, args.size, args.size)
    target_real = Variable(Tensor(args.batch_size).fill_(1.0), requires_grad=False)
    target_fake = Variable(Tensor(args.batch_size).fill_(0.0), requires_grad=False)

    fake_A_buffer = ReplayBuffer(args.buffer_size)
    fake_B_buffer = ReplayBuffer(args.buffer_size)

    # Transforms and dataloader for training set
    transforms_ = []
    if args.resize_crop:
        transforms_ += [transforms.Resize(int(args.size*1.12), Image.BICUBIC),
                        transforms.RandomCrop(args.size)]
    else:
        transforms_ += [transforms.Resize(args.size, Image.BICUBIC)]

    if args.horizontal_flip:
        transforms_ += [transforms.RandomHorizontalFlip()]

    transforms_ += [transforms.ToTensor()]

    if args.add_noise:
        transforms_ += [transforms.Lambda(lambda x: x + torch.randn_like(x))]

    transforms_norm = []
    if args.img_norm == 'znorm':
        transforms_norm += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    elif 'scale01' in args.img_norm:
        transforms_norm += [transforms.Lambda(lambda x: x.mul(1/255))]  # .mul is element-wise, so the tensor shape is preserved
        if 'flip' in args.img_norm:
            transforms_norm += [transforms.Lambda(lambda x: (x - 1).abs())]  # element-wise as well; inverts intensities in [0, 1]
    else:
        raise ValueError('wrong --img_norm. only znorm|scale01|scale01flip')

    transforms_ += transforms_norm

    dataloader = DataLoader(ImageDataset(args.dataroot, transforms_=transforms_, unaligned=True, n=args.data_subset),
                            batch_size=args.batch_size, shuffle=True, num_workers=args.n_cpu)

    # Transforms and dataloader for test set
    transforms_test_ = [transforms.Resize(args.size, Image.BICUBIC),
                        transforms.ToTensor()]
    transforms_test_ += transforms_norm

    dataloader_test = DataLoader(ImageDataset(args.dataroot, transforms_=transforms_test_, mode='test'),
                                 batch_size=args.batch_size, shuffle=False, num_workers=args.n_cpu)
    # Training ######
    if args.load_iter == 0 and args.load_epoch != 0:
        print('****** NOTE: args.load_iter == 0 and args.load_epoch != 0 ******')

    iter = args.load_iter
    prev_time = time.time()
    n_test = 10e10 if args.n_test is None else args.n_test
    n_sample = 10e10 if args.n_sample is None else args.n_sample

    rl_delta_x = args.n_epochs - args.recon_loss_epoch
    rl_delta_y = args.end_recon_loss_val - args.start_recon_loss_val

    gan_delta_x = args.n_epochs - args.gan_loss_epoch
    gan_delta_y = args.end_gan_loss_val - args.start_gan_loss_val
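
    # The reconstruction and GAN loss weights ramp linearly over training: each
    # stays at its start_* value until recon_loss_epoch / gan_loss_epoch and then
    # moves toward its end_* value over the remaining epochs (per-epoch rates below).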

    for epoch in range(args.load_epoch, args.n_epochs):

        rl_effective_epoch = max(epoch - args.recon_loss_epoch, 0)
        recon_loss_rate = args.start_recon_loss_val + rl_effective_epoch * (rl_delta_y / rl_delta_x)

        gan_effective_epoch = max(epoch - args.gan_loss_epoch, 0)
        gan_loss_rate = args.start_gan_loss_val + gan_effective_epoch * (gan_delta_y / gan_delta_x)

        id_loss_rate = 5.0

        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            # Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B)
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A)

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake, _ = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

            fake_A = netG_B2A(real_B)
            pred_fake, _ = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A)

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B)

            # Total loss
            loss_G = (loss_identity_A + loss_identity_B) * id_loss_rate
            loss_G += (loss_GAN_A2B + loss_GAN_B2A) * gan_loss_rate
            loss_G += (loss_cycle_ABA + loss_cycle_BAB) * recon_loss_rate

            loss_G.backward()

            optimizer_G.step()

            # Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real, _ = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake, _ = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            loss_D_A = (loss_D_real + loss_D_fake) * 0.5

            if args.fm_loss:
                pred_real, feats_real = netD_A(real_A)
                pred_fake, feats_fake = netD_A(fake_A.detach())

                fm_loss_A = get_fm_loss(feats_real, feats_fake, feat_criterion, args.cuda)

                loss_D_A = loss_D_A * 0.1 + fm_loss_A * 0.9

            loss_D_A.backward()

            optimizer_D_A.step()

            # Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real, _ = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake, _ = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            loss_D_B = (loss_D_real + loss_D_fake)*0.5

            if args.fm_loss:
                pred_real, feats_real = netD_B(real_B)
                pred_fake, feats_fake = netD_B(fake_B.detach())

                fm_loss_B = get_fm_loss(feats_real, feats_fake, feat_criterion, args.cuda)

                loss_D_B = loss_D_B * 0.1 + fm_loss_B * 0.9

            loss_D_B.backward()

            optimizer_D_B.step()

            if iter % args.log_interval == 0:

                print('---------------------')
                print('GAN loss:', as_np(loss_GAN_A2B), as_np(loss_GAN_B2A))
                print('Identity loss:', as_np(loss_identity_A), as_np(loss_identity_B))
                print('Cycle loss:', as_np(loss_cycle_ABA), as_np(loss_cycle_BAB))
                print('D loss:', as_np(loss_D_A), as_np(loss_D_B))
                if args.fm_loss:
                    print('fm loss:', as_np(fm_loss_A), as_np(fm_loss_B))
                print('recon loss rate:', recon_loss_rate)
                print('time:', time.time() - prev_time)
                prev_time = time.time()

            if iter % args.plot_interval == 0:
                pass

            if iter % args.image_save_interval == 0:
                samples_path_ = os.path.join(samples_path, str(iter // args.image_save_interval))
                safe_mkdirs(samples_path_)

                # New savedir
                test_pth_AB = os.path.join(test_path, str(iter // args.image_save_interval), 'AB')
                test_pth_BA = os.path.join(test_path, str(iter // args.image_save_interval), 'BA')

                safe_mkdirs(test_pth_AB)
                safe_mkdirs(test_pth_BA)

                for j, batch_ in enumerate(dataloader_test):

                    real_A_test = Variable(input_A.copy_(batch_['A']))
                    real_B_test = Variable(input_B.copy_(batch_['B']))

                    fake_AB_test = netG_A2B(real_A_test)
                    fake_BA_test = netG_B2A(real_B_test)

                    if j < n_sample:
                        recovered_ABA_test = netG_B2A(fake_AB_test)
                        recovered_BAB_test = netG_A2B(fake_BA_test)

                        fn = os.path.join(samples_path_, str(j))
                        imageio.imwrite(fn + '.A.jpg', tensor2image(real_A_test[0], args.img_norm))
                        imageio.imwrite(fn + '.B.jpg', tensor2image(real_B_test[0], args.img_norm))
                        imageio.imwrite(fn + '.BA.jpg', tensor2image(fake_BA_test[0], args.img_norm))
                        imageio.imwrite(fn + '.AB.jpg', tensor2image(fake_AB_test[0], args.img_norm))
                        imageio.imwrite(fn + '.ABA.jpg', tensor2image(recovered_ABA_test[0], args.img_norm))
                        imageio.imwrite(fn + '.BAB.jpg', tensor2image(recovered_BAB_test[0], args.img_norm))

                    if j < n_test:
                        fn_A = os.path.basename(batch_['img_A'][0])
                        imageio.imwrite(os.path.join(test_pth_AB, fn_A), tensor2image(fake_AB_test[0], args.img_norm))

                        fn_B = os.path.basename(batch_['img_B'][0])
                        imageio.imwrite(os.path.join(test_pth_BA, fn_B), tensor2image(fake_BA_test[0], args.img_norm))

            if iter % args.model_save_interval == 0:
                # Save models checkpoints
                torch.save(netG_A2B.state_dict(), os.path.join(model_path, 'G_A2B_{}.pth'.format(iter)))
                torch.save(netG_B2A.state_dict(), os.path.join(model_path, 'G_B2A_{}.pth'.format(iter)))
                torch.save(netD_A.state_dict(), os.path.join(model_path, 'D_A_{}.pth'.format(iter)))
                torch.save(netD_B.state_dict(), os.path.join(model_path, 'D_B_{}.pth'.format(iter)))

            iter += 1

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()
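
# weights_init_normal, applied to the freshly built networks above, is imported
# from elsewhere. A sketch of the DCGAN-style initializer it is assumed to follow
# (the function name here is an assumption, not this repo's code):
import torch.nn as nn


def weights_init_normal_sketch(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        # Convolutional layers: zero-mean Gaussian weights with std 0.02.
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        # BatchNorm layers: scale around 1.0, bias at 0.
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)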
Example #27
0
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Load pretrained generator
"""
weights_ = torch.load("./models/generator_40.pth")
weights = OrderedDict()
for k, v in weights_.items():
    weights[k.split('module.')[-1]] = v
generator.load_state_dict(weights)
"""

# Initialize fc layer of discriminator
discriminator.apply(weights_init_normal)

# find gpu devices
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_nums = torch.cuda.device_count()

# Use all GPUs by default
if device_nums > 1:
    generator = torch.nn.DataParallel(generator, device_ids=range(device_nums))
    discriminator = torch.nn.DataParallel(discriminator,
                                          device_ids=range(device_nums))

# Set models to gpu
generator = generator.to(device)
discriminator = discriminator.to(device)
Example #28
0
'''
remove BBG part, replace with npy files
netBBG = Generator(ngpu, nz, ngf, nc).to(device)
netBBG.load_state_dict(torch.load(opt.netBBG))

netBBD = Discriminator(ngpu, nc, ndf).to(device)
netBBD.load_state_dict(torch.load(opt.netBBD))
'''
netWBG = Generator(ngpu, nz, ngf, nc).to(device)
netWBG.apply(weights_init)
if opt.netWBG != '':
    netWBG.load_state_dict(torch.load(opt.netWBG))

netWBD = Discriminator(ngpu, nc, ndf).to(device)
netWBD.apply(weights_init)
if opt.netWBD != '':
    netWBD.load_state_dict(torch.load(opt.netWBD))
'''
netBBG.eval()
netBBD.eval()
'''

##### White-box attack ####
# Assumes we have direct access to BBD
wb_predictions = []

# loop over training data
for i, data in enumerate(trainloader, 0):
    real_cpu = data[0].to(device)
    output = netBBD(real_cpu)
    output = [x for x in output.detach().cpu().numpy()]
Example #29
0
def main():
    ############################
    # argument setup
    ############################
    args, cfg = setup_args_and_config()

    if args.show:
        print("### Run Argv:\n> {}".format(' '.join(sys.argv)))
        print("### Run Arguments:")
        s = dump_args(args)
        print(s + '\n')
        print("### Configs:")
        print(cfg.dumps())
        sys.exit()

    timestamp = utils.timestamp()
    unique_name = "{}_{}".format(timestamp, args.name)
    cfg['unique_name'] = unique_name  # for save directory
    cfg['name'] = args.name

    utils.makedirs('logs')
    utils.makedirs(Path('checkpoints', unique_name))

    # logger
    logger_path = Path('logs', f"{unique_name}.log")
    logger = Logger.get(file_path=logger_path,
                        level=args.log_lv,
                        colorize=True)

    # writer
    image_scale = 0.6
    writer_path = Path('runs', unique_name)
    if args.tb_image:
        writer = utils.TBWriter(writer_path, scale=image_scale)
    else:
        image_path = Path('images', unique_name)
        writer = utils.TBDiskWriter(writer_path, image_path, scale=image_scale)

    # log default informations
    args_str = dump_args(args)
    logger.info("Run Argv:\n> {}".format(' '.join(sys.argv)))
    logger.info("Args:\n{}".format(args_str))
    logger.info("Configs:\n{}".format(cfg.dumps()))
    logger.info("Unique name: {}".format(unique_name))

    # seed
    np.random.seed(cfg['seed'])
    torch.manual_seed(cfg['seed'])
    random.seed(cfg['seed'])

    if args.deterministic:
        #  https://discuss.pytorch.org/t/how-to-get-deterministic-behavior/18177/16
        #  https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        cfg['n_workers'] = 0
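        # Worker processes make batch order and augmentation seeding nondeterministic,
        # so data loading is kept on the main process when determinism is requested.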
        logger.info("#" * 80)
        logger.info("# Deterministic option is activated !")
        logger.info("#" * 80)
    else:
        torch.backends.cudnn.benchmark = True

    ############################
    # setup dataset & loader
    ############################
    logger.info("Get dataset ...")

    # setup language dependent values
    content_font, n_comp_types, n_comps = setup_language_dependent(cfg)

    # setup transform
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])])

    # setup data
    hdf5_data, meta = setup_data(cfg, transform)

    # setup dataset
    trn_dset, loader = get_dset_loader(hdf5_data,
                                       meta['train']['fonts'],
                                       meta['train']['chars'],
                                       transform,
                                       True,
                                       cfg,
                                       content_font=content_font)

    logger.info("### Training dataset ###")
    logger.info("# of avail fonts = {}".format(trn_dset.n_fonts))
    logger.info(f"Total {len(loader)} iterations per epochs")
    logger.info("# of avail items = {}".format(trn_dset.n_avails))
    logger.info(f"#fonts = {trn_dset.n_fonts}, #chars = {trn_dset.n_chars}")

    val_loaders = setup_cv_dset_loader(hdf5_data, meta, transform,
                                       n_comp_types, content_font, cfg)
    sfuc_loader = val_loaders['SeenFonts-UnseenChars']
    sfuc_dset = sfuc_loader.dataset
    ufsc_loader = val_loaders['UnseenFonts-SeenChars']
    ufsc_dset = ufsc_loader.dataset
    ufuc_loader = val_loaders['UnseenFonts-UnseenChars']
    ufuc_dset = ufuc_loader.dataset

    logger.info("### Cross-validation datasets ###")
    logger.info("Seen fonts, Unseen chars | "
                "#items = {}, #fonts = {}, #chars = {}, #steps = {}".format(
                    len(sfuc_dset), len(sfuc_dset.fonts), len(sfuc_dset.chars),
                    len(sfuc_loader)))
    logger.info("Unseen fonts, Seen chars | "
                "#items = {}, #fonts = {}, #chars = {}, #steps = {}".format(
                    len(ufsc_dset), len(ufsc_dset.fonts), len(ufsc_dset.chars),
                    len(ufsc_loader)))
    logger.info("Unseen fonts, Unseen chars | "
                "#items = {}, #fonts = {}, #chars = {}, #steps = {}".format(
                    len(ufuc_dset), len(ufuc_dset.fonts), len(ufuc_dset.chars),
                    len(ufuc_loader)))

    ############################
    # build model
    ############################
    logger.info("Build model ...")
    # generator
    g_kwargs = cfg.get('g_args', {})
    gen = MACore(1,
                 cfg['C'],
                 1,
                 **g_kwargs,
                 n_comps=n_comps,
                 n_comp_types=n_comp_types,
                 language=cfg['language'])
    gen.cuda()
    gen.apply(weights_init(cfg['init']))

    d_kwargs = cfg.get('d_args', {})
    disc = Discriminator(cfg['C'], trn_dset.n_fonts, trn_dset.n_chars,
                         **d_kwargs)
    disc.cuda()
    disc.apply(weights_init(cfg['init']))
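    # weights_init(cfg['init']) is assumed to return a per-module initializer for
    # the configured scheme, which .apply() then runs on every submodule.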

    if cfg['ac_w'] > 0.:
        C = gen.mem_shape[0]
        aux_clf = AuxClassifier(C, n_comps, **cfg['ac_args'])
        aux_clf.cuda()
        aux_clf.apply(weights_init(cfg['init']))
    else:
        aux_clf = None
        assert cfg[
            'ac_gen_w'] == 0., "ac_gen loss is only available with ac loss"

    # setup optimizer
    g_optim = optim.Adam(gen.parameters(),
                         lr=cfg['g_lr'],
                         betas=cfg['adam_betas'])
    d_optim = optim.Adam(disc.parameters(),
                         lr=cfg['d_lr'],
                         betas=cfg['adam_betas'])
    ac_optim = optim.Adam(aux_clf.parameters(), lr=cfg['g_lr'], betas=cfg['adam_betas']) \
               if aux_clf is not None else None

    # resume checkpoint
    st_step = 1
    if args.resume:
        st_step, loss = load_checkpoint(args.resume, gen, disc, aux_clf,
                                        g_optim, d_optim, ac_optim)
        logger.info(
            "Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
                args.resume, st_step - 1, loss))
    if args.finetune:
        load_gen_checkpoint(args.finetune, gen)

    ############################
    # setup validation
    ############################
    evaluator = Evaluator(hdf5_data,
                          trn_dset.avails,
                          logger,
                          writer,
                          cfg['batch_size'],
                          content_font=content_font,
                          transform=transform,
                          language=cfg['language'],
                          val_loaders=val_loaders,
                          meta=meta)
    if args.debug:
        evaluator.n_cv_batches = 10
        logger.info("Change CV batches to 10 for debugging")

    ############################
    # start training
    ############################
    trainer = Trainer(gen, disc, g_optim, d_optim, aux_clf, ac_optim, writer,
                      logger, evaluator, cfg)
    trainer.train(loader, st_step)
Example #30
0
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

# Feel free to change pretrained to False if you're training the model from scratch
pretrained = False
if pretrained:
    loaded_state = torch.load("pix2pix_15000.pth")
    gen.load_state_dict(loaded_state["gen"])
    gen_opt.load_state_dict(loaded_state["gen_opt"])
    disc.load_state_dict(loaded_state["disc"])
    disc_opt.load_state_dict(loaded_state["disc_opt"])
else:
    gen = gen.apply(weights_init)
    disc = disc.apply(weights_init)


# UNQ_C2 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED CLASS: get_gen_loss
def get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon):
    '''
    Return the loss of the generator given inputs.
    Parameters:
        gen: the generator; takes the condition and returns potential images
        disc: the discriminator; takes images and the condition and
          returns real/fake prediction matrices
        real: the real images (e.g. maps) to be used to evaluate the reconstruction
        condition: the source images (e.g. satellite imagery) which are used to produce the real images
        adv_criterion: the adversarial loss function; takes the discriminator 
                  predictions and the true labels and returns an adversarial