Example #1
    def build_models(self):
        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
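        # freeze the image encoder; it only serves as a fixed feature extractor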
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        # custom restore of the text encoder parameters: copy only weights
        # whose names (and shapes) match the current model
        own_state = text_encoder.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            if own_state[name].size() != param.size():
                continue
            # copy_ writes into the live parameter tensor; plain dict
            # assignment would leave the model weights unchanged
            own_state[name].copy_(param)
        # custom restore of the text encoder parameters: end
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()  # the encoder is frozen, so keep dropout disabled as well

        # #######################generator and discriminators############## #
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            # TODO: elif cfg.TREE.BRANCH_NUM > 3:
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET(text_encoder)
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
            # TODO: if cfg.TREE.BRANCH_NUM > 3:
        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))
        #
        epoch = 0
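        # optionally resume training: reload the nets and parse the start epoch from the checkpoint name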
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
        return [text_encoder, image_encoder, netG, netsD, epoch]
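
A minimal usage sketch of the returned list. It assumes the method lives on an
AttnGAN-style trainer and that cfg.TRAIN.GENERATOR_LR / cfg.TRAIN.DISCRIMINATOR_LR
exist in the config; both names are assumptions, not taken from the snippet:

    # hypothetical caller code (not part of the snippet above)
    text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
    # the encoders stay frozen; only the generator and discriminators are optimized
    optimizerG = torch.optim.Adam(netG.parameters(),
                                  lr=cfg.TRAIN.GENERATOR_LR, betas=(0.5, 0.999))
    optimizersD = [torch.optim.Adam(netD.parameters(),
                                    lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
                   for netD in netsD]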
Example #2
    def gen_example(self, data_dic):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            # Build and load the generator
            batch_size = 16
            text_encoder = \
                RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
            print("=======self.n_words: %d", self.n_words)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            # custom restore of the text encoder parameters: copy only weights
            # whose names (and shapes) match the current model
            own_state = text_encoder.state_dict()
            for name, param in state_dict.items():
                if name not in own_state:
                    continue
                if own_state[name].size() != param.size():
                    continue
                own_state[name].copy_(param)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            # build the generator
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET(text_encoder)
            # generated images are saved under a directory named after the checkpoint
            s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
            model_dir = cfg.TRAIN.NET_G
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)
            netG.cuda()
            netG.eval()
            for key in data_dic:
                save_dir = '%s/%s' % (s_tmp, key)
                mkdir_p(save_dir)
                captions, cap_lens, sorted_indices = data_dic[key]

                # number of full batches of size batch_size to generate
                total_time = len(captions) // batch_size
                nz = cfg.GAN.Z_DIM
                with torch.no_grad():
                    for i in range(total_time):  # one batch of batch_size captions per step
                        noise = Variable(torch.FloatTensor(batch_size, nz))
                        noise = noise.cuda()
                        caption_tmp = Variable(torch.from_numpy(captions[i*batch_size:(i+1)*batch_size]))
                        if i < 3:
                            print(caption_tmp.data)
                        cap_len_tmp = Variable(torch.from_numpy(cap_lens[i*batch_size:(i+1)*batch_size]))
                        caption_tmp = caption_tmp.cuda()
                        cap_len_tmp = cap_len_tmp.cuda()
                        #######################################################
                        # (1) Extract text embeddings
                        ######################################################
                        hidden = text_encoder.init_hidden(batch_size)
                        # words_embs: batch_size x nef x seq_len
                        # sent_emb: batch_size x nef
                        words_embs, sent_emb, _ = text_encoder(caption_tmp, cap_len_tmp, None)
                        words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
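                        # token id 0 marks padding; these positions are masked out of the attention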
                        mask = (caption_tmp == 0)
                        #######################################################
                        # (2) Generate fake images
                        ######################################################
                        # re-seed so repeated runs draw different noise vectors
                        random.seed(datetime.now().timestamp())
                        rnd = random.randint(0, 1000)
                        torch.cuda.manual_seed(rnd)
                        noise.data.normal_(0, 1)
                        fake_imgs, attention_maps, _, _, _ = netG(noise, sent_emb, words_embs, mask, caption_tmp, cap_len_tmp)
                        # G attention
                        # cap_lens_np = cap_lens.cpu().data.numpy()
                        cap_lens_np = cap_len_tmp.cpu().data.numpy()
                        for j in range(batch_size):
                            save_name = '%s/s_%d' % (save_dir, sorted_indices[i*batch_size+j])
                            for k in range(len(fake_imgs)):
                                im = fake_imgs[k][j].data.cpu().numpy()
                                # map pixel values from [-1, 1] to [0, 255]
                                im = ((im + 1.0) / 2) * 255.0
                                im = im.astype(np.uint8)
                                # CHW -> HWC so PIL can build the image
                                im = np.transpose(im, (1, 2, 0))
                                im = Image.fromarray(im)
                                fullpath = '%s_g%d.png' % (save_name, k)
                                im.save(fullpath)
                                # also save to a separate per-stage directory
                                save_dir2 = '%s/stage_%d' % (save_dir, k)
                                mkdir_p(save_dir2)
                                fullpath = '%s/%d_g%d.png' % (save_dir2, sorted_indices[i*batch_size+j], k)
                                im.save(fullpath)

                            for k in range(len(attention_maps)):
                                if len(fake_imgs) > 1:
                                    im = fake_imgs[k + 1].detach().cpu()
                                else:
                                    im = fake_imgs[0].detach().cpu()
                                attn_maps = attention_maps[k]
                                att_sze = attn_maps.size(2)
                                img_set, sentences = \
                                    build_super_images2(im[j].unsqueeze(0),
                                                        caption_tmp[j].unsqueeze(0),
                                                        [cap_len_tmp[j]], self.ixtoword,
                                                        [attn_maps[j]], att_sze)
                                if img_set is not None:
                                    im = Image.fromarray(img_set)
                                    fullpath = '%s_a%d.png' % (save_name, k)
                                    im.save(fullpath)
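
gen_example expects data_dic to map a key (e.g. the caption file name) to a tuple
(captions, cap_lens, sorted_indices) of numpy arrays, with captions zero-padded and
sorted by decreasing length. A hedged sketch of building one entry; wordtoix and
raw_captions are hypothetical names standing in for the project's vocabulary and
input sentences:

    import numpy as np

    # tokenize each sentence with the vocabulary mapping (id 0 doubles as padding)
    tokens = [[wordtoix[w] for w in cap.split() if w in wordtoix]
              for cap in raw_captions]
    cap_lens = np.asarray([len(t) for t in tokens])
    sorted_indices = np.argsort(cap_lens)[::-1]        # longest caption first
    cap_lens = cap_lens[sorted_indices]
    captions = np.zeros((len(tokens), int(cap_lens[0])), dtype='int64')
    for i, idx in enumerate(sorted_indices):
        captions[i, :len(tokens[idx])] = tokens[idx]   # zero-pad to max length
    data_dic = {'example_captions': (captions, cap_lens, sorted_indices)}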