Ejemplo n.º 1
0
    def build_model(self):
        """Create the generator ``self.g`` and the discriminator ``self.d``.

        Each network is first initialised with ``weights_init``. If the matching
        config path (``self.config.g`` / ``self.config.d``) is a non-empty
        string, the saved ``state_dict`` at that path is loaded on top. The
        network is then moved to the module-level ``device``.
        """
        def _prepare(net, checkpoint_path):
            # Random init first, so that a checkpoint (when given) overrides it.
            net.apply(weights_init)
            if checkpoint_path != '':
                # The module has not been moved to `device` yet, so deserialize
                # onto the CPU. This also keeps checkpoints that were saved from
                # a GPU loadable on CPU-only hosts. The .to(device) below moves
                # the loaded weights together with the rest of the module.
                net.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
            net.to(device)

        self.g = dcgan.Generator(self.nz, self.ngf, self.nch)
        _prepare(self.g, self.config.g)

        self.d = dcgan.Discriminator(self.ndf, self.nch)
        _prepare(self.d, self.config.d)
Ejemplo n.º 2
0
                                     shuffle=True, num_workers=2)

# Hyper-parameters. The Adam settings and the D:G update ratio follow the
# paper this script reproduces.
Z_dim = 128  # length of the latent vector z fed to the generator
adam_alpha, adam_beta1, adam_beta2 = 0.0002, 0.0, 0.9
n_dis = 5  # discriminator updates for every generator update
# step = 100000

# Networks: `--model resnet` selects the ResNet pair, anything else DCGAN.
# The discriminator is built (and moved to `device`) before the generator.
if args.model != 'resnet':
    discriminator, generator = (dcgan.Discriminator().to(device),
                                dcgan.Generator(Z_dim).to(device))
else:
    discriminator, generator = (resnet.Discriminator().to(device),
                                resnet.Generator(Z_dim).to(device))

# Optimizers.
# NOTE(review): unlike the generator, the discriminator does not use the
# paper's Adam settings; it was switched to SGD (lr 0.01, momentum 0.9).
# Previous discriminator optimizer:
# optim_disc = optim.Adam(discriminator.parameters(), lr=adam_alpha, betas=(adam_beta1,adam_beta2))
optim_disc = optim.SGD(discriminator.parameters(), momentum=0.9, lr=0.01)
optim_gen = optim.Adam(generator.parameters(), betas=(adam_beta1, adam_beta2), lr=adam_alpha)

# Loss function
def discriminator_loss(d_real, d_fake):
    """Discriminator loss, selected by the module-level ``args.loss``.

    Args:
        d_real: discriminator outputs for real samples (tensor).
        d_fake: discriminator outputs for generated samples (tensor).

    'hinge':       real term mean(relu(1 - d_real)), fake term mean(relu(1 + d_fake)).
    'wasserstein': real term -mean(d_real).

    NOTE(review): this excerpt ends inside the 'wasserstein' branch. The fake
    term for that branch and the return statement are not visible here;
    presumably real_loss and fake_loss are combined and returned below (confirm).
    """
    if args.loss == 'hinge':
        real_loss = nn.ReLU()(1.0 - d_real).mean()
        fake_loss = nn.ReLU()(1.0 + d_fake).mean()

    elif args.loss == 'wasserstein':
        real_loss = -d_real.mean()
