예제 #1
0
    def build_models(self):
        """Build the encoders, generator, target generator and discriminators.

        Loads the pretrained image/text encoders located via
        ``cfg.TRAIN.NET_E`` (frozen, eval mode), creates a frozen ``VGGNet``
        used as the style-loss network, builds the generator and the
        discriminators selected by ``cfg.GAN.B_DCGAN`` and
        ``cfg.TREE.BRANCH_NUM``, and resumes them from ``cfg.TRAIN.NET_G``
        when that path is set.  A frozen deep copy of the generator is kept
        as the target network.

        Returns:
            ``[text_encoder, image_encoder, netG, target_netG, netsD, epoch,
            style_loss]``, or ``None`` (after printing an error) when
            ``cfg.TRAIN.NET_E`` is empty.
        """
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        def _to_cpu(storage, loc):
            # map_location hook for torch.load: hand back the storage as is.
            return storage

        def _freeze(module):
            # Exclude every parameter of `module` from gradient computation.
            for param in module.parameters():
                param.requires_grad = False

        # VGG network used for the style loss.
        style_loss = VGGNet()
        _freeze(style_loss)
        print("Load the style loss model")
        style_loss.eval()

        # Image encoder: its checkpoint path is derived from the text
        # encoder path by swapping the name component.
        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        image_encoder.load_state_dict(
            torch.load(img_encoder_path, map_location=_to_cpu))
        _freeze(image_encoder)
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        # Text encoder.
        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)
        text_encoder.load_state_dict(
            torch.load(cfg.TRAIN.NET_E, map_location=_to_cpu))
        _freeze(text_encoder)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # Generator and discriminators.
        from model import D_NET64, D_NET128, D_NET256
        netsD = []
        if cfg.GAN.B_DCGAN:
            # A single discriminator chosen by the branch count; any count
            # other than 1 or 2 falls back to D_NET256.
            D_NET = {1: D_NET64, 2: D_NET128}.get(cfg.TREE.BRANCH_NUM,
                                                  D_NET256)
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            netG = G_NET()
            # One discriminator per branch, in increasing order.
            for min_branches, d_cls in enumerate(
                    (D_NET64, D_NET128, D_NET256)):
                if cfg.TREE.BRANCH_NUM > min_branches:
                    netsD.append(d_cls())
        netG.apply(weights_init)

        for net_d in netsD:
            net_d.apply(weights_init)
        print('# of netsD', len(netsD))

        # Optionally resume from a generator checkpoint.
        epoch = 0
        g_path = cfg.TRAIN.NET_G
        if g_path != '':
            netG.load_state_dict(torch.load(g_path, map_location=_to_cpu))

            print('Load G from: ', g_path)
            # The epoch is the text between the last '_' and the last '.'
            # of the checkpoint path; training continues with the next one.
            epoch = int(g_path[g_path.rfind('_') + 1:g_path.rfind('.')]) + 1
            if cfg.TRAIN.B_NET_D:
                # Discriminator checkpoints sit next to the generator one.
                ckpt_dir = g_path[:g_path.rfind('/')]
                for idx, net_d in enumerate(netsD):
                    d_path = '%s/netD%d.pth' % (ckpt_dir, idx)
                    print('Load D from: ', d_path)
                    net_d.load_state_dict(
                        torch.load(d_path, map_location=_to_cpu))

        # Create a target network.
        target_netG = deepcopy(netG)

        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            style_loss = style_loss.cuda()

            # The target network is stored on the secondary GPU.
            target_netG.cuda(secondary_device)
            target_netG.ca_net.device = secondary_device

            netG.cuda()
            for idx, net_d in enumerate(netsD):
                netsD[idx] = net_d.cuda()

        # Disable training in the target network.
        _freeze(target_netG)

        return [
            text_encoder, image_encoder, netG, target_netG, netsD, epoch,
            style_loss
        ]
