def train_GD(self):
    """Run one combined GAN step: `num_critic_train` discriminator updates
    on successive real chunks, followed by a single generator update.

    Side effects only (updates netD/netG parameters and sets
    self.loss_Dreal / self.loss_Dfake / self.loss_D / self.loss_G).
    """
    self.netG.train()
    self.netD.train()
    self.optimizer_G.zero_grad()
    self.optimizer_D.zero_grad()
    # Split the real images/labels into batch_size chunks; each critic
    # step consumes the next chunk via `counter`.
    x = torch.split(self.real, self.opt.batch_size)
    y = torch.split(self.label, self.opt.batch_size)
    counter = 0
    # Optionally toggle D and G's "require_grad" so only D accumulates
    # gradients during the critic steps.
    if self.opt.toggle_grads:
        toggle_grad(self.netD, True)
        toggle_grad(self.netG, False)
    for step_index in range(self.opt.num_critic_train):
        self.optimizer_D.zero_grad()
        # Generate fakes without building G's graph — this is a D-only update.
        with torch.set_grad_enabled(False):
            self.forward()
        # Feed fake and real through D in a single concatenated pass.
        D_input = torch.cat([self.fake, x[counter]], 0) if x is not None else self.fake
        # BUGFIX: the fallback previously read `... if y[counter] is not None
        # else y[counter]`, i.e. it returned the very value it had just
        # tested as None. Mirror the D_input pattern and fall back to the
        # fake labels alone.
        D_class = torch.cat([self.label_fake, y[counter]], 0) if y is not None else self.label_fake
        # Get Discriminator output
        D_out = self.netD(D_input, D_class)
        if x is not None:
            # Split back into D(fake), D(real).
            pred_fake, pred_real = torch.split(
                D_out, [self.fake.shape[0], x[counter].shape[0]])
        else:
            # BUGFIX: this path previously left `pred_real` unbound, which
            # would raise NameError inside loss_hinge_dis below. The hinge
            # D-loss needs real samples, so fail fast with a clear message.
            # (Unreachable in practice: torch.split never returns None.)
            raise RuntimeError('train_GD requires real samples for the hinge discriminator loss')
        # Combined hinge loss over real and fake predictions.
        self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(
            pred_fake, pred_real, self.len_text_fake.detach(),
            self.len_text.detach(), self.opt.mask_loss)
        self.loss_D = self.loss_Dreal + self.loss_Dfake
        self.loss_D.backward()
        counter += 1
        self.optimizer_D.step()
    # Flip requires_grad back so only G trains during the generator step.
    if self.opt.toggle_grads:
        toggle_grad(self.netD, False)
        toggle_grad(self.netG, True)
    # Zero G's gradients by default before training G, for safety.
    self.optimizer_G.zero_grad()
    self.forward()
    self.loss_G = loss_hinge_gen(self.netD(self.fake, self.label_fake),
                                 self.len_text_fake.detach(), self.opt.mask_loss)
    self.loss_G.backward()
    self.optimizer_G.step()
def visualize_fixed_noise(self):
    """Generate images from the fixed noise/text batch, print their OCR
    predictions, and stash each image plus its OCR / adversarial gradient
    maps on `self` as per-label attributes for later visualization.

    Side effects only; nothing is returned.
    """
    if self.opt.single_writer:
        # Repeat the first style vector so every fixed sample comes from
        # the same "writer".
        self.fixed_noise = self.z[0].repeat((self.fixed_noise_size, 1))
    if self.opt.one_hot:
        images = self.netG(self.fixed_noise, self.one_hot_fixed.to(self.device))
    else:
        images = self.netG(self.fixed_noise, self.fixed_text_encode_fake.to(self.device))
    # Generator hinge loss on the fixed batch — computed only so we can
    # probe its gradient w.r.t. the images below.
    loss_G = loss_hinge_gen(
        self.netD(**{
            'x': images,
            'z': self.fixed_noise
        }), self.fixed_text_len.detach(), self.opt.mask_loss)
    # self.loss_G = loss_hinge_gen(self.netD(self.fake, self.rep_label_fake))
    # OCR loss on real data
    pred_fake_OCR = self.netOCR(images)
    # One sequence-length entry per sample (CTC-style criterion input).
    preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * len(self.fixed_text_len)).detach()
    # loss_OCR_fake = self.OCR_criterion(pred_fake_OCR.log_softmax(2), self.fixed_text_encode_fake.detach().to(self.device),
    #                                    preds_size, self.fixed_text_len.detach())
    loss_OCR_fake = self.OCR_criterion(
        pred_fake_OCR, self.fixed_text_encode_fake.detach().to(self.device),
        preds_size, self.fixed_text_len.detach())
    # Mean over non-NaN entries only — presumably the criterion can emit
    # NaN for degenerate samples (TODO confirm against OCR_criterion).
    loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])
    # Gradient of each loss w.r.t. the generated images; these show where
    # the OCR signal vs. the adversarial signal pushes the generator.
    grad_fixed_OCR = torch.autograd.grad(loss_OCR_fake, images)
    grad_fixed_adv = torch.autograd.grad(loss_G, images)
    # Greedy decode: argmax over the class dimension, then (T,B) -> flat.
    _, preds = pred_fake_OCR.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)
    sim_preds = self.OCRconverter.decode(preds.data, preds_size.data, raw=False)
    raw_preds = self.OCRconverter.decode(preds.data, preds_size.data, raw=True)
    print('######## fake images OCR prediction ########')
    for i in range(self.fixed_noise_size):
        print('%-20s => %-20s, gt: %-20s' % (raw_preds[i], sim_preds[i],
                                             self.lex[int(self.fixed_fake_labels[i])]))
        image = images[i].unsqueeze(0).detach()
        # Map |grad| into [-1, 1] so the gradient can be rendered like an
        # image by the same visualization pipeline.
        grad_OCR = torch.abs(grad_fixed_OCR[0][i]).unsqueeze(0).detach()
        grad_OCR = (grad_OCR / torch.max(grad_OCR)) * 2 - 1
        grad_adv = torch.abs(grad_fixed_adv[0][i]).unsqueeze(0).detach()
        grad_adv = (grad_adv / torch.max(grad_adv)) * 2 - 1
        label = self.label_fix[i]
        # Expose the per-label tensors as dynamically-named attributes.
        setattr(self, 'grad_OCR_fixed_' + 'label_' + label, grad_OCR)
        setattr(self, 'grad_G_fixed_' + 'label_' + label, grad_adv)
        setattr(self, 'fake_fixed_' + 'label_' + label, image)