Ejemplo n.º 3
0
def main():
    """Train a DCGAN on an ``ImageFolder`` dataset and log progress to TensorBoard.

    All settings come from the command line (see the ``argparse`` options below).

    Side effects:
      * per-iteration ``Loss_D`` / ``Loss_G`` scalars are written to a
        ``SummaryWriter`` in a timestamped sub-directory of ``--log_dir``;
      * every 100 iterations, samples generated from a fixed noise batch are
        saved as ``fake_samples_epoch_XXX.png`` in ``--outf``;
      * after every epoch, both networks' ``state_dict`` are saved to ``--outf``.

    Stabilisation tricks used here: one-sided label smoothing (real label 0.9)
    and small uniform noise (scale 0.01) added to real images before they are
    shown to the discriminator.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', required=True, help='path to dataset')
    parser.add_argument('--log_dir', type=str, default='./runs/', help='path to logs')
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
    parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
    parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
    parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
    parser.add_argument('--ngf', type=int, default=64)
    parser.add_argument('--ndf', type=int, default=64)
    parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
    parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
    parser.add_argument('--netG', default='', help="path to netG (to continue training)")
    parser.add_argument('--netD', default='', help="path to netD (to continue training)")
    parser.add_argument('--outf', default='./output', help='folder to output images and model checkpoints')
    parser.add_argument('--manualSeed', type=int, help='manual seed')

    opt = parser.parse_args()
    print(opt)

    # exist_ok covers the "already exists" case without also hiding real
    # failures (e.g. permission errors) the way `except OSError: pass` did.
    os.makedirs(opt.outf, exist_ok=True)

    # One TensorBoard run directory per launch. os.path.join inserts the path
    # separator even when --log_dir is given without a trailing slash.
    writer = SummaryWriter(os.path.join(opt.log_dir, datetime.now().strftime("%Y%m%d-%H%M%S")))

    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    cudnn.benchmark = True

    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    nc = 3  # image channels (the Normalize below uses 3 per-channel means/stds)
    dataset = dset.ImageFolder(
        root=opt.dataroot,
        transform=transforms.Compose([
            transforms.Resize(opt.imageSize),
            # Resize(int) only scales the shorter side. Crop to a square so that
            # every sample is imageSize x imageSize and batches can be stacked.
            transforms.CenterCrop(opt.imageSize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    )

    assert dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, shuffle=True, num_workers=opt.workers)

    device = torch.device("cuda:0" if opt.cuda else "cpu")
    ngpu = int(opt.ngpu)
    nz = int(opt.nz)
    ngf = int(opt.ngf)
    ndf = int(opt.ndf)

    netG = dcgan.Generator(ngpu, nz, nc, ngf).to(device)
    netG.apply(weights_init)
    if opt.netG != '':
        # map_location lets a checkpoint saved on another device (e.g. GPU)
        # be resumed on the current one (e.g. CPU).
        netG.load_state_dict(torch.load(opt.netG, map_location=device))
    print(netG)

    netD = dcgan.Discriminator(ngpu, nz, nc, ndf).to(device)
    netD.apply(weights_init)
    if opt.netD != '':
        netD.load_state_dict(torch.load(opt.netD, map_location=device))
    print(netD)

    criterion = nn.BCELoss()

    fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
    real_label = 0.9  # label smoothing
    fake_label = 0

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

    for epoch in range(opt.niter):
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            real_cpu = data[0].to(device)
            batch_size = real_cpu.size(0)
            label = torch.full((batch_size,), real_label, device=device)
            # Small uniform "instance noise" on the real images.
            sample_noise = torch.rand_like(real_cpu, device=device)

            output = netD(real_cpu + 0.01 * sample_noise)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            # train with fake
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            # detach(): this step must not propagate gradients into G.
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            # print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
            #       % (epoch, opt.niter, i, len(dataloader),
            #          errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            writer.add_scalar('Loss_D', errD.item(), epoch * len(dataloader) + i)
            writer.add_scalar('Loss_G', errG.item(), epoch * len(dataloader) + i)
            # writer.add_scalar('D(x)', D_x, epoch * len(dataloader) + i)
            # writer.add_scalar('D(G(z))', D_G_z2, epoch * len(dataloader) + i)
            if i % 100 == 99:
                # the first 64 samples from the mini-batch are saved
                # vutils.save_image(real_cpu[0:64, :, :, :],
                #                   '%s/real_samples.png' % opt.outf,
                #                   normalize=True, nrow=8)
                # Sampling only: no autograd graph is needed.
                with torch.no_grad():
                    fake = netG(fixed_noise)
                vutils.save_image(fake[0:64, :, :, :], '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch), normalize=True, nrow=8, padding=0)

        # do checkpointing
        torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
        torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))

    writer.close()
Ejemplo n.º 4
0
def main():
    """Train a DCGAN whose generator (and optionally discriminator) is quantized.

    Configuration is read from the module-level ``opt`` object, which is
    defined elsewhere in this file (presumably parsed command-line arguments).

    Switches visible here:
      * ``opt.G_bnn`` -- binarize/quantize G via ``util.Bin_G``; ``opt.Gw_bit`` /
        ``opt.Ga_bit`` set the bit-widths.
      * ``opt.depth`` -- use ``dcgan.Generator_depth`` instead of ``dcgan.Generator``.
      * ``opt.D_q`` -- quantize D to ``opt.bit`` bits via ``util.Quan_D``.
      * ``opt.pretrained_D`` -- load D from ``opt.pretrained_D_path`` and never
        step D's optimizer (only G is trained).
      * ``opt.validate`` -- run ``dcgan.validate`` with ``checkpoint.tar``, show
        the generated batch, and return without training.

    Outputs are written to a directory derived from the options and from
    ``opt.prefix`` / ``opt.dataset``. They include sample grids, per-epoch G/D
    checkpoints, ``generator_config.json``, loss/grad plots, and a raw ``.bin``
    dump of ``grad_data``.

    Raises:
        ValueError: if ``opt.dataset`` is not one of 'celeA', 'cifar10',
            'mnist', 'lsun'.
    """
    quan_type = str(opt.quan_type)
    # Bit-widths for G's weights/activations. util.Bin_G always needs a weight
    # bit-width, so it defaults to 1 (plain binarization) when --Gw_bit is not
    # given. Previously the depth path and the "only one bit-width given" path
    # failed with an unbound `gw_bit`.
    gw_bit = int(opt.Gw_bit) if opt.Gw_bit is not None else 1
    ga_bit = int(opt.Ga_bit) if opt.Ga_bit is not None else None

    # Choose an output directory name that encodes the configuration.
    if not opt.depth:
        if opt.G_bnn:
            if opt.Gw_bit is not None or opt.Ga_bit is not None:
                opt.outf = 'outputs_G_bnn_' + str(opt.Gw_bit) + '_' + str(opt.Ga_bit)  # extra
                # %s rather than %d: one of the two bit-widths may be unset (None).
                print('We quantize weights and activation in %s and %s bits respectively.' % (gw_bit, ga_bit))
            else:
                opt.outf = 'outputs_G_bnn'  # only binarize the G network
                print('Binarize both weights and activation in G.')
            if opt.pretrained_D:
                opt.outf = 'outputs_G_bnn_pretrained_D'  # binarize G; D pretrained and fixed
                print('binarize G with D pretrained and fixed.')
            if opt.D_q:
                opt.outf = opt.outf + '_D_q_' + str(opt.bit)
                print("quantize G with D quantized.")
    else:
        opt.outf = 'outputs_G_fwn_depth'
        if opt.G_bnn:
            if opt.Gw_bit is not None or opt.Ga_bit is not None:
                opt.outf = 'outputs_G_bnn_depth_' + str(opt.Gw_bit) + '_' + str(opt.Ga_bit)
            if opt.pretrained_D:
                opt.outf = 'outputs_G_bnn_pretrained_D_depth'
            if opt.D_q:
                opt.outf = 'outputs_G_bnn_D_q_depth'
        else:
            if opt.D_q:
                opt.outf = 'outputs_G_fwn_D_q_depth'
    opt.outf = str(opt.prefix) + opt.dataset + '_' + opt.outf
    # Create the directory directly instead of shelling out to `mkdir`, which
    # broke on paths with spaces/shell metacharacters and on missing parents.
    os.makedirs(opt.outf, exist_ok=True)
    if opt.pretrained_D:
        model_path = opt.pretrained_D_path
        checkpoint = torch.load(model_path)
    # Random seed: a new one per run. Hard-code it (e.g. 999) to reproduce a run.
    manualSeed = random.randint(1, 10000)
    print("Random Seed: ", manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    grad_data = []  # passed to util.showWeightsInfo (pretrained-D mode only)
    layer_name = 'conv_dw2'  # netG layer whose statistics are recorded and plotted
    nc = 3  # image channels; MNIST overrides this to 1
    if opt.dataset == 'celeA':
        dataset = dset.ImageFolder(root=opt.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(opt.image_size),
                                       transforms.CenterCrop(opt.image_size),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]))
    elif opt.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    elif opt.dataset == 'mnist':
        dataset = dset.MNIST(root=opt.dataroot, download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(opt.image_size),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5,), (0.5,)),
                             ]))
        nc = 1
    elif opt.dataset == 'lsun':
        dataset = dset.LSUN(root=opt.dataroot, classes=['bedroom_train'],
                            transform=transforms.Compose([
                                # Resize first, then crop. Cropping first cut a
                                # tiny image_size patch out of the full image.
                                transforms.Resize(opt.image_size),
                                transforms.CenterCrop(opt.image_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))
    else:
        # Previously an unknown name left `dataset` unbound (UnboundLocalError below).
        raise ValueError('Unsupported dataset: %r' % (opt.dataset,))
    assert dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size,
                                             shuffle=True, num_workers=opt.workers)
    #id = 'cuda:' + str(opt.gpuid)
    #print("Gpu id: ",id)
    #device = torch.device(id if (torch.cuda.is_available() and opt.ngpu > 0) else "cpu")
    device = torch.device('cuda')
    # Write out the generator config, so that images can later be generated
    # from the training checkpoints (.pth).
    generator_config = {"image_Size": opt.image_size, "nz": opt.nz, "nc": nc, "ngf": opt.ngf, "ngpu": opt.ngpu}
    with open(os.path.join(opt.outf, "generator_config.json"), 'w') as gcfg:
        gcfg.write(json.dumps(generator_config)+"\n")

    real_batch = next(iter(dataloader))
    plt.figure(figsize=(8,8))
    plt.axis("off")
    plt.title("Training Images")
    plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

    if opt.depth:
        netG = dcgan.Generator_depth(opt.ngpu,opt.nz,nc,opt.ngf,'fwn').to(device)
        if opt.G_bnn:
            netG = dcgan.Generator_depth(opt.ngpu,opt.nz,nc,opt.ngf,'bnn').to(device)
    else:
        netG = dcgan.Generator(opt.ngpu,opt.nz,nc,opt.ngf,'fwn').to(device)
        if opt.G_bnn:
            # If weights are 1-bit and activations 32-bit, change this to 'fwn'.
            netG = dcgan.Generator(opt.ngpu,opt.nz,nc,opt.ngf,'bnn').to(device)
    if (device.type == 'cuda') and (opt.ngpu > 1):
        netG = nn.DataParallel(netG, list(range(opt.ngpu)))
    print(netG)
    # Weight init: the binarized G only initialises ConvTranspose2d / BatchNorm2d.
    if opt.G_bnn:
        for m in netG.modules():
            if isinstance(m,nn.ConvTranspose2d):
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif isinstance(m,nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0)
    else:
        netG.apply(dcgan.weights_init)
    netD = dcgan.Discriminator(opt.ngpu,nc,opt.ndf).to(device)
    if opt.pretrained_D:
        netD.load_state_dict(checkpoint['state_dict'])
    else:
        netD.apply(dcgan.weights_init)
    if (device.type == 'cuda') and (opt.ngpu > 1):
        netD = nn.DataParallel(netD, list(range(opt.ngpu)))

    print(netD)
    if opt.G_bnn:
        bin_op_G = util.Bin_G(netG,'bin_G',quan_type,gw_bit)
        if opt.depth:
            bin_op_G = util.Bin_G(netG,'bin_G_depth',quan_type,gw_bit)
    if opt.D_q:
        bit = int(opt.bit)
        # Format the message (passing `bit` as a second print arg printed "%d" literally).
        print('Quantize D with %d bits' % bit)
        bin_op_D = util.Quan_D(netD,bit)
    if opt.validate:
        modelpath = "checkpoint.tar"
        noise = torch.randn(opt.batch_size, opt.nz, 1, 1, device=device)
        with torch.no_grad():
            output = dcgan.validate(netG,modelpath,noise)
            plt.figure(figsize=(8,8))
            plt.axis("off")
            plt.title("Fake Images")
            plt.imshow(np.transpose(vutils.make_grid(output.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
            plt.show()
        return

    criterion = nn.BCELoss()
    fixed_noise = torch.randn(64, opt.nz, 1, 1, device=device)
    # Float labels: torch.full infers its dtype from the fill value, and
    # BCELoss rejects integer (Long) targets.
    real_label = 1.0
    fake_label = 0.0
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(opt.num_epochs):
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            if not opt.pretrained_D:
                # train with real
                netD.zero_grad()
                if opt.D_q:
                    bin_op_D.quantization()
                real_cpu = data[0].to(device)
                b_size = real_cpu.size(0)
                label = torch.full((b_size,), real_label, device=device)
                output = netD(real_cpu).view(-1)
                errD_real = criterion(output, label)
                errD_real.backward()
                D_x = output.mean().item()
                if opt.G_bnn:
                    bin_op_G.binarization()
                # train with fake
                noise = torch.randn(b_size, opt.nz, 1, 1, device=device)
                fake = netG(noise)
                label.fill_(fake_label)
                output = netD(fake.detach()).view(-1)
                errD_fake = criterion(output, label)
                errD_fake.backward()
                if opt.D_q:
                    bin_op_D.restore()
                D_G_z1 = output.mean().item()
                errD = errD_real + errD_fake
                optimizerD.step()
            else:
                # D is pretrained and fixed: only produce fakes for the G step.
                if opt.G_bnn:
                    bin_op_G.binarization()
                # Only the batch size is needed; no need to copy the batch to the GPU.
                b_size = data[0].size(0)
                label = torch.full((b_size,), real_label, device=device)
                noise = torch.randn(b_size, opt.nz, 1, 1, device=device)
                fake = netG(noise)
            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for the generator cost
            output = netD(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()

            if opt.G_bnn:
                bin_op_G.restore()
                #bin_op_G.updateBinaryGradWeight()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            if not opt.pretrained_D:
                if i % 50 == 0:
                    print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                          % (epoch, opt.num_epochs, i, len(dataloader),
                             errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            else:
                if i % 50 == 0:
                    print('[%d/%d][%d/%d]\tLoss_G: %.4f'
                          % (epoch, opt.num_epochs, i, len(dataloader),
                             errG.item()))
                # show mean and variance of weights in netG
                util.showWeightsInfo(netG,layer_name ,grad_data)
            G_losses.append(errG.item())
            if not opt.pretrained_D:
                D_losses.append(errD.item())

            if (iters % 500 == 0) or ((epoch == opt.num_epochs-1) and (i == len(dataloader)-1)):
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()
                img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
                vutils.save_image(fake,
                    '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                    normalize=True)
            iters += 1

        dcgan.save_netG_checkpoint({
                    'epoch':epoch ,
                    'state_dict':netG.state_dict(),
                    },opt.outf,epoch)
        dcgan.save_netD_checkpoint({
                    'epoch':epoch ,
                    'state_dict':netD.state_dict(),
                    },opt.outf,epoch)

    print("Training finished.")

    # save grad_data to a raw .bin file for later analysis
    grad_data = np.array(grad_data)
    filename = opt.outf + '/grad_data_' + layer_name + '_' + str(opt.num_epochs) + '.bin'
    grad_data.tofile(filename)
    plt.figure(figsize=(15,5))
    plt.subplot(1,2,1)
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses,label="G")
    plt.plot(D_losses,label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.subplot(1,2,2)
    plt.title('Specified layer grad During Training')
    plt.plot(grad_data,label="Grad_data_" + layer_name)
    plt.xlabel('iters')
    plt.ylabel('magnitude of grad_data in ' + layer_name)
    plt.legend()
    if opt.G_bnn:
        if not opt.D_q:
            plt.savefig(opt.outf + '/loss_G_bnn_' + str(opt.num_epochs) + '.jpg')
        if opt.D_q:
            plt.savefig(opt.outf + '/loss_G_bnn_D_q_' + str(bit) + '.jpg' )
    else:
        plt.savefig(opt.outf + '/loss_fwn_' + str(opt.num_epochs) + '.jpg')
    plt.show()

    fig = plt.figure(figsize=(8,8))
    plt.axis("off")
    ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
    ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

    HTML(ani.to_jshtml())

    real_batch = next(iter(dataloader))

    plt.figure(figsize=(15,15))
    plt.subplot(1,2,1)
    plt.axis("off")
    plt.title("Real Images")
    plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))
    # Plot the fake images from the last epoch
    plt.subplot(1,2,2)
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(img_list[-1],(1,2,0)))

    if opt.G_bnn:
        plt.savefig(opt.outf + '/Result_G_bnn_' + str(opt.num_epochs) + '.jpg')
        if opt.D_q:
            plt.savefig(opt.outf + '/Result_G_D_q.jpg')
    else:
        plt.savefig(opt.outf + '/Result_fwn_' + str(opt.num_epochs) + '.jpg')
    plt.show()