Beispiel #1
0
    def __init__(self, caption_file, saveable, cuda=False, profile=False):
        """Load the caption vocabulary, the text encoder and the generator.

        Args:
            caption_file: Path of a pickle file. Elements 2 and 3 of the
                unpickled object are kept as ``ixtoword`` and ``wordtoix``.
            saveable: Stored unchanged as ``self.saveable``.
            cuda: When true, both networks are moved to the GPU.
            profile: When true, print the flags during initialisation.
        """
        # flags
        self.cuda = cuda
        self.profile = profile

        if self.profile:
            print('Initializing Generator...')
            print('cuda={}\nprofile={}'.format(self.cuda, self.profile))

        # load caption indices; the context manager guarantees the file
        # handle is closed even if unpickling fails
        # NOTE(review): pickle can execute arbitrary code while loading --
        # caption_file must come from a trusted source.
        with open(caption_file, 'rb') as caption_fh:
            x = pickle.load(caption_fh)
        self.ixtoword = x[2]
        self.wordtoix = x[3]
        del x

        # load text encoder (weights are deserialized on the CPU first and
        # only moved to the GPU when requested)
        self.text_encoder = RNN_ENCODER(len(self.wordtoix), nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        self.text_encoder.load_state_dict(state_dict)
        if self.cuda:
            self.text_encoder.cuda()

        self.text_encoder.eval()

        # load generative model
        self.netG = G_NET()
        state_dict = torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        self.netG.load_state_dict(state_dict)
        if self.cuda:
            self.netG.cuda()

        self.netG.eval()

        # saveable items -> push to storage
        self.saveable = saveable
Beispiel #2
0
def build_models():
    """Build the text and image encoders, optionally resuming from disk.

    Reads ``dataset``, ``batch_size`` and ``cfg`` from the enclosing scope.
    When ``cfg.TRAIN.NET_E`` is non-empty it is loaded into the text encoder,
    the sibling checkpoint (same path with ``text_encoder`` replaced by
    ``image_encoder``) is loaded into the image encoder, and the number
    embedded in the file name is used to continue the epoch count.

    Returns:
        Tuple ``(text_encoder, image_encoder, labels, start_epoch)``.
        ``labels`` holds ``0 .. batch_size - 1``; ``start_epoch`` is 0 for a
        fresh run, otherwise the parsed checkpoint number plus one.
    """
    # build model ############################################################
    text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch = 0
    if cfg.TRAIN.NET_E != '':
        # Deserialize onto the CPU so checkpoints written on a GPU machine
        # also load on a CPU-only host; the models are moved to the GPU
        # below when cfg.CUDA is set.
        state_dict = torch.load(cfg.TRAIN.NET_E,
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_E)
        #
        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(name,
                                map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        print('Load ', name)

        # The slice starts 8 characters after the last underscore and ends at
        # the last dot; for a name such as 'text_encoder100.pth' that is the
        # text between '_encoder' and the extension.
        istart = cfg.TRAIN.NET_E.rfind('_') + 8
        iend = cfg.TRAIN.NET_E.rfind('.')
        start_epoch = cfg.TRAIN.NET_E[istart:iend]
        start_epoch = int(start_epoch) + 1
        print('start_epoch', start_epoch)
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()

    return text_encoder, image_encoder, labels, start_epoch
Beispiel #3
0
def build_models():
    """Create both encoders and derive the resume epoch and learning rate.

    Reads ``dataset``, ``batch_size`` and ``cfg`` from the enclosing scope.

    Returns:
        Tuple ``(text_encoder, image_encoder, labels, start_epoch, lr)``.
        Without a checkpoint ``start_epoch`` is 0 and ``lr`` is
        ``cfg.TRAIN.ENCODER_LR``.
    """
    def _keep_on_cpu(storage, loc):
        # map_location callback: leave every storage where it was
        # deserialized instead of restoring its saved device
        return storage

    text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch, lr = 0, cfg.TRAIN.ENCODER_LR

    text_ckpt = cfg.TRAIN.NET_E
    if text_ckpt != '':
        text_weights = torch.load(text_ckpt, map_location=_keep_on_cpu)
        text_encoder.load_state_dict(text_weights)
        print('Load {}'.format(text_ckpt))

        # the image encoder checkpoint lives next to the text encoder one
        image_ckpt = text_ckpt.replace('text_encoder', 'image_encoder')
        image_weights = torch.load(image_ckpt, map_location=_keep_on_cpu)
        image_encoder.load_state_dict(image_weights)
        print('Load {}'.format(image_ckpt))

        # checkpoint number = text from 8 characters past the last
        # underscore up to the last dot of the path
        first = text_ckpt.rfind('_') + 8
        last = text_ckpt.rfind('.')
        start_epoch = int(text_ckpt[first:last]) + 1
        print('start_epoch', start_epoch)

        # initial lr with the right value
        # note that the turning point is always epoch 114
        lr = (cfg.TRAIN.ENCODER_LR * (0.98 ** start_epoch)
              if start_epoch < 114
              else cfg.TRAIN.ENCODER_LR / 10)

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()
    return text_encoder, image_encoder, labels, start_epoch, lr
def build_text_encoder(ntokens):
    """Create the RNN text encoder and load ``cfg.TRAIN.NET_E`` if it is set.

    Args:
        ntokens: First argument passed to ``RNN_ENCODER``.

    Returns:
        The encoder, moved to the GPU when ``cfg.CUDA`` is true.
    """
    encoder = RNN_ENCODER(ntokens, nhidden=cfg.TEXT.EMBEDDING_DIM)
    checkpoint_path = cfg.TRAIN.NET_E
    if checkpoint_path != '':
        # deserialize on the CPU; device placement is decided below
        weights = torch.load(checkpoint_path,
                             map_location=lambda storage, loc: storage)
        encoder.load_state_dict(weights)
        print('Load ', checkpoint_path)
    return encoder.cuda() if cfg.CUDA else encoder
Beispiel #5
0
def models(modelname, cfg, word_len):
    """Return the cached text encoder and generator for ``modelname``.

    Each network is looked up in ``cache`` under a key prefixed with
    ``modelname``; on a miss it is built, loaded from the path in ``cfg``,
    optionally moved to the GPU, switched to eval mode and stored back.

    Args:
        modelname: Prefix for the two cache keys.
        cfg: Configuration object supplying paths, embedding size and CUDA flag.
        word_len: First argument passed to ``RNN_ENCODER``.

    Returns:
        Tuple ``(text_encoder, netG)``.
    """
    def _keep_on_cpu(storage, loc):
        # map_location callback: do not restore the saved device
        return storage

    encoder_key = modelname + '_text_encoder'
    text_encoder = cache.get(encoder_key, None)
    if text_encoder is None:
        text_encoder = RNN_ENCODER(word_len, nhidden=cfg.TEXT.EMBEDDING_DIM)
        encoder_weights = torch.load(cfg.TRAIN.NET_E, map_location=_keep_on_cpu)
        text_encoder.load_state_dict(encoder_weights)
        if cfg.CUDA:
            text_encoder.cuda()
        text_encoder.eval()
        cache[encoder_key] = text_encoder

    generator_key = modelname + '_netG'
    netG = cache.get(generator_key, None)
    if netG is None:
        netG = G_NET()
        generator_weights = torch.load(cfg.TRAIN.NET_G, map_location=_keep_on_cpu)
        netG.load_state_dict(generator_weights)
        if cfg.CUDA:
            netG.cuda()
        netG.eval()
        cache[generator_key] = netG

    return text_encoder, netG
Beispiel #6
0
def models(word_len):
    """Return the text encoder and generator, loading whichever is not cached.

    Progress is reported with ``print`` at every stage. Storing the loaded
    networks back into ``cache`` is disabled in this version, so a cache miss
    reloads the weights on every call.

    Args:
        word_len: First argument passed to ``RNN_ENCODER``.

    Returns:
        Tuple ``(text_encoder, netG)``, both in eval mode when freshly loaded.
    """
    print('Loading Model', word_len)
    text_encoder = cache.get('text_encoder')
    print('Text enconder', text_encoder)
    if text_encoder is None:
        print("text_encoder not cached")
        text_encoder = RNN_ENCODER(word_len, nhidden=256)
        state_dict = torch.load('../DAMSMencoders/coco/text_encoder100.pth', map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('loaded text encoder')
        # Follow cfg.CUDA exactly like the generator below. Moving the
        # encoder to the GPU unconditionally crashes on CPU-only hosts and
        # leaves the two networks on different devices when cfg.CUDA is off.
        if cfg.CUDA:
            text_encoder.cuda()
            print('text encoder cuda')
        text_encoder.eval()
        print('text encoder eval')
        #cache.set('text_encoder', text_encoder, timeout=60 * 60 * 24)

    print('Got Text Encoder, moving to netG')
    netG = cache.get('netG')
    if netG is None:
        print("netG not cached")
        netG = G_NET()
        state_dict = torch.load('../models/coco_AttnGAN2.pth', map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        if cfg.CUDA:
            netG.cuda()
        netG.eval()
        #cache.set('netG', netG, timeout=60 * 60 * 24)
    print('Got NetG')
    return text_encoder, netG
Beispiel #7
0
def models(word_len):
    """Fetch the text encoder and generator from ``cache``, loading on a miss.

    A freshly loaded network gets its weights from the path configured in
    ``cfg``, is moved to the GPU when ``cfg.CUDA`` is true, is put into eval
    mode and is cached with a timeout of 24 hours expressed in seconds.

    Args:
        word_len: First argument passed to ``RNN_ENCODER``.

    Returns:
        Tuple ``(text_encoder, netG)``.
    """
    one_day = 60 * 60 * 24

    def _finish_loading(net, weights_path):
        # Load weights on the CPU, then apply device placement and eval mode.
        weights = torch.load(weights_path, map_location=lambda storage, loc: storage)
        net.load_state_dict(weights)
        if cfg.CUDA:
            net.cuda()
        net.eval()
        return net

    text_encoder = cache.get('text_encoder')
    if text_encoder is None:
        text_encoder = _finish_loading(
            RNN_ENCODER(word_len, nhidden=cfg.TEXT.EMBEDDING_DIM),
            cfg.TRAIN.NET_E)
        cache.set('text_encoder', text_encoder, timeout=one_day)

    netG = cache.get('netG')
    if netG is None:
        netG = _finish_loading(G_NET(), cfg.TRAIN.NET_G)
        cache.set('netG', netG, timeout=one_day)

    return text_encoder, netG
Beispiel #8
0
def build_models(text_encoder_type):
    """Build a text encoder of the requested type plus the image encoder.

    Reads ``dataset``, ``batch_size`` and ``cfg`` from the enclosing scope.

    Args:
        text_encoder_type: ``'rnn'`` or ``'transformer'`` (case-insensitive).

    Returns:
        Tuple ``(text_encoder, image_encoder, labels, start_epoch)``.

    Raises:
        ValueError: If ``text_encoder_type`` is neither supported value.
    """
    # build model ############################################################
    text_encoder_type = text_encoder_type.casefold()
    if text_encoder_type not in ('rnn', 'transformer'):
        raise ValueError('Unsupported text_encoder_type')

    # The transformer encoder is created further down because it is always
    # instantiated from a pretrained source.
    if text_encoder_type == 'rnn':
        text_encoder = RNN_ENCODER(dataset.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)

    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch = 0
    if cfg.TRAIN.NET_E:
        if text_encoder_type == 'rnn':
            # Deserialize onto the CPU so GPU-saved checkpoints also load on
            # CPU-only hosts; placement happens below according to cfg.CUDA.
            state_dict = torch.load(cfg.TRAIN.NET_E,
                                    map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
        elif text_encoder_type == 'transformer':
            text_encoder = GPT2Model.from_pretrained(cfg.TRAIN.NET_E)
            # output_hidden_states = True )
        print('Load ', cfg.TRAIN.NET_E)
        #
        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(name,
                                map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        print('Load ', name)

        # Checkpoint number: the text from 8 characters past the last
        # underscore of the path up to its last dot.
        istart = cfg.TRAIN.NET_E.rfind('_') + 8
        iend = cfg.TRAIN.NET_E.rfind('.')
        start_epoch = cfg.TRAIN.NET_E[istart:iend]
        start_epoch = int(start_epoch) + 1
    else:
        if text_encoder_type == 'rnn':
            print('Training RNN from scratch')
        elif text_encoder_type == 'transformer':
            # don't initialize the weights of these huge models from scratch...
            print('Training Transformer starting from pretrained model')
            text_encoder = GPT2Model.from_pretrained(TRANSFORMER_ENCODER)
            # output_hidden_states = True )
        print('Training CNN starting from ImageNet pretrained Inception-v3')

    print('start_epoch', start_epoch)

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()

    return text_encoder, image_encoder, labels, start_epoch
Beispiel #9
0
def generate(captions, copies):
    """Generate images for ``captions`` with freshly loaded models.

    The vocabulary is read from ``data/captions.pickle``; the text encoder
    and generator weights come from the paths in ``cfg.TRAIN``. Both networks
    stay on the CPU in this function.

    Args:
        captions: Passed to ``vectorize_caption`` together with the vocabulary.
        copies: Passed through to ``vectorize_caption``.

    Returns:
        The result of ``imglist(fake_imgs, batch_size)``.
    """
    # The context manager closes the file handle even if unpickling fails.
    # NOTE(review): pickle can execute arbitrary code while loading -- the
    # captions file must be trusted.
    with open('data/captions.pickle', 'rb') as caption_fh:
        x = pickle.load(caption_fh)
    wordtoix = x[3]
    del x
    word_len = len(wordtoix)

    text_encoder = RNN_ENCODER(word_len, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = torch.load(cfg.TRAIN.NET_E,
                            map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    text_encoder.eval()

    netG = G_NET()
    state_dict = torch.load(cfg.TRAIN.NET_G,
                            map_location=lambda storage, loc: storage)
    netG.load_state_dict(state_dict)
    # NOTE(review): the generator is intentionally left as the original had
    # it, i.e. eval() is NOT called on netG here -- confirm this is wanted.
    # netG.eval()

    # load word vector
    captions, cap_lens = vectorize_caption(wordtoix, captions, copies)

    # one row of `captions` per image to generate
    batch_size = captions.shape[0]
    nz = cfg.GAN.Z_DIM
    # NOTE(review): `volatile=True` is kept for compatibility with the
    # PyTorch version this code targets; confirm before migrating to
    # torch.no_grad().
    captions = Variable(torch.from_numpy(captions), volatile=True)
    cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)
    noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)

    #######################################################
    # (1) Extract text embeddings
    #######################################################
    hidden = text_encoder.init_hidden(batch_size)
    words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
    # positions holding index 0 are masked
    mask = (captions == 0)

    #######################################################
    # (2) Generate fake images
    #######################################################
    # the noise tensor is allocated uninitialized above; fill it in place
    noise.data.normal_(0, 1)
    fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)

    return imglist(fake_imgs, batch_size)
def build_models():
    """Build the encoder pair for the configured ``model_type``.

    Reads ``model_type``, ``dataset_val``, ``batch_size`` and ``cfg`` from the
    enclosing scope. For ``model_type == 'bert'`` the vocabulary size comes
    from a tokenizer (a local vocab file when ``cfg.LOCAL_PRETRAINED`` is
    true, otherwise ``bert-base-uncased``); for every other model type it is
    ``dataset_val.n_words``.

    Returns:
        Tuple ``(text_encoder, image_encoder, labels, start_epoch)``.
    """

    # build model ############################################################
    print('build_model(), model_type: ', model_type)
    if model_type == 'bert':
        if cfg.LOCAL_PRETRAINED:
            tokenizer = tokenization.FullTokenizer(
                vocab_file=cfg.BERT_ENCODER.VOCAB, do_lower_case=True)
        else:
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        vocab_size = len(tokenizer.vocab)
        text_encoder = BERT_RNN_ENCODER(vocab_size,
                                        nhidden=cfg.TEXT.EMBEDDING_DIM)
        image_encoder = BERT_CNN_ENCODER_RNN_DECODER(cfg.TEXT.EMBEDDING_DIM,
                                                     cfg.CNN_RNN.HIDDEN_DIM,
                                                     vocab_size,
                                                     rec_unit=cfg.RNN_TYPE)
    else:
        vocab_size = dataset_val.n_words
        text_encoder = RNN_ENCODER(vocab_size, nhidden=cfg.TEXT.EMBEDDING_DIM)
        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)

    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch = 0
    if cfg.TRAIN.NET_E != '':
        # Deserialize onto the CPU so GPU-saved checkpoints also load on
        # CPU-only hosts; placement happens below according to cfg.CUDA.
        state_dict = torch.load(cfg.TRAIN.NET_E,
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_E)
        #
        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(name,
                                map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        print('Load ', name)

        # Checkpoint number: the text from 8 characters past the last
        # underscore of the path up to its last dot.
        istart = cfg.TRAIN.NET_E.rfind('_') + 8
        iend = cfg.TRAIN.NET_E.rfind('.')
        start_epoch = cfg.TRAIN.NET_E[istart:iend]
        start_epoch = int(start_epoch) + 1
        print('start_epoch', start_epoch)
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()

    return text_encoder, image_encoder, labels, start_epoch
Beispiel #11
0
    def build_text_encoder(self):
        """Load the trained text encoder and store it on ``self.text_encoder``.

        Uses ``self.n_words`` and ``self.embedding_dim`` to size the network
        and ``self.net_E`` as the checkpoint path. The encoder is left in
        eval mode; no device move is performed here.
        """
        encoder = RNN_ENCODER(self.n_words, nhidden=self.embedding_dim)
        # deserialize every storage on the CPU regardless of where it was saved
        weights = torch.load(self.net_E, map_location=lambda storage, loc: storage)
        encoder.load_state_dict(weights)
        encoder.eval()

        self.text_encoder = encoder
Beispiel #12
0
def models(word_len):
    """Return the text encoder and generator, loading whichever is not cached.

    Storing the loaded networks back into ``cache`` is disabled in this
    version, so a cache miss reloads the weights on every call.

    Args:
        word_len: First argument passed to ``RNN_ENCODER``.

    Returns:
        Tuple ``(text_encoder, netG)``, both in eval mode when freshly loaded.
    """
    text_encoder = cache.get('text_encoder')
    if text_encoder is None:
        text_encoder = RNN_ENCODER(word_len, nhidden=256)
        state_dict = torch.load('../DAMSMencoders/coco/text_encoder100.pth', map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        # Follow cfg.CUDA exactly like the generator below. Moving the
        # encoder to the GPU unconditionally crashes on CPU-only hosts and
        # leaves the two networks on different devices when cfg.CUDA is off.
        if cfg.CUDA:
            text_encoder.cuda()
        text_encoder.eval()
        #cache.set('text_encoder', text_encoder, timeout=60 * 60 * 24)

    netG = cache.get('netG')
    if netG is None:
        netG = G_NET()
        state_dict = torch.load('../models/coco_AttnGAN2.pth', map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        if cfg.CUDA:
            netG.cuda()
        netG.eval()
        #cache.set('netG', netG, timeout=60 * 60 * 24)
    return text_encoder, netG
Beispiel #13
0
def models(word_len):
    """Fetch the text encoder and generator from ``cache``, loading on a miss.

    The text encoder type is chosen by the first command-line argument
    (``rnn`` or ``transformer``, case-insensitive); the generator class is
    chosen by ``cfg.GAN.B_STYLEGEN``. Freshly loaded networks are put into
    eval mode and cached with a timeout of 24 hours expressed in seconds.

    Args:
        word_len: Vocabulary size passed to ``RNN_ENCODER``; also printed.

    Returns:
        Tuple ``(text_encoder, netG)``.
    """
    def _keep_on_cpu(storage, loc):
        # map_location callback: do not restore the saved device
        return storage

    one_day = 60 * 60 * 24

    print(word_len)
    text_encoder = cache.get('text_encoder')
    if text_encoder is None:
        print("text_encoder not cached")
        encoder_kind = sys.argv[1].casefold()
        if encoder_kind == 'rnn':
            text_encoder = RNN_ENCODER(word_len, nhidden=cfg.TEXT.EMBEDDING_DIM)
        elif encoder_kind == 'transformer':
            text_encoder = GPT2Model.from_pretrained(TRANSFORMER_ENCODER)
        encoder_weights = torch.load(cfg.TRAIN.NET_E, map_location=_keep_on_cpu)
        text_encoder.load_state_dict(encoder_weights)
        if cfg.CUDA:
            text_encoder.cuda()
        text_encoder.eval()
        cache.set('text_encoder', text_encoder, timeout=one_day)

    netG = cache.get('netG')
    if netG is None:
        print("netG not cached")
        netG = G_NET_STYLED() if cfg.GAN.B_STYLEGEN else G_NET()
        checkpoint = torch.load(cfg.TRAIN.NET_G, map_location=_keep_on_cpu)
        if cfg.GAN.B_STYLEGEN:
            # the styled checkpoint is a dict carrying 'w_ewma' next to the
            # generator weights
            netG.w_ewma = checkpoint['w_ewma']
            if cfg.CUDA:
                netG.w_ewma = netG.w_ewma.to('cuda:' + str(cfg.GPU_ID))
            netG.load_state_dict(checkpoint['netG_state_dict'])
        else:
            netG.load_state_dict(checkpoint)
        if cfg.CUDA:
            netG.cuda()
        netG.eval()
        cache.set('netG', netG, timeout=one_day)

    return text_encoder, netG
def build_models(dataset, batch_size, audio_flag=False):
    """Build the caption encoder and image encoder, optionally resuming.

    Args:
        dataset: Object whose ``n_words`` sizes the RNN encoder (only read
            when ``audio_flag`` is false).
        batch_size: Length of the returned ``labels`` tensor.
        audio_flag: When true a ``CNNRNN_Attn`` encoder with first argument
            40 is used instead of ``RNN_ENCODER``.

    Returns:
        Tuple ``(text_encoder, image_encoder, labels, start_epoch)``.
    """
    def _keep_on_cpu(storage, loc):
        # map_location callback: do not restore the saved device
        return storage

    # build model ############################################################
    if audio_flag:
        encoder_cls, input_size = CNNRNN_Attn, 40
    else:
        encoder_cls, input_size = RNN_ENCODER, dataset.n_words
    text_encoder = encoder_cls(input_size,
                               nhidden=cfg.TEXT.EMBEDDING_DIM,
                               nsent=cfg.TEXT.SENT_EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM,
                                condition=cfg.TRAIN.MASK_COND,
                                condition_channel=0)
    labels = torch.LongTensor(range(batch_size))
    start_epoch = 0

    text_ckpt = cfg.TRAIN.NET_E
    if text_ckpt != '':
        text_weights = torch.load(text_ckpt, map_location=_keep_on_cpu)
        text_encoder.load_state_dict(text_weights)
        print('Load ', text_ckpt)

        # the image encoder checkpoint sits next to the text encoder one
        image_ckpt = text_ckpt.replace('text_encoder', 'image_encoder')
        image_weights = torch.load(image_ckpt, map_location=_keep_on_cpu)
        image_encoder.load_state_dict(image_weights)
        print('Load ', image_ckpt)

        # resume after the number that follows '_encoder' in the path
        epoch_text = re.match(r'.*_encoder(\d+).*', text_ckpt).group(1)
        start_epoch = int(epoch_text) + 1
        print('start_epoch', start_epoch)

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()

    return text_encoder, image_encoder, labels, start_epoch
Beispiel #15
0
def build_models():
    """Build both encoders with optional pretrained parts and resume support.

    Reads ``dataset``, ``batch_size`` and ``cfg`` from the enclosing scope.
    Weights are applied in this order, later steps overriding earlier ones:
    ``cfg.PRETRAINED_RNN`` (RNN and embedding of the text encoder),
    ``cfg.PRETRAINED_CNN`` (image encoder), then the ``cfg.TRAIN.NET_E``
    checkpoint pair.

    Returns:
        Tuple ``(text_encoder, image_encoder, labels, start_epoch)``.
    """
    # build model ############################################################
    text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch = 0
    # MODIFIED
    if cfg.PRETRAINED_RNN:
        text_encoder_params = torch.load(
            cfg.PRETRAINED_RNN, map_location=lambda storage, loc: storage)
        text_encoder.rnn.load_state_dict(text_encoder_params['encoder'])
        # Rebuild the embedding layer with the same shape but with the
        # checkpoint's '<pad>' index as padding index, then load its weights.
        pad_idx = text_encoder_params['vocab']['word2id']['<pad>']
        n_words, embed_size = text_encoder.encoder.weight.size()
        text_encoder.encoder = nn.Embedding(n_words, embed_size, pad_idx)
        text_encoder.encoder.load_state_dict(text_encoder_params['embedding'])
    if cfg.PRETRAINED_CNN:
        image_encoder_params = torch.load(
            cfg.PRETRAINED_CNN, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(image_encoder_params)
    if cfg.TRAIN.NET_E != '':
        # Use the same CPU map_location as the pretrained loads above so
        # GPU-saved checkpoints also load on CPU-only hosts.
        state_dict = torch.load(
            cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_E)
        #
        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(
            name, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        print('Load ', name)

        # Checkpoint number: the text from 8 characters past the last
        # underscore of the path up to its last dot.
        istart = cfg.TRAIN.NET_E.rfind('_') + 8
        iend = cfg.TRAIN.NET_E.rfind('.')
        start_epoch = cfg.TRAIN.NET_E[istart:iend]
        start_epoch = int(start_epoch) + 1
        print('start_epoch', start_epoch)
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()

    return text_encoder, image_encoder, labels, start_epoch
    def sample(self, split_dir, num_samples=25, draw_bbox=False):
        """Save grids showing one real image next to nine generated ones.

        For each of the first ``num_samples`` batches of ``self.data_loader``
        the first caption of the batch is encoded, repeated nine times and
        fed to the generator with fresh noise. The real image and the nine
        images from the last generator output are written as one PNG row.
        Does nothing but print an error when ``cfg.TRAIN.NET_G`` is empty.

        Args:
            split_dir: Suffix of the output directory name; ``'test'`` is
                mapped to ``'valid'``.
            num_samples: Maximum number of batches (and output files).
            draw_bbox: When true, draw up to three boxes taken from
                ``bbox[0]`` onto every image of the grid.
        """
        # Only the image-grid helper is needed here. The former Python-2-only
        # `cPickle` import (never used) made this method fail on Python 3.
        import torchvision.utils as vutils

        if cfg.TRAIN.NET_G == '':
            print('Error: the path for model NET_G is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the text encoder
            text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            batch_size = cfg.TRAIN.BATCH_SIZE
            nz = cfg.GAN.Z_DIM

            # Build and load the generator; the checkpoint is a dict whose
            # "netG" entry holds the generator weights
            model_dir = cfg.TRAIN.NET_G
            state_dict = torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG = G_NET()
            print('Load G from: ', model_dir)
            netG.apply(weights_init)

            netG.load_state_dict(state_dict["netG"])
            netG.cuda()
            netG.eval()

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')]
            save_dir = '%s_%s' % (s_tmp, split_dir)
            mkdir_p(save_dir)
            #######################################
            # nine noise vectors -> nine generated variants per caption
            noise = Variable(torch.FloatTensor(9, nz))

            imsize = 256

            # Count the files actually written. Reporting the loop index
            # instead is off by one when the loader runs out early and raises
            # NameError when the loader is empty.
            num_saved = 0

            for step, data in enumerate(self.data_loader, 0):
                if step >= num_samples:
                    break

                imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot, bbox = \
                    prepare_data(data, eval=True)
                # keep only the first sample of the batch (re-adding the batch axis)
                transf_matrices_inv = transformation_matrices[1][0].unsqueeze(0)
                label_one_hot = label_one_hot[0].unsqueeze(0)

                img = imgs[-1][0]
                val_image = img.view(1, 3, imsize, imsize)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs[0].unsqueeze(0).detach(), sent_emb[0].unsqueeze(0).detach()
                # replicate the single caption for the nine noise vectors
                words_embs = words_embs.repeat(9, 1, 1)
                sent_emb = sent_emb.repeat(9, 1)
                mask = (captions == 0)
                mask = mask[0].unsqueeze(0)
                # trim the mask so its length matches the word embeddings
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]
                mask = mask.repeat(9, 1)
                transf_matrices_inv = transf_matrices_inv.repeat(9, 1, 1, 1)
                label_one_hot = label_one_hot.repeat(9, 1, 1)

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, sent_emb, words_embs, mask, transf_matrices_inv, label_one_hot)
                with torch.no_grad():
                    fake_imgs, _, mu, logvar = nn.parallel.data_parallel(netG, inputs, self.gpus)

                # slot 0: real image, slots 1-9: generated images
                data_img = torch.FloatTensor(10, 3, imsize, imsize).fill_(0)
                data_img[0] = val_image
                data_img[1:10] = fake_imgs[-1]

                if draw_bbox:
                    for idx in range(3):
                        # box coordinates are scaled by the image size
                        x, y, w, h = tuple([int(imsize*v) for v in bbox[0, idx]])
                        w = imsize-1 if w > imsize-1 else w
                        h = imsize-1 if h > imsize-1 else h
                        # stop at the first box whose x is not above -1
                        if x <= -1:
                            break
                        # paint the four edges of the box on all ten images
                        data_img[:10, :, y, x:x + w] = 1
                        data_img[:10, :, y:y + h, x] = 1
                        data_img[:10, :, y+h, x:x + w] = 1
                        data_img[:10, :, y:y + h, x + w] = 1

                # get caption: decode word ids up to the first 0
                cap = captions[0].data.cpu().numpy()
                sentence = ""
                for j in range(len(cap)):
                    if cap[j] == 0:
                        break
                    word = self.ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii')
                    sentence += word + " "
                sentence = sentence[:-1]
                vutils.save_image(data_img, '{}/{}_{}.png'.format(save_dir, sentence, step), normalize=True, nrow=10)
                num_saved += 1
            print("Saved {} files to {}".format(num_saved, save_dir))
    def sampling(self, split_dir, num_samples=30000):
        """Generate one image per caption from the data loader and save it.

        Loads the text encoder from ``cfg.TRAIN.NET_E`` and the generator
        from the ``"netG"`` entry of ``cfg.TRAIN.NET_G``, then walks
        ``self.data_loader`` and writes the last generator output of every
        sample as ``<save_dir>/single/<key>_s-1.png``. Only prints an error
        when ``cfg.TRAIN.NET_G`` is empty.

        Args:
            split_dir: Sub-directory name for the output; ``'test'`` is
                mapped to ``'valid'``.
            num_samples: Number of batches after which the loop stops.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for morels is not found!')
            return

        if split_dir == 'test':
            split_dir = 'valid'

        # Build the generator; its trained weights are loaded further down
        netG = G_DCGAN() if cfg.GAN.B_DCGAN else G_NET()
        netG.apply(weights_init)
        netG.cuda()
        netG.eval()

        # Text encoder: weights are deserialized on the CPU, then moved
        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        encoder_weights = torch.load(cfg.TRAIN.NET_E,
                                     map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(encoder_weights)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz)).cuda()

        model_dir = cfg.TRAIN.NET_G
        generator_ckpt = torch.load(model_dir,
                                    map_location=lambda storage, loc: storage)
        netG.load_state_dict(generator_ckpt["netG"])
        print('Load G from: ', model_dir)

        # the path to save generated images: checkpoint path without the
        # '.pth' suffix, plus the split name
        save_dir = '%s/%s' % (model_dir[:model_dir.rfind('.pth')], split_dir)
        mkdir_p(save_dir)

        # running total of requested samples; not read anywhere in this method
        cnt = 0
        # index into the generator outputs: only the last one is saved
        k = -1

        for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
            for step, data in enumerate(self.data_loader, 0):
                cnt += batch_size
                if step % 10000 == 0:
                    print('step: ', step)
                if step >= num_samples:
                    break

                (imgs, captions, cap_lens, class_ids, keys,
                 transformation_matrices, label_one_hot) = prepare_data(data)
                transf_matrices_inv = transformation_matrices[1]

                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs = words_embs.detach()
                sent_emb = sent_emb.detach()

                # mask positions holding index 0, trimmed to the embedding length
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, sent_emb, words_embs, mask, transf_matrices_inv, label_one_hot)
                with torch.no_grad():
                    fake_imgs, _, mu, logvar = nn.parallel.data_parallel(netG, inputs, self.gpus)

                for j in range(batch_size):
                    out_stem = '%s/single/%s' % (save_dir, keys[j])
                    folder = out_stem[:out_stem.rfind('/')]
                    if not os.path.isdir(folder):
                        print('Make a new folder: ', folder)
                        mkdir_p(folder)

                    pixels = fake_imgs[k][j].data.cpu().numpy()
                    # [-1, 1] --> [0, 255]
                    pixels = ((pixels + 1.0) * 127.5).astype(np.uint8)
                    # channels-first -> channels-last for PIL
                    picture = Image.fromarray(np.transpose(pixels, (1, 2, 0)))
                    picture.save('%s_s%d.png' % (out_stem, k))
def gen_example(n_words, wordtoix, ixtoword, model_dir):
    """Generate images from the example sentences listed in ``caption.txt``.

    Every non-empty line of ``caption.txt`` (opened relative to the current
    working directory) is lower-cased, split into word tokens and mapped to
    vocabulary indices through ``wordtoix``; tokens that are not in the
    vocabulary are dropped.  The captions are sorted by decreasing length,
    zero-padded into a single batch, encoded with the text encoder loaded
    from ``cfg.TRAIN.NET_E`` and rendered by the generator loaded from
    ``model_dir``.  Everything runs on the CPU.

    For every pass ``i`` in ``range(image_per_caption)`` (a name defined
    outside this function) and every caption index ``c`` the function writes

    * ``results/<i>_<c>_g<k>.png`` for each generator output ``k``, and
    * ``results/<i>_<c>_a<k>_attention.png`` for each attention map ``k``
      for which ``build_super_images2`` returns an image.

    ``c`` is the position of the caption among the usable captions of the
    file (blank / untokenizable / fully out-of-vocabulary lines do not count).

    Args:
        n_words: vocabulary size handed to ``RNN_ENCODER``.
        wordtoix: mapping from word to vocabulary index.
        ixtoword: inverse mapping, forwarded to ``build_super_images2``.
        model_dir: path of the generator checkpoint given to ``torch.load``.

    Raises:
        ValueError: if ``caption.txt`` contains no usable caption.
    """
    # filepath = 'example_captions.txt'
    filepath = 'caption.txt'
    with open(filepath, "r") as f:
        sentences = f.read().split('\n')

    # Built once: the tokenizer does not depend on the sentence.
    tokenizer = RegexpTokenizer(r'\w+')

    captions = []
    cap_lens = []
    for sent in sentences:
        if len(sent) == 0:
            continue
        sent = sent.replace("\ufffd\ufffd", " ")
        tokens = tokenizer.tokenize(sent.lower())
        if len(tokens) == 0:
            print('sentence token == 0 !')
            continue

        rev = []
        for t in tokens:
            t = t.encode('ascii', 'ignore').decode('ascii')
            if len(t) > 0 and t in wordtoix:
                rev.append(wordtoix[t])
        if not rev:
            # A caption without a single known word would enter the batch
            # with length 0; a packed-sequence RNN encoder presumably rejects
            # that (TODO confirm against RNN_ENCODER), so skip it instead of
            # failing the whole batch.
            print('sentence has no word in the vocabulary, skipped: ', sent)
            continue
        captions.append(rev)
        cap_lens.append(len(rev))

    if not captions:
        # np.max on an empty list raises an opaque ValueError; say what is
        # actually wrong instead.
        raise ValueError('no usable caption found in %s' % filepath)

    # Longest caption first, padded with 0 on the right.
    max_len = np.max(cap_lens)
    sorted_indices = np.argsort(cap_lens)[::-1]
    cap_lens = np.asarray(cap_lens)[sorted_indices]
    cap_array = np.zeros((len(captions), max_len), dtype='int64')
    for row, idx in enumerate(sorted_indices):
        cap = captions[idx]
        cap_array[row, :len(cap)] = cap
    # key = name[(name.rfind('/') + 1):]
    key = 0
    data_dic = {key: [cap_array, cap_lens, sorted_indices]}

    # algo.gen_example(data_dic)
    text_encoder = RNN_ENCODER(n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = torch.load(cfg.TRAIN.NET_E,
                            map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    netG = G_NET()
    netG.apply(weights_init)
    # netG.cuda()
    netG.eval()
    state_dict = torch.load(model_dir,
                            map_location=lambda storage, loc: storage)
    netG.load_state_dict(state_dict)
    print('Load G from: ', model_dir)

    save_dir = 'results'
    mkdir_p(save_dir)
    for key in data_dic:
        captions, cap_lens, sorted_indices = data_dic[key]

        batch_size = captions.shape[0]
        nz = cfg.GAN.Z_DIM

        captions = Variable(torch.from_numpy(captions))
        cap_lens = Variable(torch.from_numpy(cap_lens))
        # captions = captions.cuda()
        # cap_lens = cap_lens.cuda()

        # Neither depends on the noise sample, so compute them once.
        mask = (captions == 0)
        cap_lens_np = cap_lens.data.numpy()

        for i in range(image_per_caption):  # 16
            noise = Variable(torch.FloatTensor(batch_size, nz))
            # noise = noise.cuda()

            # Pure inference: keep autograd from recording the forward pass.
            with torch.no_grad():
                # (1) Extract text embeddings
                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(captions, cap_lens,
                                                    hidden)
                # (2) Generate fake images
                noise.data.normal_(0, 1)
                fake_imgs, attention_maps, _, _ = netG(noise, sent_emb,
                                                       words_embs, mask)

            for j in range(batch_size):
                save_name = '%s/%d_%d' % (save_dir, i, sorted_indices[j])
                for k in range(len(fake_imgs)):
                    im = fake_imgs[k][j].data.cpu().numpy()
                    # [-1, 1] --> [0, 255]
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    # channels-first --> channels-last for PIL
                    im = np.transpose(im, (1, 2, 0))
                    im = Image.fromarray(im)
                    fullpath = '%s_g%d.png' % (save_name, k)
                    im.save(fullpath)

                for k in range(len(attention_maps)):
                    # Attention map k is drawn on generator output k + 1 when
                    # there is more than one output, otherwise on the only one.
                    if len(fake_imgs) > 1:
                        im = fake_imgs[k + 1]
                    else:
                        im = fake_imgs[0]
                    attn_maps = attention_maps[k]
                    att_sze = attn_maps.size(2)
                    img_set, sentences = \
                        build_super_images2(im[j].unsqueeze(0),
                                            captions[j].unsqueeze(0),
                                            [cap_lens_np[j]], ixtoword,
                                            [attn_maps[j]], att_sze)
                    if img_set is not None:
                        im = Image.fromarray(img_set)
                        fullpath = '%s_a%d_attention.png' % (save_name, k)
                        im.save(fullpath)
Beispiel #19
0
    def sampling(self, split_dir, num_samples=30000):
        """Generate images with the generator checkpoint ``cfg.TRAIN.NET_G``.

        Loads the text encoder from ``cfg.TRAIN.NET_E`` and the generator
        weights from the ``"netG"`` entry of the ``cfg.TRAIN.NET_G``
        checkpoint, then draws ``num_samples // batch_size`` batches (at
        least one) from ``self.data_loader``.  For every sample only the last
        generator output is saved, as
        ``../output/<checkpoint name>/<split_dir>/<key>_s<running index>.png``.

        If the data loader is exhausted before enough batches were produced
        it is restarted, so the requested number of batches is still
        generated (with fresh noise).

        Args:
            split_dir: name of the data split; ``'test'`` is mapped to
                ``'valid'``.  Only used to build the output directory.
            num_samples: approximate number of images to generate.

        Returns:
            None.  When ``cfg.TRAIN.NET_G`` is empty an error is logged and
            nothing is generated.
        """
        if cfg.TRAIN.NET_G == '':
            logger.error('Error: the path for models is not found!')
            return

        if split_dir == 'test':
            split_dir = 'valid'

        # Build the generator; its weights are loaded further below.
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        netG.apply(weights_init)
        netG.to(cfg.DEVICE)
        netG.eval()

        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        text_encoder = text_encoder.to(cfg.DEVICE)
        text_encoder.eval()
        logger.info('Loaded text encoder from: %s', cfg.TRAIN.NET_E)

        batch_size = self.batch_size[0]
        nz = cfg.GAN.GLOBAL_Z_DIM
        # Allocated once and re-filled in place for every batch.
        noise = Variable(torch.FloatTensor(batch_size, nz)).to(cfg.DEVICE)
        local_noise = Variable(torch.FloatTensor(batch_size, cfg.GAN.LOCAL_Z_DIM)).to(cfg.DEVICE)

        model_dir = cfg.TRAIN.NET_G
        state_dict = torch.load(model_dir, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict["netG"])
        max_objects = 10
        logger.info('Load G from: %s', model_dir)

        # the path to save generated images
        s_tmp = model_dir[:model_dir.rfind('.pth')].split("/")[-1]
        save_dir = '%s/%s/%s' % ("../output", s_tmp, split_dir)
        mkdir_p(save_dir)
        logger.info("Saving images to: {}".format(save_dir))

        number_batches = max(num_samples // batch_size, 1)

        data_iter = iter(self.data_loader)

        for step in tqdm(range(number_batches)):
            # Use the built-in next(): iterators are not required to expose a
            # ``.next()`` method in Python 3.
            try:
                data = next(data_iter)
            except StopIteration:
                # Loader shorter than the requested number of batches:
                # start another pass over it.
                data_iter = iter(self.data_loader)
                data = next(data_iter)

            _, captions, cap_lens, _, keys, transformation_matrices, label_one_hot, _ = prepare_data(
                data, eval=True)

            transf_matrices = transformation_matrices[0]
            transf_matrices_inv = transformation_matrices[1]

            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
            # Index 0 is treated as padding; trim the mask to the length the
            # encoder actually returned.
            mask = (captions == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            local_noise.data.normal_(0, 1)
            inputs = (noise, local_noise, sent_emb, words_embs, mask, transf_matrices, transf_matrices_inv, label_one_hot, max_objects)
            inputs = tuple((inp.to(cfg.DEVICE) if isinstance(inp, torch.Tensor) else inp) for inp in inputs)

            with torch.no_grad():
                fake_imgs, _, _, _ = netG(*inputs)

            # Only the last generator output is written to disk.
            final_imgs = fake_imgs[-1]
            for j in range(batch_size):
                s_tmp = '%s/%s' % (save_dir, keys[j])
                # keys may contain sub-directories
                folder = s_tmp[:s_tmp.rfind('/')]
                if not os.path.isdir(folder):
                    logger.info('Make a new folder: %s', folder)
                    mkdir_p(folder)
                im = final_imgs[j].data.cpu().numpy()
                # [-1, 1] --> [0, 255]
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                im = np.transpose(im, (1, 2, 0))
                im = Image.fromarray(im)
                fullpath = '%s_s%d.png' % (s_tmp, step * batch_size + j)
                im.save(fullpath)
Beispiel #20
0
    def build_models(self):
        """Create every network needed for training.

        Builds the VGG style-loss network, the pretrained image and text
        encoders (all three frozen and in eval mode), the generator, a frozen
        deep copy of the generator used as target network, and the
        discriminators selected by ``cfg.GAN.B_DCGAN`` and
        ``cfg.TREE.BRANCH_NUM``.  When ``cfg.TRAIN.NET_G`` is set, the
        generator (and, if ``cfg.TRAIN.B_NET_D``, the discriminators stored
        next to it as ``netD<i>.pth``) are restored and the start epoch is
        parsed from the checkpoint file name.

        Returns:
            ``[text_encoder, image_encoder, netG, target_netG, netsD, epoch,
            style_loss]``, or ``None`` (after printing an error) when
            ``cfg.TRAIN.NET_E`` is empty.
        """
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        def keep_on_cpu(storage, loc):
            # map_location callback: leave loaded tensors where they were
            # deserialized.
            return storage

        def freeze(net):
            for param in net.parameters():
                param.requires_grad = False

        # vgg16 network
        style_loss = VGGNet()
        freeze(style_loss)
        print("Load the style loss model")
        style_loss.eval()

        # Pretrained image encoder; its path is derived from the text encoder's.
        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        image_encoder.load_state_dict(
            torch.load(img_encoder_path, map_location=keep_on_cpu))
        freeze(image_encoder)
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        # Pretrained text encoder.
        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)
        text_encoder.load_state_dict(
            torch.load(cfg.TRAIN.NET_E, map_location=keep_on_cpu))
        freeze(text_encoder)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # Generator and discriminators.
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            # One discriminator per branch, smallest resolution first.
            per_branch = (D_NET64, D_NET128, D_NET256)
            netsD = [
                make_d() for level, make_d in enumerate(per_branch)
                if cfg.TREE.BRANCH_NUM > level
            ]

        netG.apply(weights_init)
        for netD in netsD:
            netD.apply(weights_init)
        print('# of netsD', len(netsD))

        # Optionally resume from a generator checkpoint.
        epoch = 0
        g_path = cfg.TRAIN.NET_G
        if g_path != '':
            netG.load_state_dict(torch.load(g_path, map_location=keep_on_cpu))

            print('Load G from: ', g_path)
            # The epoch is the text between the last '_' and the last '.'.
            epoch_text = g_path[g_path.rfind('_') + 1:g_path.rfind('.')]
            epoch = int(epoch_text) + 1
            if cfg.TRAIN.B_NET_D:
                ckpt_dir = g_path[:g_path.rfind('/')]
                for i, netD in enumerate(netsD):
                    Dname = '%s/netD%d.pth' % (ckpt_dir, i)
                    print('Load D from: ', Dname)
                    netD.load_state_dict(
                        torch.load(Dname, map_location=keep_on_cpu))

        # Create a target network.
        target_netG = deepcopy(netG)

        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            style_loss = style_loss.cuda()

            # The target network is stored on the secondary GPU.
            target_netG.cuda(secondary_device)
            target_netG.ca_net.device = secondary_device

            netG.cuda()
            netsD = [netD.cuda() for netD in netsD]

        # Disable training in the target network:
        freeze(target_netG)

        return [
            text_encoder, image_encoder, netG, target_netG, netsD, epoch,
            style_loss
        ]
Beispiel #21
0
    def build_models(self):
        """Create the encoders, the generator and the discriminators.

        The image and text encoders are loaded from ``cfg.TRAIN.NET_E`` (the
        image encoder path is derived by replacing ``'text_encoder'`` with
        ``'image_encoder'``), frozen and put in eval mode.  The generator and
        discriminators are chosen by ``cfg.GAN.B_DCGAN`` and
        ``cfg.TREE.BRANCH_NUM`` and initialised with ``weights_init``.
        Weights are then restored first from ``cfg.PRETRAINED_G`` and next
        from ``cfg.TRAIN.NET_G`` when those are set; only the latter also
        sets the start epoch, parsed from the checkpoint file name.

        Returns:
            ``[text_encoder, image_encoder, netG, netsD, epoch]``, or ``None``
            (after printing an error) when ``cfg.TRAIN.NET_E`` is empty.
        """
        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        def keep_on_cpu(storage, loc):
            # map_location callback: leave loaded tensors where they were
            # deserialized.
            return storage

        def freeze(net):
            for param in net.parameters():
                param.requires_grad = False

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        image_encoder.load_state_dict(
            torch.load(img_encoder_path, map_location=keep_on_cpu))
        freeze(image_encoder)
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)
        text_encoder.load_state_dict(
            torch.load(cfg.TRAIN.NET_E, map_location=keep_on_cpu))
        freeze(text_encoder)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # #######################generator and discriminators############## #
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            # TODO: elif cfg.TREE.BRANCH_NUM > 3:
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            # One discriminator per branch, smallest resolution first.
            # TODO: if cfg.TREE.BRANCH_NUM > 3:
            per_branch = (D_NET64, D_NET128, D_NET256)
            netsD = [
                make_d() for level, make_d in enumerate(per_branch)
                if cfg.TREE.BRANCH_NUM > level
            ]

        netG.apply(weights_init)
        for netD in netsD:
            netD.apply(weights_init)
        print('# of netsD', len(netsD))

        def restore_discriminators(g_path):
            # Discriminators live next to the generator checkpoint as
            # netD0.pth, netD1.pth, ...
            ckpt_dir = g_path[:g_path.rfind('/')]
            for i, netD in enumerate(netsD):
                Dname = '%s/netD%d.pth' % (ckpt_dir, i)
                print('Load D from: ', Dname)
                netD.load_state_dict(
                    torch.load(Dname, map_location=keep_on_cpu))

        epoch = 0
        # MODIFIED: optional pretrained generator, loaded before NET_G.
        pretrained_path = cfg.PRETRAINED_G
        if pretrained_path != '':
            netG.load_state_dict(
                torch.load(pretrained_path, map_location=keep_on_cpu))
            print('Load G from: ', pretrained_path)
            if cfg.TRAIN.B_NET_D:
                restore_discriminators(pretrained_path)

        resume_path = cfg.TRAIN.NET_G
        if resume_path != '':
            netG.load_state_dict(
                torch.load(resume_path, map_location=keep_on_cpu))
            print('Load G from: ', resume_path)
            # The epoch is the text between the last '_' and the last '.'.
            first = resume_path.rfind('_') + 1
            last = resume_path.rfind('.')
            epoch = int(resume_path[first:last]) + 1
            if cfg.TRAIN.B_NET_D:
                restore_discriminators(resume_path)
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            for netD in netsD:
                netD.cuda()
        return [text_encoder, image_encoder, netG, netsD, epoch]
Beispiel #22
0
    def embedding(self, split_dir, model):
        """Dump image and sentence features of generated images to pickles.

        Loads the generator from ``cfg.TRAIN.NET_G`` and the image / text
        encoders from ``cfg.TRAIN.NET_E``, generates one image per sample of
        ``self.data_loader`` and stores, per sample key,

        * the feature of the last generator output in
          ``<NET_G path without '.pth'>.pkl`` and
        * the sentence embedding that conditioned the generator in
          ``<NET_G path without '.pth'>_text.pkl``.

        When the module-level flag ``CLIP`` is true the image feature comes
        from ``model.encode_image``, otherwise from the image encoder.  When
        ``cfg.TRAIN.CLIP_SENTENCODER`` is set the sentences are additionally
        encoded with ``model.encode_text`` and, if ``CLIP`` is true, that
        embedding replaces the text-encoder sentence embedding.

        Args:
            split_dir: ``'test'`` is mapped to ``'valid'``; the value is not
                used otherwise in this method.
            model: object providing ``encode_text`` and ``encode_image``
                (presumably a CLIP model — confirm with the caller).

        Returns:
            None.  When ``cfg.TRAIN.NET_G`` is empty an error is printed and
            nothing is written.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
            return

        if split_dir == 'test':
            split_dir = 'valid'
        use_gpu = cfg.GPU_ID != -1

        # Build and load the generator
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        netG.apply(weights_init)
        if use_gpu:
            netG.cuda()
        netG.eval()
        #
        model_dir = cfg.TRAIN.NET_G
        state_dict = \
            torch.load(model_dir, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        print(img_encoder_path)
        print('Load image encoder from:', img_encoder_path)
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        if use_gpu:
            image_encoder = image_encoder.cuda()
        image_encoder.eval()

        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        if use_gpu:
            text_encoder = text_encoder.cuda()
        text_encoder.eval()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM

        # Allocated once and re-filled in place for every batch.
        with torch.no_grad():
            noise = Variable(torch.FloatTensor(batch_size, nz))
            if use_gpu:
                noise = noise.cuda()

        # prefix of the output pickle files
        save_dir = model_dir[:model_dir.rfind('.pth')]

        # new
        if cfg.TRAIN.CLIP_SENTENCODER:
            print("Use CLIP SentEncoder for sampling")
        img_features = dict()
        txt_features = dict()

        with torch.no_grad():
            for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                for step, data in enumerate(self.data_loader, 0):
                    if step % 100 == 0:
                        print('step: ', step)

                    imgs, captions, cap_lens, class_ids, keys, texts = prepare_data(
                        data)

                    hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(
                        captions, cap_lens, hidden)
                    words_embs, sent_emb = words_embs.detach(
                    ), sent_emb.detach()
                    # Index 0 is treated as padding; trim the mask to the
                    # length the encoder actually returned.
                    mask = (captions == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    if cfg.TRAIN.CLIP_SENTENCODER:
                        sents = []
                        for idx in range(len(texts)):
                            sents_per_image = texts[idx].split(
                                '\n')  # new 3/11
                            # NOTE(review): sent_ix is drawn but never used —
                            # the first sentence is always taken below.  The
                            # draw is kept so the NumPy random stream is
                            # unchanged; confirm which behavior is intended.
                            if len(sents_per_image) > 1:
                                sent_ix = np.random.randint(
                                    0,
                                    len(sents_per_image) - 1)
                            else:
                                sent_ix = 0
                            sents.append(sents_per_image[0])
                        # print('sents: ', sents)

                        sent = clip.tokenize(sents)  # .to(device)

                        # load clip
                        #model = torch.jit.load("model.pt").cuda().eval()
                        sent_input = sent
                        if use_gpu:
                            sent_input = sent.cuda()
                        # print("text input", sent_input)
                        sent_emb_clip = model.encode_text(
                            sent_input).float()
                        if CLIP:
                            sent_emb = sent_emb_clip
                    #######################################################
                    # (2) Generate fake images
                    ######################################################
                    noise.data.normal_(0, 1)
                    fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs,
                                              mask)
                    if CLIP:
                        # Keep the CLIP inputs on the device the generator
                        # ran on instead of forcing CUDA, consistent with the
                        # cfg.GPU_ID guards used everywhere else here.
                        device = fake_imgs[-1].device
                        unloader = transforms.ToPILImage()
                        images = []
                        for j in range(fake_imgs[-1].shape[0]):
                            image = fake_imgs[-1][j].cpu().clone()
                            image = image.squeeze(0)
                            # NOTE(review): the tensor is converted to PIL
                            # without rescaling; other code in this file
                            # treats generator output as [-1, 1] — confirm.
                            image = unloader(image)

                            image = preprocess(
                                image.convert("RGB"))  # 256*256 -> 224*224
                            images.append(image)

                        image_mean = torch.tensor(
                            [0.48145466, 0.4578275, 0.40821073],
                            device=device)
                        image_std = torch.tensor(
                            [0.26862954, 0.26130258, 0.27577711],
                            device=device)

                        image_input = torch.tensor(
                            np.stack(images)).to(device)
                        # per-channel normalisation
                        image_input -= image_mean[:, None, None]
                        image_input /= image_std[:, None, None]
                        cnn_codes = model.encode_image(image_input).float()
                    else:
                        region_features, cnn_codes = image_encoder(
                            fake_imgs[-1])
                    for j in range(batch_size):
                        cnn_code = cnn_codes[j]

                        # NOTE(review): this removes every 'b' and every
                        # quote from the key, not only a leading b'...'
                        # wrapper — confirm keys cannot contain 'b'.
                        temp = keys[j].replace('b', '').replace("'", '')
                        img_features[temp] = cnn_code.cpu().numpy()
                        txt_features[temp] = sent_emb[j].cpu().numpy()
        with open(save_dir + ".pkl", 'wb') as f:
            pickle.dump(img_features, f)
        with open(save_dir + "_text.pkl", 'wb') as f:
            pickle.dump(txt_features, f)
Beispiel #23
0
    def sampling(self, split_dir):
        """Generate images for human and imperfect captions and report DDVA.

        For every batch of ``self.data_loader`` the generator loaded from
        ``cfg.TRAIN.NET_G`` is run twice: once on the human captions and once
        on the imperfect captions.  The last generator output for each human
        caption is saved as
        ``<NET_G path without '.pth'>/<split_dir>/single_s<running index>.png``.
        The negated result of ``negative_ddva`` is accumulated over the
        batches and its average is printed at the end.

        Everything is moved to CUDA unconditionally; images are moved to
        ``secondary_device`` before the metric is computed.

        Args:
            split_dir: name of the data split; ``'test'`` is mapped to
                ``'valid'``.  Only used to build the output directory.

        Returns:
            None.  When ``cfg.TRAIN.NET_G`` is empty an error is printed and
            nothing is generated.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
            return

        if split_dir == 'test':
            split_dir = 'valid'
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        netG.apply(weights_init)
        netG.cuda()
        netG.eval()
        #
        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        # `volatile=True` has had no effect since PyTorch 0.4; gradient
        # tracking is disabled with torch.no_grad() around the forward
        # passes instead (see generate() below).
        noise = Variable(torch.FloatTensor(batch_size, nz))
        noise = noise.cuda()

        model_dir = cfg.TRAIN.NET_G
        state_dict = \
            torch.load(model_dir, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)

        # the path to save generated images
        s_tmp = model_dir[:model_dir.rfind('.pth')]
        save_dir = '%s/%s' % (s_tmp, split_dir)
        mkdir_p(save_dir)

        # Images are written as <save_dir>/single_s<idx>.png; the prefix and
        # its folder do not change between samples, so handle them once.
        s_tmp = '%s/single' % (save_dir)
        folder = s_tmp[:s_tmp.rfind('/')]
        if not os.path.isdir(folder):
            print('Make a new folder: ', folder)
            mkdir_p(folder)

        def generate(caps, lens):
            # Encode one batch of captions and run the generator on it with
            # fresh noise, without building an autograd graph.
            with torch.no_grad():
                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(caps, lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Index 0 is treated as padding; trim the mask to the length
                # the encoder actually returned.
                mask = (caps == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                noise.data.normal_(0, 1)
                generated, _, _, _ = netG(noise, sent_emb, words_embs, mask)
            return generated

        idx = 0  # running index of the saved images
        num_batches = 0
        avg_ddva = 0
        for _ in range(1):
            for step, data in enumerate(self.data_loader, 0):
                num_batches += 1
                if step % 100 == 0:
                    print('step: ', step)

                captions, cap_lens, imperfect_captions, imperfect_cap_lens, misc = data

                # Generate images for human-text ---------------------------
                data_human = [captions, cap_lens, misc]
                _, captions, cap_lens, _, keys, _, _, _ = \
                    prepare_data(data_human)
                fake_imgs = generate(captions, cap_lens)

                # Generate images for imperfect caption-text ---------------
                data_imperfect = [
                    imperfect_captions, imperfect_cap_lens, misc
                ]
                imgs, imperfect_captions, imperfect_cap_lens, _, \
                    imperfect_keys, _, _, _ = prepare_data(data_imperfect)
                imperfect_fake_imgs = generate(imperfect_captions,
                                               imperfect_cap_lens)

                # Sort the results by keys to align ------------------------
                _, _, _, fake_imgs, _, _ = sort_by_keys(
                    keys, captions, cap_lens, fake_imgs, None, None)

                # NOTE(review): the key-sorted real images are returned here
                # as true_imgs, but the unsorted `imgs` is what is passed to
                # negative_ddva below.  Unless sort_by_keys reorders `imgs`
                # in place, the real images are not aligned with the sorted
                # fakes — confirm against sort_by_keys.
                _, _, _, imperfect_fake_imgs, true_imgs, _ = \
                    sort_by_keys(imperfect_keys, imperfect_captions,
                                 imperfect_cap_lens, imperfect_fake_imgs,
                                 imgs, None)

                # Shift device for the imgs, target_imgs and imperfect_imgs -
                for i in range(len(imgs)):
                    imgs[i] = imgs[i].to(secondary_device)
                    imperfect_fake_imgs[i] = imperfect_fake_imgs[i].to(
                        secondary_device)
                    fake_imgs[i] = fake_imgs[i].to(secondary_device)

                # Only the last-stage image of the human caption is saved.
                for j in range(batch_size):
                    im = fake_imgs[-1][j].data.cpu().numpy()
                    # [-1, 1] --> [0, 255]
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    im = Image.fromarray(im)
                    fullpath = '%s_s%d.png' % (s_tmp, idx)
                    im.save(fullpath)
                    idx = idx + 1

                neg_ddva = negative_ddva(
                    imperfect_fake_imgs,
                    imgs,
                    fake_imgs,
                    reduce='mean',
                    final_only=True).data.cpu().numpy()
                avg_ddva += neg_ddva * (-1)

                print(step)

        # Guard against an empty data loader (no batch processed).
        if num_batches:
            avg_ddva = avg_ddva / num_batches
        print('\n\nAvg_DDVA: ', avg_ddva)
Beispiel #24
0
    def sampling(self, split_dir, model):
        """Generate one fake image per caption of the data loader and save it.

        For every batch the captions are encoded with the pretrained RNN text
        encoder; when ``cfg.TRAIN.CLIP_SENTENCODER`` is set the sentence
        embedding is replaced by ``model.encode_text`` of one sentence picked
        per sample (the picked sentences are appended to ``eval_sents.txt``).
        The last-stage generator output is written below
        ``<NET_G path without .pth>/<split_dir>/fake`` and the matching real
        image and text file are copied into the sibling ``real`` and ``text``
        folders.

        Args:
            split_dir: name of the data split; ``'test'`` is mapped to
                ``'valid'``.
            model: object providing ``encode_text``; only used when
                ``cfg.TRAIN.CLIP_SENTENCODER`` is set (presumably a loaded
                CLIP model -- confirm against the caller).
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for morels is not found!')
            return

        if split_dir == 'test':
            split_dir = 'valid'
        use_cuda = cfg.GPU_ID != -1

        # Build the generator; its trained weights are loaded further below.
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        netG.apply(weights_init)
        if use_cuda:
            netG.cuda()
        netG.eval()

        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        if use_cuda:
            text_encoder = text_encoder.cuda()
        text_encoder.eval()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        # Allocated once; re-filled with N(0, 1) samples for every batch.
        noise = Variable(torch.FloatTensor(batch_size, nz))
        if use_cuda:
            noise = noise.cuda()

        model_dir = cfg.TRAIN.NET_G
        state_dict = \
            torch.load(model_dir, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)

        # the path to save generated images
        s_tmp = model_dir[:model_dir.rfind('.pth')]
        save_dir = '%s/%s' % (s_tmp, split_dir)
        mkdir_p(save_dir)
        real_dir = f'{save_dir}/real'
        text_dir = f'{save_dir}/text'

        if cfg.TRAIN.CLIP_SENTENCODER:
            print("Use CLIP SentEncoder for sampling")

        for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
            for step, data in enumerate(self.data_loader, 0):
                if step % 100 == 0:
                    print('step: ', step)

                imgs, captions, cap_lens, class_ids, keys, texts = \
                    prepare_data(data)

                # Pure inference: no autograd graph is needed for any of the
                # forward passes below.
                with torch.no_grad():
                    hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(
                        captions, cap_lens, hidden)
                    words_embs = words_embs.detach()
                    sent_emb = sent_emb.detach()
                    mask = (captions == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                if cfg.TRAIN.CLIP_SENTENCODER:
                    # Pick one sentence per sample.
                    # NOTE(review): np.random.randint excludes its upper
                    # bound, so the last entry of the split is never drawn.
                    # Kept as in the original -- confirm whether this is meant
                    # to skip a trailing empty line or is an off-by-one.
                    sents = []
                    for text in texts:
                        sents_per_image = text.split('\n')
                        if len(sents_per_image) > 1:
                            sent_ix = np.random.randint(
                                0, len(sents_per_image) - 1)
                        else:
                            sent_ix = 0
                        sents.append(sents_per_image[sent_ix])

                    # One append per batch instead of reopening per sentence.
                    with open('%s/%s' % (save_dir, 'eval_sents.txt'),
                              'a+') as f:
                        f.writelines(sent + '\n' for sent in sents)

                    sent_input = clip.tokenize(sents)
                    if use_cuda:
                        sent_input = sent_input.cuda()
                    with torch.no_grad():
                        sent_emb = model.encode_text(sent_input).float()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                with torch.no_grad():
                    fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs,
                                              mask)

                for j in range(batch_size):
                    s_tmp = '%s/fake/%s' % (save_dir, keys[j])
                    folder = s_tmp[:s_tmp.rfind('/')]
                    # Check every target folder on its own so the copies below
                    # cannot fail when only some of them already exist.
                    for needed in (folder, real_dir, text_dir):
                        if not os.path.isdir(needed):
                            print('Make a new folder: ', needed)
                            mkdir_p(needed)

                    k = -1  # only the last generator stage is saved
                    im = fake_imgs[k][j].data.cpu().numpy()
                    # [-1, 1] --> [0, 255]
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    im = Image.fromarray(im)
                    fullpath = '%s_s%d.png' % (s_tmp, k)
                    im.save(fullpath)

                    # Removes every "b" and "'" from the key (presumably to
                    # undo a stringified bytes key such as b'123' -- note it
                    # would also strip a "b" that is part of the name).
                    temp = keys[j].replace('b', '').replace("'", '')
                    shutil.copy(f"../data/Face/images/{temp}.jpg",
                                f"{save_dir}/real/")
                    shutil.copy(f"../data/Face/text/{temp}.txt",
                                f"{save_dir}/text/")
Beispiel #25
0
    def genDiscOutputs(self, split_dir, num_samples=57140):
        """Generate fake images and run them through the loaded discriminator.

        Loads the text encoder from ``cfg.TRAIN.NET_E`` and, from the combined
        checkpoint at ``cfg.TRAIN.NET_G``, the generator (key ``"netG"``) and
        a ``D_NET256`` discriminator (entry 2 of key ``"netD"``). For
        ``num_samples // batch_size`` batches (at least one) it synthesises
        images and calls the discriminator's ``partial_forward`` on them.

        Args:
            split_dir: name of the data split; ``'test'`` is mapped to
                ``'valid'``.
            num_samples: number of samples to process; converted to a batch
                count by integer division with the batch size.

        NOTE(review): the discriminator ``codes`` are computed but neither
        stored nor returned by the code visible here -- confirm whether
        follow-up processing is missing.
        """
        if cfg.TRAIN.NET_G == '':
            logger.error('Error: the path for morels is not found!')
            return

        if split_dir == 'test':
            split_dir = 'valid'

        def to_device(inp):
            # Non-tensor inputs (e.g. the integer max_objects) pass through.
            return inp.to(cfg.DEVICE) if isinstance(inp, torch.Tensor) else inp

        # Build and load the generator
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        netG.apply(weights_init)
        netG.to(cfg.DEVICE)
        netG.eval()
        #
        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)  ###HACK
        state_dict = torch.load(cfg.TRAIN.NET_E,
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        text_encoder = text_encoder.to(cfg.DEVICE)
        text_encoder.eval()
        logger.info('Loaded text encoder from: %s', cfg.TRAIN.NET_E)

        batch_size = self.batch_size[0]
        nz = cfg.GAN.GLOBAL_Z_DIM
        # Noise buffers are allocated once and re-sampled for every batch.
        noise = Variable(torch.FloatTensor(batch_size, nz)).to(cfg.DEVICE)
        local_noise = Variable(
            torch.FloatTensor(batch_size,
                              cfg.GAN.LOCAL_Z_DIM)).to(cfg.DEVICE)

        model_dir = cfg.TRAIN.NET_G
        state_dict = torch.load(model_dir,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict["netG"])
        for key in state_dict:
            print(key)
        logger.info('Load G from: %s', model_dir)
        max_objects = 3
        from model import D_NET256
        netD = D_NET256()
        netD.load_state_dict(state_dict["netD"][2])
        # The discriminator receives generator outputs that live on
        # cfg.DEVICE, so it has to be moved there as well.
        netD.to(cfg.DEVICE)
        netD.eval()

        # the path to save generated images
        s_tmp = model_dir[:model_dir.rfind('.pth')].split("/")[-1]
        save_dir = '%s/%s/%s' % ("../output", s_tmp, split_dir)
        mkdir_p(save_dir)
        logger.info("Saving images to: {}".format(save_dir))

        number_batches = max(num_samples // batch_size, 1)

        data_iter = iter(self.data_loader)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        for step in tqdm(range(number_batches)):
            # Python 3 iterators expose next() only through the builtin.
            data = next(data_iter)

            imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot, _ = prepare_data(
                data, eval=True)

            transf_matrices = transformation_matrices[0]
            transf_matrices_inv = transformation_matrices[1]

            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
            mask = (captions == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            local_noise.data.normal_(0, 1)
            gen_inputs = tuple(
                to_device(inp)
                for inp in (noise, local_noise, sent_emb, words_embs, mask,
                            transf_matrices, transf_matrices_inv,
                            label_one_hot, max_objects))
            with torch.no_grad():
                fake_imgs, _, mu, logvar = netG(*gen_inputs)
                disc_inputs = tuple(
                    to_device(inp)
                    for inp in (fake_imgs, fake_labels, transf_matrices,
                                transf_matrices_inv, max_objects))
                # The original referenced an undefined `netsD[-1]`; the
                # discriminator loaded above is `netD`.
                codes = netD.partial_forward(*disc_inputs)
    def build_models(self):
        """Create the encoders, the generator and the discriminators.

        The image and text encoders are loaded from ``cfg.TRAIN.NET_E`` (the
        image encoder path is derived by replacing ``'text_encoder'`` with
        ``'image_encoder'``), frozen and put into eval mode. Generator and
        discriminators are initialised with ``weights_init`` and, when
        ``cfg.TRAIN.NET_G`` is set, restored from disk.

        Returns:
            ``[text_encoder, image_encoder, netG, netsD, epoch]`` where
            ``epoch`` is 0 or the number parsed from the generator file name
            plus one. Returns ``None`` when ``cfg.TRAIN.NET_E`` is empty.
        """

        def count_parameters(model):
            # Print every trainable parameter with its size and return the
            # total number of trainable elements.
            running_total = 0
            for pname, tensor in model.named_parameters():
                if not tensor.requires_grad:
                    continue
                n_elems = np.prod(tensor.size())
                if tensor.dim() > 1:
                    shape_txt = 'x'.join(str(x) for x in list(tensor.size()))
                    print(pname, ':', shape_txt, '=', n_elems)
                else:
                    print(pname, ':', n_elems)
                running_total += n_elems
            return running_total

        def keep_on_cpu(storage, loc):
            # map_location callback: leave loaded storages where they are.
            return storage

        # ---- pretrained, frozen encoders ----
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        image_encoder.load_state_dict(
            torch.load(img_encoder_path, map_location=keep_on_cpu))
        for weight in image_encoder.parameters():
            weight.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)
        text_encoder.load_state_dict(
            torch.load(cfg.TRAIN.NET_E, map_location=keep_on_cpu))
        for weight in text_encoder.parameters():
            weight.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # ---- generator and discriminators ----
        netsD = []
        if cfg.GAN.B_DCGAN:
            # A single discriminator whose resolution follows the branch count.
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            # TODO: elif cfg.TREE.BRANCH_NUM > 3:
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            # One discriminator per branch, lowest resolution first.
            for threshold, disc_cls in enumerate(
                    (D_NET64, D_NET128, D_NET256)):
                if cfg.TREE.BRANCH_NUM > threshold:
                    netsD.append(disc_cls())
            # TODO: if cfg.TREE.BRANCH_NUM > 3:

        print('number of trainable parameters =', count_parameters(netG))
        print('number of trainable parameters =', count_parameters(netsD[-1]))

        netG.apply(weights_init)
        for disc in netsD:
            disc.apply(weights_init)
        print('# of netsD', len(netsD))

        # ---- optional restore from a generator checkpoint ----
        epoch = 0
        if cfg.TRAIN.NET_G != '':
            netG.load_state_dict(
                torch.load(cfg.TRAIN.NET_G, map_location=keep_on_cpu))
            print('Load G from: ', cfg.TRAIN.NET_G)
            # The epoch is the text between the last '_' and the last '.'.
            first = cfg.TRAIN.NET_G.rfind('_') + 1
            last = cfg.TRAIN.NET_G.rfind('.')
            epoch = int(cfg.TRAIN.NET_G[first:last]) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i, disc in enumerate(netsD):
                    ckpt_dir = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (ckpt_dir, i)
                    print('Load D from: ', Dname)
                    disc.load_state_dict(
                        torch.load(Dname, map_location=keep_on_cpu))

        # ---- device placement ----
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            for disc in netsD:
                disc.cuda()
        return [text_encoder, image_encoder, netG, netsD, epoch]
Beispiel #27
0
    def build_models(self):
        """Create encoders, caption models, generator and discriminators.

        The image/text encoders (from ``cfg.TRAIN.NET_E``) and the caption
        CNN/RNN (from ``cfg.CAP.caption_cnn_path`` /
        ``cfg.CAP.caption_rnn_path``) are loaded and frozen. Generator and
        discriminators are initialised with ``weights_init`` and, when
        ``cfg.TRAIN.NET_G`` is set, restored from disk.

        Returns:
            ``[text_encoder, image_encoder, caption_cnn, caption_rnn, netG,
            netsD, epoch]``; ``None`` when ``cfg.TRAIN.NET_E`` is empty.
        """

        def keep_on_cpu(storage, loc):
            # map_location callback: leave loaded storages where they are.
            return storage

        def freeze(module):
            # Exclude every parameter of the module from gradient updates.
            for weight in module.parameters():
                weight.requires_grad = False

        # text encoders
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        image_encoder.load_state_dict(
            torch.load(img_encoder_path, map_location=keep_on_cpu))
        freeze(image_encoder)
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        text_encoder.load_state_dict(
            torch.load(cfg.TRAIN.NET_E, map_location=keep_on_cpu))
        freeze(text_encoder)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # Caption models - cnn_encoder and rnn_decoder
        caption_cnn = CAPTION_CNN(cfg.CAP.embed_size)
        caption_cnn.load_state_dict(
            torch.load(cfg.CAP.caption_cnn_path, map_location=keep_on_cpu))
        freeze(caption_cnn)
        print('Load caption model from:', cfg.CAP.caption_cnn_path)
        caption_cnn.eval()

        # NOTE(review): unlike the other frozen models, caption_rnn is not
        # switched to eval() here -- confirm that this is intentional.
        caption_rnn = CAPTION_RNN(cfg.CAP.embed_size,
                                  cfg.CAP.hidden_size * 2,
                                  self.n_words,
                                  cfg.CAP.num_layers)
        caption_rnn.load_state_dict(
            torch.load(cfg.CAP.caption_rnn_path, map_location=keep_on_cpu))
        freeze(caption_rnn)
        print('Load caption model from:', cfg.CAP.caption_rnn_path)

        # Generator and Discriminator:
        netsD = []
        if cfg.GAN.B_DCGAN:
            # A single discriminator whose resolution follows the branch count.
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET

            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            # One discriminator per branch, lowest resolution first.
            for threshold, disc_cls in enumerate(
                    (D_NET64, D_NET128, D_NET256)):
                if cfg.TREE.BRANCH_NUM > threshold:
                    netsD.append(disc_cls())

        netG.apply(weights_init)
        for disc in netsD:
            disc.apply(weights_init)
        print('# of netsD', len(netsD))

        epoch = 0
        if cfg.TRAIN.NET_G != '':
            netG.load_state_dict(
                torch.load(cfg.TRAIN.NET_G, map_location=keep_on_cpu))
            print('Load G from: ', cfg.TRAIN.NET_G)
            # The epoch is the text between the last '_' and the last '.'.
            first = cfg.TRAIN.NET_G.rfind('_') + 1
            last = cfg.TRAIN.NET_G.rfind('.')
            epoch = int(cfg.TRAIN.NET_G[first:last]) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i, disc in enumerate(netsD):
                    ckpt_dir = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (ckpt_dir, i)
                    print('Load D from: ', Dname)
                    disc.load_state_dict(
                        torch.load(Dname, map_location=keep_on_cpu))

        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            caption_cnn = caption_cnn.cuda()
            caption_rnn = caption_rnn.cuda()
            netG.cuda()
            for disc in netsD:
                disc.cuda()
        return [text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, epoch]
    def build_models(self):
        """Create the encoders, the generator and the discriminators.

        The image and text encoders are loaded from ``cfg.TRAIN.NET_E``,
        frozen and put into eval mode. Generator and discriminators are
        initialised with ``weights_init``. When ``self.resume`` is set they
        are restored from the lexicographically last ``*.pth`` file in
        ``self.model_dir``; when ``cfg.TRAIN.NET_G`` is set the generator
        (and, with ``cfg.TRAIN.B_NET_D``, the discriminators) are loaded from
        that path afterwards, which also overrides the epoch.

        Returns:
            ``[text_encoder, image_encoder, netG, netsD, epoch]``; ``None``
            when ``cfg.TRAIN.NET_E`` is empty.

        Raises:
            FileNotFoundError: ``self.resume`` is set but ``self.model_dir``
                contains no ``*.pth`` checkpoint.
        """
        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # #######################generator and discriminators############## #
        netsD = []
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())

        netG.apply(weights_init)
        for netD in netsD:
            netD.apply(weights_init)
        print('# of netsD', len(netsD))
        epoch = 0

        if self.resume:
            checkpoint_list = sorted(glob.glob(self.model_dir + "/" + '*.pth'))
            if not checkpoint_list:
                # Fail with a clear message instead of a bare IndexError.
                raise FileNotFoundError(
                    "Cannot resume: no '*.pth' checkpoint found in {}".format(
                        self.model_dir))
            latest_checkpoint = checkpoint_list[-1]
            state_dict = torch.load(latest_checkpoint, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict["netG"])
            for i, netD in enumerate(netsD):
                netD.load_state_dict(state_dict["netD"][i])
            # The epoch is read from the four characters before ".pth".
            epoch = int(latest_checkpoint[-8:-4]) + 1
            print("Resuming training from checkpoint {} at epoch {}.".format(latest_checkpoint, epoch))

        #
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            # The epoch is the text between the last '_' and the last '.'.
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i, netD in enumerate(netsD):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netD.load_state_dict(state_dict)
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            for netD in netsD:
                netD.cuda()
        return [text_encoder, image_encoder, netG, netsD, epoch]
    def gen_example(self, data_dic):
        """Generate images and attention visualisations for given captions.

        For every entry of ``data_dic`` the captions are encoded, images are
        generated, and each generator stage is saved as
        ``<i>_s_<sorted index>_g<stage>.png`` below
        ``<NET_G path without .pth>/<key>``. Attention maps are rendered with
        ``build_super_images2`` and saved as ``..._a<k>.png`` when an image
        is returned.

        Args:
            data_dic: mapping from a name to a
                ``(captions, cap_lens, sorted_indices)`` triple; ``captions``
                and ``cap_lens`` are passed to ``torch.from_numpy``, so they
                must be NumPy arrays.

        The models are moved to CUDA unconditionally.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for morels is not found!')
            return

        # Build and load the text encoder
        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()

        # Build and load the generator
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        # the path to save generated images
        s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
        model_dir = cfg.TRAIN.NET_G
        state_dict = \
            torch.load(model_dir, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)
        netG.cuda()
        netG.eval()

        for key in data_dic:
            save_dir = '%s/%s' % (s_tmp, key)
            mkdir_p(save_dir)
            captions, cap_lens, sorted_indices = data_dic[key]

            batch_size = captions.shape[0]
            nz = cfg.GAN.Z_DIM
            # `volatile=True` no longer has any effect; gradient tracking is
            # disabled with torch.no_grad() around the forward passes instead.
            captions = torch.from_numpy(captions).cuda()
            cap_lens = torch.from_numpy(cap_lens).cuda()

            for i in range(1):  # 16
                noise = torch.FloatTensor(batch_size, nz).cuda()
                noise.normal_(0, 1)
                with torch.no_grad():
                    #######################################################
                    # (1) Extract text embeddings
                    ######################################################
                    hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                    mask = (captions == 0)
                    # Same guard as in the sampling methods: keep the mask no
                    # wider than the encoded sequence.
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]
                    #######################################################
                    # (2) Generate fake images
                    ######################################################
                    fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)

                # G attention
                cap_lens_np = cap_lens.cpu().data.numpy()
                for j in range(batch_size):
                    save_name = '%s/%d_s_%d' % (save_dir, i, sorted_indices[j])
                    for k in range(len(fake_imgs)):
                        im = fake_imgs[k][j].data.cpu().numpy()
                        # [-1, 1] --> [0, 255]
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_g%d.png' % (save_name, k)
                        im.save(fullpath)

                    for k in range(len(attention_maps)):
                        # Attention map k is drawn over generator stage k + 1
                        # when more than one stage exists.
                        if len(fake_imgs) > 1:
                            stage_imgs = fake_imgs[k + 1].detach().cpu()
                        else:
                            stage_imgs = fake_imgs[0].detach().cpu()
                        attn_maps = attention_maps[k]
                        att_sze = attn_maps.size(2)
                        img_set, sentences = \
                            build_super_images2(stage_imgs[j].unsqueeze(0),
                                                captions[j].unsqueeze(0),
                                                [cap_lens_np[j]], self.ixtoword,
                                                [attn_maps[j]], att_sze)
                        if img_set is not None:
                            att_im = Image.fromarray(img_set)
                            fullpath = '%s_a%d.png' % (save_name, k)
                            att_im.save(fullpath)
    def sampling(self, split_dir):
        """Sample images for a split and print a retrieval statistic.

        For up to 11 passes over the data loader, one image per caption is
        generated and its last stage saved below
        ``<NET_G path without .pth>/<split_dir>/single``. Each generated
        image is then compared by cosine similarity against its own sentence
        embedding plus the embeddings of the captions returned by
        ``self.dataset.get_mis_caption``; a hit is recorded when the own
        sentence scores highest. Once 30000 trials are collected they are
        shuffled, split into 10 equal parts, and the mean and standard
        deviation of the per-part hit rates are printed ("R mean ... std").

        Args:
            split_dir: name of the data split; ``'test'`` is mapped to
                ``'valid'``.

        The models are moved to CUDA unconditionally.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for morels is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the generator
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            netG.apply(weights_init)
            netG.cuda()
            netG.eval()

            # load text encoder
            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = torch.load(cfg.TRAIN.NET_E,
                                    map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            #load image encoder
            image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                       'image_encoder')
            state_dict = torch.load(img_encoder_path,
                                    map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            print('Load image encoder from:', img_encoder_path)
            image_encoder = image_encoder.cuda()
            image_encoder.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM
            # `volatile=True` no longer has any effect; the forward passes
            # below run under torch.no_grad() instead.
            noise = Variable(torch.FloatTensor(batch_size, nz))
            noise = noise.cuda()

            model_dir = cfg.TRAIN.NET_G
            state_dict = torch.load(model_dir,
                                    map_location=lambda storage, loc: storage)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')]
            save_dir = '%s/%s' % (s_tmp, split_dir)
            mkdir_p(save_dir)

            num_trials = 30000  # retrieval trials to collect
            num_splits = 10     # equal parts used for mean / std
            cnt = 0
            R_count = 0
            R = np.zeros(num_trials)
            cont = True
            for ii in range(11):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                if (cont == False):
                    break
                for step, data in enumerate(self.data_loader, 0):
                    cnt += batch_size
                    if (cont == False):
                        break
                    if step % 100 == 0:
                        print('cnt: ', cnt)
                    # if step > 50:
                    #     break

                    imgs, captions, cap_lens, class_ids, keys = prepare_data(
                        data)

                    with torch.no_grad():
                        hidden = text_encoder.init_hidden(batch_size)
                        # words_embs: batch_size x nef x seq_len
                        # sent_emb: batch_size x nef
                        words_embs, sent_emb = text_encoder(
                            captions, cap_lens, hidden)
                        words_embs, sent_emb = words_embs.detach(
                        ), sent_emb.detach()
                        mask = (captions == 0)
                        num_words = words_embs.size(2)
                        if mask.size(1) > num_words:
                            mask = mask[:, :num_words]

                        ###################################################
                        # (2) Generate fake images
                        ###################################################
                        noise.data.normal_(0, 1)
                        fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs,
                                                  mask, cap_lens)
                    for j in range(batch_size):
                        s_tmp = '%s/single/%s' % (save_dir, keys[j])
                        folder = s_tmp[:s_tmp.rfind('/')]
                        if not os.path.isdir(folder):
                            #print('Make a new folder: ', folder)
                            mkdir_p(folder)
                        k = -1  # only the last generator stage is saved
                        # for k in range(len(fake_imgs)):
                        im = fake_imgs[k][j].data.cpu().numpy()
                        # [-1, 1] --> [0, 255]
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_s%d_%d.png' % (s_tmp, k, ii)
                        im.save(fullpath)

                    with torch.no_grad():
                        _, cnn_code = image_encoder(fake_imgs[-1])

                        for i in range(batch_size):
                            # Stop at the array size so a batch size that
                            # does not divide num_trials cannot index past R.
                            if R_count >= num_trials:
                                break
                            mis_captions, mis_captions_len = self.dataset.get_mis_caption(
                                class_ids[i])
                            hidden = text_encoder.init_hidden(99)
                            _, sent_emb_t = text_encoder(mis_captions,
                                                         mis_captions_len, hidden)
                            # Row 0 is the matching sentence, the remaining
                            # rows are the mismatched ones.
                            rnn_code = torch.cat(
                                (sent_emb[i, :].unsqueeze(0), sent_emb_t), 0)
                            ### cnn_code = 1 * nef
                            ### rnn_code = 100 * nef
                            scores = torch.mm(cnn_code[i].unsqueeze(0),
                                              rnn_code.transpose(0, 1))  # 1* 100
                            cnn_code_norm = torch.norm(cnn_code[i].unsqueeze(0),
                                                       2,
                                                       dim=1,
                                                       keepdim=True)
                            rnn_code_norm = torch.norm(rnn_code,
                                                       2,
                                                       dim=1,
                                                       keepdim=True)
                            norm = torch.mm(cnn_code_norm,
                                            rnn_code_norm.transpose(0, 1))
                            # Cosine similarity; the clamp avoids dividing
                            # by zero.
                            scores0 = scores / norm.clamp(min=1e-8)
                            if torch.argmax(scores0) == 0:
                                R[R_count] = 1
                            R_count += 1

                    if R_count >= num_trials:
                        np.random.shuffle(R)
                        # Every split covers its full share of trials (the
                        # previous slice dropped the last trial of each split).
                        split_means = R.reshape(num_splits, -1).mean(axis=1)
                        R_mean = np.average(split_means)
                        R_std = np.std(split_means)
                        print("R mean:{:.4f} std:{:.4f}".format(R_mean, R_std))
                        cont = False