Ejemplo n.º 1
0
def generate_image(word, seed=None):
    """Generate an image of ``word`` with the module-level ``model``.

    Args:
        word: Text to render; converted to the model's label input by
            ``get_word``.
        seed: Optional seed forwarded to ``prepare_z_y`` for the latent noise.
            When ``None`` (the default, matching the previous behaviour) a
            seed in ``[0, 100000)`` is drawn from NumPy's global RNG. Pass an
            explicit value to reproduce the same image.

    Returns:
        A PIL image in RGB mode built from element ``[0, 0]`` of the model
        output, scaled by 255.
    """
    # Local import: this function needs torch.no_grad and must not rely on
    # a top-level import it cannot see.
    import torch

    if seed is None:
        # Same bound as the original ``10e4`` (== 1e5), written as an int.
        seed = np.random.randint(0, 100000)
    words = get_word(word)
    z, _ = prepare_z_y(1, 128, 80, device='cpu', seed=seed)
    # Pure inference: don't record an autograd graph for the forward pass.
    with torch.no_grad():
        res = model.forward(z=z, y=words)
    res = res.detach().numpy()[0, 0] * 255
    im = Image.fromarray(res).convert('RGB')
    return im
Ejemplo n.º 2
0
def generate_image(word='meet', seed=None, device=device):
    """Render ``word`` with the module-level ``model`` and return RGB pixels.

    Args:
        word: Text passed to ``get_word``; the result is moved to ``device``.
        seed: Seed for ``prepare_z_y``. If ``None``, one is drawn from
            NumPy's global RNG with ``np.random.randint(0, 10e4)``.
        device: Target device for the label input and latent noise. The
            default is the module-level ``device`` captured when this
            function is defined.

    Returns:
        A ``numpy.ndarray`` holding the RGB conversion of
        ``output[0, 0] * 255``.
    """
    seed = np.random.randint(0, 10e4) if seed is None else seed
    label = get_word(word).to(device)
    noise, _unused = prepare_z_y(1, 128, 80, device=device, seed=seed)
    with torch.no_grad():
        generated = model.forward(z=noise, y=label)
    # Bring the first sample/channel to host memory before scaling.
    pixels = generated.detach().cpu().numpy()[0, 0]
    pixels = pixels * 255
    rgb_image = Image.fromarray(pixels).convert('RGB')
    return np.array(rgb_image)
