    def train_GD(self):
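        """One training iteration: several discriminator (critic) steps on
        chunks of real data, followed by a single generator step."""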
        self.netG.train()
        self.netD.train()
        self.optimizer_G.zero_grad()
        self.optimizer_D.zero_grad()
        # Split the real images and labels into batch_size chunks (one chunk per critic step)
        x = torch.split(self.real, self.opt.batch_size)
        y = torch.split(self.label, self.opt.batch_size)
        counter = 0

        # Optionally toggle D and G's requires_grad
        if self.opt.toggle_grads:
            toggle_grad(self.netD, True)
            toggle_grad(self.netG, False)

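        # Critic loop: take num_critic_train D steps, consuming one real chunk per step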
        for step_index in range(self.opt.num_critic_train):
            self.optimizer_D.zero_grad()
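            # Generate fakes without tracking gradients: the D update must not backprop into G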
            with torch.no_grad():
                self.forward()
            D_input = torch.cat([self.fake, x[counter]],
                                0) if x is not None else self.fake
            D_class = torch.cat([self.label_fake, y[counter]],
                                0) if y is not None else self.label_fake
            # Get Discriminator output
            D_out = self.netD(D_input, D_class)
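            # D scored the concatenated [fake, real] batch; split the scores back apart.
            # (x and y come from torch.split above, so the None fallbacks are effectively unused.)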
            if x is not None:
                pred_fake, pred_real = torch.split(
                    D_out, [self.fake.shape[0], x[counter].shape[0]
                            ])  # D_fake, D_real
            else:
                pred_fake = D_out
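            # For reference, loss_hinge_dis is assumed to follow the standard GAN hinge
            # objective (sketch; the exact formulation is defined elsewhere in this repo):
            #   L_D = E[relu(1 - D(real))] + E[relu(1 + D(fake))]
            # with the len_text/mask_loss arguments presumably masking out padded regions
            # of the variable-width word images.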
            # Combined loss
            self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(
                pred_fake, pred_real, self.len_text_fake.detach(),
                self.len_text.detach(), self.opt.mask_loss)
            self.loss_D = self.loss_Dreal + self.loss_Dfake
            self.loss_D.backward()
            counter += 1
            self.optimizer_D.step()

        # Optionally toggle D and G's requires_grad
        if self.opt.toggle_grads:
            toggle_grad(self.netD, False)
            toggle_grad(self.netG, True)
        # Zero G's gradients by default before training G, for safety
        self.optimizer_G.zero_grad()
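        # Generator step: one update on the hinge generator loss, conditioned on the fake labels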
        self.forward()
        self.loss_G = loss_hinge_gen(self.netD(self.fake, self.label_fake),
                                     self.len_text_fake.detach(),
                                     self.opt.mask_loss)
        self.loss_G.backward()
        self.optimizer_G.step()

    def visualize_fixed_noise(self):
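        """Generate images from the fixed noise/text batch, print their OCR
        predictions, and store the images together with OCR/adversarial
        gradient maps for visualization."""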
        if self.opt.single_writer:
            self.fixed_noise = self.z[0].repeat((self.fixed_noise_size, 1))
        if self.opt.one_hot:
            images = self.netG(self.fixed_noise,
                               self.one_hot_fixed.to(self.device))
        else:
            images = self.netG(self.fixed_noise,
                               self.fixed_text_encode_fake.to(self.device))

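        # Adversarial loss on the fixed batch; used only to probe its gradient w.r.t. the images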
        loss_G = loss_hinge_gen(
            self.netD(x=images, z=self.fixed_noise),
            self.fixed_text_len.detach(), self.opt.mask_loss)
        # OCR loss on the generated (fake) images
        pred_fake_OCR = self.netOCR(images)
        preds_size = torch.IntTensor([pred_fake_OCR.size(0)] *
                                     len(self.fixed_text_len)).detach()
        loss_OCR_fake = self.OCR_criterion(
            pred_fake_OCR,
            self.fixed_text_encode_fake.detach().to(self.device), preds_size,
            self.fixed_text_len.detach())
        loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])

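        # Per-pixel gradients of the OCR and adversarial losses w.r.t. the generated images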
        grad_fixed_OCR = torch.autograd.grad(loss_OCR_fake, images)
        grad_fixed_adv = torch.autograd.grad(loss_G, images)
        _, preds = pred_fake_OCR.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = self.OCRconverter.decode(preds.data,
                                             preds_size.data,
                                             raw=False)
        raw_preds = self.OCRconverter.decode(preds.data,
                                             preds_size.data,
                                             raw=True)
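        # Print raw/decoded predictions against their ground-truth lexicon entries and
        # stash the images and gradient maps (rescaled to [-1, 1]) as attributes for later display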
        print('######## fake images OCR prediction ########')
        for i in range(self.fixed_noise_size):
            print('%-20s => %-20s, gt: %-20s' %
                  (raw_preds[i], sim_preds[i], self.lex[int(
                      self.fixed_fake_labels[i])]))
            image = images[i].unsqueeze(0).detach()
            grad_OCR = torch.abs(grad_fixed_OCR[0][i]).unsqueeze(0).detach()
            grad_OCR = (grad_OCR / torch.max(grad_OCR)) * 2 - 1
            grad_adv = torch.abs(grad_fixed_adv[0][i]).unsqueeze(0).detach()
            grad_adv = (grad_adv / torch.max(grad_adv)) * 2 - 1
            label = self.label_fix[i]
            setattr(self, 'grad_OCR_fixed_' + 'label_' + label, grad_OCR)
            setattr(self, 'grad_G_fixed_' + 'label_' + label, grad_adv)
            setattr(self, 'fake_fixed_' + 'label_' + label, image)

    def backward_G(self):
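        """Backward pass for G: hinge adversarial loss plus an OCR (CTC-style)
        recognition loss on the fake images, optionally re-weighted by
        gradient balancing."""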
        self.loss_G = loss_hinge_gen(self.netD(x=self.fake, z=self.z),
                                     self.len_text_fake.detach(),
                                     self.opt.mask_loss)
        # OCR loss on the generated (fake) images

        pred_fake_OCR = self.netOCR(self.fake)
        preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * self.opt.batch_size).detach()
        loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size, self.len_text_fake.detach())
        self.loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])
        # total loss
        self.loss_T = self.loss_G + self.opt.gb_alpha*self.loss_OCR_fake
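        # Gradient magnitudes of each loss w.r.t. the fake images, tracked for monitoring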
        grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]
        self.loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2)
        grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0]
        self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2)
        if not self.opt.no_grad_balance:
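            # Gradient balancing: rescale the OCR loss so that the standard deviation of its
            # gradient w.r.t. the fake images matches gb_alpha times that of the adversarial gradient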
            self.loss_T.backward(retain_graph=True)
            grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
            grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
            a = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR))
            if torch.isnan(a):
                print(self.loss_OCR_fake, self.loss_G, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
            if a > 1000 or a < 0.0001:
                print('gradient-balance ratio a out of range:', a.item())
            b = self.opt.gb_alpha * (torch.mean(grad_fake_adv) -
                                     torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_OCR)) *
                                     torch.mean(grad_fake_OCR))
            # b is only used by the commented-out variant below (which also aligns the gradient means)
            # self.loss_OCR_fake = a.detach() * self.loss_OCR_fake + b.detach() * torch.sum(self.fake)
            self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
            self.loss_T = (1 - 1 * self.opt.onlyOCR) * self.loss_G + self.loss_OCR_fake
            self.loss_T.backward(retain_graph=True)
            grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
            grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0]
            self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
            self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
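            # Note: torch.no_grad() has no effect on backward(); this call accumulates the
            # re-balanced gradients into G's parameters once more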
            with torch.no_grad():
                self.loss_T.backward()
        else:
            self.loss_T.backward()

        if self.opt.clip_grad > 0:
            clip_grad_norm_(self.netG.parameters(), self.opt.clip_grad)
        if torch.isnan(loss_OCR_fake).any() or torch.isnan(self.loss_G):
            print('loss OCR fake: ', loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words)
            sys.exit()