예제 #2
0
    def gen_example(self, data_dic):
        """Modify real images according to captions and save the results.

        Loads the text encoder, image encoder, ``VGGNet``, the main
        generator (``cfg.TRAIN.NET_G``) and the DCM (``cfg.TRAIN.NET_C``),
        moves all of them to CUDA unconditionally, and processes every entry
        of ``data_dic``.

        Args:
            data_dic: mapping ``key -> (captions, cap_lens, sorted_indices,
                imgs)``.  ``captions`` and ``cap_lens`` are numpy arrays
                (they are wrapped with ``torch.from_numpy``);
                ``sorted_indices`` is indexed per caption to name the output
                files; ``imgs`` is indexed by ``cfg.TREE.BRANCH_NUM - 1`` to
                obtain the real image.

        Output files are written below ``<NET_G path without '.pth'>/<key>/``:
        ``0_s_<idx>_g<k>.png`` (generator outputs), ``0_s_<idx>_a<k>.png``
        (attention visualisations), ``1_sf_<idx>_SF.png`` (DCM output) and
        ``1_s_9_SR.png`` (the real input image).

        Returns:
            ``None``.  When either checkpoint path is empty an error is
            printed and nothing is generated.
        """
        if cfg.TRAIN.NET_G == '' or cfg.TRAIN.NET_C == '':
            print('Error: the path for main module or DCM is not found!')
            return

        def _to_cpu(storage, loc):
            # map_location hook for torch.load: hand back the storage as is.
            return storage

        def _save_image(tensor, path):
            # Rescale a CHW tensor with (x + 1) * 127.5 (so [-1, 1] maps onto
            # [0, 255]), cast to uint8, reorder to HWC and write it to `path`.
            arr = tensor.data.cpu().numpy()
            arr = ((arr + 1.0) * 127.5).astype(np.uint8)
            Image.fromarray(np.transpose(arr, (1, 2, 0))).save(path)

        # The text encoder
        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)
        text_encoder.load_state_dict(
            torch.load(cfg.TRAIN.NET_E, map_location=_to_cpu))
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()

        # The image encoder; its checkpoint path is derived from the text
        # encoder path by swapping the name component.
        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        image_encoder.load_state_dict(
            torch.load(img_encoder_path, map_location=_to_cpu))
        print('Load image encoder from:', img_encoder_path)
        image_encoder = image_encoder.cuda()
        image_encoder.eval()

        # The VGG network
        VGG = VGGNet()
        print("Load the VGG model")
        VGG.cuda()
        VGG.eval()

        # The main module
        netG = G_DCGAN() if cfg.GAN.B_DCGAN else G_NET()
        model_dir = cfg.TRAIN.NET_G
        # Outputs go to a directory named after the generator checkpoint
        # (its path with everything from '.pth' on cut off).
        out_root = model_dir[:model_dir.rfind('.pth')]
        netG.load_state_dict(torch.load(model_dir, map_location=_to_cpu))
        print('Load G from: ', model_dir)
        netG.cuda()
        netG.eval()

        # The DCM; NET_C is known to be non-empty because of the guard above.
        netDCM = DCM_Net()
        netDCM.load_state_dict(
            torch.load(cfg.TRAIN.NET_C, map_location=_to_cpu))
        print('Load DCM from: ', cfg.TRAIN.NET_C)
        netDCM.cuda()
        netDCM.eval()

        # Index of the image fed to the networks and saved as the reference.
        last_branch = cfg.TREE.BRANCH_NUM - 1

        for key in data_dic:
            save_dir = '%s/%s' % (out_root, key)
            mkdir_p(save_dir)
            captions, cap_lens, sorted_indices, imgs = data_dic[key]

            batch_size = captions.shape[0]
            nz = cfg.GAN.Z_DIM
            # NOTE(review): `volatile` is a legacy autograd flag; newer
            # PyTorch releases ignore it and expect `torch.no_grad()` instead
            # -- confirm against the PyTorch version this project targets.
            captions = Variable(torch.from_numpy(captions), volatile=True)
            cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)

            captions = captions.cuda()
            cap_lens = cap_lens.cuda()

            # The same single-image batch feeds the image encoder and VGG.
            real_img = imgs[last_branch].unsqueeze(0)

            for i in range(1):
                noise = Variable(torch.FloatTensor(batch_size, nz),
                                 volatile=True)
                noise = noise.cuda()

                #######################################################
                # (1) Extract text and image embeddings
                ######################################################
                hidden = text_encoder.init_hidden(batch_size)

                # The text embeddings
                words_embs, sent_emb = text_encoder(
                    captions, cap_lens, hidden)

                # The image embeddings
                region_features, cnn_code = image_encoder(real_img)
                # True where the caption holds token id 0.
                mask = (captions == 0)

                #######################################################
                # (2) Modify real images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, attention_maps, _, _, h_code, c_code = netG(
                    noise, sent_emb, words_embs, mask, cnn_code,
                    region_features)

                real_features = VGG(real_img)[0]

                fake_img = netDCM(h_code, real_features, sent_emb,
                                  words_embs, mask, c_code)

                cap_lens_np = cap_lens.cpu().data.numpy()
                for j in range(batch_size):
                    save_name = '%s/%d_s_%d' % (save_dir, i,
                                                sorted_indices[j])

                    # Every generator output for caption j.
                    for k, stage_imgs in enumerate(fake_imgs):
                        _save_image(stage_imgs[j],
                                    '%s_g%d.png' % (save_name, k))

                    # Attention map k is drawn over generator output k + 1
                    # (or over the only output when there is just one).
                    for k, attn_maps in enumerate(attention_maps):
                        if len(fake_imgs) > 1:
                            base = fake_imgs[k + 1].detach().cpu()
                        else:
                            base = fake_imgs[0].detach().cpu()
                        att_sze = attn_maps.size(2)
                        img_set, _ = build_super_images2(
                            base[j].unsqueeze(0),
                            captions[j].unsqueeze(0),
                            [cap_lens_np[j]], self.ixtoword,
                            [attn_maps[j]], att_sze)
                        if img_set is not None:
                            Image.fromarray(img_set).save(
                                '%s_a%d.png' % (save_name, k))

                    # The DCM output for caption j.
                    save_name = '%s/%d_sf_%d' % (save_dir, 1,
                                                 sorted_indices[j])
                    _save_image(fake_img[j], '%s_SF.png' % (save_name))

                # The real input image.  It is taken from the same branch
                # that was fed to the networks (previously a hard-coded
                # index 2, which is the same image when BRANCH_NUM is 3).
                save_name = '%s/%d_s_%d' % (save_dir, 1, 9)
                _save_image(imgs[last_branch], '%s_SR.png' % (save_name))
