Example #1
    def build_model(self):
        if self.config.no_bn:
            self.netG = dcgan.DCGAN_G_nobn(self.image_size, self.nz, self.nc,
                                           self.ngf, self.ngpu,
                                           self.config.n_extra_layers)
        elif self.config.mlp_G:
            self.netG = mlp.MLP_G(self.image_size, self.nz, self.nc, self.ngf,
                                  self.ngpu)
        else:
            self.netG = dcgan.DCGAN_G(self.image_size, self.nz, self.nc,
                                      self.ngf, self.ngpu,
                                      self.config.n_extra_layers)
            self.netG.apply(weights_init)

        if self.config.netG != '':  # load checkpoint if needed
            self.netG.load_state_dict(torch.load(self.config.netG))

        if self.config.mlp_D:
            self.netD = mlp.MLP_D(self.image_size, self.nz, self.nc, self.ndf,
                                  self.ngpu)
        else:
            self.netD = dcgan.DCGAN_D(self.image_size, self.nz, self.nc,
                                      self.ndf, self.ngpu,
                                      self.config.n_extra_layers)
            self.netD.apply(weights_init)

        if self.config.netD != '':
            self.netD.load_state_dict(torch.load(self.config.netD))
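
The `config` object only needs the attributes the method reads. A minimal sketch of a hypothetical stand-in (field names come from the code above; values are illustrative):

from types import SimpleNamespace

# hypothetical config; every field is one that build_model actually reads
config = SimpleNamespace(
    no_bn=False,          # True -> DCGAN_G_nobn generator
    mlp_G=False,          # True -> MLP_G generator
    mlp_D=False,          # True -> MLP_D discriminator
    n_extra_layers=0,     # extra conv layers in the DCGAN nets
    netG='',              # path to a generator checkpoint ('' = none)
    netD='',              # path to a discriminator checkpoint ('' = none)
)
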
Example #2
    def __init__(self, cfg_path, weights_path):

        with open(cfg_path, 'r') as gencfg:
            generator_config = json.load(gencfg)

        imageSize = generator_config["imageSize"]
        nz = generator_config["nz"]
        nc = generator_config["nc"]
        ngf = generator_config["ngf"]
        noBN = generator_config["noBN"]
        ngpu = generator_config["ngpu"]
        mlp_G = generator_config["mlp_G"]
        n_extra_layers = generator_config["n_extra_layers"]

        if noBN:
            netG = dcgan.DCGAN_G_nobn(imageSize, nz, nc, ngf, ngpu,
                                      n_extra_layers)
        elif mlp_G:
            netG = mlp.MLP_G(imageSize, nz, nc, ngf, ngpu)
        else:
            netG = dcgan.DCGAN_G(imageSize, nz, nc, ngf, ngpu, n_extra_layers)

        # load weights (map to CPU first; moved to the GPU below if available)
        netG.load_state_dict(torch.load(weights_path, map_location='cpu'))

        if torch.cuda.is_available():
            netG = netG.cuda()

        self.model = netG
        self.nz = nz
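
A hedged usage sketch, assuming this `__init__` belongs to a wrapper class named `Generator` (the class name is not shown in the snippet) and that the JSON config file carries the keys read above:

# hypothetical class name and file paths
gen = Generator('generator_config.json', 'netG.pth')
noise = torch.randn(1, gen.nz, 1, 1)
if torch.cuda.is_available():
    noise = noise.cuda()
with torch.no_grad():
    img = gen.model(noise)  # (1, nc, imageSize, imageSize); values in
                            # [-1, 1] if the generator ends in Tanh, as
                            # DCGAN generators typically do
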
Example #3
def __getGenerator(opt, ngpu, nz, ngf, ndf, nc, n_extra_layers):
    if opt.noBN:
        if isDebug: print("Using No Batch Norm (DCGAN_G_nobn) for Generator")
        netG = dcgan.DCGAN_G_nobn(opt.imageSize, nz, nc, ngf, ngpu,
                                  n_extra_layers)
    elif opt.mlp_G:
        if isDebug: print("Using MLP_G for Generator")
        netG = mlp.MLP_G(opt.imageSize, nz, nc, ngf, ngpu)
    else:
        if isDebug: print("Using DCGAN_G for Generator")
        netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)

    netG.apply(weights_init)
    if opt.netG != '':  # load checkpoint if needed
        netG.load_state_dict(torch.load(opt.netG))
    print("netG:\n {0}".format(netG))

    return netG
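
Note that `weights_init` is applied before any checkpoint is restored, so when `opt.netG` is non-empty the random initialization is simply overwritten by `load_state_dict`.
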
Example #4
nc = 3
n_extra_layers = int(opt.n_extra_layers)


# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


if opt.noBN:
    netG = dcgan.DCGAN_G_nobn(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
elif opt.mlp_G:
    netG = mlp.MLP_G(opt.imageSize, nz, nc, ngf, ngpu)
else:
    netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)

netG.apply(weights_init)
if opt.netG != '':  # load checkpoint if needed
    netG.load_state_dict(torch.load(opt.netG))
print(netG)

if opt.mlp_D:
    netD = mlp.MLP_D(opt.imageSize, nz, nc, ndf, ngpu)
else:
    netD = dcgan.DCGAN_D(opt.imageSize, nz, nc, ndf, ngpu, n_extra_layers)
    netD.apply(weights_init)
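
The name matching in `weights_init` implements the DCGAN initialization scheme: conv weights drawn from N(0, 0.02) and BatchNorm scale from N(1, 0.02) with zero bias. A sketch of an equivalent isinstance-based variant, which avoids accidentally matching unrelated classes whose names happen to contain 'Conv':

import torch.nn as nn

def weights_init_by_type(m):
    # same DCGAN scheme, keyed on module types instead of class names
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
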
Example #5
def main(opt, reporter=None):
    writer = SummaryWriter()

    with open(writer.file_writer.get_logdir() + '/args.json', 'w') as f:
        json.dump(opt, f)

    if opt['experiment'] is None:
        opt['experiment'] = 'samples'
    os.makedirs(opt['experiment'], exist_ok=True)

    opt['manualSeed'] = random.randint(1, 10000)  # fix seed
    print("Random Seed: ", opt['manualSeed'])
    random.seed(opt['manualSeed'])
    torch.manual_seed(opt['manualSeed'])

    cudnn.benchmark = True

    if torch.cuda.is_available() and not opt['cuda']:
        print("WARNING: You have a CUDA device, "
              "so you should probably run with --cuda")

    if opt['dataset'] in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=opt['dataroot'],
                                   transform=transforms.Compose([
                                       transforms.Resize(opt['imageSize']),
                                       transforms.CenterCrop(opt['imageSize']),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5)),
                                   ]))
    elif opt['dataset'] == 'lsun':
        dataset = dset.LSUN(root=opt['dataroot'],
                            classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(opt['imageSize']),
                                transforms.CenterCrop(opt['imageSize']),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5)),
                            ]))
    elif opt['dataset'] == 'cifar10':
        dataset = dset.CIFAR10(root=opt['dataroot'],
                               download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(opt['imageSize']),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))

    assert dataset
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt['batchSize'],
                                             shuffle=True,
                                             num_workers=int(opt['workers']))

    ngpu = int(opt['ngpu'])
    nz = int(opt['nz'])
    ngf = int(opt['ngf'])
    ndf = int(opt['ndf'])
    nc = int(opt['nc'])
    n_extra_layers = int(opt['n_extra_layers'])

    # custom weights initialization called on netG and netD
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    if opt['noBN']:
        netG = dcgan.DCGAN_G_nobn(opt['imageSize'], nz, nc, ngf, ngpu,
                                  n_extra_layers)
    elif opt['type'] == 'mlp':
        netG = mlp.MLP_G(opt['imageSize'], nz, nc, ngf, ngpu)
    elif opt['type'] == 'resnet':
        netG = resnet.Generator(nz)
    else:
        netG = dcgan.DCGAN_G(opt['imageSize'], nz, nc, ngf, ngpu,
                             n_extra_layers)

    netG.apply(weights_init)
    print(netG)

    if opt['type'] == 'mlp':
        netD = mlp.MLP_D(opt['imageSize'], nz, nc, ndf, ngpu)
    elif opt['type'] == 'resnet':
        netD = resnet.Discriminator(nz)
    else:
        netD = dcgan.DCGAN_D(opt['imageSize'], nz, nc, ndf, ngpu,
                             n_extra_layers)
        netD.apply(weights_init)

    print(netD)

    inc_noise = torch.utils.data.TensorDataset(
        torch.randn(50000, nz, 1, 1).cuda())
    inc_noise_dloader = torch.utils.data.DataLoader(
        inc_noise, batch_size=opt['batchSize'])

    input = torch.FloatTensor(opt['batchSize'], 3, opt['imageSize'],
                              opt['imageSize'])
    noise = torch.FloatTensor(opt['batchSize'], nz, 1, 1)
    fixed_noise = torch.FloatTensor(opt['batchSize'], nz, 1, 1).normal_(0, 1)
    one = torch.FloatTensor([1])
    mone = one * -1
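    # one/mone are kept from the original WGAN loop, where the two critic
    # losses were backpropagated via errD_real.backward(one) and
    # errD_fake.backward(mone); here loss.backward() is used instead, so
    # they are only moved to the GPU below and otherwise unused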

    if opt['cuda']:
        netD.cuda()
        netG.cuda()
        input = input.cuda()
        one, mone = one.cuda(), mone.cuda()
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

    # setup optimizer
    if opt['adam']:
        optimizerD = optim.Adam(netD.parameters(),
                                lr=opt['lrD'],
                                betas=(opt['beta1'], opt['beta2']))
        optimizerG = optim.Adam(netG.parameters(),
                                lr=opt['lrG'],
                                betas=(opt['beta1'], opt['beta2']))
    else:
        optimizerD = optim.RMSprop(netD.parameters(), lr=opt['lrD'])
        optimizerG = optim.RMSprop(netG.parameters(), lr=opt['lrG'])

    var_weight = 0.5
    w = torch.tensor(
        [var_weight * (1 - var_weight)**i for i in range(9, -1, -1)]).cuda()
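    # geometric series of weights; prepared here but never consumed in the
    # training loop below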

    gen_iterations = 0
    for epoch in range(opt['niter']):
        data_iter = iter(dataloader)
        i = 0
        while i < len(dataloader):
            # l_var = opt.l_var + (gen_iterations + 1)/3000
            l_var = opt['l_var']
            ############################
            # (1) Update D network
            ###########################
            for p in netD.parameters():  # reset requires_grad
                p.requires_grad = True

            # train the discriminator Diters times
            # if gen_iterations < 25 or gen_iterations % 500 == 0:
            if gen_iterations % 500 == 0:
                Diters = 100
            else:
                Diters = opt['Diters']

            j = 0
            while j < Diters and i < len(dataloader):
                j += 1

                # enforce constraint
                if not opt['var_constraint']:
                    for p in netD.parameters():
                        p.data.clamp_(opt['clamp_lower'], opt['clamp_upper'])

                data = next(data_iter)
                i += 1

                # train with real
                real_cpu, _ = data

                netD.zero_grad()
                batch_size = real_cpu.size(0)

                if opt['cuda']:
                    real_cpu = real_cpu.cuda()
                input.resize_as_(real_cpu).copy_(real_cpu)
                inputv = Variable(input)

                out_D_real = netD(inputv)
                errD_real = out_D_real.mean(0).view(1)

                if opt['var_constraint']:
                    vm_real = out_D_real.var(0)

                # train with fake
                noise.resize_(opt['batchSize'], nz, 1, 1).normal_(0, 1)
                with torch.no_grad():
                    noisev = Variable(noise)  # totally freeze netG
                    fake = netG(noisev).data
                inputv = fake
                out_D_fake = netD(inputv)
                errD_fake = out_D_fake.mean(0).view(1)

                if opt['var_constraint']:
                    vm_fake = out_D_fake.var(0)

                errD = errD_real - errD_fake

                if opt['var_constraint']:
                    # penalty vanishes when both output variances equal 1
                    loss = -(errD - l_var * torch.exp(
                        torch.sqrt(torch.log(vm_real)**2 +
                                   torch.log(vm_fake)**2)))
                else:
                    loss = -errD  # plain WGAN loss; vm_* undefined here
                loss.backward()

                optimizerD.step()

                if opt['var_constraint']:
                    writer.add_scalars(
                        'train/variance',
                        {'real': vm_real.item(), 'fake': vm_fake.item()},
                        epoch * len(dataloader) + i)

            ############################
            # (2) Update G network
            ###########################
            for p in netD.parameters():
                p.requires_grad = False  # to avoid computation
            netG.zero_grad()
            # in case our last batch was the tail batch of the dataloader,
            # make sure we feed a full batch of noise
            noise.resize_(opt['batchSize'], nz, 1, 1).normal_(0, 1)
            noisev = Variable(noise)
            fake = netG(noisev)
            errG = -netD(fake).mean(0).view(1)
            errG.backward()
            optimizerG.step()
            gen_iterations += 1

            if torch.isnan(errG):
                raise ValueError("Loss is nan")

            ############################
            # Log Data
            ###########################
            print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_D_real: %f'
                  ' Loss_D_fake: %f' %
                  (epoch, opt['niter'], i, len(dataloader), gen_iterations,
                   errD.item(), errG.item(), errD_real.item(),
                   errD_fake.item()))
            writer.add_scalar('train/critic', -errD.item(), gen_iterations)
            if gen_iterations % (500 * 64 // opt['batchSize']) == 0:
                real_cpu = real_cpu.mul(0.5).add(0.5)
                vutils.save_image(real_cpu,
                                  f'{opt["experiment"]}/real_samples.png')
                with torch.no_grad():
                    fake = netG(Variable(fixed_noise))
                fake.data = fake.data.mul(0.5).add(0.5)
                vutils.save_image(
                    fake.data, f'{opt["experiment"]}/'
                    f'fake_samples_{gen_iterations:010d}.png')
                # add_image expects a single CHW image, so collapse the
                # batch into a grid first (fake.data is already in [0, 1])
                writer.add_image('train/sample',
                                 vutils.make_grid(fake.data.clamp(0, 1)),
                                 gen_iterations)

            ############################
            # (3) Compute Scores
            ############################
            if gen_iterations % (500 * 64 // opt['batchSize']) == 0:
                with torch.no_grad():
                    netG.eval()
                    samples = []
                    for (x, ) in inc_noise_dloader:
                        samples.append(netG(x))
                    netG.train()
                    samples = torch.cat(samples, dim=0).cpu()
                    samples = (samples - samples.mean()) / samples.std()

                score, _ = inception_score(samples.numpy(),
                                           cuda=True,
                                           resize=True,
                                           splits=10)
                writer.add_scalar('test/inception_50k', score, gen_iterations)
                # fids = fid_score(
                #     samples.permute(0, 2, 3,
                #                     1).mul(128).add(128).clamp(255).numpy(),
                #     'cifar10'
                # )
                # writer.add_scalar('test/fid_50k', fids, gen_iterations)
                if reporter:
                    reporter(inception=score, fid=0)

        # do checkpointing
        torch.save(netG.state_dict(),
                   f'{opt["experiment"]}/netG_epoch_{epoch}.pth')
        torch.save(netD.state_dict(),
                   f'{opt["experiment"]}/netD_epoch_{epoch}.pth')