def forward(self, words=None, z=None):
    """Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
    if hasattr(self, 'fake'):
        del self.fake, self.text_encode_fake, self.len_text_fake, self.one_hot_fake
    self.label_fake.sample_()
    if words is None:
        # sample words from the lexicon according to the sampled fake labels
        words = [self.lex[int(i)] for i in self.label_fake]
        if self.opt.capitalize:
            # randomly capitalize the first letter of half of the sampled words
            for i, word in enumerate(words):
                if random.random() < 0.5:
                    word = list(word)
                    word[0] = unicodedata.normalize('NFKD', word[0].upper()).encode('ascii', 'ignore').decode("utf-8")
                    word = ''.join(word)
                words[i] = word
        words = [word.encode('utf-8') for word in words]
    if z is None:
        if not self.opt.single_writer:
            self.z.sample_()
    else:
        if z.shape[0] == 1:
            # broadcast a single style vector to all words
            self.z = z.repeat((len(words), 1))
        else:
            self.z = z
    self.words = words
    self.text_encode_fake, self.len_text_fake = self.netconverter.encode(self.words)
    self.text_encode_fake = self.text_encode_fake.to(self.device)
    if self.opt.one_hot:
        self.one_hot_fake = make_one_hot(self.text_encode_fake, self.len_text_fake, self.opt.n_classes).to(self.device)
        try:
            self.fake = self.netG(self.z, self.one_hot_fake)
        except Exception:
            # log the offending words before propagating the error instead of
            # silently leaving self.fake undefined
            print(words)
            raise
    else:
        self.fake = self.netG(self.z, self.text_encode_fake)  # generate the fake image conditioned on the text encoding
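# Illustrative usage sketch (not part of the original file; `model` and `fixed_z`
# are assumed names): <forward> can either sample words and noise internally or
# generate images for caller-supplied words and a fixed style vector, e.g.
#
#     model.forward()                               # sample words and z internally
#     model.forward(words=['hello', 'world'])       # render specific words
#     model.forward(words=['hello'], z=fixed_z)     # reuse one style vector for all words
#
# After the call, the generated images are available in model.fake.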
def set_input(self, input):
    """Unpack input data from the dataloader and perform necessary pre-processing steps.

    Parameters:
        input: a dictionary that contains the data itself and its metadata information.
    """
    # if hasattr(self, 'real'): del self.real, self.one_hot_real, self.text_encode, self.len_text
    self.real = input['img'].to(self.device)
    if 'label' in input:
        self.label = input['label']
        self.text_encode, self.len_text = self.netconverter.encode(self.label)
        if self.opt.one_hot:
            self.one_hot_real = make_one_hot(self.text_encode, self.len_text, self.opt.n_classes).to(self.device).detach()
        self.text_encode = self.text_encode.to(self.device).detach()
        self.len_text = self.len_text.detach()
    self.img_path = input['img_path']  # get image paths
    self.idx_real = input['idx']       # get sample indices within the dataset
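# Illustrative sketch of the batch dictionary <set_input> expects. The keys are taken
# from the accesses above; the value types are assumptions, not read from the dataloader:
#
#     input = {'img':      ...,  # tensor of real handwriting images, moved to self.device
#              'label':    ...,  # transcriptions (optional), encoded via self.netconverter
#              'img_path': ...,  # list of source image paths
#              'idx':      ...}  # indices of the samples in the dataset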
def __init__(self, opt):
    BaseModel.__init__(self, opt)  # call the initialization method of BaseModel
    opt.G_activation = activation_dict[opt.G_nl]
    opt.D_activation = activation_dict[opt.D_nl]
    # load a saved model to fine-tune:
    if self.isTrain and opt.saved_model != '':
        opt.G_init = os.path.join(opt.checkpoints_dir, opt.saved_model)
        opt.D_init = os.path.join(opt.checkpoints_dir, opt.saved_model)
        opt.OCR_init = os.path.join(opt.checkpoints_dir, opt.saved_model)
    # specify the training losses you want to print out. The program will call base_model.get_current_losses
    # to plot the losses to the console and save them to the disk.
    self.loss_names = ['G', 'D', 'Dreal', 'Dfake', 'OCR_real', 'OCR_fake', 'grad_fake_OCR', 'grad_fake_adv']
    self.loss_G = torch.zeros(1)
    self.loss_D = torch.zeros(1)
    self.loss_Dreal = torch.zeros(1)
    self.loss_Dfake = torch.zeros(1)
    self.loss_OCR_real = torch.zeros(1)
    self.loss_OCR_fake = torch.zeros(1)
    self.loss_grad_fake_OCR = torch.zeros(1)
    self.loss_grad_fake_adv = torch.zeros(1)
    # specify the models you want to save to the disk. The program will call base_model.save_networks and
    # base_model.load_networks to save and load networks. You can use opt.isTrain to specify different
    # behaviors for training and test; for example, some networks are not used during test and need not be loaded.
    self.model_names = ['G', 'D', 'OCR']
    # define networks and build the model
    opt.n_classes = len(opt.alphabet)
    self.netG = Generator(**vars(opt))
    self.Gradloss = torch.nn.L1Loss()
    self.netconverter = strLabelConverter(opt.alphabet)
    self.netOCR = CRNN(opt).to(self.device)
    if len(opt.gpu_ids) > 0:
        assert torch.cuda.is_available()
        self.netOCR.to(opt.gpu_ids[0])
        self.netG.to(opt.gpu_ids[0])
        # net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
        if len(opt.gpu_ids) > 1:
            self.netOCR = torch.nn.DataParallel(self.netOCR, device_ids=opt.gpu_ids, dim=1,
                                                output_device=opt.gpu_ids[0]).cuda()
            self.netG = torch.nn.DataParallel(self.netG, device_ids=opt.gpu_ids,
                                              output_device=opt.gpu_ids[0]).cuda()
    self.OCR_criterion = CTCLoss(zero_infinity=True, reduction='none')
    print(self.netG)
    if self.isTrain:  # only defined during training time
        # define loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
        # We also provide a GANLoss class "networks.GANLoss".
        self.criterionGAN = networks.GANLoss().to(self.device)
        # define and initialize optimizers. You can define one optimizer for each network.
        # If two networks are updated at the same time, you can use itertools.chain to group them.
        # See cycle_gan_model.py for an example.
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.G_lr,
                                            betas=(opt.G_B1, opt.G_B2), weight_decay=0, eps=opt.adam_eps)
        self.optimizer_OCR = torch.optim.Adam(self.netOCR.parameters(), lr=opt.OCR_lr,
                                              betas=(opt.OCR_B1, opt.OCR_B2), weight_decay=0, eps=opt.adam_eps)
        self.optimizers = [self.optimizer_G, self.optimizer_OCR]
        self.optimizer_G.zero_grad()
        self.optimizer_OCR.zero_grad()
        # build the lexicon used to sample words for the generator
        exception_chars = ['ï', 'ü', '.', '_', 'ö', ',', 'ã', 'ñ']
        if opt.lex.endswith('.tsv'):
            self.lex = pd.read_csv(opt.lex, sep='\t')['lemme']
            self.lex = [word.split()[-1] for word in self.lex
                        if (pd.notnull(word) and all(char not in word for char in exception_chars))]
        elif opt.lex.endswith('.txt'):
            with open(opt.lex, 'rb') as f:
                self.lex = f.read().splitlines()
            lex = []
            for word in self.lex:
                try:
                    word = word.decode("utf-8")
                except UnicodeDecodeError:
                    continue
                if len(word) < 20:
                    lex.append(word)
            self.lex = lex
        # fixed noise and labels used to visualize the generator's progress
        self.fixed_noise_size = 2
        self.fixed_noise, self.fixed_fake_labels = prepare_z_y(self.fixed_noise_size, opt.dim_z, len(self.lex),
                                                               device=self.device, fp16=opt.G_fp16, seed=opt.seed)
        self.fixed_noise.sample_()
        self.fixed_fake_labels.sample_()
        self.rep_dict = {"'": "", '"': '', ' ': '_', ';': '', '.': ''}
        fixed_words_fake = [self.lex[int(i)].encode('utf-8') for i in self.fixed_fake_labels]
        self.fixed_text_encode_fake, self.fixed_text_len = self.netconverter.encode(fixed_words_fake)
        if self.opt.one_hot:
            self.one_hot_fixed = make_one_hot(self.fixed_text_encode_fake, self.fixed_text_len, self.opt.n_classes)
        # TODO: change to display names of classes instead of numbers
        self.label_fix = [multiple_replace(word.decode("utf-8"), self.rep_dict) for word in fixed_words_fake]
        visual_names_fixed_noise = ['fake_fixed_' + 'label_' + label for label in self.label_fix]
        visual_names_grad_OCR = ['grad_OCR_fixed_' + 'label_' + label for label in self.label_fix]
        visual_names_grad_G = ['grad_G_fixed_' + 'label_' + label for label in self.label_fix]
        # specify the images you want to save and display. The program will call base_model.get_current_visuals
        # to save and display these images.
        self.visual_names = ['real', 'fake']
        self.visual_names.extend(visual_names_fixed_noise)
        self.visual_names.extend(visual_names_grad_G)
        self.visual_names.extend(visual_names_grad_OCR)
        self.z, self.label_fake = prepare_z_y(opt.batch_size, opt.dim_z, len(self.lex),
                                              device=self.device, fp16=opt.G_fp16, z_dist=opt.z_dist, seed=opt.seed)
        if opt.single_writer:
            # use a single optimizable style vector shared across the batch
            self.fixed_noise = self.z[0].repeat((self.fixed_noise_size, 1))
            self.z = self.z[0].repeat((opt.batch_size, 1)).to(self.device)
            self.z.requires_grad = True
            self.optimizer_z = torch.optim.SGD([self.z], lr=opt.G_lr)
            self.optimizer_z.zero_grad()
        self.l1_loss = L1Loss()
        self.mse_loss = MSELoss()
        self.OCRconverter = OCRLabelConverter(opt.alphabet)
        self.epsilon = 10e-50
        self.real_z = None
        self.real_z_mean = None
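# Illustrative training-loop sketch (assumptions: `opt` is the parsed option object,
# `dataloader` yields dictionaries in the format expected by <set_input>, `Model` stands
# in for this class's actual name, and <optimize_parameters> is defined elsewhere in it):
#
#     model = Model(opt)
#     for data in dataloader:
#         model.set_input(data)           # unpack real images and labels
#         model.optimize_parameters()     # alternate D / OCR / G updates
#         losses = model.get_current_losses()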