예제 #3
0
    def sampling(self, split_dir):
        """Generate modified images for the whole data loader and save them.

        Loads the generator (``cfg.TRAIN.NET_G``), the text/image encoders
        (``cfg.TRAIN.NET_E``), ``VGGNet`` and the DCM (``cfg.TRAIN.NET_C``),
        moves them to CUDA unconditionally, then iterates ``self.data_loader``
        five times.  For each batch the mismatched captions returned by
        ``prepare_data`` are encoded, the generator and the DCM are run, and
        every DCM output is written as
        ``<NET_G path without '.pth'>/<split_dir>/single_s<idx>.png`` with a
        running ``idx``.

        Args:
            split_dir: name of the output sub-directory; ``'test'`` is
                replaced by ``'valid'``.

        Returns:
            ``None``.  When either checkpoint path is empty an error is
            printed and nothing is generated.
        """
        if cfg.TRAIN.NET_G == '' or cfg.TRAIN.NET_C == '':
            print('Error: the path for main module or DCM is not found!')
            return

        if split_dir == 'test':
            split_dir = 'valid'

        def _to_cpu(storage, loc):
            # map_location hook for torch.load: hand back the storage as is.
            return storage

        # The generator is initialised here; its checkpoint is loaded
        # further below.
        netG = G_DCGAN() if cfg.GAN.B_DCGAN else G_NET()
        netG.apply(weights_init)
        netG.cuda()
        netG.eval()

        # The text encoder
        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)
        text_encoder.load_state_dict(
            torch.load(cfg.TRAIN.NET_E, map_location=_to_cpu))
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()

        # The image encoder; its checkpoint path is derived from the text
        # encoder path by swapping the name component.
        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        image_encoder.load_state_dict(
            torch.load(img_encoder_path, map_location=_to_cpu))
        print('Load image encoder from:', img_encoder_path)
        image_encoder = image_encoder.cuda()
        image_encoder.eval()

        # The VGG network
        VGG = VGGNet()
        print("Load the VGG model")
        VGG.cuda()
        VGG.eval()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        # NOTE(review): `volatile` is a legacy autograd flag; newer PyTorch
        # releases ignore it and expect `torch.no_grad()` instead -- confirm
        # against the PyTorch version this project targets.
        noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
        noise = noise.cuda()

        # The DCM; NET_C is known to be non-empty because of the guard above.
        netDCM = DCM_Net()
        netDCM.load_state_dict(
            torch.load(cfg.TRAIN.NET_C, map_location=_to_cpu))
        print('Load DCM from: ', cfg.TRAIN.NET_C)
        netDCM.cuda()
        netDCM.eval()

        model_dir = cfg.TRAIN.NET_G
        netG.load_state_dict(torch.load(model_dir, map_location=_to_cpu))
        print('Load G from: ', model_dir)

        # The path to save modified images: a directory named after the
        # generator checkpoint (its path with everything from '.pth' cut off).
        save_dir = '%s/%s' % (model_dir[:model_dir.rfind('.pth')], split_dir)
        mkdir_p(save_dir)

        # The file-name prefix and its parent folder do not depend on the
        # batch, so they are prepared once instead of once per saved image.
        single_prefix = '%s/single' % (save_dir)
        folder = single_prefix[:single_prefix.rfind('/')]
        if not os.path.isdir(folder):
            print('Make a new folder: ', folder)
            mkdir_p(folder)

        last_branch = cfg.TREE.BRANCH_NUM - 1
        idx = 0  # running index over all saved images
        for _ in range(5):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
            for step, data in enumerate(self.data_loader, 0):
                if step % 100 == 0:
                    print('step: ', step)

                # Only the images and the mismatched captions are used.
                imgs, _, _, _, _, wrong_caps, wrong_caps_len, _ = \
                    prepare_data(data)

                #######################################################
                # (1) Extract text and image embeddings
                ######################################################
                hidden = text_encoder.init_hidden(batch_size)

                words_embs, sent_emb = text_encoder(
                    wrong_caps, wrong_caps_len, hidden)
                words_embs = words_embs.detach()
                sent_emb = sent_emb.detach()

                # True where the caption holds token id 0; trimmed so its
                # width does not exceed the number of word embeddings.
                mask = (wrong_caps == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                real_img = imgs[last_branch]
                region_features, cnn_code = image_encoder(real_img)

                #######################################################
                # (2) Modify real images
                ######################################################
                noise.data.normal_(0, 1)
                _, _, _, _, h_code, c_code = netG(
                    noise, sent_emb, words_embs, mask, cnn_code,
                    region_features)

                real_features = VGG(real_img)[0]

                fake_img = netDCM(h_code, real_features, sent_emb,
                                  words_embs, mask, c_code)

                for j in range(batch_size):
                    # Rescale with (x + 1) * 127.5 (so [-1, 1] maps onto
                    # [0, 255]), cast to uint8 and reorder CHW -> HWC.
                    arr = fake_img[j].data.cpu().numpy()
                    arr = ((arr + 1.0) * 127.5).astype(np.uint8)
                    im = Image.fromarray(np.transpose(arr, (1, 2, 0)))
                    fullpath = '%s_s%d.png' % (single_prefix, idx)
                    idx += 1
                    im.save(fullpath)
예제 #4
0
    def build_models(self):
        """Build the VGG network, encoders, generator and discriminators.

        The text/image encoders are always created.  Only when
        ``cfg.TRAIN.NET_E`` is set are their weights loaded (from saved
        objects exposing ``state_dict()``), frozen and switched to eval mode;
        with an empty ``NET_E`` they are left trainable so that they can be
        trained jointly.  Discriminators are created only when
        ``cfg.TRAIN.W_GAN`` is set.  The generator (and optionally the
        discriminators) are resumed from ``cfg.TRAIN.NET_G`` when given.

        Returns:
            ``[text_encoder, image_encoder, netG, netsD, epoch, VGG]``.
        """

        def _to_cpu(storage, loc):
            # map_location hook for torch.load: hand back the storage as is.
            return storage

        def _freeze(module):
            # Exclude every parameter of `module` from gradient computation.
            for param in module.parameters():
                param.requires_grad = False

        ################### Text and Image encoders ###################
        VGG = VGGNet()
        _freeze(VGG)
        print("Load the VGG model")
        VGG.eval()

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)

        # When NET_E == '', the image and text encoders are trained jointly,
        # so nothing is loaded or frozen in that case.
        if cfg.TRAIN.NET_E != '':
            saved_text_encoder = torch.load(cfg.TRAIN.NET_E,
                                            map_location=_to_cpu)
            text_encoder.load_state_dict(saved_text_encoder.state_dict())
            _freeze(text_encoder)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder.eval()

            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                       'image_encoder')
            saved_image_encoder = torch.load(img_encoder_path,
                                             map_location=_to_cpu)
            image_encoder.load_state_dict(saved_image_encoder.state_dict())
            _freeze(image_encoder)
            print('Load image encoder from:', img_encoder_path)
            image_encoder.eval()

        ####################### Generator and Discriminators ##############
        from model import D_NET64, D_NET128, D_NET256
        netsD = []
        if cfg.GAN.B_DCGAN:
            # A single discriminator chosen by the branch count; any count
            # other than 1 or 2 falls back to D_NET256.
            D_NET = {1: D_NET64, 2: D_NET128}.get(cfg.TREE.BRANCH_NUM,
                                                  D_NET256)
            netG = G_DCGAN()
            if cfg.TRAIN.W_GAN:
                netsD = [D_NET(b_jcu=False)]
            # NOTE(review): weights_init is applied only in the non-DCGAN
            # branch below -- confirm that this is intended.
        else:
            netG = G_NET()
            netG.apply(weights_init)
            if cfg.TRAIN.W_GAN:
                # One discriminator per branch, in increasing order.
                for min_branches, d_cls in enumerate(
                        (D_NET64, D_NET128, D_NET256)):
                    if cfg.TREE.BRANCH_NUM > min_branches:
                        netsD.append(d_cls())
                for net_d in netsD:
                    net_d.apply(weights_init)

        print('# of netsD', len(netsD))

        # Optionally resume from a generator checkpoint.
        epoch = 0
        g_path = cfg.TRAIN.NET_G
        if g_path != '':
            netG.load_state_dict(torch.load(g_path, map_location=_to_cpu))
            print('Load G from: ', g_path)
            # The epoch is the text between the last '_' and the last '.'
            # of the checkpoint path; training continues with the next one.
            epoch = int(g_path[g_path.rfind('_') + 1:g_path.rfind('.')]) + 1
            if cfg.TRAIN.B_NET_D:
                # Discriminator checkpoints sit next to the generator one.
                ckpt_dir = g_path[:g_path.rfind('/')]
                for idx, net_d in enumerate(netsD):
                    d_path = '%s/netD%d.pth' % (ckpt_dir, idx)
                    print('Load D from: ', d_path)
                    net_d.load_state_dict(
                        torch.load(d_path, map_location=_to_cpu))

        # Move everything to the GPU when requested.
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            VGG = VGG.cuda()
            for net_d in netsD:
                net_d.cuda()
        return [text_encoder, image_encoder, netG, netsD, epoch, VGG]
