def disc_gp(self, x):
    """Extract conv features from `x` and compute the discriminator's
    gradient penalty on the flattened feature map.

    Args:
        x: input tensor fed to `self.conv` (shape not visible here).

    Returns:
        tuple: (cnn_feats, gp) — the raw conv feature map and the scalar
        gradient penalty from `utils.calc_gradient_penalty`.
    """
    with torch.no_grad():
        # Spatial features
        cnn_feats = self.conv(x)
        bsz, c, h, w = cnn_feats.size()
        # Discriminator
        # NOTE(review): view(c, -1) folds the batch dim into the spatial
        # dims, so rows only line up with channels when bsz == 1 —
        # presumably called with a single sample; confirm against callers.
        disc_inp = cnn_feats.view(c, -1).t()
        # Gradient Penalty
        # NOTE(review): this runs inside no_grad(); assumes
        # utils.calc_gradient_penalty re-enables autograd internally — verify.
        gp = utils.calc_gradient_penalty(self.disc, disc_inp)
        return cnn_feats, gp
def train(train_data, g_net, d_net, criterion, g_optimizer, d_optimizer, epoch, logger):
    """Run one epoch of WGAN-GP training.

    The discriminator (critic) is updated every batch with the WGAN loss
    plus a gradient penalty; the generator is updated once every
    `config.n_critic` batches.

    Args:
        train_data: iterable of (img, label) batches; labels are ignored.
        g_net, d_net: generator and discriminator modules.
        criterion: unused here (kept for a uniform train() signature).
        g_optimizer, d_optimizer: optimizers for G and D.
        epoch: current epoch index (for logging only).
        logger: logger with an .info() method.

    Returns:
        tuple: (average generator loss, average discriminator loss).
    """
    g_net.train()
    d_net.train()
    g_losses = AverageMeter()
    d_losses = AverageMeter()
    for i, (img, _) in enumerate(train_data):
        mini_batch = img.size()[0]
        # train discriminator
        x_ = Variable(img.cuda())
        z_ = torch.randn(mini_batch, config.G_input_dim).view(-1, config.G_input_dim, 1, 1)
        z_ = Variable(z_.cuda())
        # detach so D's update does not backprop into G
        gen_img = g_net(z_).detach()
        gradient_penalty = calc_gradient_penalty(d_net, x_.data, gen_img.data)
        # print(gradient_penalty, -d_net(x_).mean(), d_net(gen_img).mean())
        # WGAN critic loss: maximize D(real) - D(fake), plus GP term
        d_loss = -d_net(x_).mean() + d_net(gen_img).mean() + config.lambda_gp * gradient_penalty
        # bp
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        # weight clipping from the original WGAN is superseded by the
        # gradient penalty, hence left disabled:
        # for p in d_net.parameters():
        #     p.data.clamp_(-config.clip, config.clip)
        if i % config.n_critic == 0:
            # train generator (only every n_critic critic steps)
            z_ = torch.randn(mini_batch, config.G_input_dim).view(-1, config.G_input_dim, 1, 1)
            z_ = Variable(z_.cuda())
            gen_img = g_net(z_)
            g_loss = -d_net(gen_img).mean()
            # bp
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()
        # NOTE(review): g_loss is first bound at i == 0 (0 % n_critic == 0),
        # so reading it every iteration here is safe, but between G steps it
        # re-records the stale value.
        g_losses.update(g_loss.item())
        d_losses.update(d_loss.item())
        if i % config.print_freq == 0:
            logger.info('Epoch: [{0}][{1}][{2}]\t'
                        'g_loss {g_loss.val:.4f} ({g_loss.avg:.4f})\t'
                        'd_loss {d_loss.val:.3f} ({d_loss.avg:.3f})'
                        .format(epoch, i, len(train_data), g_loss=g_losses, d_loss=d_losses))
    return g_losses.avg, d_losses.avg
