# Discriminator update (excerpt from the training loop): freeze the generator,
# take a gradient step on the discriminator loss, and clip the weights for WGAN
# when no gradient penalty is used.
z = Variable(utils.sample(DISTRIBUTION, (len(x_true), N_LATENT)))
if CUDA:
    x_true = x_true.cuda(0)
    z = z.cuda(0)

x_gen = gen(z)
p_true, p_gen = dis(x_true), dis(x_gen)

if UPDATE_FREQUENCY == 1 or (n_iteration_t + 1) % UPDATE_FREQUENCY != 0:
    for p in gen.parameters():
        p.requires_grad = False

    dis_optimizer.zero_grad()
    dis_loss = -utils.compute_gan_loss(p_true, p_gen, mode=MODE)
    if GRADIENT_PENALTY:
        penalty = dis.get_penalty(x_true.data, x_gen.data)
        dis_loss = dis_loss + GRADIENT_PENALTY * penalty

    # Keep the graph alive when the generator is updated in the same iteration.
    if UPDATE_FREQUENCY == 1:
        dis_loss.backward(retain_graph=True)
    else:
        dis_loss.backward()
    dis_optimizer.step()

    # WGAN without gradient penalty enforces the Lipschitz constraint by weight clipping.
    if MODE == 'wgan' and not GRADIENT_PENALTY:
        for p in dis.parameters():
            p.data.clamp_(-CLIP, CLIP)
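# The snippet above relies on dis.get_penalty for the gradient penalty term. For reference,
# below is a minimal, self-contained sketch of a WGAN-GP style penalty (Gulrajani et al.).
# It illustrates the idea only; it is not the repository's implementation, and the function
# name `gradient_penalty` is made up here.
import torch


def gradient_penalty(dis, x_true, x_gen):
    # Sample a random interpolation point between each real and generated sample.
    alpha = torch.rand(x_true.size(0), *([1] * (x_true.dim() - 1)), device=x_true.device)
    x_hat = (alpha * x_true + (1 - alpha) * x_gen).requires_grad_(True)
    out = dis(x_hat)
    # Gradient of the critic output with respect to the interpolated input.
    grads = torch.autograd.grad(out.sum(), x_hat, create_graph=True)[0]
    # Penalize deviations of the per-sample gradient norm from 1.
    return ((grads.view(grads.size(0), -1).norm(2, dim=1) - 1.0) ** 2).mean()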
if CUDA:
    penalty = penalty.cuda(0)

for i, data in enumerate(trainloader):
    _t = time.time()
    x_true, _ = data
    x_true = Variable(x_true)
    z = Variable(utils.sample(DISTRIBUTION, (len(x_true), N_LATENT)))
    if CUDA:
        x_true = x_true.cuda(0)
        z = z.cuda(0)

    x_gen = gen(z)
    p_true, p_gen = dis(x_true), dis(x_gen)

    gen_loss = utils.compute_gan_loss(p_true, p_gen, mode=MODE)
    dis_loss = -gen_loss.clone()
    if GRADIENT_PENALTY:
        penalty = dis.get_penalty(x_true.data, x_gen.data)
        dis_loss += GRADIENT_PENALTY * penalty

    for p in gen.parameters():
        p.requires_grad = False
    dis_optimizer.zero_grad()
    # https://github.com/pytorch/examples/issues/116
    dis_loss.backward(retain_graph=True)

    if ALGORITHM == 'ExtraAdam' or ALGORITHM == 'ExtraSGD':
        if (n_iteration_t + 1) % 2 != 0:
            dis_optimizer.extrapolation()
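# The loop above calls dis_optimizer.extrapolation(), a method exposed only by the
# extragradient optimizers (ExtraAdam / ExtraSGD). The class below is a stripped-down
# sketch of that extrapolation/step pattern, written here purely for illustration
# (SimpleExtraSGD is not a name from this codebase): extrapolation() remembers the
# base point w_t and moves to the look-ahead point w_{t+1/2} = w_t - lr * grad(w_t);
# step() then applies the gradient computed at w_{t+1/2} to the saved base point.
import torch


class SimpleExtraSGD(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3):
        super().__init__(params, dict(lr=lr))
        self._saved = None  # base-point copies of the parameters

    @torch.no_grad()
    def extrapolation(self):
        # Save w_t and move the parameters to the look-ahead point.
        self._saved = []
        for group in self.param_groups:
            for p in group['params']:
                self._saved.append(p.detach().clone())
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group['lr'])

    @torch.no_grad()
    def step(self):
        # w_{t+1} = w_t - lr * grad(w_{t+1/2}), using the gradient of the current
        # (look-ahead) parameters but updating from the saved base point.
        saved = iter(self._saved)
        for group in self.param_groups:
            for p in group['params']:
                w_t = next(saved)
                if p.grad is not None:
                    p.copy_(w_t - group['lr'] * p.grad)
                else:
                    p.copy_(w_t)
        self._saved = None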
# Alternating gradient descent-ascent (AGDA) training for a GAN.
# `compute_gan_loss` and `weights_init` are assumed to be defined elsewhere in the project.
import random

import torch
import torch.optim as optim
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from tqdm import tqdm


def train_agda(dataset, manual_seed, options):
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)

    model = options['model']
    loss = options['loss']
    data = options['data']
    lr = options['learning_rate']
    nz = options['nz']
    batch_size = options['batch_size']
    num_epochs = options['num_epochs']
    device = options['device']

    # Define the GAN networks.
    if model == 'vgan':
        from vgan import VanillaDiscriminator, VanillaGenerator
        if data == 'mnist':
            generator = VanillaGenerator(nz).to(device)
            discriminator = VanillaDiscriminator().to(device)
        elif data == 'cifar10':
            generator = VanillaGenerator(nz, n_c=3).to(device)
            discriminator = VanillaDiscriminator(n_c=3).to(device)
    elif model == 'dcgan':
        from dcgan import DCGANDiscriminator, DCGANGenerator
        if data == 'mnist':
            generator = DCGANGenerator(nz, n_out=1).to(device)
            discriminator = DCGANDiscriminator(n_in=1).to(device)
        elif data == 'cifar10':
            generator = DCGANGenerator(nz).to(device)
            discriminator = DCGANDiscriminator().to(device)

    generator.apply(weights_init)
    discriminator.apply(weights_init)

    # Optimizers: plain SGD for each player; optim_d only takes D's parameters.
    optim_g = optim.SGD(generator.parameters(), lr=lr)
    optim_d = optim.SGD(discriminator.parameters(), lr=lr)

    # Per-epoch parameter snapshots.
    gen_param = []
    dis_param = []

    print('Training......')
    for epoch in range(num_epochs):
        # Dataloader that samples uniformly with replacement.
        sampler = RandomSampler(dataset, replacement=True, num_samples=len(dataset))
        train_loader = DataLoader(
            dataset,
            batch_sampler=BatchSampler(sampler, batch_size=batch_size, drop_last=False))

        # Parameter snapshots and running losses for this epoch.
        epoch_gen_param = []
        epoch_dis_param = []
        losses_g = 0.0
        losses_d = 0.0

        # Batch training; the image labels are not needed.
        for i, (images, _) in tqdm(enumerate(train_loader, 0),
                                   total=int(len(dataset) / batch_size)):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ############################
            # Train with real images.
            discriminator.zero_grad()
            images = images.to(device)
            b_size = images.size()[0]
            label = torch.full((b_size,), 1, dtype=images.dtype, device=device)
            output = discriminator(images)
            loss_real = compute_gan_loss(output, label, loss=loss)
            loss_real.backward()
            D_x = output.mean().item()

            # Train with fake images.
            noises = torch.randn(b_size, nz, device=device)
            images_fake = generator(noises)
            label.fill_(0)
            # Detach the fake batch so this backward pass does not reach the generator.
            output = discriminator(images_fake.detach())
            loss_fake = compute_gan_loss(output, label, loss=loss)
            loss_fake.backward()
            D_G_z1 = output.mean().item()
            loss_d = loss_real + loss_fake
            optim_d.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ############################
            generator.zero_grad()
            label.fill_(1)
            output = discriminator(images_fake)
            loss_g = compute_gan_loss(output, label, loss=loss)
            loss_g.backward()
            D_G_z2 = output.mean().item()
            optim_g.step()

            losses_d += loss_d.item()
            losses_g += loss_g.item()

        print('[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, num_epochs, losses_d / (i + 1), losses_g / (i + 1),
                 D_x, D_G_z1, D_G_z2))

        # Save this epoch's parameters. Using .data gives a tensor with
        # requires_grad=False, so cloning it does not involve autograd.
        for param in generator.parameters():
            epoch_gen_param.append(param.data.clone())
        for param in discriminator.parameters():
            epoch_dis_param.append(param.data.clone())

        gen_param.append(epoch_gen_param)
        dis_param.append(epoch_dis_param)

    return gen_param, dis_param
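# A hypothetical invocation sketch of train_agda. The dataset, the option values, and the
# 'bce' loss identifier below are assumptions for illustration, not taken from the original
# script; train_agda only requires a torch Dataset plus an options dict with the keys read
# at the top of the function.
if __name__ == '__main__':
    import torchvision
    import torchvision.transforms as transforms

    mnist = torchvision.datasets.MNIST(root='./data', train=True, download=True,
                                       transform=transforms.ToTensor())

    options = {
        'model': 'dcgan',
        'loss': 'bce',            # whichever identifiers compute_gan_loss understands
        'data': 'mnist',
        'learning_rate': 0.01,
        'nz': 100,
        'batch_size': 64,
        'num_epochs': 5,
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    }

    gen_param, dis_param = train_agda(mnist, manual_seed=1234, options=options)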