def main(args):
    ### config
    global noise_multiplier
    dataset = args.dataset
    num_discriminators = args.num_discriminators
    noise_multiplier = args.noise_multiplier
    z_dim = args.z_dim
    model_dim = args.model_dim
    batchsize = args.batchsize
    L_gp = args.L_gp
    L_epsilon = args.L_epsilon
    critic_iters = args.critic_iters
    latent_type = args.latent_type
    load_dir = args.load_dir
    save_dir = args.save_dir
    if_dp = (args.dp > 0.)
    gen_arch = args.gen_arch
    num_gpus = args.num_gpus

    ### CUDA
    use_cuda = torch.cuda.is_available()
    devices = [
        torch.device("cuda:%d" % i if use_cuda else "cpu")
        for i in range(num_gpus)
    ]
    device0 = devices[0]
    if use_cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    ### Random seed
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    ### Fix noise for visualization
    if latent_type == 'normal':
        fix_noise = torch.randn(10, z_dim)
    elif latent_type == 'bernoulli':
        p = 0.5
        bernoulli = torch.distributions.Bernoulli(torch.tensor([p]))
        fix_noise = bernoulli.sample((10, z_dim)).view(10, z_dim)
    else:
        raise NotImplementedError

    ### Set up models
    print('gen_arch:' + gen_arch)
    netG = GeneratorDCGAN(z_dim=z_dim, model_dim=model_dim, num_classes=10)
    netGS = copy.deepcopy(netG)
    netD_list = []
    for i in range(num_discriminators):
        netD = DiscriminatorDCGAN()
        netD_list.append(netD)

    ### Load pre-trained discriminators
    print("load pre-training...")
    if load_dir is not None:
        for netD_id in range(num_discriminators):
            print('Load NetD ', str(netD_id))
            network_path = os.path.join(load_dir, 'netD_%d' % netD_id, 'netD.pth')
            netD = netD_list[netD_id]
            netD.load_state_dict(torch.load(network_path))

    netG = netG.to(device0)
    for netD_id, netD in enumerate(netD_list):
        device = devices[get_device_id(netD_id, num_discriminators, num_gpus)]
        netD.to(device)

    ### Set up optimizers
    optimizerD_list = []
    for i in range(num_discriminators):
        netD = netD_list[i]
        optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.99))
        optimizerD_list.append(optimizerD)
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.99))

    ### Data loaders
    if dataset == 'mnist' or dataset == 'fashionmnist':
        transform_train = transforms.Compose([
            transforms.CenterCrop((28, 28)),
            transforms.ToTensor(),
            #transforms.Grayscale(),
        ])
    elif dataset == 'cifar_100' or dataset == 'cifar_10':
        transform_train = transforms.Compose([
            transforms.CenterCrop((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
        ])

    if dataset == 'mnist':
        dataloader = datasets.MNIST
        trainset = dataloader(root=os.path.join(DATA_ROOT, 'MNIST'),
                              train=True,
                              download=True,
                              transform=transform_train)
        IMG_DIM = 784
        NUM_CLASSES = 10
    elif dataset == 'fashionmnist':
        dataloader = datasets.FashionMNIST
        trainset = dataloader(root=os.path.join(DATA_ROOT, 'FashionMNIST'),
                              train=True,
                              download=True,
                              transform=transform_train)
        # FashionMNIST images are 28x28 grayscale with 10 classes
        IMG_DIM = 784
        NUM_CLASSES = 10
    elif dataset == 'cifar_100':
        dataloader = datasets.CIFAR100
        trainset = dataloader(root=os.path.join(DATA_ROOT, 'CIFAR100'),
                              train=True,
                              download=True,
                              transform=transform_train)
        IMG_DIM = 3072
        NUM_CLASSES = 100
    elif dataset == 'cifar_10':
        IMG_DIM = 784
        NUM_CLASSES = 10
        dataloader = datasets.CIFAR10
        trainset = dataloader(root=os.path.join(DATA_ROOT, 'CIFAR10'),
                              train=True,
                              download=True,
                              transform=transform_train)
    else:
        raise NotImplementedError

    print('create indices file')
    indices_full = np.arange(len(trainset))
    np.random.shuffle(indices_full)
    #indices_full.dump(os.path.join(save_dir, 'indices.npy'))
    trainset_size = int(len(trainset) / num_discriminators)
    print('Size of the dataset: ', trainset_size)
    input_pipelines = []
    for i in range(num_discriminators):
        start = i * trainset_size
        end = (i + 1) * trainset_size
        indices = indices_full[start:end]
        trainloader = DataLoader(trainset,
                                 batch_size=args.batchsize,
                                 drop_last=False,
                                 num_workers=args.num_workers,
                                 sampler=SubsetRandomSampler(indices))
        #input_data = inf_train_gen(trainloader)
        input_pipelines.append(trainloader)

    if if_dp:
        ### Register hook
        global dynamic_hook_function
        for netD in netD_list:
            netD.conv1.register_backward_hook(master_hook_adder)

    prg_bar = tqdm(range(args.iterations + 1))
    for iters in prg_bar:
        #########################
        ### Update D network
        #########################
        netD_id = np.random.randint(num_discriminators, size=1)[0]
        device = devices[get_device_id(netD_id, num_discriminators, num_gpus)]
        netD = netD_list[netD_id]
        optimizerD = optimizerD_list[netD_id]
        input_data = input_pipelines[netD_id]

        for p in netD.parameters():
            p.requires_grad = True

        for iter_d in range(critic_iters):
            real_data, real_y = next(iter(input_data))
            real_data = real_data.view(-1, IMG_DIM)
            real_data = real_data.to(device)
            real_y = real_y.to(device)
            real_data_v = autograd.Variable(real_data)

            ### train with real
            dynamic_hook_function = dummy_hook
            netD.zero_grad()
            D_real_score = netD(real_data_v, real_y)
            D_real = -D_real_score.mean()

            ### train with fake
            batchsize = real_data.shape[0]
            if latent_type == 'normal':
                noise = torch.randn(batchsize, z_dim).to(device0)
            elif latent_type == 'bernoulli':
                noise = bernoulli.sample(
                    (batchsize, z_dim)).view(batchsize, z_dim).to(device0)
            else:
                raise NotImplementedError
            noisev = autograd.Variable(noise)
            fake = autograd.Variable(netG(noisev, real_y.to(device0)).data)
            inputv = fake.to(device)
            D_fake = netD(inputv, real_y.to(device))
            D_fake = D_fake.mean()

            '''
            ### train with gradient penalty
            gradient_penalty = netD.calc_gradient_penalty(real_data_v.data,
                                                          fake.data, real_y,
                                                          L_gp, device)
            D_cost = D_fake + D_real + gradient_penalty

            ### train with epsilon penalty
            logit_cost = L_epsilon * torch.pow(D_real_score, 2).mean()
            D_cost += logit_cost
            '''
            D_cost = D_fake + D_real

            ### update
            D_cost.backward()
            Wasserstein_D = -D_real - D_fake
            optimizerD.step()

            del real_data, real_y, fake, noise, inputv, D_real, D_fake  #, logit_cost, gradient_penalty
            torch.cuda.empty_cache()

        ############################
        # Update G network
        ###########################
        if if_dp:
            ### Sanitize the gradients passed to the Generator
            dynamic_hook_function = dp_conv_hook
        else:
            ### Only modify the gradient norm, without adding noise
            dynamic_hook_function = modify_gradnorm_conv_hook

        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()

        ### train with sanitized discriminator output
        if latent_type == 'normal':
            noise = torch.randn(batchsize, z_dim).to(device0)
        elif latent_type == 'bernoulli':
            noise = bernoulli.sample(
                (batchsize, z_dim)).view(batchsize, z_dim).to(device0)
        else:
            raise NotImplementedError
        label = torch.randint(0, NUM_CLASSES, [batchsize]).to(device0)
        noisev = autograd.Variable(noise)
        fake = netG(noisev, label)
        #summary(netG, input_data=[noisev, label])
        fake = fake.to(device)
        label = label.to(device)
        G = netD(fake, label)
        G = -G.mean()

        ### update
        G.backward()
        G_cost = G
        optimizerG.step()

        ### update the exponential moving average
        exp_mov_avg(netGS, netG, alpha=0.999, global_step=iters)

        ############################
        ### Results visualization
        ############################
        prg_bar.set_description(
            'iter:{}, G_cost:{:.2f}, D_cost:{:.2f}, Wasserstein:{:.2f}'.format(
                iters, G_cost.cpu().data, D_cost.cpu().data,
                Wasserstein_D.cpu().data))

        if iters % args.vis_step == 0:
            if dataset == 'mnist':
                generate_image_mnist(iters, netGS, fix_noise, save_dir, device0)
            elif dataset == 'cifar_100':
                generate_image_cifar100(iters, netGS, fix_noise, save_dir, device0)
            elif dataset == 'cifar_10':
                generate_image_mnist(iters, netGS, fix_noise, save_dir, device0)

        if iters % args.save_step == 0:
            ### save model
            torch.save(netGS.state_dict(),
                       os.path.join(save_dir, 'netGS_%d.pth' % iters))

        del label, fake, noisev, noise, G, G_cost, D_cost
        torch.cuda.empty_cache()

        if ((iters + 1) % 500 == 0):
            classify_training(netGS, dataset, iters + 1)

    ### save model
    torch.save(netG, os.path.join(save_dir, 'netG.pth'))
    torch.save(netGS, os.path.join(save_dir, 'netGS.pth'))
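# --------------------------------------------------------------------------
# Note on the gradient-sanitization hooks used above (master_hook_adder,
# dummy_hook, modify_gradnorm_conv_hook, dp_conv_hook): they are defined
# elsewhere in the repo. The commented sketch below is only an illustrative,
# hedged reconstruction of the mechanism, assuming GS-WGAN-style sanitization
# at the discriminator's first conv layer; CLIP_BOUND and the exact noise
# scaling are assumptions, not the repo's actual values.
#
# def master_hook_adder(module, grad_input, grad_output):
#     # Registered once on netD.conv1; delegates to whatever function
#     # `dynamic_hook_function` currently points at, so the training loop
#     # can switch behaviour without re-registering hooks.
#     global dynamic_hook_function
#     return dynamic_hook_function(module, grad_input, grad_output)
#
# def dummy_hook(module, grad_input, grad_output):
#     # No-op used while updating the discriminator on real/fake batches.
#     return grad_input
#
# def dp_conv_hook(module, grad_input, grad_output):
#     # Used while updating the generator: clip the per-sample gradient that
#     # flows back through conv1 and add Gaussian noise scaled by the global
#     # `noise_multiplier`, so only a sanitized signal reaches G.
#     CLIP_BOUND = 1.0  # assumed clipping bound
#     grad_wrt_input = grad_input[0]
#     batch = grad_wrt_input.size(0)
#     flat = grad_wrt_input.view(batch, -1)
#     norms = flat.norm(2, dim=1, keepdim=True)
#     factor = (CLIP_BOUND / (norms + 1e-10)).clamp(max=1.0)
#     clipped = flat * factor
#     clipped = clipped + CLIP_BOUND * noise_multiplier * torch.randn_like(clipped)
#     return (clipped.view_as(grad_wrt_input), ) + grad_input[1:]
# --------------------------------------------------------------------------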
def main(args):
    ### config
    global noise_multiplier
    dataset = args.dataset
    num_discriminators = args.num_discriminators
    noise_multiplier = args.noise_multiplier
    z_dim = args.z_dim
    if dataset == 'celeba':
        z_dim = 100
    model_dim = args.model_dim
    batchsize = args.batchsize
    L_gp = args.L_gp
    L_epsilon = args.L_epsilon
    critic_iters = args.critic_iters
    latent_type = args.latent_type
    load_dir = args.load_dir
    save_dir = args.save_dir
    if_dp = (args.noise_multiplier > 0.)
    gen_arch = args.gen_arch
    num_gpus = args.num_gpus

    ### CUDA
    use_cuda = torch.cuda.is_available()
    devices = [
        torch.device("cuda:%d" % i if use_cuda else "cpu")
        for i in range(num_gpus)
    ]
    device0 = devices[0]
    if use_cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    ### Random seed
    if args.random_seed == 1:
        args.random_seed = np.random.randint(10000, size=1)[0]
    print('random_seed: {}'.format(args.random_seed))
    os.system('rm ' + os.path.join(save_dir, 'seed*'))
    os.system('touch ' + os.path.join(save_dir, 'seed=%s' % str(args.random_seed)))
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    ### Set up models
    print('gen_arch:' + gen_arch)
    if dataset == 'celeba':
        ngpu = 1
        netG = Generator_celeba(ngpu).to(device0)
        #netG.load_state_dict(torch.load('../results/celeba/main/d_1_2e-4_g_1_2e-4_SN_full/netG_15000.pth'))

        # Handle multi-gpu if desired
        if (device0.type == 'cuda') and (ngpu > 1):
            netG = nn.DataParallel(netG, list(range(ngpu)))

        # Apply the weights_init function to randomly initialize all weights
        # to mean=0, stdev=0.02.
        netG.apply(weights_init)

    netGS = copy.deepcopy(netG).to(device0)

    if dataset == 'celeba':
        ngpu = 1
        netD = Discriminator_celeba(ngpu).to(device0)
        #netD.load_state_dict(torch.load('../results/celeba/main/d_1_2e-4_g_1_2e-4_SN_full/netD_15000.pth'))

        # Handle multi-gpu if desired
        if (device0.type == 'cuda') and (ngpu > 1):
            netD = nn.DataParallel(netD, list(range(ngpu)))

        # Apply the weights_init function to randomly initialize all weights
        # to mean=0, stdev=0.02.
        #netD.apply(weights_init)

    ### Set up optimizers
    optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.99))
    optimizerG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.99))

    ### Data loaders
    if dataset == 'celeba':
        transform_train = transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    if dataset == 'celeba':
        IMG_DIM = 64 * 64 * 3
        NUM_CLASSES = 2
        trainset = CelebA(
            root=os.path.join('/work/u5366584/exp/datasets/celeba'),
            split='train',
            transform=transform_train,
            download=False)  #, custom_subset=True)
        #trainset = CelebA(root=os.path.join('../data'), split='train',
        #                  transform=transform_train, download=False, custom_subset=True)
    else:
        raise NotImplementedError

    ### fix sub-training set (fix to 10000 training samples)
    if args.update_train_dataset:
        if dataset == 'mnist':
            indices_full = np.arange(60000)
        elif dataset == 'cifar_10':
            indices_full = np.arange(50000)
        elif dataset == 'celeba':
            indices_full = np.arange(len(trainset))
        np.random.shuffle(indices_full)
        '''
        ##### ref
        indices = np.loadtxt('index_20k.txt', dtype=np.int_)
        remove_idx = [np.argwhere(indices_full==x) for x in indices]
        indices_ref = np.delete(indices_full, remove_idx)
        indices_slice = indices_ref[:20000]
        np.savetxt('index_20k_ref.txt', indices_slice, fmt='%i')  ## ref index is disjoint to original index
        '''
        ### growing dataset
        indices = np.loadtxt('index_20k.txt', dtype=np.int_)
        remove_idx = [np.argwhere(indices_full == x) for x in indices]
        indices_rest = np.delete(indices_full, remove_idx)
        indices_rest = indices_rest[:20000]
        indices_slice = np.concatenate((indices, indices_rest), axis=0)
        np.savetxt('index_40k.txt', indices_slice, fmt='%i')

    indices = np.loadtxt('index_100k.txt', dtype=np.int_)
    trainset = torch.utils.data.Subset(trainset, indices)
    print(len(trainset))

    workers = 4
    dataloader = torch.utils.data.DataLoader(trainset,
                                             batch_size=batchsize,
                                             shuffle=True,
                                             num_workers=workers)

    if if_dp:
        ### Register hook
        # This script trains a single discriminator, so the hook is
        # registered on it directly.
        global dynamic_hook_function
        netD.conv1.register_backward_hook(master_hook_adder)

    criterion = nn.BCELoss()
    real_label = 1.
    fake_label = 0.
    nz = 100
    fixed_noise = torch.randn(100, nz, 1, 1, device=device0)

    iters = 0
    num_epochs = 256 * 5 + 1
    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, (data, y) in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_cpu = data.to(device0)
            b_size = real_cpu.size(0)
            label = torch.full((b_size, ), real_label, dtype=torch.float, device=device0)
            # Forward pass real batch through D
            output = netD(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=device0)
            # Generate fake image batch with G
            fake = netG(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = netD(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch, accumulated (summed) with previous gradients
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            iters += 1

            for iter_g in range(1):
                ############################
                # Update G network
                ###########################
                if if_dp:
                    ### Sanitize the gradients passed to the Generator
                    dynamic_hook_function = dp_conv_hook
                else:
                    ### Only modify the gradient norm, without adding noise
                    dynamic_hook_function = modify_gradnorm_conv_hook

                ############################
                # (2) Update G network: maximize log(D(G(z)))
                ###########################
                noise = torch.randn(b_size, nz, 1, 1, device=device0)
                fake = netG(noise)
                label = torch.full((b_size, ), real_label, dtype=torch.float, device=device0)
                netG.zero_grad()
                label.fill_(real_label)  # fake labels are real for generator cost
                # Since we just updated D, perform another forward pass of all-fake batch through D
                output = netD(fake).view(-1)
                # Calculate G's loss based on this output
                errG = criterion(output, label)
                # Calculate gradients for G
                errG.backward()
                D_G_z2 = output.mean().item()
                # Update G
                optimizerG.step()

            ### update the exponential moving average
            exp_mov_avg(netGS, netG, alpha=0.999, global_step=iters)

            ############################
            ### Results visualization
            ############################
            if iters % 10 == 0:
                print('iter:{}, G_cost:{:.2f}, D_cost:{:.2f}'.format(
                    iters, errG.item(), errD.item()))

            if iters % args.vis_step == 0:
                if dataset == 'celeba':
                    generate_image_celeba(str(iters + 0), netGS, fixed_noise,
                                          save_dir, device0)

            if iters % args.save_step == 0:
                ### save model
                torch.save(
                    netGS.state_dict(),
                    os.path.join(save_dir, 'netGS_%s.pth' % str(iters + 0)))
                torch.save(
                    netD.state_dict(),
                    os.path.join(save_dir, 'netD_%s.pth' % str(iters + 0)))

            torch.cuda.empty_cache()
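# --------------------------------------------------------------------------
# `exp_mov_avg` (a repo utility used by both training loops) keeps netGS as
# an exponential moving average of netG's weights. The commented sketch below
# is a minimal, assumed reconstruction; the warm-up rule for the decay is an
# assumption, not necessarily the repo's exact implementation.
#
# def exp_mov_avg(netGS, netG, alpha=0.999, global_step=999):
#     # Use a smaller effective decay during the first steps so netGS does
#     # not stay stuck at its random initialization.
#     alpha = min(alpha, (global_step + 1) / (global_step + 10))
#     for p_ema, p in zip(netGS.parameters(), netG.parameters()):
#         p_ema.data.mul_(alpha).add_(p.data, alpha=1 - alpha)
# --------------------------------------------------------------------------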