def train(vocabs, char_vocab, tag_vocab, train_sets, dev_sets, test_sets, unlabeled_sets):
    """
    train_sets, dev_sets, test_sets: dict[lang] -> AmazonDataset
    For unlabeled langs, no train_sets are available

    Adversarial multilingual training: a shared feature extractor F_s is
    trained against a language discriminator D (MAN-style), alongside
    per-language private features F_p and a tagger C. Returns the best
    validation/test accuracies (or, with opt.test_only, evaluation results
    for a previously saved model).
    """
    # dataset loaders
    train_loaders, unlabeled_loaders = {}, {}
    train_iters, unlabeled_iters, d_unlabeled_iters = {}, {}, {}
    dev_loaders, test_loaders = {}, {}
    my_collate = utils.sorted_collate if opt.model=='lstm' else utils.unsorted_collate
    for lang in opt.langs:
        train_loaders[lang] = DataLoader(train_sets[lang],
                opt.batch_size, shuffle=True, collate_fn = my_collate)
        train_iters[lang] = iter(train_loaders[lang])
    for lang in opt.dev_langs:
        dev_loaders[lang] = DataLoader(dev_sets[lang],
                opt.batch_size, shuffle=False, collate_fn = my_collate)
        test_loaders[lang] = DataLoader(test_sets[lang],
                opt.batch_size, shuffle=False, collate_fn = my_collate)
    for lang in opt.all_langs:
        if lang in opt.unlabeled_langs:
            uset = unlabeled_sets[lang]
        else:
            # for labeled langs, consider which data to use as unlabeled set
            if opt.unlabeled_data == 'both':
                uset = ConcatDataset([train_sets[lang], unlabeled_sets[lang]])
            elif opt.unlabeled_data == 'unlabeled':
                uset = unlabeled_sets[lang]
            elif opt.unlabeled_data == 'train':
                uset = train_sets[lang]
            else:
                raise Exception(f'Unknown options for the unlabeled data usage: {opt.unlabeled_data}')
        unlabeled_loaders[lang] = DataLoader(uset,
                opt.batch_size, shuffle=True, collate_fn = my_collate)
        # two independent iterators: one for D updates, one for F updates
        unlabeled_iters[lang] = iter(unlabeled_loaders[lang])
        d_unlabeled_iters[lang] = iter(unlabeled_loaders[lang])

    # embeddings
    emb = MultiLangWordEmb(vocabs, char_vocab, opt.use_wordemb, opt.use_charemb).to(opt.device)
    # models
    F_s = None
    F_p = None
    C, D = None, None
    # one extra expert for the shared/private mixture when expert_sp is on
    num_experts = len(opt.langs)+1 if opt.expert_sp else len(opt.langs)
    if opt.model.lower() == 'lstm':
        if opt.shared_hidden_size > 0:
            F_s = LSTMFeatureExtractor(opt.total_emb_size, opt.F_layers, opt.shared_hidden_size,
                    opt.word_dropout, opt.dropout, opt.bdrnn)
        if opt.private_hidden_size > 0:
            if not opt.concat_sp:
                # add_sp sums shared+private features, so dims must match
                assert opt.shared_hidden_size == opt.private_hidden_size, "shared dim != private dim when using add_sp!"
            F_p = nn.Sequential(
                    LSTMFeatureExtractor(opt.total_emb_size, opt.F_layers, opt.private_hidden_size,
                        opt.word_dropout, opt.dropout, opt.bdrnn),
                    MixtureOfExperts(opt.MoE_layers, opt.private_hidden_size,
                        len(opt.langs), opt.private_hidden_size,
                        opt.private_hidden_size, opt.dropout, opt.MoE_bn, False))
    else:
        raise Exception(f'Unknown model architecture {opt.model}')
    if opt.C_MoE:
        C = SpMixtureOfExperts(opt.C_layers, opt.shared_hidden_size,
                opt.private_hidden_size, opt.concat_sp, num_experts,
                opt.shared_hidden_size + opt.private_hidden_size,
                len(tag_vocab), opt.mlp_dropout, opt.C_bn)
    else:
        C = SpMlpTagger(opt.C_layers, opt.shared_hidden_size,
                opt.private_hidden_size, opt.concat_sp,
                opt.shared_hidden_size + opt.private_hidden_size,
                len(tag_vocab), opt.mlp_dropout, opt.C_bn)
    # language discriminator (only when there are shared features to confuse)
    if opt.shared_hidden_size > 0 and opt.n_critic > 0:
        if opt.D_model.lower() == 'lstm':
            d_args = {
                'num_layers': opt.D_lstm_layers,
                'input_size': opt.shared_hidden_size,
                'hidden_size': opt.shared_hidden_size,
                'word_dropout': opt.D_word_dropout,
                'dropout': opt.D_dropout,
                'bdrnn': opt.D_bdrnn,
                'attn_type': opt.D_attn
            }
        elif opt.D_model.lower() == 'cnn':
            d_args = {
                'num_layers': 1,
                'input_size': opt.shared_hidden_size,
                'hidden_size': opt.shared_hidden_size,
                'kernel_num': opt.D_kernel_num,
                'kernel_sizes': opt.D_kernel_sizes,
                'word_dropout': opt.D_word_dropout,
                'dropout': opt.D_dropout
            }
        else:
            d_args = None
        if opt.D_model.lower() == 'mlp':
            # token-level discriminator
            D = MLPLanguageDiscriminator(opt.D_layers, opt.shared_hidden_size,
                    opt.shared_hidden_size, len(opt.all_langs), opt.loss,
                    opt.D_dropout, opt.D_bn)
        else:
            # sentence-level discriminator
            D = LanguageDiscriminator(opt.D_model, opt.D_layers,
                    opt.shared_hidden_size, opt.shared_hidden_size,
                    len(opt.all_langs), opt.D_dropout, opt.D_bn, d_args)
    if opt.use_data_parallel:
        F_s, C, D = nn.DataParallel(F_s).to(opt.device) if F_s else None, nn.DataParallel(C).to(opt.device), nn.DataParallel(D).to(opt.device) if D else None
    else:
        F_s, C, D = F_s.to(opt.device) if F_s else None, C.to(opt.device), D.to(opt.device) if D else None
    if F_p:
        if opt.use_data_parallel:
            F_p = nn.DataParallel(F_p).to(opt.device)
        else:
            F_p = F_p.to(opt.device)
    # optimizers: one for the main model (emb/F_s/C/F_p), one for D
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
            itertools.chain(*map(list, [emb.parameters(),
                F_s.parameters() if F_s else [],
                C.parameters(), F_p.parameters() if F_p else []]))),
            lr=opt.learning_rate, weight_decay=opt.weight_decay)
    if D:
        optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate, weight_decay=opt.D_weight_decay)

    # testing
    if opt.test_only:
        log.info(f'Loading model from {opt.model_save_file}...')
        if F_s:
            F_s.load_state_dict(torch.load(os.path.join(opt.model_save_file, f'netF_s.pth')))
        # NOTE(review): this loads the same net_F_p.pth once per language —
        # the loop is redundant (or a per-language filename was intended); verify.
        for lang in opt.all_langs:
            F_p.load_state_dict(torch.load(os.path.join(opt.model_save_file, f'net_F_p.pth')))
        C.load_state_dict(torch.load(os.path.join(opt.model_save_file, f'netC.pth')))
        if D:
            D.load_state_dict(torch.load(os.path.join(opt.model_save_file, f'netD.pth')))
        log.info('Evaluating validation sets:')
        acc = {}
        log.info(dev_loaders)
        log.info(vocabs)
        for lang in opt.all_langs:
            acc[lang] = evaluate(f'{lang}_dev', dev_loaders[lang], vocabs[lang],
                    tag_vocab, emb, lang, F_s, F_p, C)
        avg_acc = sum([acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for lang in opt.all_langs:
            test_acc[lang] = evaluate(f'{lang}_test', test_loaders[lang], vocabs[lang],
                    tag_vocab, emb, lang, F_s, F_p, C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average test accuracy: {avg_test_acc}')
        return {'valid': acc, 'test': test_acc}

    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    epochs_since_decay = 0
    # lambda scheduling
    if opt.lambd > 0 and opt.lambd_schedule:
        opt.lambd_orig = opt.lambd
    # geometric mean of per-language loader lengths = iterations per epoch
    num_iter = int(utils.gmean([len(train_loaders[l]) for l in opt.langs]))
    # adapt max_epoch so that total iterations reach at least 15000
    if opt.max_epoch > 0 and num_iter * opt.max_epoch < 15000:
        opt.max_epoch = 15000 // num_iter
        log.info(f"Setting max_epoch to {opt.max_epoch}")

    for epoch in range(opt.max_epoch):
        emb.train()
        if F_s:
            F_s.train()
        C.train()
        if D:
            D.train()
        if F_p:
            F_p.train()
        # lambda scheduling: ramp the adversarial weight up at epochs 5 and 15
        if hasattr(opt, 'lambd_orig') and opt.lambd_schedule:
            if epoch == 0:
                opt.lambd = opt.lambd_orig
            elif epoch == 5:
                opt.lambd = 10 * opt.lambd_orig
            elif epoch == 15:
                opt.lambd = 100 * opt.lambd_orig
            log.info(f'Scheduling lambda = {opt.lambd}')
        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        gate_correct = defaultdict(int)
        c_gate_correct = defaultdict(int)
        # D accuracy
        d_correct, d_total = 0, 0
        for i in tqdm(range(num_iter), ascii=True):
            # D iterations: train the language discriminator with everything
            # else frozen
            if opt.shared_hidden_size > 0:
                utils.freeze_net(emb)
                utils.freeze_net(F_s)
                utils.freeze_net(F_p)
                utils.freeze_net(C)
                utils.unfreeze_net(D)
                # WGAN n_critic trick since D trains slower
                n_critic = opt.n_critic
                if opt.wgan_trick:
                    if opt.n_critic>0 and ((epoch==0 and i<25) or i%500==0):
                        n_critic = 100
                for _ in range(n_critic):
                    D.zero_grad()
                    loss_d = {}
                    lang_features = {}
                    # train on both labeled and unlabeled langs
                    for lang in opt.all_langs:
                        # targets not used
                        d_inputs, _ = utils.endless_get_next_batch(
                                unlabeled_loaders, d_unlabeled_iters, lang)
                        d_inputs, d_lengths, mask, d_chars, d_char_lengths = d_inputs
                        d_embeds = emb(lang, d_inputs, d_chars, d_char_lengths)
                        shared_feat = F_s((d_embeds, d_lengths))
                        if opt.grad_penalty != 'none':
                            # detached copy kept for the gradient penalty below
                            lang_features[lang] = shared_feat.detach()
                        if opt.D_model.lower() == 'mlp':
                            d_outputs = D(shared_feat)
                            # if token-level D, we can reuse the gate label generator
                            d_targets = utils.get_gate_label(d_outputs, lang, mask, False, all_langs=True)
                            d_total += torch.sum(d_lengths).item()
                        else:
                            d_outputs = D((shared_feat, d_lengths))
                            d_targets = utils.get_lang_label(opt.loss, lang, len(d_lengths))
                            d_total += len(d_lengths)
                        # D accuracy
                        _, pred = torch.max(d_outputs, -1)
                        # d_total += len(d_lengths)
                        d_correct += (pred==d_targets).sum().item()
                        if opt.use_data_parallel:
                            l_d = functional.nll_loss(d_outputs.view(-1, D.module.num_langs),
                                    d_targets.view(-1), ignore_index=-1)
                        else:
                            l_d = functional.nll_loss(d_outputs.view(-1, D.num_langs),
                                    d_targets.view(-1), ignore_index=-1)
                        l_d.backward()
                        loss_d[lang] = l_d.item()
                    # gradient penalty
                    if opt.grad_penalty != 'none':
                        gp = utils.calc_gradient_penalty(D, lang_features,
                                onesided=opt.onesided_gp,
                                interpolate=(opt.grad_penalty=='wgan'))
                        gp.backward()
                    optimizerD.step()
            # F&C iteration: train the main model with D frozen
            utils.unfreeze_net(emb)
            if opt.use_wordemb and opt.fix_emb:
                for lang in emb.langs:
                    emb.wordembs[lang].weight.requires_grad = False
            if opt.use_charemb and opt.fix_charemb:
                emb.charemb.weight.requires_grad = False
            utils.unfreeze_net(F_s)
            utils.unfreeze_net(F_p)
            utils.unfreeze_net(C)
            utils.freeze_net(D)
            emb.zero_grad()
            if F_s:
                F_s.zero_grad()
            if F_p:
                F_p.zero_grad()
            C.zero_grad()
            # optimizer.zero_grad()
            for lang in opt.langs:
                inputs, targets = utils.endless_get_next_batch(
                        train_loaders, train_iters, lang)
                inputs, lengths, mask, chars, char_lengths = inputs
                bs, seq_len = inputs.size()
                embeds = emb(lang, inputs, chars, char_lengths)
                shared_feat, private_feat = None, None
                if opt.shared_hidden_size > 0:
                    shared_feat = F_s((embeds, lengths))
                if opt.private_hidden_size > 0:
                    private_feat, gate_outputs = F_p((embeds, lengths))
                if opt.C_MoE:
                    c_outputs, c_gate_outputs = C((shared_feat, private_feat))
                else:
                    c_outputs = C((shared_feat, private_feat))
                # targets are padded with -1
                l_c = functional.nll_loss(c_outputs.view(bs*seq_len, -1),
                        targets.view(-1), ignore_index=-1)
                # gate loss: supervise the F_p expert gate to pick its language
                if F_p:
                    gate_targets = utils.get_gate_label(gate_outputs, lang, mask, False)
                    l_gate = functional.cross_entropy(gate_outputs.view(bs*seq_len, -1),
                            gate_targets.view(-1), ignore_index=-1)
                    l_c += opt.gate_loss_weight * l_gate
                    _, gate_pred = torch.max(gate_outputs.view(bs*seq_len, -1), -1)
                    gate_correct[lang] += (gate_pred == gate_targets.view(-1)).sum().item()
                if opt.C_MoE and opt.C_gate_loss_weight > 0:
                    # tagger-gate loss (multi-label when expert_sp is on)
                    c_gate_targets = utils.get_gate_label(c_gate_outputs, lang, mask, opt.expert_sp)
                    _, c_gate_pred = torch.max(c_gate_outputs.view(bs*seq_len, -1), -1)
                    if opt.expert_sp:
                        l_c_gate = functional.binary_cross_entropy_with_logits(
                                mask.unsqueeze(-1) * c_gate_outputs, c_gate_targets)
                        c_gate_correct[lang] += torch.index_select(
                                c_gate_targets.view(bs*seq_len, -1), -1,
                                c_gate_pred.view(bs*seq_len)).sum().item()
                    else:
                        l_c_gate = functional.cross_entropy(c_gate_outputs.view(bs*seq_len, -1),
                                c_gate_targets.view(-1), ignore_index=-1)
                        c_gate_correct[lang] += (c_gate_pred == c_gate_targets.view(-1)).sum().item()
                    l_c += opt.C_gate_loss_weight * l_c_gate
                l_c.backward()
                _, pred = torch.max(c_outputs, -1)
                total[lang] += torch.sum(lengths).item()
                correct[lang] += (pred == targets).sum().item()
            # update F with D gradients on all langs (adversarial step:
            # negated D loss pushes F_s toward language-invariant features)
            if D:
                for lang in opt.all_langs:
                    inputs, _ = utils.endless_get_next_batch(
                            unlabeled_loaders, unlabeled_iters, lang)
                    inputs, lengths, mask, chars, char_lengths = inputs
                    embeds = emb(lang, inputs, chars, char_lengths)
                    shared_feat = F_s((embeds, lengths))
                    # d_outputs = D((shared_feat, lengths))
                    if opt.D_model.lower() == 'mlp':
                        d_outputs = D(shared_feat)
                        # if token-level D, we can reuse the gate label generator
                        d_targets = utils.get_gate_label(d_outputs, lang, mask, False, all_langs=True)
                    else:
                        d_outputs = D((shared_feat, lengths))
                        d_targets = utils.get_lang_label(opt.loss, lang, len(lengths))
                    if opt.use_data_parallel:
                        l_d = functional.nll_loss(d_outputs.view(-1, D.module.num_langs),
                                d_targets.view(-1), ignore_index=-1)
                    else:
                        l_d = functional.nll_loss(d_outputs.view(-1, D.num_langs),
                                d_targets.view(-1), ignore_index=-1)
                    if opt.lambd > 0:
                        l_d *= -opt.lambd
                    # NOTE(review): backward is reconstructed at this level
                    # (runs even when lambd == 0) — confirm against the
                    # original formatting.
                    l_d.backward()
            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch+1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(100.0*d_correct/d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.langs))
        log.info('\t'.join([str(100.0*correct[d]/total[d]) for d in opt.langs]))
        log.info('Gate accuracy:')
        log.info('\t'.join([str(100.0*gate_correct[d]/total[d]) for d in opt.langs]))
        log.info('Tagger Gate accuracy:')
        log.info('\t'.join([str(100.0*c_gate_correct[d]/total[d]) for d in opt.langs]))
        log.info('Evaluating validation sets:')
        acc = {}
        for lang in opt.dev_langs:
            acc[lang] = evaluate(f'{lang}_dev', dev_loaders[lang], vocabs[lang],
                    tag_vocab, emb, lang, F_s, F_p, C)
        avg_acc = sum([acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for lang in opt.dev_langs:
            test_acc[lang] = evaluate(f'{lang}_test', test_loaders[lang], vocabs[lang],
                    tag_vocab, emb, lang, F_s, F_p, C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average test accuracy: {avg_test_acc}')
        if avg_acc > best_avg_acc:
            # new best model: snapshot options and all weights
            epochs_since_decay = 0
            log.info(f'New best average validation accuracy: {avg_acc}')
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'), 'wb') as ouf:
                pickle.dump(opt, ouf)
            if F_s:
                torch.save(F_s.state_dict(), '{}/netF_s.pth'.format(opt.model_save_file))
            torch.save(emb.state_dict(), '{}/net_emb.pth'.format(opt.model_save_file))
            if F_p:
                torch.save(F_p.state_dict(), '{}/net_F_p.pth'.format(opt.model_save_file))
            torch.save(C.state_dict(), '{}/netC.pth'.format(opt.model_save_file))
            if D:
                torch.save(D.state_dict(), '{}/netD.pth'.format(opt.model_save_file))
        else:
            # no improvement: decay the LR after opt.lr_decay_epochs stale epochs
            epochs_since_decay += 1
            if opt.lr_decay < 1 and epochs_since_decay >= opt.lr_decay_epochs:
                epochs_since_decay = 0
                old_lr = optimizer.param_groups[0]['lr']
                optimizer.param_groups[0]['lr'] = old_lr * opt.lr_decay
                log.info(f'Decreasing LR to {old_lr * opt.lr_decay}')

    # end of training
    log.info(f'Best average validation accuracy: {best_avg_acc}')
    return best_acc
def train(NetG, NetD, optimizerG, optimizerD, dataloader, epoch):
    """Run one epoch of adversarial segmentation training (SegAN-style).

    The discriminator is trained on the L1 difference between its features
    for the GT-masked and prediction-masked images (plus a gradient
    penalty); the generator minimizes that difference plus a dice loss.

    Args:
        NetG: segmentation generator producing logits for the mask.
        NetD: critic network scoring masked images.
        optimizerG, optimizerD: optimizers for the two networks.
        dataloader: yields (image, ground-truth mask) batches.
        epoch: current epoch number (for the progress printout).

    Relies on module globals: `use_cuda`, `conf.epochs`,
    `utils.dice_loss`, `utils.calc_gradient_penalty`.
    """
    total_dice = 0
    total_g_loss = 0
    total_g_loss_dice = 0
    total_g_loss_bce = 0
    total_d_loss = 0
    total_d_loss_penalty = 0
    NetG.train()
    NetD.train()
    for i, data in enumerate(dataloader, 1):
        # train D: freeze G so only the critic updates
        optimizerD.zero_grad()
        NetD.zero_grad()
        for p in NetG.parameters():
            p.requires_grad = False
        for p in NetD.parameters():
            p.requires_grad = True
        input, target = Variable(data[0]), Variable(data[1])
        input = input.float()
        target = target.float()
        if use_cuda:
            input = input.cuda()
            target = target.cuda()
        output = NetG(input)
        # FIX: torch.sigmoid replaces the deprecated F.sigmoid (same math)
        output = torch.sigmoid(output)
        # detach so D's loss does not backprop into G
        output = output.detach()
        input_img = input.clone()
        output_masked = input_img * output
        if use_cuda:
            output_masked = output_masked.cuda()
        result = NetD(output_masked)
        target_masked = input_img * target
        if use_cuda:
            target_masked = target_masked.cuda()
        target_D = NetD(target_masked)
        # critic maximizes |D(real-masked) - D(fake-masked)|
        loss_mac = -torch.mean(torch.abs(result - target_D))
        loss_mac.backward()
        # D net gradient_penalty
        batch_size = target_masked.size(0)
        gradient_penalty = utils.calc_gradient_penalty(NetD, target_masked,
                                                       output_masked, batch_size,
                                                       use_cuda, input.shape)
        gradient_penalty.backward()
        optimizerD.step()
        # train G: freeze D so only the generator updates
        optimizerG.zero_grad()
        NetG.zero_grad()
        for p in NetG.parameters():
            p.requires_grad = True
        for p in NetD.parameters():
            p.requires_grad = False
        output = NetG(input)
        output = torch.sigmoid(output)
        target_dice = target.view(-1).long()
        output_dice = output.view(-1)
        loss_dice = utils.dice_loss(output_dice, target_dice)
        output_masked = input_img * output
        if use_cuda:
            output_masked = output_masked.cuda()
        result = NetD(output_masked)
        target_G = NetD(target_masked)
        loss_G = torch.mean(torch.abs(result - target_G))
        loss_G_joint = loss_G + loss_dice
        loss_G_joint.backward()
        optimizerG.step()
        # FIX: .item() replaces the legacy `.data[0]`, which raises
        # IndexError on 0-dim loss tensors in PyTorch >= 0.5
        total_dice += 1 - loss_dice.item()
        total_g_loss += loss_G_joint.item()
        total_g_loss_dice += loss_dice.item()
        total_g_loss_bce += loss_G.item()
        total_d_loss += loss_mac.item()
        total_d_loss_penalty += gradient_penalty.item()
    # restore grad flags so later code sees both nets trainable
    for p in NetG.parameters():
        p.requires_grad = True
    for p in NetD.parameters():
        p.requires_grad = True
    size = len(dataloader)
    epoch_dice = total_dice / size
    epoch_g_loss = total_g_loss / size
    epoch_g_loss_dice = total_g_loss_dice / size
    epoch_g_loss_bce = total_g_loss_bce / size
    epoch_d_loss = total_d_loss / size
    epoch_d_loss_penalty = total_d_loss_penalty / size
    print_format = [
        epoch, conf.epochs, epoch_dice * 100, epoch_g_loss, epoch_g_loss_dice,
        epoch_g_loss_bce, epoch_d_loss, epoch_d_loss_penalty
    ]
    print('===> Training step {}/{} \tepoch_dice: {:.5f}'
          '\tepoch_g_loss: {:.5f} \tepoch_g_loss_dice: {:.5f}'
          '\tepoch_g_loss_bce: {:.5f} \tepoch_d_loss: {:.5f}'
          '\tepoch_d_loss_penalty: {:.5f}'.format(*print_format))
output = D_a(real_a).to(opt.device) errD_real = -1 * (2 + opt.lambda_self) * output.mean() # -a errD_real.backward(retain_graph=True) output_a = D_a(mix_g_a.detach()) output_a2 = D_a(fake_a.detach()) if opt.lambda_self > 0.0: output_a3 = D_a(self_a.detach()) output_a3 = output_a3.mean() else: output_a3 = 0 errD_fake_a = output_a.mean() + output_a2.mean( ) + opt.lambda_self * output_a3 errD_fake_a.backward(retain_graph=True) gradient_penalty_a = calc_gradient_penalty( D_a, real_a, mix_g_a, opt.lambda_grad, opt.device) gradient_penalty_a += calc_gradient_penalty( D_a, real_a, fake_a, opt.lambda_grad, opt.device) if opt.lambda_self > 0.0: gradient_penalty_a += opt.lambda_self * calc_gradient_penalty( D_a, real_a, self_a, opt.lambda_grad, opt.device) gradient_penalty_a.backward(retain_graph=True) ############################# #### Train D_b #### ############################# D_b.zero_grad() output = D_b(real_b).to(opt.device) errD_real = -1 * (2 + opt.lambda_self) * output.mean() # -a
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Train the generator/discriminator pair for one pyramid scale
    (SinGAN-style WGAN-GP training with a reconstruction loss).

    Args:
        netD, netG: discriminator and generator for the current scale.
        reals: list of the real image downscaled to every scale.
        Gs, Zs: trained generators / fixed noise maps from coarser scales.
        in_s: zero image used as the coarsest-scale input.
        NoiseAmp: per-scale noise amplitude list.
        opt: options namespace (mutated: nzx, nzy, receptive_field, noise_amp).
        centers: color-quantization centers, used only in 'paint_train' mode.

    Returns:
        tuple: (z_opt, in_s, netG) — the fixed reconstruction noise, the
        (possibly initialized) coarsest input, and the trained generator.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha  # reconstruction-loss weight

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        # coarsest scale: single-channel noise expanded to 3 channels;
        # finer scales: full nc_z-channel noise
        if (Gs == []) & (opt.mode != 'SR_train'):
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            #D_real_map = output.detach()
            errD_real = -output.mean()  #-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                # first iteration at this scale: initialize prev/z_prev and
                # calibrate opt.noise_amp from the reconstruction RMSE
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                # fresh random upsampled sample from the coarser scales
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            #D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                # reconstruction loss: G must reproduce the real image
                # from the fixed noise z_opt
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            #plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            #plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            #plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            #plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            #plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            #plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
# compute real data loss for discriminator d_loss_real = D(images) d_loss_real = d_loss_real.mean() d_loss_real.backward(fake_labels) # compute fake data loss for discriminator noise = make_variable(torch.randn( params.batch_size, params.z_dim, 1, 1).normal_(0, 1), volatile=True) fake_images = make_variable(G(noise).data) d_loss_fake = D(fake_images.detach()) d_loss_fake = d_loss_fake.mean() d_loss_fake.backward(real_labels) # compute gradient penalty gradient_penalty = calc_gradient_penalty( D, images.data, fake_images.data) gradient_penalty.backward() # optimize weights of discriminator d_loss = - d_loss_real + d_loss_fake + gradient_penalty d_optimizer.step() ########################## # (2) training generator # ########################## # avoid to compute gradients for D for p in D.parameters(): p.requires_grad = False # zero grad for optimizer of generator g_optimizer.zero_grad()
D_real_spk.backward(mone) # train with real from other speakers D_real_nspk = BETA * D_net(real_data_nspk).mean() D_real_nspk.backward(one) # train with fake data noise = autograd.Variable(torch.randn(BATCH_SIZE, 128), volatile=True).cuda() fake_data = autograd.Variable(G_net(noise).data) D_fake = D_net(fake_data).mean() D_fake.backward(one) # train with gradient penalty gradient_penalty = calc_gradient_penalty(D_net, real_data_spk.data, fake_data.data, BATCH_SIZE, LAMBDA) gradient_penalty.backward() D_cost = D_fake + D_real_nspk - D_real_spk + gradient_penalty Wasserstein_D = D_real_spk - D_fake D_optimizer.step() ############################ # (2) Update G network ########################### for p in D_net.parameters(): p.requires_grad = False # to avoid computation G_net.zero_grad() noise = autograd.Variable(torch.randn(BATCH_SIZE, 128)).cuda()
def train(self):
    """Training Discriminator with Generator.

    AC-GAN with a WGAN-GP critic: the generator is conditioned by writing a
    one-hot class label into the first `n_classes` noise dimensions; the
    discriminator returns a (source, class) pair. G is updated once every
    `self.gen_train` batches, D every batch. After each epoch the classifier
    head is evaluated on the test loader and sample grids/checkpoints are
    written to ../output and ../model.
    """
    start_time = time.time()
    n_classes = 10
    # before epoch training loop starts
    loss1 = []
    loss2 = []
    loss3 = []
    loss4 = []
    loss5 = []
    acc1 = []
    # fixed noise grid (10 samples per class) reused every epoch for the
    # saved sample images
    np.random.seed(352)
    label = np.asarray(list(range(10)) * 10)
    noise = np.random.normal(0, 1, (100, self.n_z))
    label_onehot = np.zeros((100, n_classes))
    label_onehot[np.arange(100), label] = 1
    # embed the class one-hot into the first n_classes noise dims
    noise[np.arange(100), :n_classes] = label_onehot[np.arange(100)]
    noise = noise.astype(np.float32)
    save_noise = torch.from_numpy(noise)
    if self.cuda:
        save_noise = save_noise.cuda()
    save_noise = Variable(save_noise)

    # Train the model
    for epoch in range(self.start_epoch, self.start_epoch + self.epochs):
        # turn models to `train` mode
        self.aG.train()
        self.aD.train()
        for batch_idx, (X_train_batch, Y_train_batch) in enumerate(self.trainloader):
            # skip ragged final batch so batch_size stays constant
            if Y_train_batch.shape[0] < self.batch_size:
                continue
            # train G
            if batch_idx % self.gen_train == 0:
                for p in self.aD.parameters():
                    p.requires_grad_(False)
                self.aG.zero_grad()
                label = np.random.randint(0, n_classes, self.batch_size)
                noise = np.random.normal(0, 1, (self.batch_size, self.n_z))
                label_onehot = np.zeros((self.batch_size, n_classes))
                label_onehot[np.arange(self.batch_size), label] = 1
                noise[np.arange(self.batch_size), :n_classes] = label_onehot[np.arange(
                    self.batch_size)]
                noise = noise.astype(np.float32)
                noise = torch.from_numpy(noise)
                if self.cuda:
                    noise = noise.cuda()
                noise = Variable(noise)
                # noise = Variable(noise).cuda()
                if self.cuda:
                    fake_label = Variable(torch.from_numpy(label)).cuda()
                else:
                    fake_label = Variable(torch.from_numpy(label))
                # fake_label = Variable(torch.from_numpy(label)).cuda()
                fake_data = self.aG(noise)
                gen_source, gen_class = self.aD(fake_data)
                gen_source = gen_source.mean()
                gen_class = self.criterion(gen_class, fake_label)
                # maximize critic score, minimize class loss
                gen_cost = -gen_source + gen_class
                gen_cost.backward()
                # NOTE(review): clamping Adam's per-param 'step' below 1024
                # appears to be a workaround for an optimizer-state overflow;
                # confirm why 1024/1000 were chosen.
                for group in self.optimizer_g.param_groups:
                    for p in group['params']:
                        state = self.optimizer_g.state[p]
                        if('step' in state and state['step'] >= 1024):
                            state['step'] = 1000
                self.optimizer_g.step()
            # train D
            for p in self.aD.parameters():
                p.requires_grad_(True)
            self.aD.zero_grad()
            # train discriminator with input from generator
            label = np.random.randint(0, n_classes, self.batch_size)
            noise = np.random.normal(0, 1, (self.batch_size, self.n_z))
            label_onehot = np.zeros((self.batch_size, n_classes))
            label_onehot[np.arange(self.batch_size), label] = 1
            noise[np.arange(self.batch_size), :n_classes] = label_onehot[np.arange(
                self.batch_size)]
            noise = noise.astype(np.float32)
            noise = torch.from_numpy(noise)
            if self.cuda:
                noise = noise.cuda()
            noise = Variable(noise)
            if self.cuda:
                fake_label = Variable(torch.from_numpy(label)).cuda()
            else:
                fake_label = Variable(torch.from_numpy(label))
            # no grad needed through G for the D step
            with torch.no_grad():
                fake_data = self.aG(noise)
            disc_fake_source, disc_fake_class = self.aD(fake_data)
            disc_fake_source = disc_fake_source.mean()
            disc_fake_class = self.criterion(disc_fake_class, fake_label)
            # train discriminator with input from the discriminator
            if self.cuda:
                real_data, real_label = X_train_batch.cuda(), Y_train_batch.cuda()
            else:
                real_data, real_label = X_train_batch, Y_train_batch
            real_data, real_label = Variable(
                real_data), Variable(real_label)
            disc_real_source, disc_real_class = self.aD(real_data)
            prediction = disc_real_class.data.max(1)[1]
            accuracy = (float(prediction.eq(real_label.data).sum()
                              ) / float(self.batch_size)) * 100.0
            disc_real_source = disc_real_source.mean()
            disc_real_class = self.criterion(disc_real_class, real_label)
            gradient_penalty = calc_gradient_penalty(
                self.aD, real_data, fake_data, self.batch_size, self.cuda)
            # WGAN critic loss + AC-GAN class losses + gradient penalty
            disc_cost = disc_fake_source - disc_real_source + \
                disc_real_class + disc_fake_class + gradient_penalty
            disc_cost.backward()
            # same Adam 'step' clamp as for the generator optimizer above
            for group in self.optimizer_d.param_groups:
                for p in group['params']:
                    state = self.optimizer_d.state[p]
                    if('step' in state and state['step'] >= 1024):
                        state['step'] = 1000
            self.optimizer_d.step()
            # within the training loop
            loss1.append(gradient_penalty.item())
            loss2.append(disc_fake_source.item())
            loss3.append(disc_real_source.item())
            loss4.append(disc_real_class.item())
            loss5.append(disc_fake_class.item())
            acc1.append(accuracy)
            if batch_idx % 50 == 0:
                print("Trainig epoch: {} | Accuracy: {} | Batch: {} | Gradient penalty: {} | Discriminator fake source: {} | Discriminator real source: {} | Discriminator real class: {} | Discriminator fake class: {}".format(
                    epoch, np.mean(acc1), batch_idx, np.mean(loss1), np.mean(loss2), np.mean(loss3), np.mean(loss4), np.mean(loss5)))
        # Test the model
        self.aD.eval()
        with torch.no_grad():
            test_accu = []
            for batch_idx, (X_test_batch, Y_test_batch) in enumerate(self.testloader):
                if self.cuda:
                    X_test_batch, Y_test_batch = X_test_batch.cuda(), Y_test_batch.cuda()
                X_test_batch, Y_test_batch = Variable(
                    X_test_batch), Variable(Y_test_batch)
                with torch.no_grad():
                    _, output = self.aD(X_test_batch)
                # first column has actual prob.
                prediction = output.data.max(1)[1]
                accuracy = (
                    float(prediction.eq(Y_test_batch.data).sum()) / float(self.batch_size)) * 100.0
                test_accu.append(accuracy)
            accuracy_test = np.mean(test_accu)
        # print('Testing', accuracy_test, time.time() - start_time)
        print("Testing accuracy: {} | Eplased time: {}".format(
            accuracy_test, time.time() - start_time))
        # save output: generate the fixed grid and rescale from [-1, 1] to [0, 1]
        with torch.no_grad():
            self.aG.eval()
            samples = self.aG(save_noise)
            samples = samples.data.cpu().numpy()
            samples += 1.0
            samples /= 2.0
            samples = samples.transpose(0, 2, 3, 1)
            self.aG.train()
        fig = plot(samples)
        if not os.path.isdir('../output'):
            os.mkdir('../output')
        plt.savefig('../output/%s.png' % str(epoch).zfill(3), bbox_inches='tight')
        plt.close(fig)
        # checkpoint every epoch (whole modules, not state_dicts)
        if (epoch + 1) % 1 == 0:
            torch.save(self.aG, '../model/tempG.model')
            torch.save(self.aD, '../model/tempD.model')
# Discriminateur D optimizerD.zero_grad() outputTrue = netD(x_cuda, alpha=(1-alpha_value) if alpha else -1) # lossDT = F.binary_cross_entropy_with_logits(outputTrue, real_label) lossDT = -torch.mean(outputTrue) # with false label outputG = netG(Variable(noise)) outputFalse = netD(outputG.detach(), alpha=(1-alpha_value) if alpha else -1) # lossDF = F.binary_cross_entropy_with_logits(outputFalse, fake_label) lossDF = torch.mean(outputFalse) dTrue.append(F.sigmoid(outputTrue).data.mean()) dFalse.append(F.sigmoid(outputFalse).data.mean()) gradient_penalty = utils.calc_gradient_penalty(netD, x_cuda, outputG, batch_size=batchsize, lda=10, view=x_cuda.size()) (lossDT+lossDF+gradient_penalty).backward() optimizerD.step() ldf += lossDF ldt += lossDT # Generateur optimizerG.zero_grad() outputG = netG(noise, alpha=(1-alpha_value) if alpha else -1) outputD = netD(outputG, alpha=(1-alpha_value) if alpha else -1) # lossG = F.binary_cross_entropy_with_logits(outputD, real_label) lossG = -torch.mean(outputD) lossG.backward() optimizerG.step()
def train():
    """Adversarially train an Encoder/Decoder caption-to-image model with a
    WGAN-GP critic on (resized) COCO data.

    Side effects: reads ``vocab.pkl`` and the COCO images/annotations from
    disk, pickles ``encoder``/``decoder``/``critic`` after every successful
    batch, and every 100 batches writes a sample image to
    ``{i+1}generated{word}.png``.  Returns ``None``.
    """
    # Hyper-parameters (the ``#@param`` markers come from the original Colab
    # form; several are only used by the commented-out LatentModel variant).
    TRAINING_ITERATIONS = 100000  #@param {type:"number"}
    MAX_CONTEXT_POINTS = 50  #@param {type:"number"}
    PLOT_AFTER = 100  #10000 #@param {type:"number"}
    HIDDEN_SIZE = 300  #@param {type:"number"}
    MODEL_TYPE = 'ANP'  #@param ['NP','ANP']
    ATTENTION_TYPE = 'multihead'  #@param ['uniform','laplace','dot_product','multihead']
    batch_size = 64
    X_SIZE = 1
    Y_SIZE = 1
    # FIX: the original `pickle.load(open("vocab.pkl", "rb"))` leaked the
    # file handle; use a context manager.
    with open("vocab.pkl", "rb") as vocab_file:
        vocab = pickle.load(vocab_file)
    test_sentences = [
        "Two men seated at an open air restaurant",
        "flowers in a pot sitting on a cement wall",
        "a vase and lids are sitting on a table",
        "a teddy bear that is sitting next to some item on a table",
        "a plant in a vase by the window",
        "a young girl is similing and she has food around her on a table"
    ]
    dataset_train = get_coco_loader("./resized_small_train2014/",
                                    "./annotations/captions_train2014.json",
                                    vocab=vocab, transform=None,
                                    batch_size=batch_size, shuffle=True,
                                    num_workers=4)
    dataset_test = dataset_train  # we will need to build out the test dataset soon

    # Sizes of the layers of the MLPs for the encoders and decoder.  The final
    # decoder layer outputs two values (mean and variance of the prediction at
    # the target location).  Kept for the (removed) commented-out
    # LatentModel/Attention variant; unused by the path below.
    latent_encoder_output_sizes = [HIDDEN_SIZE] * 4
    num_latents = HIDDEN_SIZE
    deterministic_encoder_output_sizes = [HIDDEN_SIZE] * 4
    decoder_output_sizes = [32] * 2 + [2]
    use_deterministic_path = True
    xy_size = X_SIZE + Y_SIZE

    # Models and optimizers.  Encoder/Decoder form the "generator"; Critic is
    # the WGAN-GP discriminator with its own optimizer.
    encoder = Encoder().to(device)
    decoder = Decoder().to(device)
    critic = Critic().to(device)
    optimizer_critic = torch.optim.Adam(critic.parameters())
    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()))

    for epoch in range(10):
        progress = tqdm(enumerate(dataset_train))
        total_loss = 0
        count = 0
        for i, (images, targets, lengths) in progress:
            try:
                # ---- generator (encoder+decoder) step ----
                optimizer.zero_grad()
                gen_loss = 0
                for instance in range(batch_size):
                    image, target, length = images[instance].to(
                        device), targets[instance], lengths[instance]
                    sentence = get_sentence(target, vocab)
                    # Word vectors for every non-special token of the caption.
                    vectors = torch.Tensor([
                        nlp(word).vector for word in sentence if "<" not in word
                    ]).to(device)
                    vectors = vectors.unsqueeze(0)
                    image = image.unsqueeze(0)
                    r = encoder(vectors, image)
                    decoder_input = torch.cat(
                        (r.repeat(1, vectors.shape[1], 1), vectors.float()), -1)
                    out = decoder(decoder_input)
                    fake_image = out.view(32, 32, 3)
                    disc_fake = critic(fake_image)
                    # BUG FIX: the generator must *maximize* the critic score,
                    # i.e. descend on -critic(fake).  The original called
                    # disc_fake.backward() (wrong gradient sign) and only
                    # computed gen_loss = -disc_fake afterwards, unused.
                    gen_loss = -disc_fake
                    gen_loss.backward()
                optimizer.step()

                # ---- critic steps: 5 per generator step (WGAN-GP) ----
                for t in range(5):
                    optimizer_critic.zero_grad()
                    loss = 0
                    for instance in range(batch_size):
                        image, target, length = images[instance].to(
                            device), targets[instance], lengths[instance]
                        sentence = get_sentence(target, vocab)
                        vectors = torch.Tensor([
                            nlp(word).vector for word in sentence
                            if "<" not in word
                        ]).to(device)
                        vectors = vectors.unsqueeze(0)
                        image = image.unsqueeze(0)
                        r = encoder(vectors, image)
                        decoder_input = torch.cat((r.repeat(
                            1, vectors.shape[1], 1), vectors.float()), -1)
                        out = decoder(decoder_input)
                        fake_image = out.view(32, 32, 3)
                        # fake_image = fake_image.transpose(1,0).transpose(2,1)
                        disc_real = critic(image)
                        disc_fake = critic(fake_image)
                        gradient_penalty = utils.calc_gradient_penalty(
                            critic, image, fake_image)
                        # Standard WGAN-GP critic objective.
                        loss = disc_fake - disc_real + gradient_penalty
                        loss.backward()
                        w_dist = disc_real - disc_fake
                    optimizer_critic.step()
                progress.set_description("E{} - L{:.4f}".format(
                    epoch, w_dist.item()))
                # Checkpoint after every successful batch (costly but kept for
                # behavior compatibility with the original).
                with open("encoder.pkl", "wb") as of:
                    pickle.dump(encoder, of)
                with open("decoder.pkl", "wb") as of:
                    pickle.dump(decoder, of)
                with open("critic.pkl", "wb") as of:
                    pickle.dump(critic, of)
            except Exception as e:
                # Best-effort training: log the failure and move on.
                print(e)
                continue
            try:
                if i % 100 == 0:
                    # Render a sample from the last processed caption.
                    with torch.no_grad():
                        decoder_input = torch.cat((r.repeat(
                            1, vectors.shape[1], 1), vectors.float()), -1)
                        out = decoder(decoder_input)
                        fake_image = out.view(32, 32, 3)
                        plt.imshow(fake_image.detach().cpu())
                        plt.xlabel(" ".join([
                            x for x in sentence
                            if x not in {"<end>", "<pad>", "<start>", "<unk>"}
                        ]), wrap=True)
                        plt.tight_layout()
                        plt.savefig("{}generated{}.png".format(
                            i + 1, sentence[1]))
                        plt.close()
            except Exception:
                # FIX: the original bare `except:` also swallowed
                # KeyboardInterrupt/SystemExit.
                continue
    print("done")
# --- Critic (D) update on a conditional (input, target) image pair ---
# NOTE(review): statement fragment; the enclosing loop and the gradient
# direction tensors `one`/`mone` (presumably +1/-1) are outside this chunk.
D.zero_grad()
D_optimizer.zero_grad()
# Real pair: input image concatenated channel-wise with its ground truth.
real_pair = torch.cat((imgs, g_truth), dim=1)
d_real = D(real_pair)
d_real = d_real.mean()
d_real.backward(mone)  # ascend on D(real)
# Fake pair: input image with the generator's prediction; detach() keeps the
# critic step from updating G.
fake_pair = torch.cat((imgs, G(imgs).detach()), dim=1)
d_fake = D(fake_pair)
d_fake = d_fake.mean()
d_fake.backward(one)  # descend on D(fake)
# WGAN-GP penalty on interpolates between the real and fake pairs.
gradient_penalty = calc_gradient_penalty(D, real_pair.data, fake_pair.data)
gradient_penalty.backward()
D_optimizer.step()
# Wasserstein distance estimate, recorded for logging only.
Wasserstein_D = d_real- d_fake
D_losses.append(Wasserstein_D.item())
# train the generator
for idx, (imgs, g_truth) in tqdm.tqdm(enumerate(train_loader)):
    mini_batch = imgs.size()[0]
    # BCE-style label tensors; they appear unused by the Wasserstein
    # objective above — presumably a leftover from a vanilla-GAN variant.
    y_real_ = torch.ones(mini_batch)
    y_fake_ = torch.zeros(mini_batch)
    imgs, g_truth, y_real_, y_fake_ = Variable(imgs.cuda()), Variable(g_truth.cuda()), Variable(
        y_real_.cuda()), Variable(y_fake_.cuda())
    # NOTE(review): the next commented line is truncated in this chunk.
    #imgs, g_truth, y_real_, y_fake_ = Variable(imgs), Variable(g_truth), Variable(
disc_fake_class = criterion(disc_fake_class, fake_label) # calculate discriminator loss with real data real_data = Variable(x).cuda() real_label = Variable(y).cuda() disc_real_source, disc_real_class = aD(real_data) prediction = disc_real_class.data.max(1)[1] accuracy = (float(prediction.eq(real_label.data).sum()) / float(batch_size)) * 100.0 disc_real_source = disc_real_source.mean() disc_real_class = criterion(disc_real_class, real_label) gradient_penalty = calc_gradient_penalty(aD, real_data, fake_data, batch_size) disc_cost = disc_fake_source - disc_real_source + disc_real_class + disc_fake_class + gradient_penalty disc_cost.backward() optimizer_d.step() """ Append losses and print """ loss1.append(gradient_penalty.item()) loss2.append(disc_fake_source.item()) loss3.append(disc_real_source.item()) loss4.append(disc_real_class.item()) loss5.append(disc_fake_class.item()) acc1.append(accuracy) if batch_idx % 50 == 0:
def deepinversion_improved(self, use_generator: bool = False,
                           discrete_label: bool = True,
                           noisify_network: float = 0.0,
                           knowledge_distill: float = 0.0,
                           mutual_info: float = 0.0,
                           batchnorm_transfer: float = 0.0,
                           use_discriminator: float = 0.0,
                           n_iters: int = 100) -> None:
    """Run DeepInversion-style input synthesis against the frozen classifier
    ``self.net``, optionally training a generator instead of raw inputs.

    Each nonzero float flag both enables a regularizer and acts as its loss
    weight: ``noisify_network`` (decaying uniform noise on the classifier),
    ``knowledge_distill`` (teacher-student term + student updates),
    ``mutual_info`` (latent-reconstruction + diversity terms),
    ``batchnorm_transfer`` (generator batch-norm statistics matching), and
    ``use_discriminator`` (per-BN-layer feature discriminators, WGAN-GP).

    Side effects: writes TensorBoard logs, appends option tags to
    ``self.imgname``, saves a progress figure to ``self.basedir``, and shows
    the plot.  Returns None.
    """
    tb = SummaryWriter()
    if use_generator == True:
        # Latent codes and (discrete or continuous) target labels for the
        # generator path; `x` is produced by the generator.
        z = torch.randn((self.n_samples, self.latent_dim),
                        requires_grad=False, device=self.device,
                        dtype=torch.float)
        if discrete_label == True:
            y_gt = torch.randint(0, 2, (self.n_samples, self.label_dim),
                                 dtype=torch.float, device=self.device)
        else:
            y_gt = torch.cuda.FloatTensor(self.n_samples,
                                          self.label_dim).uniform_(0, 1)
        x = self.net_gen(z, y_gt)
        if mutual_info > 0.0:
            ''' declare the optimizer for the encoder network '''
            optimizer = torch.optim.Adam(
                list(self.net_gen.parameters()) +
                list(self.net_enc.parameters()), lr=self.lr)
        else:
            optimizer = torch.optim.Adam(self.net_gen.parameters(),
                                         lr=self.lr)
    else:
        # Direct-optimization path: the 2-D inputs themselves are the
        # trainable parameters.
        x = torch.randn((self.n_samples, 2), requires_grad=True,
                        device=self.device, dtype=torch.float)
        if discrete_label == True:
            y_gt = torch.randint(0, 2, (self.n_samples, self.label_dim),
                                 dtype=torch.float, device=self.device)
        else:
            y_gt = torch.cuda.FloatTensor(self.n_samples,
                                          self.label_dim).uniform_(0, 1)
        optimizer = torch.optim.Adam([x], lr=self.lr)
    #update name of output
    self.imgname = self.imgname + "_gen%d" % (use_generator)
    ''' declare the optimizer for the student network '''
    optimizer_std = torch.optim.Adam(self.net_std.parameters(),
                                     lr=self.classifier_lr)
    if self.device == 'cuda':
        x_np = x.cpu().detach().clone().numpy()
    else:
        x_np = x.detach().clone().numpy()
    fig, ax = self.setup_plot_progress(x_np)
    total_loss = []
    # set for testing with batchnorm
    self.net.eval()
    ## Create hooks for feature statistics
    loss_bn_feature_layers = []
    if use_generator == True and use_discriminator > 0.0:
        nets_dis = []
        nets_dis_params = []
    # One bn1dfeathook per BatchNorm1d layer; optionally also one small
    # per-layer feature discriminator.
    for module in self.net.modules():
        if isinstance(module, nn.BatchNorm1d):
            loss_bn_feature_layers.append(bn1dfeathook(module))
            if use_generator == True and use_discriminator > 0.0:
                net_dis = netdis(module.running_mean.shape[0],
                                 self.n_hidden, 1).cuda()
                net_dis.apply(weights_init)
                nets_dis.append(net_dis)
                nets_dis_params += list(net_dis.parameters())
    if use_generator == True and use_discriminator > 0.0:
        self.optimizer_dis = torch.optim.Adam(nets_dis_params, lr=self.lr,
                                              betas=(0.5, 0.9))
    ## Create hooks for feature statistics for generator
    if use_generator == True and batchnorm_transfer > 0.0:
        loss_bn_feature_layers_gen = []
        self.compute_loss_bn_gen(loss_bn_feature_layers_gen)
    for it in range(n_iters):
        self.net.zero_grad()
        self.net_gen.zero_grad()
        self.net_std.zero_grad()
        self.net_enc.zero_grad()
        optimizer.zero_grad()
        optimizer_std.zero_grad()
        if use_generator == True:
            ''' randomly sampling latent and labels '''
            z = torch.randn((self.n_samples, self.latent_dim),
                            requires_grad=False, device=self.device,
                            dtype=torch.float)
            y_gt = torch.randint(0, 2, (self.n_samples, self.label_dim),
                                 dtype=torch.float, device=self.device)
        if use_generator == True:
            ''' generating samples with generator '''
            x = self.net_gen(z, y_gt)
        '''
        **********************************************************************
        To optimize the generated samples or training the generator
        **********************************************************************
        '''
        if noisify_network > 0.0:
            ''' adding noise into the pre-trained classifier '''
            # Noise magnitude decays linearly from noisify_network to ~0.
            weight = noisify_network * (n_iters - it) / n_iters
            self.net, orig_params = add_noise_to_net(self.net, weight=weight,
                                                     noise_type='uniform')
            if it == 0:
                self.imgname = self.imgname + "_nosify%0.3f" % (
                    noisify_network)
        y_pd = self.net(x)
        ''' main loss (cross-entropy loss) '''
        loss_main = self.loss_func(y_pd, y_gt)
        ''' l2 regularization '''
        loss_l2 = torch.norm(x.view(-1, self.n_input_dim), dim=1).mean()
        ''' batch-norm regularization '''
        # Per-layer rescale weights (currently all 1).
        rescale = [1. for _ in range(len(loss_bn_feature_layers))]
        loss_bn = sum([
            mod.r_feature * rescale[idx]
            for (idx, mod) in enumerate(loss_bn_feature_layers)
        ])
        ''' total loss '''
        # BN term is down-weighted when the feature discriminators are on.
        if use_generator == True and use_discriminator > 0.0:
            bn_w = 0.05
        else:
            bn_w = 1.0
        loss = loss_main + 0.005 * loss_l2 + bn_w * loss_bn
        if knowledge_distill > 0.0:
            ''' knowledge distillation (teacher-student) based regularization '''
            y_st = self.net_std(x)
            #loss_kd = 1 - self.loss_func(y_st, y_pd.detach())
            loss_kd = knowledge_distill_loss(y_pd.detach(), y_st)
            loss = loss + knowledge_distill * loss_kd
            if it == 0:
                self.imgname = self.imgname + "_kdistill%0.3f" % (
                    knowledge_distill)
        if mutual_info > 0.0:
            ''' mutual information constraint '''
            # Encoder should recover the latent code from the sample.
            ze = self.net_enc(x)
            loss_mi = ((z - ze)**2).mean()
            # Diversity term from a second, independent latent draw.
            zdiv = torch.randn((self.n_samples, self.latent_dim),
                               requires_grad=False, device=self.device,
                               dtype=torch.float)
            xdiv = self.net_gen(zdiv, y_gt)
            loss_div = diveristy_loss(z, x, zdiv, xdiv)
            loss = loss + mutual_info * loss_mi + 0.1 * mutual_info * loss_div
            if it == 0:
                self.imgname = self.imgname + "_minfo%0.3f" % (mutual_info)
        if use_generator == True and batchnorm_transfer > 0.0:
            ''' batch-norm transfer loss '''
            rescale_gen = [
                1. for _ in range(len(loss_bn_feature_layers_gen))
            ]
            loss_bn_gen = sum([
                mod.r_feature * rescale_gen[idx]
                for (idx, mod) in enumerate(loss_bn_feature_layers_gen)
            ])
            loss = loss + batchnorm_transfer * loss_bn_gen
            if it == 0:
                self.imgname = self.imgname + "_btransfer%0.3f" % (
                    batchnorm_transfer)
        if use_generator == True and use_discriminator > 0.0:
            # train the generator on features
            loss_g = 0
            # traing the generator on features
            for (idx, mod) in enumerate(loss_bn_feature_layers):
                nets_dis[idx].zero_grad()
                # frozen the gradient for the discriminator
                for p in nets_dis[idx].parameters():
                    p.requires_grad = False  # to avoid computation
                feat_fake = mod.feat_fake.cuda()
                d_fake = nets_dis[idx](feat_fake)
                loss_g = loss_g - d_fake.mean()
            loss = loss + use_discriminator * loss_g
            if use_generator == True and it == 0:
                self.imgname = self.imgname + "_discriminator%0.3f" % (
                    use_discriminator)
        loss.backward()
        optimizer.step()
        if it % 100 == 0:
            tb.add_scalar("Total loss: ", loss, it)
            tb.add_scalar("Loss batchnorm", loss_bn, it)
            tb.add_histogram("Input", x, it)
            # tb.add_histogram("Input/gradients", x.grad, it)
            for name, param in self.net_gen.named_parameters():
                tb.add_histogram(name, param.data, it)
                tb.add_histogram(name + "/gradients", param.grad, it)
        if noisify_network > 0.0:
            ''' reset the network's parameters '''
            reset_params(self.net, orig_params)
        if knowledge_distill > 0.0:
            '''
            ******************************************************************
            To update the student network
            ******************************************************************
            '''
            if use_generator == True:
                ''' generating samples with generator '''
                x = self.net_gen(z, y_gt)
            y_pd = self.net(x)
            y_st = self.net_std(x)
            #loss_kd = self.loss_func(y_st, y_pd.detach())
            loss_kd = 1. - knowledge_distill_loss(y_pd.detach(), y_st)
            loss_kd.backward()
            optimizer_std.step()
        ''' store the main loss to plot on the figure '''
        total_loss.append(loss.item())
        if use_generator == True and use_discriminator > 0.0:
            # traing the discriminator on features (5 critic steps, WGAN-GP)
            for _ in range(5):
                loss_d = 0
                x = self.net_gen(z, y_gt)
                # Forward pass only to refresh the BN hooks' cached features.
                self.net(x)
                for (idx, mod) in enumerate(loss_bn_feature_layers):
                    nets_dis[idx].zero_grad()
                    for p in nets_dis[idx].parameters():  # reset requires_grad
                        p.requires_grad = True
                    feat_real = mod.feat_real.cuda()
                    feat_fake = mod.feat_fake.cuda()
                    d_real = nets_dis[idx](feat_real)
                    d_fake = nets_dis[idx](feat_fake)
                    penalty = calc_gradient_penalty(nets_dis[idx], feat_real,
                                                    feat_fake, LAMBDA=1.0)
                    loss_d = loss_d + use_discriminator * (
                        d_fake.mean() - d_real.mean() + penalty)
                loss_d.backward()
                self.optimizer_dis.step()
        if it % 10 == 0:
            print('-- iter %d --' % (it))
            print('target loss: %f' % (loss_main.item()))
            print('l2-norm loss: %f' % (loss_l2.item()))
            print('batchnorm loss: %f' % (loss_bn.item()))
            if knowledge_distill > 0.0:
                # NOTE(review): prints loss_bn here — presumably loss_kd was
                # intended; confirm before relying on this log line.
                print('distillation loss: %f' % (loss_bn.item()))
            if mutual_info > 0.0:
                print('mutual information / diversity losses: %f / %f' %
                      (loss_mi.item(), loss_div.item()))
            if batchnorm_transfer > 0.0:
                print('batch-norm transfer loss: %f ' % (loss_bn_gen.item()))
            if use_generator == True and use_discriminator > 0.0:
                print('loss d / loss g: %f / %f' %
                      (loss_d.item(), loss_g.item()))
            print('total loss: %f' % (loss.item()))
            ''' realtime plot '''
            ax[0].plot(total_loss, c='b')
            fig.canvas.draw()
    if self.device == 'cuda':
        x_np = x.cpu().detach().numpy()
    else:
        x_np = x.detach().numpy()
    tb.close()
    ax[1].scatter(x_np[:, 0], x_np[:, 1], c='b', cmap=plt.cm.Accent)
    plt.savefig(self.basedir + "%s.png" % (self.imgname))
    plt.show()
# train with fake theta = minR + 2 * (np.pi - minR) * torch.rand( batch_size, 17, device=device) with torch.no_grad(): z_pred = netG(real_data) fake_data = utils.rotate_and_project(real_data, z_pred, theta) D_fake = netD(fake_data) D_fake = D_fake.mean() D_fake.backward(one) # gradient penalty GP = utils.calc_gradient_penalty(netD, real_data, fake_data, LAMBDA=LAMBDA, device=device) GP.backward() GP = GP.item() loss_D = (D_fake - D_real + GP).item() WD = (D_real - D_fake).item() optimizerD.step() else: ############################ # (1) Update G network: maximize E[D(G(x))] ########################### netG.zero_grad() for p in netD.parameters():