Example #1
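The four examples below appear to be methods of a single GAN trainer class for text-guided image manipulation (names such as EncDecNet, DCM, RNN_ENCODER, and VGG16 point that way). Since they are excerpts, the imports they rely on are not shown; a plausible preamble, with every module path being an assumption, is:

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from PIL import Image, ImageDraw, ImageFont

from miscc.config import cfg                      # assumed config module
from miscc.utils import mkdir_p, weights_init     # assumed utility module
from model import (RNN_ENCODER, CNN_ENCODER,      # assumed model module
                   VGG16, EncDecNet, G_DCGAN)
from datasets import prepare_data                 # assumed data module
# build_images (Example #3) and cosine_similarity (Example #4) are also
# assumed to come from the same repository's utility modules.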
    def gen_example(self, data_dic):
        if cfg.TRAIN.NET_G == '' or cfg.TRAIN.NET_C == '':
            print('Error: the path for the main module or DCM is not found!')
        else:
            # The text encoder
            text_encoder = \
                RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            # The image encoder
            """
            image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
            state_dict = \
                torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            print('Load image encoder from:', img_encoder_path)
            image_encoder = image_encoder.cuda()
            image_encoder.eval()
            """
            """
            image_encoder = CNN_dummy()
            image_encoder = image_encoder.cuda()
            image_encoder.eval()
            """

            # The VGG network
            VGG = VGG16()
            print("Load the VGG model")
            VGG.cuda()
            VGG.eval()

            # The main module
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = EncDecNet()
            # Output directory for the generated examples
            s_tmp = os.path.join(cfg.DATA_DIR, 'output', self.args.netG,
                                 'valid/gen_example')

            model_dir = os.path.join(cfg.DATA_DIR, 'output', self.args.netG,
                                     'Model/netG_epoch_8.pth')
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)
            #netG = nn.DataParallel(netG, device_ids= self.gpus)
            netG.cuda()
            netG.eval()

            for key in data_dic:
                save_dir = '%s/%s' % (s_tmp, key)
                mkdir_p(save_dir)
                captions, cap_lens, sorted_indices, imgs = data_dic[key]

                batch_size = captions.shape[0]
                nz = cfg.GAN.Z_DIM
                captions = Variable(torch.from_numpy(captions), volatile=True)
                cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)

                captions = captions.cuda()
                cap_lens = cap_lens.cuda()
                for i in range(1):  # a single noise draw per set of captions
                    noise = Variable(torch.FloatTensor(batch_size, nz),
                                     volatile=True)
                    noise = noise.cuda()

                    #######################################################
                    # (1) Extract text and image embeddings
                    ######################################################
                    hidden = text_encoder.init_hidden(batch_size)

                    # The text embeddings
                    words_embs, sent_emb = text_encoder(
                        captions, cap_lens, hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                    # Padding mask: True where the caption token is 0 (<pad>)
                    mask = (captions == 0)
                    #######################################################
                    # (2) Modify real images
                    ######################################################
                    noise.data.normal_(0, 1)

                    imgs_256 = imgs[-1].unsqueeze(0).repeat(
                        batch_size, 1, 1, 1)
                    enc_features = VGG(imgs_256)
                    fake_img, mu, logvar = nn.parallel.data_parallel(
                        netG, (imgs[-1], sent_emb, words_embs, noise, mask,
                               enc_features), self.gpus)

                    cap_lens_np = cap_lens.cpu().data.numpy()

                    one_imgs = []
                    for j in range(captions.shape[0]):
                        font = ImageFont.truetype('./FreeMono.ttf', 20)
                        canv = Image.new('RGB', (256, 256), (255, 255, 255))
                        draw = ImageDraw.Draw(canv)
                        sent = []
                        for k in range(len(captions[j])):
                            if (captions[j][k] == 0):
                                break
                            word = self.ixtoword[captions[j][k].item()].encode(
                                'ascii', 'ignore').decode('ascii')
                            if k % 2 == 1:  # wrap the caption text every two words
                                word = word + '\n'
                            sent.append(word)
                        fake_sent = ' '.join(sent)
                        draw.text((0, 0), fake_sent, font=font, fill=(0, 0, 0))
                        canv_np = np.asarray(canv)

                        real_im = imgs[-1]
                        real_im = (real_im + 1) * 127.5
                        real_im = real_im.cpu().numpy().astype(np.uint8)
                        real_im = np.transpose(real_im, (1, 2, 0))

                        fake_im = fake_img[j]
                        fake_im = (fake_im + 1.0) * 127.5
                        fake_im = fake_im.detach().cpu().numpy().astype(
                            np.uint8)
                        fake_im = np.transpose(fake_im, (1, 2, 0))

                        one_img = np.concatenate([real_im, canv_np, fake_im],
                                                 axis=1)
                        one_imgs.append(one_img)

                    img_set = np.concatenate(one_imgs, axis=0)
                    super_img = Image.fromarray(img_set)
                    full_path = os.path.join(save_dir, 'super.png')
                    super_img.save(full_path)
                    """
                    for j in range(5): ## batch_size
                        save_name = '%s/%d_s_%d' % (save_dir, i, sorted_indices[j])
                        for k in range(len(fake_imgs)):
                            im = fake_imgs[k][j].data.cpu().numpy()
                            im = (im + 1.0) * 127.5
                            im = im.astype(np.uint8)
                            im = np.transpose(im, (1, 2, 0))
                            im = Image.fromarray(im)
                            fullpath = '%s_g%d.png' % (save_name, k)
                            im.save(fullpath)

                        for k in range(len(attention_maps)):
                            if len(fake_imgs) > 1:
                                im = fake_imgs[k + 1].detach().cpu()
                            else:
                                im = fake_imgs[0].detach().cpu()
                            attn_maps = attention_maps[k]
                            att_sze = attn_maps.size(2)
                    """
                    """
                            img_set, sentences = \
                                build_super_images2(im[j].unsqueeze(0),
                                                    captions[j].unsqueeze(0),
                                                    [cap_lens_np[j]], self.ixtoword,
                                                    [attn_maps[j]], att_sze)
                            if img_set is not None:
                                im = Image.fromarray(img_set)
                                fullpath = '%s_a%d.png' % (save_name, k)
                                im.save(fullpath)
                    """
                    """
Example #2
    def build_models(self):
        ################### Text and Image encoders ########################################
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return
        """
        image_encoder = CNN_dummy()
        image_encoder.cuda()
        image_encoder.eval()
        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()
        """

        VGG = VGG16()
        VGG.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        ####################### Generator and Discriminators ##############
        from model import D_NET256
        netD = D_NET256()
        netG = EncDecNet()

        netD.apply(weights_init)
        netG.apply(weights_init)

        #
        epoch = 0
        """
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        """
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            netG.cuda()
            netD.cuda()
            VGG.cuda()
        return [text_encoder, netG, netD, epoch, VGG]
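A minimal usage sketch for build_models() (the call site below is an assumption, mirroring the order of the returned list):

        # Hypothetical call inside the trainer's training entry point
        text_encoder, netG, netD, start_epoch, VGG = self.build_models()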
Example #3
    def sampling(self, split_dir):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for the main module is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'

            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = EncDecNet()
            netG.apply(weights_init)
            netG.cuda()
            netG.eval()
            # The text encoder
            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()
            # The image encoder
            """
            image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
            state_dict = \
                torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            print('Load image encoder from:', img_encoder_path)
            image_encoder = image_encoder.cuda()
            image_encoder.eval()
            """

            # The VGG network
            VGG = VGG16()
            print("Load the VGG model")
            VGG.cuda()
            VGG.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
            noise = noise.cuda()

            model_dir = os.path.join(cfg.DATA_DIR, 'output', self.args.netG,
                                     'Model/netG_epoch_600.pth')
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            # the path to save modified images
            save_dir_valid = os.path.join(cfg.DATA_DIR, 'output',
                                          self.args.netG, 'valid')
            #mkdir_p(save_dir)

            cnt = 0
            idx = 0
            for i in range(5):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                # the path to save modified images
                save_dir = os.path.join(save_dir_valid, 'valid_%d' % i)
                save_dir_super = os.path.join(save_dir, 'super')
                save_dir_single = os.path.join(save_dir, 'single')
                mkdir_p(save_dir_super)
                mkdir_p(save_dir_single)
                for step, data in enumerate(self.data_loader, 0):
                    cnt += batch_size
                    if step % 100 == 0:
                        print('step: ', step)

                    imgs, w_imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                                wrong_caps_len, wrong_cls_id = prepare_data(data)

                    #######################################################
                    # (1) Extract text and image embeddings
                    ######################################################

                    hidden = text_encoder.init_hidden(batch_size)

                    words_embs, sent_emb = text_encoder(
                        wrong_caps, wrong_caps_len, hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                    mask = (wrong_caps == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    #######################################################
                    # (2) Modify real images
                    ######################################################

                    noise.data.normal_(0, 1)

                    fake_img, mu, logvar = netG(imgs[-1], sent_emb, words_embs,
                                                noise, mask, VGG)

                    img_set = build_images(imgs[-1], fake_img, captions,
                                           wrong_caps, self.ixtoword)
                    img = Image.fromarray(img_set)
                    full_path = '%s/super_step%d.png' % (save_dir_super, step)
                    img.save(full_path)

                    for j in range(batch_size):
                        s_tmp = '%s/single' % (save_dir_single)
                        folder = s_tmp[:s_tmp.rfind('/')]
                        if not os.path.isdir(folder):
                            print('Make a new folder: ', folder)
                            mkdir_p(folder)
                        im = fake_img[j].data.cpu().numpy()
                        # Map from [-1, 1] back to [0, 255] before casting,
                        # as in Example #1; without this the uint8 cast
                        # produces a garbage image
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_s%d.png' % (s_tmp, idx)
                        idx = idx + 1
                        im.save(fullpath)
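The padding-mask handling in sampling() (and again in calc_mp() below) truncates the mask when the tokenized caption is longer than the number of word slots the encoder emitted. A standalone toy illustration, with all shapes assumed:

import torch

captions = torch.tensor([[5, 9, 2, 0, 0]])   # (batch=1, seq_len=5), 0 = <pad>
words_embs = torch.randn(1, 256, 4)          # (batch, emb_dim, num_words=4)
mask = (captions == 0)                       # True at padding positions
num_words = words_embs.size(2)
if mask.size(1) > num_words:
    mask = mask[:, :num_words]               # -> (1, 4), matches words_embs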
Example #4
    def calc_mp(self):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for the main module is not found!')
        else:
            #if split_dir == 'test':
            #    split_dir = 'valid'

            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = EncDecNet()
            netG.apply(weights_init)
            netG.cuda()
            netG.eval()
            # The text encoder

            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()
            # The image encoder
            #image_encoder = CNN_dummy()
            #print('define image_encoder')

            image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                       'image_encoder')
            state_dict = \
                torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            print('Load image encoder from:', img_encoder_path)

            image_encoder = image_encoder.cuda()
            image_encoder.eval()

            # The VGG network
            VGG = VGG16()
            print("Load the VGG model")
            #VGG.to(torch.device("cuda:1"))
            VGG.cuda()
            VGG.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
            noise = noise.cuda()

            model_dir = os.path.join(cfg.DATA_DIR, 'output', self.args.netG,
                                     'Model', self.args.netG_epoch)
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            # the path to save modified images

            cnt = 0
            idx = 0
            diffs, sims = [], []
            for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                for step, data in enumerate(self.data_loader, 0):
                    cnt += batch_size
                    if step % 100 == 0:
                        print('step: ', step)

                    imgs, w_imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                                wrong_caps_len, wrong_cls_id = prepare_data(data)

                    #######################################################
                    # (1) Extract text and image embeddings
                    ######################################################

                    hidden = text_encoder.init_hidden(batch_size)

                    words_embs, sent_emb = text_encoder(
                        wrong_caps, wrong_caps_len, hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                    mask = (wrong_caps == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    #######################################################
                    # (2) Modify real images
                    ######################################################

                    noise.data.normal_(0, 1)

                    fake_img, mu, logvar = netG(imgs[-1], sent_emb, words_embs,
                                                noise, mask, VGG)

                    diff = F.l1_loss(fake_img, imgs[-1])
                    diffs.append(diff.item())

                    region_code, cnn_code = image_encoder(fake_img)

                    sim = cosine_similarity(sent_emb, cnn_code)
                    sim = torch.mean(sim)
                    sims.append(sim.item())

            diff = np.mean(diffs)
            sim = np.mean(sims)
            print('diff: %.3f, sim: %.3f' % (diff, sim))
            print('MP: %.3f' % ((1 - diff) * sim))
            netG_epoch = self.args.netG_epoch[self.args.netG_epoch.find('_') +
                                              1:-4]
            print('model_epoch: %s, diff: %.3f, sim: %.3f, MP: %.3f' %
                  (netG_epoch, diff, sim, (1 - diff) * sim))
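The final score combines the two statistics as MP = (1 - diff) * sim, where diff is the mean L1 distance between the modified and original images and sim is the mean cosine similarity between the sentence embeddings and the image codes of the modified images. A standalone sketch of the same arithmetic (the function name is an assumption):

import numpy as np

def manipulation_score(diffs, sims):
    # diffs: per-batch L1 distances; sims: per-batch cosine similarities
    diff, sim = np.mean(diffs), np.mean(sims)
    return (1.0 - diff) * sim

# e.g. manipulation_score([0.12, 0.10], [0.71, 0.69])
#      == (1 - 0.11) * 0.70 ≈ 0.623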