def backward_G(self):
    """Compute the generator losses (adversarial hinge + OCR/CTC on the
    fakes), optionally rebalance the OCR loss against the adversarial
    gradient statistics, and backpropagate into netG.

    Side effects only: sets self.loss_G, self.loss_OCR_fake, self.loss_T,
    self.loss_grad_fake_OCR, self.loss_grad_fake_adv and accumulates
    gradients on netG's parameters. Exits the process on NaN losses.
    """
    self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake, 'z': self.z}),
                                 self.len_text_fake.detach(), self.opt.mask_loss)
    # OCR loss on real data
    pred_fake_OCR = self.netOCR(self.fake)
    # One sequence-length entry per sample (CTC-style criterion input).
    preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * self.opt.batch_size).detach()
    loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(),
                                       preds_size, self.len_text_fake.detach())
    # Mean over non-NaN entries only — presumably the criterion can emit
    # NaN for degenerate samples (TODO confirm against OCR_criterion).
    self.loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])
    # total loss
    self.loss_T = self.loss_G + self.opt.gb_alpha * self.loss_OCR_fake
    # Per-loss gradient magnitudes w.r.t. the fake images, scaled by 1e6
    # for readability; kept as attributes for logging/monitoring.
    grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]
    self.loss_grad_fake_OCR = 10**6 * torch.mean(grad_fake_OCR**2)
    grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0]
    self.loss_grad_fake_adv = 10**6 * torch.mean(grad_fake_adv**2)
    if not self.opt.no_grad_balance:
        # Gradient balancing: rescale the OCR loss so the spread (std) of
        # its image-gradient matches the adversarial one, weighted by
        # gb_alpha. Statement order matters: the first backward seeds the
        # parameter grads, then the scaled loss_T is backpropagated again.
        self.loss_T.backward(retain_graph=True)
        grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake,
                                            create_graph=True, retain_graph=True)[0]
        grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake,
                                            create_graph=True, retain_graph=True)[0]
        a = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv),
                                          self.epsilon + torch.std(grad_fake_OCR))
        # NOTE(review): `a` is a tensor and can never be None here, so this
        # branch is dead; it was presumably meant to catch NaN — confirm.
        if a is None:
            print(self.loss_OCR_fake, self.loss_G, torch.std(grad_fake_adv),
                  torch.std(grad_fake_OCR))
        # Warn when the balancing factor is extreme.
        if a > 1000 or a < 0.0001:
            print(a)
        # NOTE(review): `b` is computed but only used by the commented-out
        # variant below — currently dead.
        b = self.opt.gb_alpha * (torch.mean(grad_fake_adv) -
                                 torch.div(torch.std(grad_fake_adv),
                                           self.epsilon + torch.std(grad_fake_OCR)) *
                                 torch.mean(grad_fake_OCR))
        # self.loss_OCR_fake = a.detach() * self.loss_OCR_fake + b.detach() * torch.sum(self.fake)
        self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
        # onlyOCR zeroes the adversarial term when set.
        self.loss_T = (1 - 1 * self.opt.onlyOCR) * self.loss_G + self.loss_OCR_fake
        self.loss_T.backward(retain_graph=True)
        # Re-probe the (now rescaled) gradient magnitudes for logging.
        grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake,
                                            create_graph=False, retain_graph=True)[0]
        grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake,
                                            create_graph=False, retain_graph=True)[0]
        self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
        self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
        # NOTE(review): backward() inside no_grad() is unusual — no_grad
        # does not stop gradient accumulation here; confirm the intent
        # (likely to avoid building a new graph during this extra pass).
        with torch.no_grad():
            self.loss_T.backward()
    else:
        self.loss_T.backward()
    if self.opt.clip_grad > 0:
        clip_grad_norm_(self.netG.parameters(), self.opt.clip_grad)
    # Hard-stop on NaN so a poisoned run does not continue silently.
    if any(torch.isnan(loss_OCR_fake)) or torch.isnan(self.loss_G):
        print('loss OCR fake: ', loss_OCR_fake, ' loss_G: ', self.loss_G,
              ' words: ', self.words)
        sys.exit()