Example #1
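Both excerpts on this page are function-level snippets taken out of a larger project; their imports and project helpers (GeneratorDCGAN, DiscriminatorDCGAN, DATA_ROOT, get_device_id, the hook functions, exp_mov_avg, the generate_image_* and classify_training utilities, CelebA, Generator_celeba, Discriminator_celeba, weights_init) are defined elsewhere and are not reproduced here. A plausible import header, assuming only standard PyTorch/torchvision packages are needed beyond those project helpers, would be:

# Assumed module header (sketch) -- project-specific helpers are not included here.
import os
import copy
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms
from tqdm import tqdm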
def main(args):
    ### config
    global noise_multiplier
    dataset = args.dataset
    num_discriminators = args.num_discriminators
    noise_multiplier = args.noise_multiplier
    z_dim = args.z_dim
    model_dim = args.model_dim
    batchsize = args.batchsize
    L_gp = args.L_gp
    L_epsilon = args.L_epsilon
    critic_iters = args.critic_iters
    latent_type = args.latent_type
    load_dir = args.load_dir
    save_dir = args.save_dir
    if_dp = (args.dp > 0.)
    gen_arch = args.gen_arch
    num_gpus = args.num_gpus

    ### CUDA
    use_cuda = torch.cuda.is_available()
    devices = [
        torch.device("cuda:%d" % i if use_cuda else "cpu")
        for i in range(num_gpus)
    ]
    device0 = devices[0]
    if use_cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    ### Random seed
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    ### Fix noise for visualization
    if latent_type == 'normal':
        fix_noise = torch.randn(10, z_dim)
    elif latent_type == 'bernoulli':
        p = 0.5
        bernoulli = torch.distributions.Bernoulli(torch.tensor([p]))
        fix_noise = bernoulli.sample((10, z_dim)).view(10, z_dim)
    else:
        raise NotImplementedError

    ### Set up models
    print('gen_arch:' + gen_arch)
    netG = GeneratorDCGAN(z_dim=z_dim, model_dim=model_dim, num_classes=10)

    netGS = copy.deepcopy(netG)
    netD_list = []
    for i in range(num_discriminators):
        netD = DiscriminatorDCGAN()
        netD_list.append(netD)

    ### Load pre-trained discriminators
    print("load pre-training...")
    if load_dir is not None:
        for netD_id in range(num_discriminators):
            print('Load NetD ', str(netD_id))
            network_path = os.path.join(load_dir, 'netD_%d' % netD_id,
                                        'netD.pth')
            netD = netD_list[netD_id]
            netD.load_state_dict(torch.load(network_path))

    netG = netG.to(device0)
    for netD_id, netD in enumerate(netD_list):
        device = devices[get_device_id(netD_id, num_discriminators, num_gpus)]
        netD.to(device)

    ### Set up optimizers
    optimizerD_list = []
    for i in range(num_discriminators):
        netD = netD_list[i]
        optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.99))
        optimizerD_list.append(optimizerD)
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.99))

    ### Data loaders
    if dataset == 'mnist' or dataset == 'fashionmnist':
        transform_train = transforms.Compose([
            transforms.CenterCrop((28, 28)),
            transforms.ToTensor(),
            #transforms.Grayscale(),
        ])
    elif dataset == 'cifar_100' or dataset == 'cifar_10':
        transform_train = transforms.Compose([
            transforms.CenterCrop((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
        ])

    if dataset == 'mnist':
        dataloader = datasets.MNIST
        trainset = dataloader(root=os.path.join(DATA_ROOT, 'MNIST'),
                              train=True,
                              download=True,
                              transform=transform_train)
        IMG_DIM = 784
        NUM_CLASSES = 10
    elif dataset == 'fashionmnist':
        dataloader = datasets.FashionMNIST
        trainset = dataloader(root=os.path.join(DATA_ROOT, 'FashionMNIST'),
                              train=True,
                              download=True,
                              transform=transform_train)
        IMG_DIM = 784
        NUM_CLASSES = 10
    elif dataset == 'cifar_100':
        dataloader = datasets.CIFAR100
        trainset = dataloader(root=os.path.join(DATA_ROOT, 'CIFAR100'),
                              train=True,
                              download=True,
                              transform=transform_train)
        IMG_DIM = 3072
        NUM_CLASSES = 100
    elif dataset == 'cifar_10':
        IMG_DIM = 784  # 28x28 after the CenterCrop + Grayscale transform above
        NUM_CLASSES = 10
        dataloader = datasets.CIFAR10
        trainset = dataloader(root=os.path.join(DATA_ROOT, 'CIFAR10'),
                              train=True,
                              download=True,
                              transform=transform_train)
    else:
        raise NotImplementedError

    print('creating data split indices')
    indices_full = np.arange(len(trainset))
    np.random.shuffle(indices_full)
    #indices_full.dump(os.path.join(save_dir, 'indices.npy'))
    trainset_size = int(len(trainset) / num_discriminators)
    print('Size of each discriminator subset: ', trainset_size)

    input_pipelines = []
    for i in range(num_discriminators):
        start = i * trainset_size
        end = (i + 1) * trainset_size
        indices = indices_full[start:end]
        trainloader = DataLoader(trainset,
                                 batch_size=args.batchsize,
                                 drop_last=False,
                                 num_workers=args.num_workers,
                                 sampler=SubsetRandomSampler(indices))
        #input_data = inf_train_gen(trainloader)
        input_pipelines.append(trainloader)

    if if_dp:
        ### Register hook
        global dynamic_hook_function
        for netD in netD_list:
            netD.conv1.register_backward_hook(master_hook_adder)

    prg_bar = tqdm(range(args.iterations + 1))
    for iters in prg_bar:
        #########################
        ### Update D network
        #########################
        netD_id = np.random.randint(num_discriminators, size=1)[0]
        device = devices[get_device_id(netD_id, num_discriminators, num_gpus)]
        netD = netD_list[netD_id]
        optimizerD = optimizerD_list[netD_id]
        input_data = input_pipelines[netD_id]

        for p in netD.parameters():
            p.requires_grad = True

        for iter_d in range(critic_iters):
            real_data, real_y = next(iter(input_data))  # fresh random batch each call (new DataLoader iterator every time)
            real_data = real_data.view(-1, IMG_DIM)
            real_data = real_data.to(device)
            real_y = real_y.to(device)
            real_data_v = autograd.Variable(real_data)

            ### train with real
            dynamic_hook_function = dummy_hook
            netD.zero_grad()
            D_real_score = netD(real_data_v, real_y)
            D_real = -D_real_score.mean()

            ### train with fake
            batchsize = real_data.shape[0]
            if latent_type == 'normal':
                noise = torch.randn(batchsize, z_dim).to(device0)
            elif latent_type == 'bernoulli':
                noise = bernoulli.sample(
                    (batchsize, z_dim)).view(batchsize, z_dim).to(device0)
            else:
                raise NotImplementedError
            noisev = autograd.Variable(noise)
            fake = autograd.Variable(netG(noisev, real_y.to(device0)).data)
            inputv = fake.to(device)
            D_fake = netD(inputv, real_y.to(device))
            D_fake = D_fake.mean()
            '''
            ### train with gradient penalty
            gradient_penalty = netD.calc_gradient_penalty(real_data_v.data, fake.data, real_y, L_gp, device)
            D_cost = D_fake + D_real + gradient_penalty

            ### train with epsilon penalty
            logit_cost = L_epsilon * torch.pow(D_real_score, 2).mean()
            D_cost += logit_cost
            '''
            D_cost = D_fake + D_real
            ### update
            D_cost.backward()
            Wasserstein_D = -D_real - D_fake
            optimizerD.step()

        del real_data, real_y, fake, noise, inputv, D_real, D_fake  #, logit_cost, gradient_penalty
        torch.cuda.empty_cache()

        ############################
        # Update G network
        ###########################
        if if_dp:
            ### Sanitize the gradients passed to the Generator
            dynamic_hook_function = dp_conv_hook
        else:
            ### Only modify the gradient norm, without adding noise
            dynamic_hook_function = modify_gradnorm_conv_hook

        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()

        ### train with sanitized discriminator output
        if latent_type == 'normal':
            noise = torch.randn(batchsize, z_dim).to(device0)
        elif latent_type == 'bernoulli':
            noise = bernoulli.sample(
                (batchsize, z_dim)).view(batchsize, z_dim).to(device0)
        else:
            raise NotImplementedError
        label = torch.randint(0, NUM_CLASSES, [batchsize]).to(device0)
        noisev = autograd.Variable(noise)
        fake = netG(noisev, label)
        #summary(netG, input_data=[noisev,label])
        fake = fake.to(device)
        label = label.to(device)
        G = netD(fake, label)
        G = -G.mean()

        ### update
        G.backward()
        G_cost = G
        optimizerG.step()

        ### update the exponential moving average
        exp_mov_avg(netGS, netG, alpha=0.999, global_step=iters)

        ############################
        ### Results visualization
        ############################
        prg_bar.set_description(
            'iter:{}, G_cost:{:.2f}, D_cost:{:.2f}, Wasserstein:{:.2f}'.format(
                iters,
                G_cost.cpu().data,
                D_cost.cpu().data,
                Wasserstein_D.cpu().data))
        if iters % args.vis_step == 0:
            if dataset == 'mnist':
                generate_image_mnist(iters, netGS, fix_noise, save_dir,
                                     device0)
            elif dataset == 'cifar_100':
                generate_image_cifar100(iters, netGS, fix_noise, save_dir,
                                        device0)
            elif dataset == 'cifar_10':
                generate_image_mnist(iters, netGS, fix_noise, save_dir,
                                     device0)

        if iters % args.save_step == 0:
            ### save model
            torch.save(netGS.state_dict(),
                       os.path.join(save_dir, 'netGS_%d.pth' % iters))

        del label, fake, noisev, noise, G, G_cost, D_cost
        torch.cuda.empty_cache()

        if ((iters + 1) % 500 == 0):
            classify_training(netGS, dataset, iters + 1)

    ### save model
    torch.save(netG, os.path.join(save_dir, 'netG.pth'))
    torch.save(netGS, os.path.join(save_dir, 'netGS.pth'))
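Both examples maintain a shadow generator netGS through exp_mov_avg, which is not shown on this page. A minimal sketch, assuming a plain parameter-wise exponential moving average of the generator weights with a short warm-up (the project's actual helper may differ), is:

def exp_mov_avg(netGS, netG, alpha=0.999, global_step=0):
    # Effective decay ramps up towards alpha so that the first iterations are not
    # dominated by the randomly initialised weights.
    alpha = min(alpha, (global_step + 1) / (global_step + 10))
    for p_avg, p in zip(netGS.parameters(), netG.parameters()):
        p_avg.data.mul_(alpha).add_(p.data, alpha=1 - alpha)

netGS is then used for visualization and checkpointing, while netG keeps receiving the raw gradient updates.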
Example #2
def main(args):
    ### config
    global noise_multiplier
    dataset = args.dataset
    num_discriminators = args.num_discriminators
    noise_multiplier = args.noise_multiplier
    z_dim = args.z_dim
    if dataset == 'celeba':
        z_dim = 100
    model_dim = args.model_dim
    batchsize = args.batchsize
    L_gp = args.L_gp
    L_epsilon = args.L_epsilon
    critic_iters = args.critic_iters
    latent_type = args.latent_type
    load_dir = args.load_dir
    save_dir = args.save_dir
    if_dp = (args.noise_multiplier > 0.)
    gen_arch = args.gen_arch
    num_gpus = args.num_gpus

    ### CUDA
    use_cuda = torch.cuda.is_available()
    devices = [
        torch.device("cuda:%d" % i if use_cuda else "cpu")
        for i in range(num_gpus)
    ]
    device0 = devices[0]
    if use_cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    ### Random seed
    if args.random_seed == 1:
        args.random_seed = np.random.randint(10000, size=1)[0]
    print('random_seed: {}'.format(args.random_seed))
    os.system('rm ' + os.path.join(save_dir, 'seed*'))
    os.system('touch ' +
              os.path.join(save_dir, 'seed=%s' % str(args.random_seed)))
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    ### Set up models
    print('gen_arch:' + gen_arch)
    if dataset == 'celeba':
        ngpu = 1
        netG = Generator_celeba(ngpu).to(device0)
        #netG.load_state_dict(torch.load('../results/celeba/main/d_1_2e-4_g_1_2e-4_SN_full/netG_15000.pth'))

        # Handle multi-gpu if desired
        if (device0.type == 'cuda') and (ngpu > 1):
            netG = nn.DataParallel(netG, list(range(ngpu)))

        # Apply the weights_init function to randomly initialize all weights
        #  to mean=0, stdev=0.02.
        netG.apply(weights_init)

    netGS = copy.deepcopy(netG).to(device0)
    if dataset == 'celeba':
        ngpu = 1
        netD = Discriminator_celeba(ngpu).to(device0)
        #netD.load_state_dict(torch.load('../results/celeba/main/d_1_2e-4_g_1_2e-4_SN_full/netD_15000.pth'))
        # Handle multi-gpu if desired
        if (device0.type == 'cuda') and (ngpu > 1):
            netD = nn.DataParallel(netD, list(range(ngpu)))

        # Apply the weights_init function to randomly initialize all weights
        #  to mean=0, stdev=0.02.
        #netD.apply(weights_init)

    ### Set up optimizers
    optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.99))
    optimizerG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.99))

    ### Data loaders
    if dataset == 'celeba':
        transform_train = transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    if dataset == 'celeba':
        IMG_DIM = 64 * 64 * 3
        NUM_CLASSES = 2
        trainset = CelebA(
            root=os.path.join('/work/u5366584/exp/datasets/celeba'),
            split='train',
            transform=transform_train,
            download=False)  #, custom_subset=True)
        #trainset = CelebA(root=os.path.join('../data'), split='train',
        #    transform=transform_train, download=False, custom_subset=True)
    else:
        raise NotImplementedError

    ###fix sub-training set (fix to 10000 training samples)
    if args.update_train_dataset:
        if dataset == 'mnist':
            indices_full = np.arange(60000)
        elif dataset == 'cifar_10':
            indices_full = np.arange(50000)
        elif dataset == 'celeba':
            indices_full = np.arange(len(trainset))
        np.random.shuffle(indices_full)
        '''
        #####ref
        indices = np.loadtxt('index_20k.txt', dtype=np.int_)
        remove_idx = [np.argwhere(indices_full==x) for x in indices]
        indices_ref = np.delete(indices_full, remove_idx)
        
        indices_slice = indices_ref[:20000]
        np.savetxt('index_20k_ref.txt', indices_slice, fmt='%i')   ##ref index is disjoint to original index
        '''

        ### growing dataset
        indices = np.loadtxt('index_20k.txt', dtype=np.int_)
        remove_idx = [np.argwhere(indices_full == x) for x in indices]
        indices_rest = np.delete(indices_full, remove_idx)

        indices_rest = indices_rest[:20000]
        indices_slice = np.concatenate((indices, indices_rest), axis=0)
        np.savetxt('index_40k.txt', indices_slice, fmt='%i')
    indices = np.loadtxt('index_100k.txt', dtype=np.int_)
    trainset = torch.utils.data.Subset(trainset, indices)
    print(len(trainset))

    workers = 4
    dataloader = torch.utils.data.DataLoader(trainset,
                                             batch_size=batchsize,
                                             shuffle=True,
                                             num_workers=workers)

    if if_dp:
        ### Register hook (this example uses a single discriminator)
        global dynamic_hook_function
        netD.conv1.register_backward_hook(master_hook_adder)

    criterion = nn.BCELoss()
    real_label = 1.
    fake_label = 0.
    nz = 100
    fixed_noise = torch.randn(100, nz, 1, 1, device=device0)
    iters = 0
    num_epochs = 256 * 5 + 1

    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, (data, y) in enumerate(dataloader, 0):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_cpu = data.to(device0)
            b_size = real_cpu.size(0)
            label = torch.full((b_size, ),
                               real_label,
                               dtype=torch.float,
                               device=device0)
            # Forward pass real batch through D
            output = netD(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=device0)
            # Generate fake image batch with G
            fake = netG(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = netD(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch, accumulated (summed) with previous gradients
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            iters += 1

            for iter_g in range(1):
                ############################
                # Update G network
                ###########################
                if if_dp:
                    ### Sanitize the gradients passed to the Generator
                    dynamic_hook_function = dp_conv_hook
                else:
                    ### Only modify the gradient norm, without adding noise
                    dynamic_hook_function = modify_gradnorm_conv_hook

                ############################
                # (2) Update G network: maximize log(D(G(z)))
                ###########################

                noise = torch.randn(b_size, nz, 1, 1, device=device0)
                fake = netG(noise)
                label = torch.full((b_size, ),
                                   real_label,
                                   dtype=torch.float,
                                   device=device0)

                netG.zero_grad()
                label.fill_(
                    real_label)  # fake labels are real for generator cost
                # Since we just updated D, perform another forward pass of all-fake batch through D
                output = netD(fake).view(-1)
                # Calculate G's loss based on this output
                errG = criterion(output, label)
                # Calculate gradients for G
                errG.backward()
                D_G_z2 = output.mean().item()
                # Update G
                optimizerG.step()

            ### update the exponential moving average
            exp_mov_avg(netGS, netG, alpha=0.999, global_step=iters)

            ############################
            ### Results visualization
            ############################
            if iters % 10 == 0:
                print('iter:{}, G_cost:{:.2f}, D_cost:{:.2f}'.format(
                    iters,
                    errG.item(),
                    errD.item(),
                ))
            if iters % args.vis_step == 0:
                if dataset == 'celeba':
                    generate_image_celeba(str(iters), netGS, fixed_noise,
                                          save_dir, device0)

            if iters % args.save_step == 0:
                ### save model
                torch.save(netGS.state_dict(),
                           os.path.join(save_dir, 'netGS_%d.pth' % iters))
                torch.save(netD.state_dict(),
                           os.path.join(save_dir, 'netD_%d.pth' % iters))

        torch.cuda.empty_cache()
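Neither example defines the backward-hook machinery (master_hook_adder, dummy_hook, modify_gradnorm_conv_hook, dp_conv_hook) that it registers on netD.conv1; those functions come from the surrounding project. The sketch below illustrates the presumed mechanism: a single registered hook dispatches to whichever function is currently bound to the global dynamic_hook_function, so the same hook acts as a no-op during discriminator updates and as a clip-and-noise sanitiser during generator updates. CLIP_BOUND is a hypothetical clipping bound, and the clipping shown is whole-batch for brevity; the actual implementation may clip per sample.

import torch

CLIP_BOUND = 1.0  # assumed gradient clipping bound

def dummy_hook(module, grad_input, grad_output):
    # No-op: leave gradients untouched while the discriminator is being updated.
    pass

def master_hook_adder(module, grad_input, grad_output):
    # Indirection layer: forward to whichever hook is currently selected.
    return dynamic_hook_function(module, grad_input, grad_output)

def modify_gradnorm_conv_hook(module, grad_input, grad_output):
    # Rescale the gradient w.r.t. the conv input so its norm is at most CLIP_BOUND.
    grad_wrt_input = grad_input[0]
    clip_coef = (CLIP_BOUND / grad_wrt_input.norm().clamp(min=1e-12)).clamp(max=1.0)
    return (grad_wrt_input * clip_coef,) + tuple(grad_input[1:])

def dp_conv_hook(module, grad_input, grad_output):
    # Clip as above, then add Gaussian noise scaled by noise_multiplier (DP-SGD style).
    grad_wrt_input = grad_input[0]
    clip_coef = (CLIP_BOUND / grad_wrt_input.norm().clamp(min=1e-12)).clamp(max=1.0)
    noise = noise_multiplier * CLIP_BOUND * torch.randn_like(grad_wrt_input)
    return (grad_wrt_input * clip_coef + noise,) + tuple(grad_input[1:])

Here dynamic_hook_function and noise_multiplier are the module-level globals that the snippets declare with `global` and assign during training.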