예제 #5
0
    def build_models(self):
        """Build the models needed to train the DCM.

        Requires pretrained encoders (``cfg.TRAIN.NET_E``) and a pretrained
        main generator (``cfg.TRAIN.NET_G``).  The encoders and ``VGGNet``
        are frozen and put in eval mode; the generator is loaded and put in
        eval mode.  The DCM and its discriminator are resumed from
        ``cfg.TRAIN.NET_C`` / ``cfg.TRAIN.NET_D`` when those paths are set.

        Returns:
            ``[text_encoder, image_encoder, netG, netD, epoch, VGG, netDCM]``,
            or ``None`` (after printing an error) when ``cfg.TRAIN.NET_E`` or
            ``cfg.TRAIN.NET_G`` is empty.
        """
        # ################### models ################################### #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return
        if cfg.TRAIN.NET_G == '':
            print('Error: no pretrained main module')
            return

        def _to_cpu(storage, loc):
            # map_location hook for torch.load: hand back the storage as is.
            return storage

        def _freeze(module):
            # Exclude every parameter of `module` from gradient computation.
            for param in module.parameters():
                param.requires_grad = False

        VGG = VGGNet()
        _freeze(VGG)
        print("Load the VGG model")
        VGG.eval()

        # Image encoder: its checkpoint path is derived from the text
        # encoder path by swapping the name component.
        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        image_encoder.load_state_dict(
            torch.load(img_encoder_path, map_location=_to_cpu))
        _freeze(image_encoder)
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        # Text encoder.
        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)
        text_encoder.load_state_dict(
            torch.load(cfg.TRAIN.NET_E, map_location=_to_cpu))
        _freeze(text_encoder)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # Main generator plus the D_NET256 discriminator; the DCGAN variant
        # builds the discriminator with b_jcu=False.
        from model import D_NET256
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
            netD = D_NET256(b_jcu=False)
        else:
            netG = G_NET()
            netD = D_NET256()

        netD.apply(weights_init)

        netG.load_state_dict(
            torch.load(cfg.TRAIN.NET_G, map_location=_to_cpu))
        netG.eval()
        print('Load G from: ', cfg.TRAIN.NET_G)

        # The DCM, optionally resumed from a checkpoint.
        epoch = 0
        netDCM = DCM_Net()
        dcm_path = cfg.TRAIN.NET_C
        if dcm_path != '':
            netDCM.load_state_dict(
                torch.load(dcm_path, map_location=_to_cpu))
            print('Load DCM from: ', dcm_path)
            # The epoch is the text between the last '_' and the last '.'
            # of the checkpoint path; training continues with the next one.
            epoch = int(
                dcm_path[dcm_path.rfind('_') + 1:dcm_path.rfind('.')]) + 1

        if cfg.TRAIN.NET_D != '':
            netD.load_state_dict(
                torch.load(cfg.TRAIN.NET_D, map_location=_to_cpu))
            print('Load DCM Discriminator from: ', cfg.TRAIN.NET_D)

        # Move everything to the GPU when requested.
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            netDCM.cuda()
            VGG = VGG.cuda()
            netD.cuda()
        return [text_encoder, image_encoder, netG, netD, epoch, VGG, netDCM]