Example no. 1
    def __init__(self, opt):
        self.opt = opt
        self.netG = _netG(opt.anchor_num, opt.latent_dim, opt.nz, opt.ngf, opt.nc)
        self.netD = _netD(opt.nc, opt.ndf)
        self.encoder = _encoder(opt.nc, opt.ndf, opt.latent_dim)
        self.decoder = _decoder(opt.nc, opt.ngf, opt.latent_dim)
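        # linear layers whose weight matrices presumably hold the anchor basis and the per-sample mixing coefficients (stage 2)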
        self.learnBasis = nn.Linear(self.opt.anchor_num, self.opt.latent_dim, bias=False)
        self.learnCoeff = nn.Linear(self.opt.anchor_num, self.opt.batchSize_s2, bias=False)
        self.dataloader = torch.utils.data.DataLoader(utils.createDataSet(self.opt, self.opt.imageSize),
                                                      batch_size=self.opt.batchSize_s1,
                                                      shuffle=True, num_workers=int(self.opt.workers))

        self.criterion_bce = nn.BCELoss()
        self.criterion_l1 = nn.L1Loss(reduction='mean')
        self.criterion_l2 = nn.MSELoss(reduction='mean')

        # initialize the optimizers
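        # each training stage has its own learning rate: s1_lr for the encoder/decoder, s2_lr for basis/coefficients, s3_lr for D/G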
        self.optimizerD = optim.Adam(self.netD.parameters(), lr=opt.s3_lr, betas=(opt.beta1, opt.beta2))
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=opt.s3_lr, betas=(opt.beta1, opt.beta2))
        self.optimizerEncoder = optim.Adam(self.encoder.parameters(), lr=opt.s1_lr, betas=(opt.beta1, opt.beta2))
        self.optimizerDecoder = optim.Adam(self.decoder.parameters(), lr=opt.s1_lr, betas=(opt.beta1, opt.beta2))
        self.optimizerBasis = optim.Adam(self.learnBasis.parameters(), lr=opt.s2_lr, betas=(opt.beta1, opt.beta2))
        self.optimizerCoeff = optim.Adam(self.learnCoeff.parameters(), lr=opt.s2_lr, betas=(opt.beta1, opt.beta2))

        # some variables
        input = torch.FloatTensor(opt.batchSize_s1, opt.nc, opt.imageSize, opt.imageSize)
        label = torch.FloatTensor(opt.batchSize_s3)
        # +1 / -1 scalar tensors (typically passed to backward() in WGAN-style updates)
        self.one = torch.FloatTensor([1])
        self.mone = self.one * -1
        if opt.cuda:
            input, label = input.cuda(), label.cuda()
            self.one, self.mone = self.one.cuda(), self.mone.cuda()
            self.netD = utils.dataparallel(self.netD, opt.ngpu, opt.gpu)
            self.netG = utils.dataparallel(self.netG, opt.ngpu, opt.gpu)
            self.encoder = utils.dataparallel(self.encoder, opt.ngpu, opt.gpu)
            self.decoder = utils.dataparallel(self.decoder, opt.ngpu, opt.gpu)
            self.learnBasis = utils.dataparallel(self.learnBasis, opt.ngpu, opt.gpu)
            self.learnCoeff = utils.dataparallel(self.learnCoeff, opt.ngpu, opt.gpu)
            self.criterion_bce.cuda()
            self.criterion_l1.cuda()
            self.criterion_l2.cuda()

        self.input = Variable(input)
        self.label = Variable(label)
        self.batchSize = self.opt.batchSize_s1
Example no. 2
    def _build_model(self):
        '''
        Builds generator and discriminator
        :return: tuple of (generator, discriminator)
        '''
        netG = _netG(self.opt.nz, self.opt.ngf, self.opt.alpha, self.nc,
                     self.use_gpu)
        netG.apply(self._weights_init)
        print(netG)

        netD = _netD(self.opt.ndf, self.opt.alpha, self.nc, self.opt.drop_rate,
                     self.num_classes, self.use_gpu)
        netD.apply(self._weights_init)
        print(netD)

        if self.use_gpu:
            netG = netG.cuda()
            netD = netD.cuda()

        return netG, netD
Example no. 3
    def __init__(self, glo_params, image_params, rn):
        self.netZ = model._netZ(glo_params.nz, image_params.n)
        self.netZ.apply(model.weights_init)
        self.netZ.cuda()
        self.rn = rn

        self.netG = model._netG(glo_params.nz, image_params.sz[0],
                                image_params.nc, glo_params.do_bn)
        self.netG.apply(model.weights_init)
        self.netG.cuda()
        # self.netG = nn.DataParallel(self.netG)

        self.vis_n = 64

        fixed_noise = torch.FloatTensor(self.vis_n,
                                        glo_params.nz).normal_(0, 1)
        self.fixed_noise = fixed_noise.cuda()

        self.glo_params = glo_params
        self.image_params = image_params

        # lap_criterion = pyr.MS_Lap(4, 5).cuda()
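        # reconstruction metric; force_l2 presumably switches from a perceptual (Laplacian-pyramid) distance to plain L2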
        self.dist = utils.distance_metric(image_params.sz[0], image_params.nc,
                                          glo_params.force_l2)
Example no. 4
overlapL2Weight = 10
# custom weights initialization called on netG and netD


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


resume_epoch = 0

netG = _netG(opt)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(
        torch.load(
            opt.netG,
            map_location=lambda storage, location: storage)['state_dict'])
    resume_epoch = torch.load(opt.netG)['epoch']
print(netG)

netD = _netlocalD(opt)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(
        torch.load(
            opt.netD,
Example no. 5
nc = params['nc']
sz = params['sz']
batch_size = params['fid']['batch_size']
total_n = params['fid']['n_images']
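
# evaluation setup: load the learned per-image latent codes (netZ), the noise-to-latent mapping network (netT) and the generator (netG), presumably for FID evaluation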

W = torch.load('runs/nets_%s/netZ_nag.pth' % (rn))
W = W['emb.weight'].data.cpu().numpy()

Zs = utils.sample_gaussian(torch.from_numpy(W), total_n)
Zs = Zs.data.cpu().numpy()

state_dict = torch.load('runs/nets_%s/netT_nag.pth' % rn)
netT = icp._netT(d, nz).cuda()
netT.load_state_dict(state_dict)

netG = model._netG(nz, sz, nc, do_bn).cuda()
state_dict = torch.load('runs/nets_%s/netG_nag.pth' % (rn))
netG.load_state_dict(state_dict)

train_ims = np.load(train_path).astype('float')
test_ims = np.load(test_path).astype('float')

rp = np.random.permutation(len(train_ims))[:total_n]
train_ims = train_ims[rp]
rp = np.random.permutation(len(test_ims))[:total_n]
test_ims = test_ims[rp]

batch_n = total_n // batch_size
ims_glann = np.zeros((batch_n * batch_size, nc, sz, sz))
ims_glo = np.zeros((batch_n * batch_size, nc, sz, sz))
ims_reconstruction = np.zeros((batch_n * batch_size, nc, sz, sz))
Example no. 6
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx,
                                                                         0])
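        # read the image and transpose it; note that .T on an H x W x C array yields (C, W, H), swapping height and width as well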
        image = io.imread(img_name).T
        #print("image size :",image.shape)
        # landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
        # landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = image

        if self.transform:
            sample = self.transform(sample)
        return sample


MSN_path = "./log/MSNGAN"
model_checkpoint = 'netG_epoch_399.pth'
model = _netG(128, 3, 64)
MSN_root_path, MSN_trained_model = Generate_images(model, model_checkpoint,
                                                   MSN_path, 1000)

MSN_root_path = MSN_path + '/' + 'final_images' + '/'

SN_path = "./log/SNGAN"
model_checkpoint = 'netG_epoch_399.pth'
model = _netG(128, 3, 64)
SN_root_path, SN_trained_model = Generate_images(model, model_checkpoint,
                                                 SN_path, 1000)

SN_root_path = SN_path + '/' + 'final_images' + '/'

trans = transforms.Compose([
    transforms.ToTensor(),
Example no. 7
assert dataset
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=opt.batchSize,
                                         shuffle=True,
                                         num_workers=int(opt.workers))

ngpu = len(gpu_ids)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
nc = 3

#-------Initial Model ----------
#--------E-----------
if opt.instance:
    netG = _netG(ngpu, norm_layer=nn.InstanceNorm2d)
    netG.apply(weights_init)
else:
    netG = _netG(ngpu)
    netG.apply(weights_init)

if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)

#---------D------------
if opt.instance:
    netD = _netD(ngpu,
                 use_sigmoid=(not opt.lsgan),
                 norm_layer=nn.InstanceNorm2d)
    netD.apply(weights_init)
Example no. 8
def main():
    try:
        os.makedirs("result/test/cropped")
        os.makedirs("result/test/real")
        os.makedirs("result/test/recon")
        os.makedirs("model")
    except OSError:
        pass

    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.cuda:
        torch.cuda.manual_seed_all(opt.manualSeed)

    cudnn.benchmark = True

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    transform = transforms.Compose([
        transforms.Resize(opt.imageSize),
        transforms.CenterCrop(opt.imageSize),
        transforms.ToTensor(),
        transforms.Normalize((0.0, 0.0, 0.0), (1.0, 1.0, 1.0))
    ])
    dataset = dset.ImageFolder(root=opt.dataroot, transform=transform)
    assert dataset
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batchSize,
                                             shuffle=False,
                                             num_workers=int(opt.workers))

    ngpu = int(opt.ngpu)
    nz = int(opt.nz)
    ngf = int(opt.ngf)
    ndf = int(opt.ndf)
    nc = 3
    nef = int(opt.nef)
    nBottleneck = int(opt.nBottleneck)
    wtl2 = float(opt.wtl2)
    overlapL2Weight = 10

    netG = _netG(opt)
    netG.load_state_dict(
        torch.load(
            opt.netG,
            map_location=lambda storage, location: storage)['state_dict'])
    print(netG)

    criterion = nn.BCELoss()
    criterionMSE = nn.MSELoss()

    input_real = torch.FloatTensor(opt.batchSize, 3, opt.imageSize,
                                   opt.imageSize)
    input_cropped = torch.FloatTensor(opt.batchSize, 3, opt.imageSize,
                                      opt.imageSize)
    label = torch.FloatTensor(opt.batchSize)
    real_label = 1
    fake_label = 0

    print(opt.batchSize)
    print(opt.imageSize)

    real_center = torch.FloatTensor(int(opt.batchSize), 3,
                                    int(opt.imageSize / 2),
                                    int(opt.imageSize / 2))
    #real_center = torch.FloatTensor(64, 3, 64,64)

    if opt.cuda:
        netG.cuda()
        criterion.cuda()
        criterionMSE.cuda()
        input_real = input_real.cuda()
        input_cropped = input_cropped.cuda()
        label = label.cuda()
        real_center = real_center.cuda()

    input_real = Variable(input_real)
    input_cropped = Variable(input_cropped)
    label = Variable(label)

    real_center = Variable(real_center)
    # optionally jitter the center-crop position
    randwf = random.uniform(-1.0, 1.0)
    randhf = random.uniform(-1.0, 1.0)
    if opt.jittering:
        jitterSizeW = int(opt.imageSize / 5 * randwf)
        jitterSizeH = int(opt.imageSize / 5 * randhf)
        print("jittering : W > ", jitterSizeW, " H >", jitterSizeH)
    else:
        jitterSizeW = 0
        jitterSizeH = 0
    for i, data in enumerate(dataloader, 0):
        real_cpu, _ = data
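        # slice out the (optionally jittered) central imageSize/2 x imageSize/2 region as the ground-truth center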
        real_center_cpu = real_cpu[:, :,
                                   int(opt.imageSize / 4 +
                                       jitterSizeW):int(opt.imageSize / 4 +
                                                        opt.imageSize / 2 +
                                                        jitterSizeW),
                                   int(opt.imageSize / 4 +
                                       jitterSizeH):int(opt.imageSize / 4 +
                                                        opt.imageSize / 2 +
                                                        jitterSizeH)]
        batch_size = real_cpu.size(0)
        input_real.data.resize_(real_cpu.size()).copy_(real_cpu)
        input_cropped.data.resize_(real_cpu.size()).copy_(real_cpu)
        real_center.data.resize_(real_center_cpu.size()).copy_(real_center_cpu)
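        # overwrite the center of the input, except for an overlapPred-wide border, with fixed per-channel values (presumably dataset channel means rescaled to [-1, 1])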
        input_cropped.data[:, 0,
                           int(opt.imageSize / 4 + opt.overlapPred +
                               jitterSizeW):int(opt.imageSize / 4 +
                                                opt.imageSize / 2 -
                                                opt.overlapPred + jitterSizeW),
                           int(opt.imageSize / 4 + opt.overlapPred +
                               jitterSizeH
                               ):int(opt.imageSize / 4 + opt.imageSize / 2 -
                                     opt.overlapPred +
                                     jitterSizeH)] = 2 * 117.0 / 255.0 - 1.0
        input_cropped.data[:, 1,
                           int(opt.imageSize / 4 + opt.overlapPred +
                               jitterSizeW):int(opt.imageSize / 4 +
                                                opt.imageSize / 2 -
                                                opt.overlapPred + jitterSizeW),
                           int(opt.imageSize / 4 + opt.overlapPred +
                               jitterSizeH
                               ):int(opt.imageSize / 4 + opt.imageSize / 2 -
                                     opt.overlapPred +
                                     jitterSizeH)] = 2 * 104.0 / 255.0 - 1.0
        input_cropped.data[:, 2,
                           int(opt.imageSize / 4 + opt.overlapPred +
                               jitterSizeW):int(opt.imageSize / 4 +
                                                opt.imageSize / 2 -
                                                opt.overlapPred + jitterSizeW),
                           int(opt.imageSize / 4 + opt.overlapPred +
                               jitterSizeH
                               ):int(opt.imageSize / 4 + opt.imageSize / 2 -
                                     opt.overlapPred +
                                     jitterSizeH)] = 2 * 104.0 / 255.0 - 1.0

        label.data.resize_(batch_size).fill_(real_label)
        fake = netG(input_cropped)
        label.data.fill_(fake_label)

        errG = criterionMSE(fake, real_center)
        print('errG: %.4f' % errG.item())

        vutils.save_image(real_cpu, 'result/test/real/real_samples.png')
        vutils.save_image(input_cropped.data,
                          'result/test/cropped/cropped_samples.png')
        recon_image = input_cropped.clone()
        recon_image.data[:, :,
                         int(opt.imageSize / 4 +
                             jitterSizeW):int(opt.imageSize / 4 +
                                              opt.imageSize / 2 + jitterSizeW),
                         int(opt.imageSize / 4 +
                             jitterSizeH):int(opt.imageSize / 4 +
                                              opt.imageSize / 2 +
                                              jitterSizeH)] = fake.data
        vutils.save_image(recon_image.data,
                          'result/test/recon/recon_center_samples.png')
Example no. 9
def train_GAN(model = "SNGAN"):
    D_loss = []
    G_loss = []
    Dx_loss = []
    DGx_loss = []
    G = _netG(nz, 3, 64)

    if model == "SNGAN":
        SND = _netD(3, 64)
    elif model == "MSNGAN":
        SND = _MSNnetD(3, 64)
    else:
        raise ValueError("Invalid GAN type given")

    if not os.path.exists('log/' + model + '/'):
        os.mkdir('log/' + model + '/')

    print(G)
    print(SND)
    G.apply(weight_filler)
    SND.apply(weight_filler)

    input = torch.FloatTensor(opt.batchsize, 3, 32, 32)
    noise = torch.FloatTensor(opt.batchsize, nz, 1, 1)
    fixed_noise = torch.FloatTensor(opt.batchsize, nz, 1, 1).normal_(0, 1)
    label = torch.FloatTensor(opt.batchsize)
    real_label = 1
    fake_label = 0

    fixed_noise = Variable(fixed_noise)
    criterion = nn.BCELoss()

    if opt.cuda:
        G.cuda()
        SND.cuda()
        criterion.cuda()
        input, label = input.cuda(), label.cuda()
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

    optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizerSND = optim.Adam(SND.parameters(), lr=0.0002, betas=(0.5, 0.999))
    print(len(list(SND.parameters())))

    num_iter = 400
    for epoch in range(num_iter):
        start_time = time()
        for i, data in enumerate(dataloader, 0):

            step = epoch * len(dataloader) + i
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
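            # the D losses below use softplus on raw logits: softplus(-D(x)) = -log(sigmoid(D(x))) for real samples, softplus(D(G(z))) = -log(1 - sigmoid(D(G(z)))) for fakes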
            SND.zero_grad()
            real_cpu, _ = data
            batch_size = real_cpu.size(0)
            if opt.cuda:
                real_cpu = real_cpu.cuda()
            input.resize_(real_cpu.size()).copy_(real_cpu)
            label.resize_(batch_size).fill_(real_label)
            inputv = Variable(input)
            labelv = Variable(label)
            output = SND(inputv)

            errD_real = torch.mean(F.softplus(-output))
            #errD_real = criterion(output, labelv)
            #errD_real.backward()
            D_x = output.data.mean()
            # train with fake
            noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
            noisev = Variable(noise)
            fake = G(noisev)
            labelv = Variable(label.fill_(fake_label))
            output = SND(fake.detach())
            errD_fake = torch.mean(F.softplus(output))
            #errD_fake = criterion(output, labelv)

            D_G_z1 = output.data.mean()
            #grad_penal = gradient_penalty(inputv.data, SND)
            errD = errD_real + errD_fake #+ grad_penal*10.0
            #print(output)
            errD.backward()
            optimizerSND.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            if step % n_dis == 0:
                G.zero_grad()
                labelv = Variable(label.fill_(real_label))  # fake labels are real for generator cost
                output = SND(fake)
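                # non-saturating generator loss: softplus(-D(G(z))) = -log(sigmoid(D(G(z))))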
                errG = torch.mean(F.softplus(-output))
                #errG = criterion(output, labelv)
                errG.backward()
                D_G_z2 = output.data.mean()
                optimizerG.step()
            if i % 100 == 0:
                end_time = time() - start_time
                print('[%3d/%3d][%3d/%3d] Loss_D: %.4f Loss_G: %.4f D(x): %+.4f D(G(z)): %+.4f / %+.4f Time : %.3f'
                      % (epoch, num_iter, i, len(dataloader),
                         errD.data.item(), errG.data.item(), D_x, D_G_z1, D_G_z2, end_time))

                G_loss.append(errG.data.item())
                D_loss.append(errD.data.item())
                Dx_loss.append(D_x)
                DGx_loss.append(D_G_z1 / D_G_z2)

                start_time = time()
            if i % 100 == 0:
                vutils.save_image(real_cpu,
                        '%s/%s/real_samples.png' % ('log', model),
                        normalize=True)
                fake = G(fixed_noise)
                vutils.save_image(fake.data,
                        '%s/%s/fake_samples_epoch_%03d.png' % ('log', model, epoch),
                        normalize=True)
        # do checkpointing
    torch.save(G.state_dict(), '%s/%s/netG_epoch_%d.pth' % ('log', model, epoch))
    torch.save(SND.state_dict(), '%s/%s/netD_epoch_%d.pth' % ('log', model, epoch))

    np.savetxt('log/'+model+'/G_loss.csv', G_loss, delimiter=',')
    np.savetxt('log/'+model+'/D_loss.csv', D_loss, delimiter=',')
    np.savetxt('log/'+model+'/Dx_loss.csv', Dx_loss, delimiter=',')
    np.savetxt('log/'+model+'/DGx_loss.csv', DGx_loss, delimiter=',')
Example no. 10
nt = int(opt.nt)
nte = int(opt.nte)
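# nt and nte above presumably are text-embedding dimensions for a text-conditioned GAN; netG and netD below both take them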


# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
        # m.bias.data.fill_(0)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


netG = model._netG(ngpu, nz, ngf, nc, nte, nt)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)

netD = model._netD(ngpu, nc, ndf, nte, nt)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

criterion = nn.BCELoss()

input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
Example no. 11
                                decay_epochs=decay, decay_rate=0.5,  pad_image=[0,0])

    
    nt = glo.GLOTrainer(data, glo_params, rn, is_cuda, [], [])
    G, Z, noise_amp = nt.train_glo(glo_opt_params)
    

    # icp

    dim = params['icp']['dim']
    nepoch = params['icp']['total_epoch']

    W = torch.load('runs/nets_%s/netZ_nag.pth' % (rn))
    W = W['emb.weight'].data.cpu().numpy()

    netG = model._netG(nz, sz, 3)
    if is_cuda:
        netG = netG.cuda()
    state_dict = torch.load('runs/nets_%s/netG_nag.pth' % (rn))
    netG.load_state_dict(state_dict)

    icpt = icp.ICPTrainer(W, dim, is_cuda)
    icpt.train_icp(nepoch)
    torch.save(icpt.icp.netT.state_dict(), 'runs/nets_%s/netT_nag.pth' % rn)

    if is_cuda:
        z = icpt.icp.netT(torch.randn(64, dim).cuda())
    else:
        z = icpt.icp.netT(torch.randn(64, dim))
    
    #net_T.append(icpt.icp.netT)
Example no. 12
        input_noise = input_noise.resize_(opt.batchsize, 100, 1, 1)
        input_noise = input_noise.cuda()
        input_noise = Variable(input_noise)
        outputs = model(input_noise)
        fake = outputs.data
        print(fake.shape)
        for j in range(opt.batchsize):
            im = fake[j, :, :, :]
            torchvision.utils.save_image(im.view(1, im.size(0), im.size(1),
                                                 im.size(2)),
                                         os.path.join('./result', opt.name,
                                                      '%d_%d.jpg' % (i, j)),
                                         nrow=1,
                                         padding=0,
                                         normalize=True)


#-----------------Load model
def load_network(network):
    save_path = os.path.join('./model', opt.name,
                             'netG_epoch_%s.pth' % opt.which_epoch)
    network.load_state_dict(torch.load(save_path))
    return network


model_structure = _netG()
model = load_network(model_structure)
model = model.cuda()

generate_img(model)
Example no. 13
wtl2 = float(opt.wtl2)
overlapL2Weight = 10

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


resume_epoch = 0

netG = _netG(opt)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG, map_location=lambda storage, location: storage)['state_dict'])
    resume_epoch = torch.load(opt.netG)['epoch']
print(netG)


netD = _netlocalD(opt)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD, map_location=lambda storage, location: storage)['state_dict'])
    resume_epoch = torch.load(opt.netD)['epoch']
print(netD)

criterion = nn.BCELoss()
Example no. 14

# custom weights initialization called on netG and netD
def weights_init(m):
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)  # mean 0.0, std 0.02
    elif class_name.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


# resume from the specified epoch
resume_epoch = 0

netG = _netG(opt)  # generator network (encoder-decoder)
netG.apply(weights_init)  # initialize weights

if opt.netG != '':  # optionally load a pretrained model
    netG.load_state_dict(
        torch.load(
            opt.netG,
            map_location=lambda storage, location: storage)['state_dict'])
    resume_epoch = torch.load(opt.netG)['epoch']

print(netG)

netD = _netlocalD(opt)  # discriminator network
netD.apply(weights_init)  # initialize the discriminator

if opt.netD != '':  # optionally load a pretrained model
Example no. 15
def main():
    try:
        os.makedirs("result/train/cropped")
        os.makedirs("result/train/real")
        os.makedirs("result/train/recon")
        os.makedirs("model")
    except OSError:
        pass

    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.cuda:
        torch.cuda.manual_seed_all(opt.manualSeed)

    cudnn.benchmark = True

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if opt.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=opt.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(opt.imageSize),
                                       transforms.CenterCrop(opt.imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5)),
                                   ]))
    elif opt.dataset == 'lsun':
        dataset = dset.LSUN(db_path=opt.dataroot,
                            classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(opt.imageSize),
                                transforms.CenterCrop(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5)),
                            ]))
    elif opt.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=opt.dataroot,
                               download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))
    elif opt.dataset == 'streetview':
        transform = transforms.Compose([
            transforms.Resize(opt.imageSize),
            transforms.CenterCrop(opt.imageSize),
            transforms.ToTensor(),
            transforms.Normalize((0.0, 0.0, 0.0), (1.0, 1.0, 1.0))
        ])  # no normalization
        dataset = dset.ImageFolder(root=opt.dataroot, transform=transform)
    assert dataset
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    ngpu = int(opt.ngpu)
    nz = int(opt.nz)
    ngf = int(opt.ngf)
    ndf = int(opt.ndf)
    nc = 3
    nef = int(opt.nef)
    nBottleneck = int(opt.nBottleneck)
    wtl2 = float(opt.wtl2)
    overlapL2Weight = 10

    # custom weights initialization called on netG and netD
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    resume_epoch = 0

    netG = _netG(opt)
    netG.apply(weights_init)
    if opt.netG != '':
        netG.load_state_dict(
            torch.load(
                opt.netG,
                map_location=lambda storage, location: storage)['state_dict'])
        resume_epoch = torch.load(opt.netG)['epoch']
    print(netG)

    netD = _netlocalD(opt)
    netD.apply(weights_init)
    if opt.netD != '':
        netD.load_state_dict(
            torch.load(
                opt.netD,
                map_location=lambda storage, location: storage)['state_dict'])
        resume_epoch = torch.load(opt.netD)['epoch']
    print(netD)

    criterion = nn.BCELoss()
    criterionMSE = nn.MSELoss()

    input_real = torch.FloatTensor(opt.batchSize, 3, opt.imageSize,
                                   opt.imageSize)
    input_cropped = torch.FloatTensor(opt.batchSize, 3, opt.imageSize,
                                      opt.imageSize)
    label = torch.FloatTensor(opt.batchSize)
    real_label = 1
    fake_label = 0

    print(opt.batchSize)
    print(opt.imageSize)

    real_center = torch.FloatTensor(int(opt.batchSize), 3,
                                    int(opt.imageSize / 2),
                                    int(opt.imageSize / 2))
    #real_center = torch.FloatTensor(64, 3, 64,64)

    if opt.cuda:
        netD.cuda()
        netG.cuda()
        criterion.cuda()
        criterionMSE.cuda()
        input_real = input_real.cuda()
        input_cropped = input_cropped.cuda()
        label = label.cuda()
        real_center = real_center.cuda()

    input_real = Variable(input_real)
    input_cropped = Variable(input_cropped)
    label = Variable(label)

    real_center = Variable(real_center)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.999))

    for epoch in range(resume_epoch, opt.niter):
        # optionally jitter the center-crop position
        randwf = random.uniform(-1.0, 1.0)
        randhf = random.uniform(-1.0, 1.0)
        if opt.jittering:
            jitterSizeW = int(opt.imageSize / 5 * randwf)
            jitterSizeH = int(opt.imageSize / 5 * randhf)
            print("jittering : W > ", jitterSizeW, " H >", jitterSizeH)
        else:
            jitterSizeW = 0
            jitterSizeH = 0
        for i, data in enumerate(dataloader, 0):
            real_cpu, _ = data
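            # slice out the (optionally jittered) central imageSize/2 x imageSize/2 region as the ground-truth center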
            real_center_cpu = real_cpu[:, :,
                                       int(opt.imageSize / 4 +
                                           jitterSizeW):int(opt.imageSize / 4 +
                                                            opt.imageSize / 2 +
                                                            jitterSizeW),
                                       int(opt.imageSize / 4 +
                                           jitterSizeH):int(opt.imageSize / 4 +
                                                            opt.imageSize / 2 +
                                                            jitterSizeH)]
            batch_size = real_cpu.size(0)
            input_real.data.resize_(real_cpu.size()).copy_(real_cpu)
            input_cropped.data.resize_(real_cpu.size()).copy_(real_cpu)
            real_center.data.resize_(
                real_center_cpu.size()).copy_(real_center_cpu)
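            # blank out the center of the input, keeping an overlapPred-wide border of real pixels, with fixed per-channel values (presumably dataset channel means rescaled to [-1, 1])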
            input_cropped.data[:, 0,
                               int(opt.imageSize / 4 + opt.overlapPred +
                                   jitterSizeW):int(opt.imageSize / 4 +
                                                    opt.imageSize / 2 -
                                                    opt.overlapPred +
                                                    jitterSizeW),
                               int(opt.imageSize / 4 + opt.overlapPred +
                                   jitterSizeH):
                               int(opt.imageSize / 4 + opt.imageSize / 2 -
                                   opt.overlapPred +
                                   jitterSizeH)] = 2 * 117.0 / 255.0 - 1.0
            input_cropped.data[:, 1,
                               int(opt.imageSize / 4 + opt.overlapPred +
                                   jitterSizeW):int(opt.imageSize / 4 +
                                                    opt.imageSize / 2 -
                                                    opt.overlapPred +
                                                    jitterSizeW),
                               int(opt.imageSize / 4 + opt.overlapPred +
                                   jitterSizeH):
                               int(opt.imageSize / 4 + opt.imageSize / 2 -
                                   opt.overlapPred +
                                   jitterSizeH)] = 2 * 104.0 / 255.0 - 1.0
            input_cropped.data[:, 2,
                               int(opt.imageSize / 4 + opt.overlapPred +
                                   jitterSizeW):int(opt.imageSize / 4 +
                                                    opt.imageSize / 2 -
                                                    opt.overlapPred +
                                                    jitterSizeW),
                               int(opt.imageSize / 4 + opt.overlapPred +
                                   jitterSizeH):
                               int(opt.imageSize / 4 + opt.imageSize / 2 -
                                   opt.overlapPred +
                                   jitterSizeH)] = 2 * 104.0 / 255.0 - 1.0

            # train with real
            netD.zero_grad()
            label.data.resize_(batch_size).fill_(real_label)

            output = netD(real_center)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.data.mean()

            # train with fake
            # noise.data.resize_(batch_size, nz, 1, 1)
            # noise.data.normal_(0, 1)
            fake = netG(input_cropped)
            label.data.fill_(fake_label)
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.data.mean()
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.data.fill_(
                real_label)  # fake labels are real for generator cost
            output = netD(fake)
            errG_D = criterion(output, label)
            # errG_D.backward(retain_variables=True)

            # errG_l2 = criterionMSE(fake,real_center)
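            # per-pixel weight map for the L2 term: the overlapPred border is weighted overlapL2Weight * wtl2, the interior only wtl2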
            wtl2Matrix = real_center.clone()
            wtl2Matrix.data.fill_(wtl2 * overlapL2Weight)
            wtl2Matrix.data[:, :,
                            int(opt.overlapPred):int(opt.imageSize / 2 -
                                                     opt.overlapPred),
                            int(opt.overlapPred):int(opt.imageSize / 2 -
                                                     opt.overlapPred)] = wtl2

            errG_l2 = (fake - real_center).pow(2)
            errG_l2 = errG_l2 * wtl2Matrix
            errG_l2 = errG_l2.mean()

            errG = (1 - wtl2) * errG_D + wtl2 * errG_l2

            errG.backward()

            D_G_z2 = output.data.mean()
            optimizerG.step()

            print(
                '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f l_D(x): %.4f l_D(G(z)): %.4f'
                % (
                    epoch,
                    opt.niter,
                    i,
                    len(dataloader),
                    errD.item(),
                    errG_D.item(),
                    errG_l2.item(),
                    D_x,
                    D_G_z1,
                ))
            if i % 100 == 0:
                vutils.save_image(
                    real_cpu,
                    'result/train/real/real_samples_epoch_%03d.png' % (epoch))
                vutils.save_image(
                    input_cropped.data,
                    'result/train/cropped/cropped_samples_epoch_%03d.png' %
                    (epoch))
                recon_image = input_cropped.clone()
                recon_image.data[:, :,
                                 int(opt.imageSize / 4 +
                                     jitterSizeW):int(opt.imageSize / 4 +
                                                      opt.imageSize / 2 +
                                                      jitterSizeW),
                                 int(opt.imageSize / 4 +
                                     jitterSizeH):int(opt.imageSize / 4 +
                                                      opt.imageSize / 2 +
                                                      jitterSizeH)] = fake.data
                vutils.save_image(
                    recon_image.data,
                    'result/train/recon/recon_center_samples_epoch_%03d.png' %
                    (epoch))

        # do checkpointing
        torch.save({
            'epoch': epoch + 1,
            'state_dict': netG.state_dict()
        }, 'model/netG_streetview.pth')
        torch.save({
            'epoch': epoch + 1,
            'state_dict': netD.state_dict()
        }, 'model/netlocalD.pth')