def forward(self, words=None, z=None):
        """Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
        if hasattr(self, 'fake'): del self.fake, self.text_encode_fake, self.len_text_fake, self.one_hot_fake

        self.label_fake.sample_()
        if words is None:
            words = [self.lex[int(i)] for i in self.label_fake]
            if self.opt.capitalize:
                for i, word in enumerate(words):
                    if random.random()<0.5:
                        word = list(word)
                        word[0] = unicodedata.normalize('NFKD',word[0].upper()).encode('ascii', 'ignore').decode("utf-8")
                        word = ''.join(word)
                    words[i] = word
            words = [word.encode('utf-8') for word in words]
        if z is None:
            if not self.opt.single_writer:
                self.z.sample_()
        else:
            if z.shape[0]==1:
                self.z = z.repeat((len(words), 1))
                self.z = z.repeat((len(words), 1))
            else:
                self.z = z
        self.words = words
        self.text_encode_fake, self.len_text_fake = self.netconverter.encode(self.words)
        self.text_encode_fake = self.text_encode_fake.to(self.device)
        if self.opt.one_hot:
            self.one_hot_fake = make_one_hot(self.text_encode_fake, self.len_text_fake, self.opt.n_classes).to(self.device)
            try:
                self.fake = self.netG(self.z, self.one_hot_fake)
            except:
                print(words)
        else:
            self.fake = self.netG(self.z, self.text_encode_fake)  # generate output image given the input data_A
    def set_input(self, input):
        """Store one dataloader batch on the model.

        Parameters:
            input: a dictionary holding the batch. ``'img'`` is moved to ``self.device``
                and stored as ``self.real``; ``'img_path'`` and ``'idx'`` are stored as
                ``self.img_path`` and ``self.idx_real``. When a ``'label'`` key is present,
                the labels are encoded with ``self.netconverter`` into ``self.text_encode``
                / ``self.len_text`` (and ``self.one_hot_real`` when ``opt.one_hot`` is set).
        """
        batch = input
        self.real = batch['img'].to(self.device)

        if 'label' in batch:
            self.label = batch['label']
            self.text_encode, self.len_text = self.netconverter.encode(self.label)
            if self.opt.one_hot:
                # The one-hot tensor is built from the encoding before it is moved to the device.
                one_hot = make_one_hot(self.text_encode, self.len_text, self.opt.n_classes)
                self.one_hot_real = one_hot.to(self.device).detach()
            self.text_encode = self.text_encode.to(self.device).detach()
            self.len_text = self.len_text.detach()

        self.img_path = batch['img_path']  # source image paths
        self.idx_real = batch['idx']  # dataset indices of the batch items
    def __init__(self, opt):
        """Build the generator and OCR networks, losses, optimizers, lexicon and fixed display samples.

        Parameters:
            opt: option namespace. It is mutated in place: ``G_activation``,
                ``D_activation`` and ``n_classes`` are always set, and when training with a
                non-empty ``saved_model`` the ``G_init``/``D_init``/``OCR_init`` paths are
                pointed at that checkpoint.

        Raises:
            ValueError: if ``opt.lex`` is neither a ``.tsv`` nor a ``.txt`` file.
        """
        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel
        opt.G_activation = activation_dict[opt.G_nl]
        opt.D_activation = activation_dict[opt.D_nl]
        # load saved model to finetune:
        if self.isTrain and opt.saved_model != '':
            opt.G_init = os.path.join(opt.checkpoints_dir, opt.saved_model)
            opt.D_init = os.path.join(opt.checkpoints_dir, opt.saved_model)
            opt.OCR_init = os.path.join(opt.checkpoints_dir, opt.saved_model)
        # Losses reported by base_model.get_current_losses (printed to the console and saved to disk).
        self.loss_names = ['G', 'D', 'Dreal', 'Dfake', 'OCR_real', 'OCR_fake', 'grad_fake_OCR', 'grad_fake_adv']
        # Every reported loss starts at zero.
        self.loss_G = torch.zeros(1)
        self.loss_D = torch.zeros(1)
        self.loss_Dreal = torch.zeros(1)
        self.loss_Dfake = torch.zeros(1)
        self.loss_OCR_real = torch.zeros(1)
        self.loss_OCR_fake = torch.zeros(1)
        self.loss_grad_fake_OCR = torch.zeros(1)
        self.loss_grad_fake_adv = torch.zeros(1)
        # Networks saved/loaded by base_model.save_networks / base_model.load_networks.
        # NOTE(review): no netD is constructed in this method although 'D' is listed;
        # presumably it is created elsewhere — confirm.
        self.model_names = ['G', 'D', 'OCR']
        # Build the networks; the alphabet size defines the number of character classes.
        opt.n_classes = len(opt.alphabet)
        self.netG = Generator(**vars(opt))
        self.Gradloss = torch.nn.L1Loss()

        self.netconverter = strLabelConverter(opt.alphabet)
        self.netOCR = CRNN(opt).to(self.device)
        if len(opt.gpu_ids) > 0:
            assert (torch.cuda.is_available())
            self.netOCR.to(opt.gpu_ids[0])
            self.netG.to(opt.gpu_ids[0])
            if len(opt.gpu_ids) > 1:
                # The OCR network is split across GPUs along dim=1 rather than the default dim=0.
                self.netOCR = torch.nn.DataParallel(self.netOCR, device_ids=opt.gpu_ids, dim=1, output_device=opt.gpu_ids[0]).cuda()
                self.netG = torch.nn.DataParallel(self.netG, device_ids=opt.gpu_ids, output_device=opt.gpu_ids[0]).cuda()

        # Per-sample CTC losses (reduction='none'); infinite losses are zeroed.
        self.OCR_criterion = CTCLoss(zero_infinity=True, reduction='none')
        print(self.netG)

        if self.isTrain:  # only defined during training time
            # One Adam optimizer per network.
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.G_lr, betas=(opt.G_B1, opt.G_B2), weight_decay=0, eps=opt.adam_eps)
            self.optimizer_OCR = torch.optim.Adam(self.netOCR.parameters(),
                                                lr=opt.OCR_lr, betas=(opt.OCR_B1, opt.OCR_B2), weight_decay=0,
                                                eps=opt.adam_eps)
            self.optimizers = [self.optimizer_G, self.optimizer_OCR]

            self.optimizer_G.zero_grad()
            self.optimizer_OCR.zero_grad()

        # Load the lexicon of words that fake samples are drawn from.
        exception_chars = ['ï', 'ü', '.', '_', 'ö', ',', 'ã', 'ñ']
        if opt.lex.endswith('.tsv'):
            # Keep the last token of each non-null 'lemme' entry that has none of the excluded characters.
            self.lex = pd.read_csv(opt.lex, sep='\t')['lemme']
            self.lex = [word.split()[-1] for word in self.lex if
                        (pd.notnull(word) and all(char not in word for char in exception_chars))]
        elif opt.lex.endswith('.txt'):
            # One word per line; lines that are not valid UTF-8 or have 20+ characters are skipped.
            with open(opt.lex, 'rb') as f:
                raw_lines = f.read().splitlines()
            lex = []
            for word in raw_lines:
                try:
                    word = word.decode("utf-8")
                except UnicodeDecodeError:
                    continue
                if len(word) < 20:
                    lex.append(word)
            self.lex = lex
        else:
            # Without this, self.lex would be undefined and fail with an AttributeError below.
            raise ValueError("Unsupported lexicon file %r: expected a .tsv or .txt file" % (opt.lex,))

        # Fixed noise and fixed words, used for the 'fake_fixed_*' / 'grad_*_fixed_*' visuals.
        self.fixed_noise_size = 2
        self.fixed_noise, self.fixed_fake_labels = prepare_z_y(self.fixed_noise_size, opt.dim_z,
                                       len(self.lex), device=self.device,
                                       fp16=opt.G_fp16, seed=opt.seed)
        self.fixed_noise.sample_()
        self.fixed_fake_labels.sample_()
        # Characters rewritten so fixed words can be embedded in visual names.
        self.rep_dict = {"'":"", '"':'', ' ':'_', ';':'', '.':''}
        fixed_words_fake = [self.lex[int(i)].encode('utf-8') for i in self.fixed_fake_labels]
        self.fixed_text_encode_fake, self.fixed_text_len = self.netconverter.encode(fixed_words_fake)
        if self.opt.one_hot:
            self.one_hot_fixed = make_one_hot(self.fixed_text_encode_fake, self.fixed_text_len, self.opt.n_classes)
        # Todo change to display names of classes instead of numbers
        self.label_fix = [multiple_replace(word.decode("utf-8"), self.rep_dict) for word in fixed_words_fake]
        visual_names_fixed_noise = ['fake_fixed_' + 'label_' + label for label in self.label_fix]
        visual_names_grad_OCR = ['grad_OCR_fixed_' + 'label_' + label for label in self.label_fix]
        visual_names_grad_G = ['grad_G_fixed_' + 'label_' + label for label in self.label_fix]
        # Images saved/displayed by base_model.get_current_visuals.
        self.visual_names = ['real', 'fake']
        self.visual_names.extend(visual_names_fixed_noise)
        self.visual_names.extend(visual_names_grad_G)
        self.visual_names.extend(visual_names_grad_OCR)
        # Per-batch latent and fake-label samplers (resampled in forward()).
        self.z, self.label_fake = prepare_z_y(opt.batch_size, opt.dim_z, len(self.lex),
                                   device=self.device, fp16=opt.G_fp16, z_dist=opt.z_dist, seed=opt.seed)
        if opt.single_writer:
            # A single latent vector is shared by the whole batch and optimised directly with SGD.
            self.fixed_noise = self.z[0].repeat((self.fixed_noise_size, 1))
            self.z = self.z[0].repeat((opt.batch_size, 1)).to(self.device)
            self.z.requires_grad = True
            self.optimizer_z = torch.optim.SGD([self.z], lr=opt.G_lr)
            self.optimizer_z.zero_grad()
        self.l1_loss = L1Loss()
        self.mse_loss = MSELoss()
        self.OCRconverter = OCRLabelConverter(opt.alphabet)
        self.epsilon = 10e-50  # NOTE(review): 10e-50 == 1e-49; confirm this is the intended value
        self.real_z = None
        self.real_z_mean = None