Ejemplo n.º 3
0
 def GenImgs(words=None, z=None, nsamples=5, device=0):
     """Generate one image per entry of ``words`` using the global ``model``.

     Args:
         words: Sequence of words to render. If ``None``, ``nsamples``
             calls are made with ``word=None``.
         z: Per-word latent inputs indexed in step with ``words``. If
             ``words`` is ``None``, this single value is reused for every
             sample. If ``words`` is given and ``z`` is ``None``, ``None``
             is passed for every word.
         nsamples: Number of samples when ``words`` is ``None``.
         device: Device that ``model.netG`` and the refreshed noise are
             placed on; also stored as ``model.device``.

     Returns:
         ``(wordBins, words_encoded)``: the two values returned by
         ``GenImg`` for each word, collected into parallel lists.
     """
     model.netG.to(device)
     # Refresh the model's batch noise/labels on the target device.
     model.z, model.label_fake = prepare_z_y(opt.batch_size,
                                             opt.dim_z,
                                             len(model.lex),
                                             device=device,
                                             fp16=opt.G_fp16)
     model.device = device
     if words is None:
         # presumably GenImg picks a word itself when given None — confirm.
         words = nsamples * [words]
         z = nsamples * [z]
     elif z is None:
         # Previously z stayed None here and z[i] raised TypeError.
         z = len(words) * [None]
     words_encoded = []
     wordBins = []
     for i, word in enumerate(tqdm(words)):
         wordBin, encoded = GenImg(word, z[i])
         words_encoded.append(encoded)
         wordBins.append(wordBin)
     return wordBins, words_encoded
    def __init__(self, opt):
        """Set up networks, losses, optimizers, lexicon and fixed noise.

        Args:
            opt: Options namespace. Many attributes are read (e.g. ``G_nl``,
                ``D_nl``, ``alphabet``, ``gpu_ids``, ``lex``, ``batch_size``,
                ``dim_z``, ``seed``, ``single_writer``). Some are also written
                back: ``G_activation``, ``D_activation`` and ``n_classes``,
                plus ``G_init``/``D_init``/``OCR_init`` when fine-tuning.

        NOTE(review): ``'D'`` appears in ``model_names`` and ``opt.D_init`` is
        set, but no ``self.netD`` or D optimizer is created in the visible part
        of this method. Presumably a subclass or later code builds them;
        confirm.
        """
        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel
        # Resolve the configured activation names to activation objects.
        opt.G_activation = activation_dict[opt.G_nl]
        opt.D_activation = activation_dict[opt.D_nl]
        # load saved model to finetune:
        # G, D and OCR are all pointed at the same checkpoint path.
        if self.isTrain and opt.saved_model!='':
            opt.G_init = os.path.join(opt.checkpoints_dir, opt.saved_model)
            opt.D_init = os.path.join(opt.checkpoints_dir, opt.saved_model)
            opt.OCR_init = os.path.join(opt.checkpoints_dir, opt.saved_model)
        # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
        self.loss_names = ['G', 'D', 'Dreal', 'Dfake', 'OCR_real', 'OCR_fake', 'grad_fake_OCR', 'grad_fake_adv']
        # One zero-initialised tensor per entry of loss_names (attribute loss_<name>).
        self.loss_G = torch.zeros(1)
        self.loss_D =torch.zeros(1)
        self.loss_Dreal =torch.zeros(1)
        self.loss_Dfake =torch.zeros(1)
        self.loss_OCR_real =torch.zeros(1)
        self.loss_OCR_fake =torch.zeros(1)
        self.loss_grad_fake_OCR =torch.zeros(1)
        self.loss_grad_fake_adv =torch.zeros(1)
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
        # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
        self.model_names = ['G', 'D', 'OCR']
        # define networks; you can use opt.isTrain to specify different behaviors for training and test.
        # Next, build the model
        opt.n_classes = len(opt.alphabet)
        # Every option is forwarded to Generator as a keyword argument.
        self.netG = Generator(**vars(opt))
        self.Gradloss = torch.nn.L1Loss()

        self.netconverter = strLabelConverter(opt.alphabet)
        # self.device is presumably set by BaseModel.__init__ — confirm.
        self.netOCR = CRNN(opt).to(self.device)
        # With GPUs configured, move OCR and G to the first listed GPU. With
        # more than one GPU, wrap both in DataParallel. netOCR scatters and
        # gathers along dim=1, presumably its batch axis (confirm).
        if len(opt.gpu_ids) > 0:
            assert (torch.cuda.is_available())
            self.netOCR.to(opt.gpu_ids[0])
            self.netG.to(opt.gpu_ids[0])
            # net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
            if len(opt.gpu_ids) > 1:
                self.netOCR = torch.nn.DataParallel(self.netOCR, device_ids=opt.gpu_ids, dim=1, output_device=opt.gpu_ids[0]).cuda()
                self.netG = torch.nn.DataParallel(self.netG, device_ids=opt.gpu_ids, output_device=opt.gpu_ids[0]).cuda()

        # Per-sample CTC losses (reduction='none'). Assuming torch.nn.CTCLoss,
        # zero_infinity=True zeroes infinite losses and their gradients.
        self.OCR_criterion = CTCLoss(zero_infinity=True, reduction='none')
        print(self.netG)


        if self.isTrain:  # only defined during training time
            # One Adam optimizer each for G and OCR, with hyper-parameters
            # taken from opt. No D optimizer is created here.
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.G_lr, betas=(opt.G_B1, opt.G_B2), weight_decay=0, eps=opt.adam_eps)
            self.optimizer_OCR = torch.optim.Adam(self.netOCR.parameters(),
                                                lr=opt.OCR_lr, betas=(opt.OCR_B1, opt.OCR_B2), weight_decay=0,
                                                eps=opt.adam_eps)
            self.optimizers = [self.optimizer_G, self.optimizer_OCR]

            self.optimizer_G.zero_grad()
            self.optimizer_OCR.zero_grad()

        # Entries of a .tsv lexicon containing any of these characters are dropped.
        exception_chars = ['ï', 'ü', '.', '_', 'ö', ',', 'ã', 'ñ']
        if opt.lex.endswith('.tsv'):
            # Read the tab-separated 'lemme' column. Keep non-null entries
            # without exception chars, and store only the last
            # whitespace-separated token of each.
            # NOTE(review): a whitespace-only entry would make split()[-1] raise IndexError.
            self.lex = pd.read_csv(opt.lex, sep='\t')['lemme']
            self.lex = [word.split()[-1] for word in self.lex if
                        (pd.notnull(word) and all(char not in word for char in exception_chars))]
        elif opt.lex.endswith('.txt'):
            # One word per line, read as raw bytes.
            with open(opt.lex, 'rb') as f:
                self.lex = f.read().splitlines()
            lex=[]
            for word in self.lex:
                try:
                    word=word.decode("utf-8")
                except:
                    # Skip lines that cannot be decoded.
                    # NOTE(review): bare except also hides unrelated errors; UnicodeDecodeError is the expected case.
                    continue
                # Keep only words shorter than 20 characters.
                if len(word)<20:
                    lex.append(word)
            self.lex = lex
        # NOTE(review): there is no else branch. For any other extension,
        # self.lex is not assigned here, and len(self.lex) below fails unless
        # it was set elsewhere (e.g. in BaseModel). Confirm.
        # Fixed latent noise and label indices for 2 monitoring samples, seeded with opt.seed.
        self.fixed_noise_size = 2
        self.fixed_noise, self.fixed_fake_labels = prepare_z_y(self.fixed_noise_size, opt.dim_z,
                                       len(self.lex), device=self.device,
                                       fp16=opt.G_fp16, seed=opt.seed)
        # presumably sample_() fills each tensor with fresh draws in place — confirm in prepare_z_y.
        self.fixed_noise.sample_()
        self.fixed_fake_labels.sample_()
        # Character replacements used to turn the fixed words into display-safe label strings.
        self.rep_dict = {"'":"", '"':'', ' ':'_', ';':'', '.':''}
        # Look up the fixed words by label index and UTF-8 encode them for netconverter.
        fixed_words_fake = [self.lex[int(i)].encode('utf-8') for i in self.fixed_fake_labels]
        self.fixed_text_encode_fake, self.fixed_text_len = self.netconverter.encode(fixed_words_fake)
        # one_hot_fixed only exists when self.opt.one_hot is truthy.
        if self.opt.one_hot:
            self.one_hot_fixed = make_one_hot(self.fixed_text_encode_fake, self.fixed_text_len, self.opt.n_classes)
        # Todo change to display names of classes instead of numbers
        self.label_fix = [multiple_replace(word.decode("utf-8"), self.rep_dict) for word in fixed_words_fake]
        # One visual per fixed label for the fake image and for the G / OCR gradient maps.
        visual_names_fixed_noise = ['fake_fixed_' + 'label_' + label for label in self.label_fix]
        visual_names_grad_OCR = ['grad_OCR_fixed_' + 'label_' + label for label in self.label_fix]
        visual_names_grad_G = ['grad_G_fixed_' + 'label_' + label for label in self.label_fix]
        # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
        self.visual_names = ['real', 'fake']
        self.visual_names.extend(visual_names_fixed_noise)
        self.visual_names.extend(visual_names_grad_G)
        self.visual_names.extend(visual_names_grad_OCR)
        # Batch-sized latent noise and label indices (distribution from opt.z_dist, seeded with opt.seed).
        self.z, self.label_fake = prepare_z_y(opt.batch_size, opt.dim_z, len(self.lex),
                                   device=self.device, fp16=opt.G_fp16, z_dist=opt.z_dist, seed=opt.seed)
        # With opt.single_writer, every fixed and batch sample reuses the
        # first latent row. The batch z becomes a grad-requiring tensor
        # optimised by its own SGD optimizer (lr=opt.G_lr).
        if opt.single_writer:
            self.fixed_noise = self.z[0].repeat((self.fixed_noise_size, 1))
            self.z = self.z[0].repeat((opt.batch_size, 1)).to(self.device)
            self.z.requires_grad=True
            self.optimizer_z = torch.optim.SGD([self.z], lr=opt.G_lr)
            self.optimizer_z.zero_grad()
        self.l1_loss = L1Loss()
        self.mse_loss = MSELoss()
        self.OCRconverter = OCRLabelConverter(opt.alphabet)
        # NOTE(review): 10e-50 equals 1e-49, not 1e-50. If this is combined
        # with float32 tensors it underflows to 0 (the smallest float32
        # subnormal is ~1.4e-45), so it would not guard against division by
        # zero. Verify how it is used.
        self.epsilon = 10e-50
        self.real_z = None
        self.real_z_mean = None