Example #1
    def evaluate(self, split_dir):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            # Build and load the generator
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # the path to save generated images
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d/%s' % (s_tmp, iteration, split_dir)
            if cfg.TEST.B_EXAMPLE:
                folder = '%s/super' % (save_dir)
            else:
                folder = '%s/single' % (save_dir)
            print('Make a new folder: ', folder)
            mkdir_p(folder)

            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            num_batches = int(cfg.TEST.SAMPLE_NUM / self.batch_size)
            cnt = 0
            for step in range(num_batches):
                noise.data.normal_(0, 1)
                fake_imgs, _, _ = netG(noise)
                if cfg.TEST.B_EXAMPLE:
                    self.save_superimages(fake_imgs[-1], folder, cnt, 256)
                else:
                    self.save_singleimages(fake_imgs[-1], folder, cnt, 256)
                    # self.save_singleimages(fake_imgs[-2], folder, 128)
                    # self.save_singleimages(fake_imgs[-3], folder, 64)
                cnt += self.batch_size
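
The checkpoint-iteration parsing above (rfind('_') / rfind('.')) recurs throughout these examples. A minimal standalone sketch of the same idea (the helper name is hypothetical):

import os

def parse_iteration(ckpt_path):
    # 'models/netG_50000.pth' -> 50000, mirroring the rfind-based slicing above
    stem = os.path.splitext(os.path.basename(ckpt_path))[0]
    return int(stem.rsplit('_', 1)[-1])

assert parse_iteration('models/netG_50000.pth') == 50000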
Example #2
def load_network(gpus):
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    if cfg.TREE.BRANCH_NUM > 3:
        netsD.append(D_NET512())
    if cfg.TREE.BRANCH_NUM > 4:
        netsD.append(D_NET1024())
    # TODO: if cfg.TREE.BRANCH_NUM > 5:

    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
        # print(netsD[i])
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        count = cfg.TRAIN.NET_G[istart:iend]
        count = int(count) + 1

    if cfg.TRAIN.NET_D != '':
        for i in range(len(netsD)):
            print('Load %s_%d.pth' % (cfg.TRAIN.NET_D, i))
            state_dict = torch.load('%s%d.pth' % (cfg.TRAIN.NET_D, i))
            netsD[i].load_state_dict(state_dict)

    inception_model = INCEPTION_V3()

    if cfg.CUDA:
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()
        inception_model = inception_model.cuda()
    inception_model.eval()

    return netG, netsD, len(netsD), inception_model, count
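
Every example applies a weights_init callback via net.apply(...) before loading weights. The exact initializer lives in the source repos; a typical definition in this family of models looks like the following sketch (illustrative, not the repos' verbatim code):

import torch.nn as nn

def weights_init(m):
    # net.apply(weights_init) calls this on every submodule
    classname = m.__class__.__name__
    if 'Conv' in classname:
        nn.init.orthogonal_(m.weight.data, 1.0)
    elif 'BatchNorm' in classname:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    elif 'Linear' in classname:
        nn.init.orthogonal_(m.weight.data, 1.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)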
Example #3
    def build_models(self):
        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        ####################################################################
        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():  # keep the image encoder trainable
            p.requires_grad = True
        # freeze layers 1-5 (eval mode keeps their BatchNorm statistics fixed)
        for k, v in image_encoder.named_children():
            if k in frozen_list_image_encoder:
                v.train(False)
                v.requires_grad_(False)
        print('Load image encoder from:', img_encoder_path)
        #         image_encoder.eval()

        ###################################################################
        text_encoder = TEXT_TRANSFORMER_ENCODERv2(
            emb=cfg.TEXT.EMBEDDING_DIM,
            heads=8,
            depth=1,
            seq_length=cfg.TEXT.WORDS_NUM,
            num_tokens=self.n_words)
        #         state_dict = torch.load(cfg.TRAIN.NET_E)
        #         text_encoder.load_state_dict(state_dict)
        #         print('Load ', cfg.TRAIN.NET_E)
        state_dict = torch.load(cfg.TRAIN.NET_E,
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = True
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        #         text_encoder.eval()

        # #######################generator and discriminators############## #
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            # TODO: elif cfg.TREE.BRANCH_NUM > 3:
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
            # TODO: if cfg.TREE.BRANCH_NUM > 3:
        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))
        #
        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        # ########################################################## #
        # config = Config()  # `config` must already be in scope if this stays commented out
        cap_model = caption.build_model_v3(config)
        print("Initializing from Checkpoint...")
        cap_model_path = cfg.TRAIN.NET_E.replace('text_encoder', 'cap_model')

        if os.path.exists(cap_model_path):
            print('Load C from: {0}'.format(cap_model_path))
            state_dict = \
                torch.load(cap_model_path, map_location=lambda storage, loc: storage)
            cap_model.load_state_dict(state_dict['model'])
        else:
            base_line_path = 'catr/checkpoints/catr_damsm256_proj_coco2014_ep02.pth'
            print('Load C from: {0}'.format(base_line_path))
            checkv3 = torch.load(base_line_path, map_location='cpu')
            cap_model.load_state_dict(checkv3['model'], strict=False)

        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            cap_model = cap_model.cuda()  # caption transformer added
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
        return [text_encoder, image_encoder, netG, netsD, epoch, cap_model]

    def sampling(self, split_dir, num_samples=30000):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the generator
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            netG.apply(weights_init)
            netG.cuda()
            netG.eval()
            #
            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = torch.load(cfg.TRAIN.NET_E,
                                    map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()
            print('Loaded text encoder from:', cfg.TRAIN.NET_E)

            batch_size = self.batch_size[0]
            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(batch_size, nz)).cuda()
            local_noise = Variable(torch.FloatTensor(batch_size, 32)).cuda()

            model_dir = cfg.TRAIN.NET_G
            state_dict = torch.load(model_dir,
                                    map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict["netG"])
            max_objects = 10
            print('Load G from: ', model_dir)

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')].split("/")[-1]
            save_dir = '%s/%s/%s' % ("../output", s_tmp, split_dir)
            mkdir_p(save_dir)
            print("Saving images to: {}".format(save_dir))

            number_batches = num_samples // batch_size
            if number_batches < 1:
                number_batches = 1

            data_iter = iter(self.data_loader)

            for step in tqdm(range(number_batches)):
                data = next(data_iter)

                imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot, _ = prepare_data(
                    data, eval=True)

                transf_matrices = transformation_matrices[0]
                transf_matrices_inv = transformation_matrices[1]

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                local_noise.data.normal_(0, 1)
                inputs = (noise, local_noise, sent_emb, words_embs, mask,
                          transf_matrices, transf_matrices_inv, label_one_hot,
                          max_objects)
                with torch.no_grad():
                    fake_imgs, _, mu, logvar = nn.parallel.data_parallel(
                        netG, inputs, self.gpus)
                for j in range(batch_size):
                    s_tmp = '%s/%s' % (save_dir, keys[j])
                    folder = s_tmp[:s_tmp.rfind('/')]
                    if not os.path.isdir(folder):
                        print('Make a new folder: ', folder)
                        mkdir_p(folder)
                    k = -1
                    # for k in range(len(fake_imgs)):
                    im = fake_imgs[k][j].data.cpu().numpy()
                    # [-1, 1] --> [0, 255]
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    im = Image.fromarray(im)
                    fullpath = '%s_s%d.png' % (s_tmp, step * batch_size + j)
                    im.save(fullpath)
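
The inline [-1, 1] -> [0, 255] conversion at the end of sampling reappears in most of the later examples. Factored out, it is simply (helper name hypothetical):

import numpy as np
from PIL import Image

def tensor_to_pil(t):
    # CHW float tensor in [-1, 1] -> HWC uint8 PIL image
    im = t.data.cpu().numpy()
    im = (im + 1.0) * 127.5
    im = im.astype(np.uint8)
    return Image.fromarray(np.transpose(im, (1, 2, 0)))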
Example #5
    def evaluate(self, split_dir):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            # Build and load the generator
            if split_dir == 'test':
                split_dir = 'valid'
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # the path to save generated images
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d' % (s_tmp, iteration)

            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            for step, data in enumerate(self.data_loader, 0):
                imgs, t_embeddings, filenames = data
                if cfg.CUDA:
                    t_embeddings = Variable(t_embeddings).cuda()
                else:
                    t_embeddings = Variable(t_embeddings)
                # print(t_embeddings[:, 0, :], t_embeddings.size(1))

                embedding_dim = t_embeddings.size(1)
                batch_size = imgs[0].size(0)
                noise.data.resize_(batch_size, nz)
                noise.data.normal_(0, 1)

                fake_img_list = []
                for i in range(embedding_dim):
                    fake_imgs, _, _ = netG(noise, t_embeddings[:, i, :])
                    if cfg.TEST.B_EXAMPLE:
                        # fake_img_list.append(fake_imgs[0].data.cpu())
                        # fake_img_list.append(fake_imgs[1].data.cpu())
                        fake_img_list.append(fake_imgs[2].data.cpu())
                    else:
                        self.save_singleimages(fake_imgs[-1], filenames,
                                               save_dir, split_dir, i, 256)
                        # self.save_singleimages(fake_imgs[-2], filenames,
                        #                        save_dir, split_dir, i, 128)
                        # self.save_singleimages(fake_imgs[-3], filenames,
                        #                        save_dir, split_dir, i, 64)
                    # break
                if cfg.TEST.B_EXAMPLE:
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 64)
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 128)
                    self.save_superimages(fake_img_list, filenames,
                                          save_dir, split_dir, 256)
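
mkdir_p is a small helper these examples import; a plausible definition, equivalent to mkdir -p (or to os.makedirs(path, exist_ok=True) on Python 3), is:

import errno
import os

def mkdir_p(path):
    try:
        os.makedirs(path)
    except OSError as exc:
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            pass
        else:
            raise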
Example #6
    def gen_example(self, data_dic):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            # Build and load the generator
            text_encoder = \
                RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            # the path to save generated images
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
            model_dir = cfg.TRAIN.NET_G
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)
            netG.cuda()
            netG.eval()
            for key in data_dic:
                save_dir = '%s/%s' % (s_tmp, key)
                mkdir_p(save_dir)
                captions, cap_lens, sorted_indices = data_dic[key]

                batch_size = captions.shape[0]
                nz = cfg.GAN.Z_DIM
                # volatile is a no-op in PyTorch >= 0.4; use torch.no_grad() instead
                captions = Variable(torch.from_numpy(captions), volatile=True)
                cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)

                captions = captions.cuda()
                cap_lens = cap_lens.cuda()
                for i in range(1):  # 16
                    noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
                    noise = noise.cuda()
                    # (1) Extract text embeddings
                    hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                    mask = (captions == 0)
                    # (2) Generate fake images
                    noise.data.normal_(0, 1)
                    fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
                    # G attention
                    cap_lens_np = cap_lens.cpu().data.numpy()
                    for j in range(batch_size):
                        save_name = '%s/%d_s_%d' % (save_dir, i, sorted_indices[j])
                        for k in range(len(fake_imgs)):
                            im = fake_imgs[k][j].data.cpu().numpy()
                            im = (im + 1.0) * 127.5
                            im = im.astype(np.uint8)
                            # print('im', im.shape)
                            im = np.transpose(im, (1, 2, 0))
                            # print('im', im.shape)
                            im = Image.fromarray(im)
                            fullpath = '%s_g%d.png' % (save_name, k)
                            im.save(fullpath)

                        for k in range(len(attention_maps)):
                            if len(fake_imgs) > 1:
                                im = fake_imgs[k + 1].detach().cpu()
                            else:
                                im = fake_imgs[0].detach().cpu()
                            attn_maps = attention_maps[k]
                            att_sze = attn_maps.size(2)
                            img_set, sentences = \
                                build_super_images2(im[j].unsqueeze(0),
                                                    captions[j].unsqueeze(0),
                                                    [cap_lens_np[j]], self.ixtoword,
                                                    [attn_maps[j]], att_sze)
                            if img_set is not None:
                                im = Image.fromarray(img_set)
                                fullpath = '%s_a%d.png' % (save_name, k)
                                im.save(fullpath)
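
The mask = (captions == 0) line marks padding positions (token id 0) so the generator's attention can ignore them. A toy illustration:

import torch

captions = torch.tensor([[5, 12, 7, 0, 0],
                         [3,  9, 0, 0, 0]])
mask = (captions == 0)
# tensor([[False, False, False,  True,  True],
#         [False, False,  True,  True,  True]])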
Example #7
    def build_models(self):

        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        # vgg16 network
        style_loss = VGGNet()

        for p in style_loss.parameters():
            p.requires_grad = False

        print("Load the style loss model")
        style_loss.eval()

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
        netG.apply(weights_init)

        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
        print('# of netsD', len(netsD))
        #
        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)

        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            style_loss = style_loss.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()

        return [text_encoder, image_encoder, netG, netsD, epoch, style_loss]
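
All of these examples load checkpoints with map_location=lambda storage, loc: storage, which deserializes every tensor onto the CPU regardless of the device it was saved from. Current PyTorch spells the same thing more directly (the path below is a placeholder):

import torch

state_dict = torch.load('netG.pth', map_location='cpu')  # equivalent to the lambda form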
Example #8
    def build_models(self):
        # ###################encoders######################################## #

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        image_encoder.train()

        # #######################generator and discriminators############## #
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            # TODO: elif cfg.TREE.BRANCH_NUM > 3:
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
            # TODO: if cfg.TREE.BRANCH_NUM > 3:
        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))
        #
        epoch = 0

        if cfg.PRETRAINED_CNN:
            image_encoder_params = torch.load(
                cfg.PRETRAINED_CNN, map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(image_encoder_params)

        if cfg.PRETRAINED_G != '':
            state_dict = torch.load(cfg.PRETRAINED_G,
                                    map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.PRETRAINED_G)
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.PRETRAINED_G
                s_tmp = Gname[:Gname.rfind('/')]
                for i in range(len(netsD)):
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    # discriminator checkpoints share one prefix and differ only in the index i
                    print('Load D from: ', Dname)
                    state_dict = torch.load(
                        Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        # ########################################################### #
        if cfg.CUDA:
            image_encoder = image_encoder.cuda()
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
        return [image_encoder, netG, netsD, epoch]
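
The recurring Variable(torch.FloatTensor(batch_size, nz)) followed by noise.data.normal_(0, 1) predates PyTorch 0.4. A modern equivalent, with an assumed latent size standing in for cfg.GAN.Z_DIM:

import torch

batch_size, nz = 16, 100  # nz assumed; read it from cfg.GAN.Z_DIM in practice
noise = torch.randn(batch_size, nz)  # fresh N(0, 1) latent batch, no Variable wrapper needed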
Example #9
    def gen_example(self, data_dic):
        if cfg.TRAIN.NET_G == '' or cfg.TRAIN.NET_C == '':
            print('Error: the path for main module or DCM is not found!')
        else:
            # The text encoder
            text_encoder = \
                RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            # The image encoder
            image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                       'image_encoder')
            state_dict = \
                torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            print('Load image encoder from:', img_encoder_path)
            image_encoder = image_encoder.cuda()
            image_encoder.eval()

            # The VGG network
            VGG = VGGNet()
            print("Load the VGG model")
            VGG.cuda()
            VGG.eval()

            # The main module
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
            model_dir = cfg.TRAIN.NET_G
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)
            netG.cuda()
            netG.eval()

            # The DCM
            netDCM = DCM_Net()
            if cfg.TRAIN.NET_C != '':
                state_dict = \
                    torch.load(cfg.TRAIN.NET_C, map_location=lambda storage, loc: storage)
                netDCM.load_state_dict(state_dict)
                print('Load DCM from: ', cfg.TRAIN.NET_C)
            netDCM.cuda()
            netDCM.eval()

            for key in data_dic:
                save_dir = '%s/%s' % (s_tmp, key)
                mkdir_p(save_dir)
                captions, cap_lens, sorted_indices, imgs = data_dic[key]

                batch_size = captions.shape[0]
                nz = cfg.GAN.Z_DIM
                # volatile is a no-op in PyTorch >= 0.4; use torch.no_grad() instead
                captions = Variable(torch.from_numpy(captions), volatile=True)
                cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)

                captions = captions.cuda()
                cap_lens = cap_lens.cuda()
                for i in range(1):
                    noise = Variable(torch.FloatTensor(batch_size, nz),
                                     volatile=True)
                    noise = noise.cuda()

                    #######################################################
                    # (1) Extract text and image embeddings
                    ######################################################
                    hidden = text_encoder.init_hidden(batch_size)

                    # The text embeddings
                    words_embs, sent_emb = text_encoder(
                        captions, cap_lens, hidden)

                    # The image embeddings
                    region_features, cnn_code = \
                        image_encoder(imgs[cfg.TREE.BRANCH_NUM - 1].unsqueeze(0))
                    mask = (captions == 0)

                    #######################################################
                    # (2) Generate fake images
                    ######################################################
                    noise.data.normal_(0, 1)
                    fake_imgs, attention_maps, mu, logvar, h_code, c_code = netG(
                        noise, sent_emb, words_embs, mask, cnn_code,
                        region_features)

                    real_img = imgs[cfg.TREE.BRANCH_NUM - 1].unsqueeze(0)
                    real_features = VGG(real_img)[0]

                    fake_img = netDCM(h_code, real_features, sent_emb, words_embs,
                                      mask, c_code)

                    cap_lens_np = cap_lens.cpu().data.numpy()
                    for j in range(batch_size):
                        save_name = '%s/%d_s_%d' % (save_dir, i,
                                                    sorted_indices[j])
                        for k in range(len(fake_imgs)):
                            im = fake_imgs[k][j].data.cpu().numpy()
                            im = (im + 1.0) * 127.5
                            im = im.astype(np.uint8)
                            im = np.transpose(im, (1, 2, 0))
                            im = Image.fromarray(im)
                            fullpath = '%s_g%d.png' % (save_name, k)
                            im.save(fullpath)

                        for k in range(len(attention_maps)):
                            if len(fake_imgs) > 1:
                                im = fake_imgs[k + 1].detach().cpu()
                            else:
                                im = fake_imgs[0].detach().cpu()
                            attn_maps = attention_maps[k]
                            att_sze = attn_maps.size(2)
                            img_set, sentences = \
                                build_super_images2(im[j].unsqueeze(0),
                                                    captions[j].unsqueeze(0),
                                                    [cap_lens_np[j]], self.ixtoword,
                                                    [attn_maps[j]], att_sze)
                            if img_set is not None:
                                im = Image.fromarray(img_set)
                                fullpath = '%s_a%d.png' % (save_name, k)
                                im.save(fullpath)

                        save_name = '%s/%d_sf_%d' % (save_dir, 1,
                                                     sorted_indices[j])
                        im = fake_img[j].data.cpu().numpy()
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_SF.png' % (save_name)
                        im.save(fullpath)

                    save_name = '%s/%d_s_%d' % (save_dir, 1, 9)
                    im = imgs[2].data.cpu().numpy()
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    im = Image.fromarray(im)
                    fullpath = '%s_SR.png' % (save_name)
                    im.save(fullpath)
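
Example #9 runs the real image through a VGG network and keeps an intermediate feature map (VGG(real_img)[0]). The exact VGGNet wrapper is repo-specific; a minimal stand-in built from torchvision (the layer cut-off is an assumption) could be:

import torch
from torchvision.models import vgg16

features = vgg16(weights=None).features[:21]  # truncated extractor; older torchvision uses pretrained=False
features.eval()
with torch.no_grad():
    real_features = features(torch.randn(1, 3, 256, 256))  # stand-in for a 256x256 RGB image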
Example #10
    def gen_example(self, data_dic):        
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            # Build and load the generator
            batch_size = 16
            text_encoder = \
                RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
            print("=======self.n_words: %d", self.n_words)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            # custom restore of the text encoder parameters
            # text_encoder.load_state_dict(state_dict)
            own_state = text_encoder.state_dict()
            for name, param in state_dict.items():
                if name not in own_state:
                    continue
                own_state[name].copy_(param)  # copy in place so the module weights are actually updated
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            # the path to save generated images
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET(text_encoder)
            s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
            model_dir = cfg.TRAIN.NET_G
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)
            netG.cuda()
            netG.eval()
            for key in data_dic:
                save_dir = '%s/%s' % (s_tmp, key)
                mkdir_p(save_dir)
                captions, cap_lens, sorted_indices = data_dic[key]

                # batch_size = captions.shape[0]
                total_time = len(captions)//batch_size
                nz = cfg.GAN.Z_DIM
                # captions = Variable(torch.from_numpy(captions), volatile=True)
                # cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)

                # captions = captions.cuda()
                # cap_lens = cap_lens.cuda()
                with torch.no_grad():
                    for i in range(total_time):  # 16
                        noise = Variable(torch.FloatTensor(batch_size, nz))
                        noise = noise.cuda()
                        caption_tmp = Variable(torch.from_numpy(captions[i*batch_size:(i+1)*batch_size]))
                        if i < 3:
                            print(caption_tmp.data)
                        cap_len_tmp = Variable(torch.from_numpy(cap_lens[i*batch_size:(i+1)*batch_size]))
                        caption_tmp = caption_tmp.cuda()
                        cap_len_tmp = cap_len_tmp.cuda()
                        #######################################################
                        # (1) Extract text embeddings
                        ######################################################
                        hidden = text_encoder.init_hidden(batch_size)
                        # words_embs: batch_size x nef x seq_len
                        # sent_emb: batch_size x nef
                        words_embs, sent_emb, _ = text_encoder(caption_tmp, cap_len_tmp, None)
                        words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                        mask = (caption_tmp == 0)
                        #######################################################
                        # (2) Generate fake images
                        ######################################################
                        random.seed()  # reseed from fresh OS entropy
                        rnd = random.randint(0, 1000)
                        torch.cuda.manual_seed(rnd)
                        noise.data.normal_(0, 1)
                        fake_imgs, attention_maps, _, _, _ = netG(noise, sent_emb, words_embs, mask, caption_tmp, cap_len_tmp)
                        # G attention
                        # cap_lens_np = cap_lens.cpu().data.numpy()
                        cap_lens_np = cap_len_tmp.cpu().data.numpy()
                        for j in range(batch_size):
                            save_name = '%s/s_%d' % (save_dir, sorted_indices[i*batch_size+j])
                            for k in range(len(fake_imgs)):
                                im = fake_imgs[k][j].data.cpu().numpy()
                                im = ((im + 1.0) / 2) * 255.0
                                im = im.astype(np.uint8)
                                # print('im', im.shape)
                                im = np.transpose(im, (1, 2, 0))
                                # print('im', im.shape)
                                im = Image.fromarray(im)
                                fullpath = '%s_g%d.png' % (save_name, k)
                                im.save(fullpath)
                                # save to a separate directory
                                save_dir2 = '%s/stage_%d' % (save_dir, k)
                                mkdir_p(save_dir2)
                                fullpath = '%s/%d_g%d.png' % (save_dir2, sorted_indices[i*batch_size+j], k)
                                im.save(fullpath)

                            for k in range(len(attention_maps)):
                                if len(fake_imgs) > 1:
                                    im = fake_imgs[k + 1].detach().cpu()
                                else:
                                    im = fake_imgs[0].detach().cpu()
                                attn_maps = attention_maps[k]
                                att_sze = attn_maps.size(2)
                                img_set, sentences = \
                                    build_super_images2(im[j].unsqueeze(0),
                                                        caption_tmp[j].unsqueeze(0),
                                                        [cap_len_tmp[j]], self.ixtoword,
                                                        [attn_maps[j]], att_sze)
                                if img_set is not None:
                                    im = Image.fromarray(img_set)
                                    fullpath = '%s_a%d.png' % (save_name, k)
                                    im.save(fullpath)                            
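
Note that total_time = len(captions) // batch_size floors the division, so up to batch_size - 1 trailing captions are silently skipped. If every caption must be generated, iterate in slices that include the final short batch; a sketch:

def batched(seq, batch_size):
    # Yield successive slices, including the final short one.
    for i in range(0, len(seq), batch_size):
        yield seq[i:i + batch_size]

for chunk in batched(list(range(10)), 4):
    print(len(chunk))  # 4, 4, 2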
Example #11
    def generate_fake_im(self, data_dic):

        global text_encoder_path, net_G_path

        # Build and load the generator
        #####################################
        ## load the encoder                 #
        #####################################
        text_encoder = \
            BERT_RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(text_encoder_path,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)

        print('Loaded text encoder from:', text_encoder_path)
        text_encoder.eval()
        text_encoder = text_encoder.cuda()

        netG = G_NET()
        ######################################
        ## load the generator                #
        ######################################

        state_dict = \
            torch.load(net_G_path, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load Generator from: ', net_G_path)
        s_tmp = net_G_path[:net_G_path.rfind('.pth')]

        netG.cuda()
        netG.eval()
        for key in data_dic:
            save_dir = '%s/%s' % (s_tmp, key)
            mkdir_p(save_dir)
            captions, cap_lens, sorted_indices = data_dic[key]

            batch_size = captions.shape[0]
            nz = cfg.GAN.Z_DIM
            captions = Variable(torch.from_numpy(captions))
            cap_lens = Variable(torch.from_numpy(cap_lens))

            captions = captions.cuda()
            cap_lens = cap_lens.cuda()
            for i in range(1):  # 16
                noise = Variable(torch.FloatTensor(batch_size, nz))
                noise = noise.cuda()
                #######################################################
                # (1) Extract text embeddings
                ######################################################
                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                mask = (captions == 0)
                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, attention_maps, _, _ = netG(noise, sent_emb,
                                                       words_embs, mask)

                return fake_imgs, attention_maps  # returns after the first key's first batch
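
text_encoder.init_hidden(batch_size) returns zeroed initial states for the recurrent encoder. Its exact form depends on the repo; for a bidirectional LSTM it plausibly looks like this (shapes assumed):

import torch

def init_hidden(batch_size, nlayers=1, nhidden=256, bidirectional=True):
    # (h_0, c_0), each shaped (num_layers * num_directions, batch, hidden // directions)
    directions = 2 if bidirectional else 1
    shape = (nlayers * directions, batch_size, nhidden // directions)
    return (torch.zeros(shape), torch.zeros(shape))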
Example #12
    def evaluate(self):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            # Build and load the generator

            self.num_Ds = cfg.TREE.BRANCH_NUM
            self.base_num = 135
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # the path to save generated images
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d' % (s_tmp, iteration)

            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            for step, data in enumerate(self.data_loader, 0):
                imgs, t_embeddings, filenames, _ = data
                embedding_dim = t_embeddings.size(1)
                batch_size = imgs[0].size(0)
                noise.data.resize_(batch_size, nz)
                noise.data.normal_(0, 1)

                crop_vbase = []

                crop_base_imgs = torch.zeros(batch_size, 3, self.img_size,
                                             self.img_size)
                for b, base_img_list in enumerate(data[3]):
                    if cfg.DATASET_NAME.find('flower') != -1:
                        base_ix = random.randint(1, self.base_num)
                        base_img_name = '%s/%s.jpg' % (base_img_list,
                                                       str(base_ix))
                    else:
                        temp_base_list = os.listdir(base_img_list)
                        base_ix = random.randint(0, len(temp_base_list) - 1)
                        base_img_name = '%s/%s.jpg' % (base_img_list,
                                                       str(base_ix))

                    base_img = Image.open(base_img_name).convert('RGB')
                    crop_base = base_img.resize([self.img_size, self.img_size])

                    crop_base = Torchtransform(crop_base)
                    crop_base_imgs[b, :] = crop_base

                if cfg.CUDA:
                    crop_vbase.append(Variable(crop_base_imgs).cuda())
                else:
                    crop_vbase.append(Variable(crop_base_imgs))

                if cfg.CUDA:
                    t_embeddings = Variable(t_embeddings).cuda()
                else:
                    t_embeddings = Variable(t_embeddings)
                for i in range(embedding_dim):
                    fake_imgs, fake_segs, _, _ = netG(noise,
                                                      t_embeddings[:, i, :],
                                                      crop_vbase)
                    self.save_singleimages(fake_imgs, fake_segs[-1],
                                           crop_vbase[0], filenames, save_dir,
                                           i, self.img_size)
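
Torchtransform is not shown in this example; given that generated images are mapped from [-1, 1] back to [0, 255] elsewhere, a consistent definition would normalize the cropped base image into the same range (an assumption, not the repo's verbatim code):

import torchvision.transforms as transforms

Torchtransform = transforms.Compose([
    transforms.ToTensor(),                                   # PIL -> CHW float in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # -> [-1, 1]
])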
Example #13
    def build_models(self):
        def count_parameters(model):
            total_param = 0
            for name, param in model.named_parameters():
                if param.requires_grad:
                    num_param = np.prod(param.size())
                    if param.dim() > 1:
                        print(name, ':', 'x'.join(str(x) for x in list(param.size())), '=', num_param)
                    else:
                        print(name, ':', num_param)
                    total_param += num_param
            return total_param

        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # #######################generator and discriminators############## #
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            # TODO: elif cfg.TREE.BRANCH_NUM > 3:
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
            # TODO: if cfg.TREE.BRANCH_NUM > 3:

        print('number of trainable parameters (G) =', count_parameters(netG))
        print('number of trainable parameters (D) =', count_parameters(netsD[-1]))

        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))
        #
        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        # ########################################################### #
        
        # Create a target network.
        target_netG = deepcopy(netG)
                    
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            
            # The target network is stored on the secondary GPU. --------------------------------
            target_netG.cuda()
            # target_netG.ca_net.device = secondary_device
            # -----------------------------------------------------------------------------------

            for i in range(len(netsD)):
                netsD[i].cuda()
                
        # Disable training in the target network:
        for p in target_netG.parameters():
            p.requires_grad = False  
            
        return [text_encoder, image_encoder, netG, target_netG, netsD, epoch]
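
The target-network pattern at the end of Example #13 (deep-copy, then freeze) is self-contained enough to reuse; the update rule is not shown in this excerpt, but the construction itself is simply:

from copy import deepcopy

import torch.nn as nn

def make_target(net):
    # Clone the network and detach it from training, as done for target_netG above.
    target = deepcopy(net)
    for p in target.parameters():
        p.requires_grad = False
    return target

target = make_target(nn.Linear(4, 4))  # toy usage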
Example #14
    def sampling(self, split_dir):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the generator
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            netG.apply(weights_init)
            netG.cuda()
            netG.eval()

            # load text encoder
            text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            #load image encoder
            image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
            state_dict = torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            print('Load image encoder from:', img_encoder_path)
            image_encoder = image_encoder.cuda()
            image_encoder.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)  # volatile is a no-op in PyTorch >= 0.4
            noise = noise.cuda()

            model_dir = cfg.TRAIN.NET_G
            state_dict = torch.load(model_dir, map_location=lambda storage, loc: storage)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')]
            save_dir = '%s/%s' % (s_tmp, split_dir)
            mkdir_p(save_dir)

            cnt = 0
            R_count = 0
            R = np.zeros(30000)
            cont = True
            for ii in range(11):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                if not cont:
                    break
                for step, data in enumerate(self.data_loader, 0):
                    cnt += batch_size
                    if not cont:
                        break
                    if step % 100 == 0:
                        print('cnt: ', cnt)
                    # if step > 50:
                    #     break

                    imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                    hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                    mask = (captions == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    #######################################################
                    # (2) Generate fake images
                    ######################################################
                    noise.data.normal_(0, 1)
                    fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask, cap_lens)
                    for j in range(batch_size):
                        s_tmp = '%s/single/%s' % (save_dir, keys[j])
                        folder = s_tmp[:s_tmp.rfind('/')]
                        if not os.path.isdir(folder):
                            #print('Make a new folder: ', folder)
                            mkdir_p(folder)
                        k = -1
                        # for k in range(len(fake_imgs)):
                        im = fake_imgs[k][j].data.cpu().numpy()
                        # [-1, 1] --> [0, 255]
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_s%d_%d.png' % (s_tmp, k, ii)
                        im.save(fullpath)

                    _, cnn_code = image_encoder(fake_imgs[-1])

                    for i in range(batch_size):
                        mis_captions, mis_captions_len = self.dataset.get_mis_caption(class_ids[i])
                        hidden = text_encoder.init_hidden(99)
                        _, sent_emb_t = text_encoder(mis_captions, mis_captions_len, hidden)
                        rnn_code = torch.cat((sent_emb[i, :].unsqueeze(0), sent_emb_t), 0)
                        ### cnn_code = 1 * nef
                        ### rnn_code = 100 * nef
                        scores = torch.mm(cnn_code[i].unsqueeze(0), rnn_code.transpose(0, 1))  # 1* 100
                        cnn_code_norm = torch.norm(cnn_code[i].unsqueeze(0), 2, dim=1, keepdim=True)
                        rnn_code_norm = torch.norm(rnn_code, 2, dim=1, keepdim=True)
                        norm = torch.mm(cnn_code_norm, rnn_code_norm.transpose(0, 1))
                        scores0 = scores / norm.clamp(min=1e-8)
                        if torch.argmax(scores0) == 0:
                            R[R_count] = 1
                        R_count += 1

                    if R_count >= 30000:
                        chunk_means = np.zeros(10)
                        np.random.shuffle(R)
                        for i in range(10):
                            chunk_means[i] = np.average(R[i * 3000:(i + 1) * 3000])  # mean over each full 3000-sample chunk
                        R_mean = np.average(chunk_means)
                        R_std = np.std(chunk_means)
                        print("R mean:{:.4f} std:{:.4f}".format(R_mean, R_std))
                        cont = False
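
The scoring loop above computes R-precision: each generated image's CNN code is compared against its true caption plus 99 mismatched ones, and a hit is counted when the true caption has the highest cosine similarity. The core computation, with toy shapes:

import torch

nef = 256
cnn_code = torch.randn(1, nef)    # image embedding
rnn_code = torch.randn(100, nef)  # row 0 is the true caption's embedding
scores = cnn_code @ rnn_code.t()  # raw dot products, 1 x 100
norm = cnn_code.norm(dim=1, keepdim=True) @ rnn_code.norm(dim=1, keepdim=True).t()
cosine = scores / norm.clamp(min=1e-8)
hit = (cosine.argmax() == 0)      # True when the real caption ranks first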
Example #15
def loading_model(dataset_name='bird'):
  # IMPORTANT ARGUMENTS
  if dataset_name == 'bird':
    cfg_file = os.path.join(current_dir, "cfg/eval_bird.yml")
  else:
    cfg_file = os.path.join(current_dir, "cfg/eval_coco.yml")

  gpu_id = -1  # change to 0 or higher when using a GPU
  data_dir = ''
  manualSeed = 100

  # apply the cfg file
  if cfg_file is not None:
    cfg_from_file(cfg_file)

  if gpu_id != -1:
    cfg.GPU_ID = gpu_id
  else:
    cfg.CUDA = False

  if data_dir != '':
    cfg.DATA_DIR = data_dir


  now = datetime.datetime.now(dateutil.tz.tzlocal())
  timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
  output_dir = '../output/%s_%s_%s' % \
    (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

  split_dir, bshuffle = 'train', True
  if not cfg.TRAIN.FLAG:
    # bshuffle = False
    split_dir = 'test'


  # Get data loader
  imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1))
  image_transform = transforms.Compose([
      # transforms.Scale was deprecated; Resize is the modern torchvision name
      transforms.Resize(int(imsize * 76 / 64)),
      transforms.RandomCrop(imsize),
      transforms.RandomHorizontalFlip()])
  dataset = TextDataset(cfg.DATA_DIR, split_dir,
                        base_size=cfg.TREE.BASE_SIZE,
                        transform=image_transform)
  assert dataset
  dataloader = torch.utils.data.DataLoader(
          dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
          drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS))


  ### setting up ALGO
  # Define models and go to train/evaluate
  algo = trainer(output_dir, dataloader, dataset.n_words, dataset.ixtoword)

  # loading text ENCODER
  text_encoder = RNN_ENCODER(algo.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
  # the TRAIN.NET_E path can be given directly
  state_dict = torch.load(cfg.TRAIN.NET_E,
                          map_location=lambda storage, loc: storage)
  text_encoder.load_state_dict(state_dict)
  # print('Load text encoder from:', cfg.TRAIN.NET_E) ###edited here
  if cfg.CUDA:
    text_encoder = text_encoder.cuda()
  text_encoder.eval()


  # LOADING Generator
  netG = G_NET()
  model_dir = cfg.TRAIN.NET_G  # the model directory can be given directly as well
  state_dict = torch.load(model_dir, map_location=lambda storage, loc: storage)
  netG.load_state_dict(state_dict)
  # print('Load G from: ', model_dir)  ###edited here
  if cfg.CUDA:
    netG.cuda()
  netG.eval()

  return [algo, text_encoder, netG, dataset]
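
A hedged usage sketch for the loader above; it only inspects the returned objects, since the exact generator call signature varies across the forks collected on this page:

# assumes the cfg files referenced inside loading_model() exist on disk
algo, text_encoder, netG, dataset = loading_model(dataset_name='bird')
print('vocabulary size:', dataset.n_words)
print('generator in eval mode:', not netG.training)
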
Ejemplo n.º 16
0
    def build_models(self):
        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # #######################generator and discriminators############## #
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            # TODO: elif cfg.TREE.BRANCH_NUM > 3:
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
            # TODO: if cfg.TREE.BRANCH_NUM > 3:
        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))
        #
        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        # ########################################################### #
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
            text_encoder = nn.DataParallel(text_encoder)
            image_encoder = nn.DataParallel(image_encoder)
            netG = nn.DataParallel(netG)
            for i in range(len(netsD)):
                netsD[i] = nn.DataParallel(netsD[i])

        image_encoder.to(self.device)
        text_encoder.to(self.device)
        netG.to(self.device)
        for i in range(len(netsD)):
            netsD[i].to(self.device)

        # if cfg.CUDA and torch.cuda.is_available():
        #     text_encoder = text_encoder.cuda()
        #     image_encoder = image_encoder.cuda()
        #     netG.cuda()
        #     for i in range(len(netsD)):
        #         netsD[i].cuda()

        # if cfg.PARALLEL:
        #     netG = torch.nn.DataParallel(netG, device_ids=[0, 1, 2])
        #     text_encoder = torch.nn.DataParallel(text_encoder, device_ids=[0, 1, 2])
        #     image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[0, 1, 2])
        #     for i in range(len(netsD)):
        #         netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=[0, 1, 2])

        return [text_encoder, image_encoder, netG, netsD, epoch]
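
This variant is the only one on the page that pairs nn.DataParallel with a .to(self.device) move instead of raw .cuda() calls. The pattern in isolation, with a stand-in module:

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(256, 256)   # stand-in for any of the networks above
if torch.cuda.device_count() > 1:
    # DataParallel splits each batch along dim 0 across the visible GPUs
    model = nn.DataParallel(model)
model.to(device)              # one move covers the wrapper and the module
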
Ejemplo n.º 17
0
    def build_models(self):
        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        # custom restore of the text encoder parameters: keep only the
        # checkpoint weights whose names exist in this model
        own_state = text_encoder.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            # copy_ writes into the model's own tensors; rebinding the dict
            # entry ('own_state[name] = param') would not update the model
            own_state[name].copy_(param)
        # text_encoder.load_state_dict(state_dict)
        # custom restore of text encoder parameters: end
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.train()

        # #######################generator and discriminators############## #
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            # TODO: elif cfg.TREE.BRANCH_NUM > 3:
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET(text_encoder)
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
            # TODO: if cfg.TREE.BRANCH_NUM > 3:
        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))
        #
        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
        return [text_encoder, image_encoder, netG, netsD, epoch]
Ejemplo n.º 18
0
    def gen_example(self, data_dic):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            # Build and load the generator
            image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                       'image_encoder')
            state_dict = \
                torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            print('Load image encoder from:', img_encoder_path)
            image_encoder = image_encoder.cuda()
            image_encoder.eval()

            # the path to save generated images
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
            model_dir = cfg.TRAIN.NET_G
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)
            netG.cuda()
            netG.eval()

            words_embs = Variable(
                torch.zeros(1, cfg.TEXT.EMBEDDING_DIM, cfg.TEXT.WORDS_NUM))
            mask = Variable(torch.zeros(1, cfg.TEXT.WORDS_NUM))
            # volatile was removed in PyTorch 0.4+; on newer versions wrap
            # the forward pass in torch.no_grad() instead
            noise = Variable(torch.FloatTensor(1, cfg.GAN.Z_DIM),
                             volatile=True)
            words_embs, mask, noise = \
                words_embs.cuda(), mask.cuda(), noise.cuda()

            for key in data_dic:
                save_path = '%s/custom/%s' % (s_tmp, key)
                img = data_dic[key]
                img = Variable(img).unsqueeze(0).cuda()

                #######################################################
                # (1) Extract image embeddings
                ######################################################
                _, sent_emb = image_encoder(img)
                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, attention_maps, _, _ = netG(noise, sent_emb,
                                                       words_embs, mask)
                # G attention
                for k in range(len(fake_imgs)):
                    im = fake_imgs[k][-1].data.cpu().numpy()
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    # print('im', im.shape)
                    im = np.transpose(im, (1, 2, 0))
                    # print('im', im.shape)
                    im = Image.fromarray(im)
                    fullpath = '%s_g%d.png' % (save_path, k)
                    im.save(fullpath)
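
The same four-step tensor-to-PIL conversion recurs in nearly every example on this page; a small standalone helper, assuming a (3, H, W) float tensor in [-1, 1]:

import numpy as np
from PIL import Image

def to_pil(tensor):
    """Convert a (3, H, W) tensor in [-1, 1] to a PIL image."""
    im = tensor.detach().cpu().numpy()
    im = (im + 1.0) * 127.5           # [-1, 1] -> [0, 255]
    im = im.astype(np.uint8)
    im = np.transpose(im, (1, 2, 0))  # CHW -> HWC
    return Image.fromarray(im)
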
Ejemplo n.º 19
0
    def sample(self, split_dir, num_samples=25, draw_bbox=False):
        from PIL import Image, ImageDraw, ImageFont
        try:
            import cPickle as pickle  # Python 2
        except ImportError:
            import pickle  # Python 3
        import torchvision
        import torchvision.utils as vutils

        if cfg.TRAIN.NET_G == '':
            print('Error: the path for model NET_G is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the generator
            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            batch_size = cfg.TRAIN.BATCH_SIZE
            nz = cfg.GAN.Z_DIM

            model_dir = cfg.TRAIN.NET_G
            state_dict = torch.load(model_dir,
                                    map_location=lambda storage, loc: storage)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            netG = G_NET()
            print('Load G from: ', model_dir)
            netG.apply(weights_init)

            netG.load_state_dict(state_dict["netG"])
            netG.cuda()
            netG.eval()

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')]
            save_dir = '%s_%s' % (s_tmp, split_dir)
            mkdir_p(save_dir)
            #######################################
            noise = Variable(torch.FloatTensor(9, nz))

            imsize = 256

            for step, data in enumerate(self.data_loader, 0):
                if step >= num_samples:
                    break

                imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot, bbox = \
                    self.prepare_data(data, eval=True)
                transf_matrices_inv = transformation_matrices[1][0].unsqueeze(0)
                label_one_hot = label_one_hot[0].unsqueeze(0)

                img = imgs[-1][0]
                val_image = img.view(1, 3, imsize, imsize)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs = words_embs[0].unsqueeze(0).detach()
                sent_emb = sent_emb[0].unsqueeze(0).detach()
                words_embs = words_embs.repeat(9, 1, 1)
                sent_emb = sent_emb.repeat(9, 1)
                mask = (captions == 0)
                mask = mask[0].unsqueeze(0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]
                mask = mask.repeat(9, 1)
                transf_matrices_inv = transf_matrices_inv.repeat(9, 1, 1, 1)
                label_one_hot = label_one_hot.repeat(9, 1, 1)

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, sent_emb, words_embs, mask,
                          transf_matrices_inv, label_one_hot)
                with torch.no_grad():
                    fake_imgs, _, mu, logvar = nn.parallel.data_parallel(
                        netG, inputs, self.gpus)

                data_img = torch.FloatTensor(10, 3, imsize, imsize).fill_(0)
                data_img[0] = val_image
                data_img[1:10] = fake_imgs[-1]

                if draw_bbox:
                    for idx in range(3):
                        x, y, w, h = tuple(
                            [int(imsize * x) for x in bbox[0, idx]])
                        w = imsize - 1 if w > imsize - 1 else w
                        h = imsize - 1 if h > imsize - 1 else h
                        if x <= -1:
                            break
                        data_img[:10, :, y, x:x + w] = 1
                        data_img[:10, :, y:y + h, x] = 1
                        data_img[:10, :, y + h, x:x + w] = 1
                        data_img[:10, :, y:y + h, x + w] = 1

                # get caption
                cap = captions[0].data.cpu().numpy()
                sentence = ""
                for j in range(len(cap)):
                    if cap[j] == 0:
                        break
                    word = self.ixtoword[cap[j]].encode(
                        'ascii', 'ignore').decode('ascii')
                    sentence += word + " "
                sentence = sentence[:-1]
                vutils.save_image(data_img,
                                  '{}/{}_{}.png'.format(
                                      save_dir, sentence, step),
                                  normalize=True,
                                  nrow=10)
            print("Saved {} files to {}".format(step, save_dir))
Ejemplo n.º 20
0
    def evaluate(self, split_dir):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            # Build and load the generator
            if split_dir == 'test':
                split_dir = 'valid'
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            netE = load_embedding_model(self.data_loader.dataset.dictionary)
            print(netE)

            nz = cfg.GAN.Z_DIM
            sample_size = cfg.TEST.NUM_IMAGES
            noise = Variable(torch.FloatTensor(sample_size, nz))
            if cfg.CUDA:
                netG.cuda()
                netE.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            count = 0
            output_dir = os.path.join(cfg.OUTPUT_DIR, cfg.EXPERIMENT_NAME)
            for step, data in enumerate(
                    tqdm(self.data_loader, desc='evaluate'), 0):
                imgs, txt_ids, txts = data

                if cfg.CUDA:
                    txt_ids = Variable(txt_ids).cuda()
                else:
                    txt_ids = Variable(txt_ids)

                txts_embeddings = netE(txt_ids)

                batch_size = imgs[0].size(0)

                imgs64, imgs128, imgs256 = [], [], []
                for i in range(0, batch_size):
                    noise.data.normal_(0, 1)
                    txt_embedding = txts_embeddings[i].repeat(sample_size, 1)

                    fake_imgs, _, _ = netG(noise, txt_embedding)

                    imgs64.append(normalize_(fake_imgs[0]))
                    imgs128.append(normalize_(fake_imgs[1]))
                    imgs256.append(normalize_(fake_imgs[2]))

                save_images_with_text(imgs64, imgs128, imgs256, imgs, txts,
                                      batch_size, cfg.TEXT.MAX_LEN, count,
                                      output_dir)

                count = count + batch_size + 1
Ejemplo n.º 21
0
    def generate_fake_images_with_incremental_noise(self, data_dic, sizeim):

        global text_encoder_path, net_G_path
        print(os.getcwd(), os.path.join(os.getcwd(), text_encoder_path))
        text_encoder_path = os.path.join(os.getcwd(), text_encoder_path)
        net_G_path = os.path.join(os.getcwd(), net_G_path)

        # Build and load the generator
        #####################################
        ## load the encoder                 #
        #####################################

        print('Loading text encoder from:', text_encoder_path)


        text_encoder = \
            BERT_RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(text_encoder_path,
                        map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)

        print('Loaded text encoder from:', text_encoder_path)
        text_encoder.eval()
        text_encoder = text_encoder.cuda()

        netG = G_NET()
        ######################################
        ## load the generator                #
        ######################################

        state_dict = \
            torch.load(net_G_path, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load Generator from: ', net_G_path)
        s_tmp = net_G_path[:net_G_path.rfind('.pth')]

        netG.cuda()
        netG.eval()
        for key in data_dic:
            save_dir = '%s/%s' % ('res', key)
            mkdir_p(save_dir)
            captions, cap_lens, sorted_indices = data_dic[key]

            batch_size = captions.shape[0]
            nz = cfg.GAN.Z_DIM
            captions = Variable(torch.from_numpy(captions))
            cap_lens = Variable(torch.from_numpy(cap_lens))

            captions = captions.cuda()
            cap_lens = cap_lens.cuda()
            base_noise = Variable(torch.FloatTensor(batch_size, nz))
            base_noise = base_noise.cuda()

            for i in range(sizeim):  # number of images to be created
                noise = base_noise.clone()
                # nudge one coordinate of the fixed base latent per image
                noise[0][i % 100] = base_noise[0][i % 100] + torch.mean(base_noise)
                #######################################################
                # (1) Extract text embeddings
                ######################################################
                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                mask = (captions == 0)
                #######################################################
                # (2) Generate fake images
                ######################################################
                # note: re-sampling the noise here would overwrite the
                # incremental perturbation applied above, so keep it fixed
                # noise.data.normal_(0, 1)
                fake_imgs, attention_maps, _, _ = netG(noise, sent_emb,
                                                       words_embs, mask)

                im = fake_imgs[2].squeeze(0).data.cpu().numpy()
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                # print('im', im.shape)
                im = np.transpose(im, (1, 2, 0))
                # print('im', im.shape)
                im = Image.fromarray(im)

                fullpath = os.path.join(save_dir, '{0}.png'.format(i))

                im.save(fullpath)
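
The perturbation scheme above keeps one fixed base latent and shifts a single coordinate per generated image, so consecutive outputs vary along one latent direction at a time. A compact restatement, assuming a 100-dim latent:

import torch

nz = 100
base = torch.randn(1, nz)
variants = []
for i in range(8):
    z = base.clone()
    z[0, i % nz] += base.mean()   # move one coordinate, keep the rest fixed
    variants.append(z)
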
Ejemplo n.º 22
0
    def sampling(self, split_dir):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for model is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the generator
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            netG.apply(weights_init)
            netG.cuda()
            netG.eval()
            #
            text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
            noise = noise.cuda()

            model_dir = cfg.TRAIN.NET_G
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')]
            save_dir = '%s/%s' % (s_tmp, split_dir)
            mkdir_p(save_dir)

            cnt = 0

            for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                for step, data in enumerate(self.data_loader, 0):
                    cnt += batch_size
                    if step % 100 == 0:
                        print('step: ', step)
                    # if step > 50:
                    #     break

                    imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                    hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                    mask = (captions == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    # (2) Generate fake images
                    noise.data.normal_(0, 1)
                    fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)
                    for j in range(batch_size):
                        s_tmp = '%s/single/%s' % (save_dir, keys[j])
                        folder = s_tmp[:s_tmp.rfind('/')]
                        if not os.path.isdir(folder):
                            print('Make a new folder: ', folder)
                            mkdir_p(folder)
                        k = -1
                        # for k in range(len(fake_imgs)):
                        im = fake_imgs[k][j].data.cpu().numpy()
                        # [-1, 1] --> [0, 255]
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_s%d.png' % (s_tmp, k)
                        im.save(fullpath)
Ejemplo n.º 23
0
    def evaluate_finegan(self):
        self.save_dir = os.path.join(cfg.SAVE_DIR, 'images')
        mkdir_p(self.save_dir)
        random.seed(datetime.now())
        depth = cfg.TEST_DEPTH
        res = 32 * 2**depth
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for the model is not found!')
        else:
            # Build and load the generator
            netG = G_NET(depth)
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            model_dict = netG.state_dict()

            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)

            state_dict = {
                k: v
                for k, v in state_dict.items() if k in model_dict
            }

            model_dict.update(state_dict)
            netG.load_state_dict(model_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # Uncomment this to print Generator layers
            # print(netG)

            nrow = 6
            ncol = 4
            z_std = 0.1
            p_vs_c = False
            reprod = False

            if not reprod:
                torch.manual_seed(random.randint(-9999, 9999))

            bg_li = []
            pf_li = []
            cf_li = []
            pk_li = []
            ck_li = []
            pfg_li = []
            cfg_li = []
            pfgmk_li = []
            cfgmk_li = []
            b = random.randint(0, cfg.FINE_GRAINED_CATEGORIES - 1)

            nz = cfg.GAN.Z_DIM
            noise = torch.FloatTensor(1, nz)

            noise.data.normal_(0, z_std)
            # noise = noise.repeat(self.batch_size, 1)

            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            netG.eval()

            # note: np.random.randint's upper bound is exclusive, so the last
            # category index is never drawn here (unlike random.randint above)
            c_li = np.random.randint(0,
                                     cfg.FINE_GRAINED_CATEGORIES - 1,
                                     size=nrow)
            p_li = np.random.randint(0, cfg.SUPER_CATEGORIES - 1, size=nrow)
            for k in range(ncol):
                p = p_li[k]
                # p = random.randint(0, cfg.SUPER_CATEGORIES-1)
                for i in range(nrow):
                    bg_code = torch.zeros(
                        [self.batch_size, cfg.FINE_GRAINED_CATEGORIES])
                    p_code = torch.zeros(
                        [self.batch_size, cfg.SUPER_CATEGORIES])
                    c_code = torch.zeros(
                        [self.batch_size, cfg.FINE_GRAINED_CATEGORIES])
                    c = c_li[i]
                    for j in range(self.batch_size):
                        bg_code[j][b] = 1
                        p_code[j][p] = 1
                        c_code[j][c] = 1

                    fake_imgs, fg_imgs, mk_imgs, fgmk_imgs = netG(
                        noise, c_code, None, p_code,
                        bg_code)  # Forward pass through the generator
                    bg_li.append(fake_imgs[3 * depth][0])
                    pf_li.append(fake_imgs[3 * depth + 1][0])
                    cf_li.append(fake_imgs[3 * depth + 2][0])
                    pk_li.append(mk_imgs[2 * depth][0])
                    ck_li.append(mk_imgs[2 * depth + 1][0])
                    pfg_li.append(fg_imgs[2 * depth][0])
                    cfg_li.append(fg_imgs[2 * depth + 1][0])
                    pfgmk_li.append(fgmk_imgs[2 * depth][0])
                    cfgmk_li.append(fgmk_imgs[2 * depth + 1][0])

            save_image(bg_li, self.save_dir, 'background_pvc', nrow, res)
            save_image(pf_li, self.save_dir, 'parent_final_pvc', nrow, res)
            save_image(cf_li, self.save_dir, 'child_final_pvc', nrow, res)
            save_image(pfg_li, self.save_dir, 'parent_foreground_pvc', nrow,
                       res)
            save_image(cfg_li, self.save_dir, 'child_foreground_pvc', nrow,
                       res)
            save_image(pk_li, self.save_dir, 'parent_mask_pvc', nrow, res)
            save_image(ck_li, self.save_dir, 'child_mask_pvc', nrow, res)
            save_image(pfgmk_li, self.save_dir, 'parent_foreground_masked_pvc',
                       nrow, res)
            save_image(cfgmk_li, self.save_dir, 'child_foreground_masked_pvc',
                       nrow, res)

            bg_li = []
            pf_li = []
            cf_li = []
            pk_li = []
            ck_li = []
            pfg_li = []
            cfg_li = []
            pfgmk_li = []
            cfgmk_li = []
            for _ in range(ncol):
                noise.data.normal_(0, z_std)
                for i in range(nrow):
                    bg_code = torch.zeros(
                        [self.batch_size, cfg.FINE_GRAINED_CATEGORIES])
                    p_code = torch.zeros(
                        [self.batch_size, cfg.SUPER_CATEGORIES])
                    c_code = torch.zeros(
                        [self.batch_size, cfg.FINE_GRAINED_CATEGORIES])
                    c = c_li[i]
                    p = p_li[i]
                    for j in range(self.batch_size):
                        bg_code[j][b] = 1
                        p_code[j][p] = 1
                        c_code[j][c] = 1

                    fake_imgs, fg_imgs, mk_imgs, fgmk_imgs = netG(
                        noise, c_code, None, p_code,
                        bg_code)  # Forward pass through the generator
                    bg_li.append(fake_imgs[3 * depth][0])
                    pf_li.append(fake_imgs[3 * depth + 1][0])
                    cf_li.append(fake_imgs[3 * depth + 2][0])
                    pk_li.append(mk_imgs[2 * depth][0])
                    ck_li.append(mk_imgs[2 * depth + 1][0])
                    pfg_li.append(fg_imgs[2 * depth][0])
                    cfg_li.append(fg_imgs[2 * depth + 1][0])
                    pfgmk_li.append(fgmk_imgs[2 * depth][0])
                    cfgmk_li.append(fgmk_imgs[2 * depth + 1][0])

            save_image(bg_li, self.save_dir, 'background_zvpc', nrow, res)
            save_image(pf_li, self.save_dir, 'parent_final_zvpc', nrow, res)
            save_image(cf_li, self.save_dir, 'child_final_zvpc', nrow, res)
            save_image(pfg_li, self.save_dir, 'parent_foreground_zvpc', nrow,
                       res)
            save_image(cfg_li, self.save_dir, 'child_foreground_zvpc', nrow,
                       res)
            save_image(pk_li, self.save_dir, 'parent_mask_zvpc', nrow, res)
            save_image(ck_li, self.save_dir, 'child_mask_zvpc', nrow, res)
            save_image(pfgmk_li, self.save_dir,
                       'parent_foreground_masked_zvpc', nrow, res)
            save_image(cfgmk_li, self.save_dir, 'child_foreground_masked_zvpc',
                       nrow, res)
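
FineGAN conditions the generator on one-hot background/parent/child codes; the nested loops above can be collapsed into a single indexed assignment. A hedged equivalent, with illustrative category counts (e.g. 200 fine-grained classes for CUB):

import torch

def one_hot(batch_size, num_classes, index):
    code = torch.zeros(batch_size, num_classes)
    code[:, index] = 1.0   # same row-wise one-hot the loops above build
    return code

bg_code = one_hot(4, 200, 17)   # background category 17 for a batch of 4
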
Ejemplo n.º 24
0
    def build_models(self):
        # text encoders
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # Caption models - cnn_encoder and rnn_decoder
        caption_cnn = CAPTION_CNN(cfg.CAP.embed_size)
        caption_cnn.load_state_dict(torch.load(cfg.CAP.caption_cnn_path, map_location=lambda storage, loc: storage))
        for p in caption_cnn.parameters():
            p.requires_grad = False
        print('Load caption model from:', cfg.CAP.caption_cnn_path)
        caption_cnn.eval()

        caption_rnn = CAPTION_RNN(cfg.CAP.embed_size, cfg.CAP.hidden_size * 2, self.n_words, cfg.CAP.num_layers)
        caption_rnn.load_state_dict(torch.load(cfg.CAP.caption_rnn_path, map_location=lambda storage, loc: storage))
        for p in caption_rnn.parameters():
            p.requires_grad = False
        print('Load caption model from:', cfg.CAP.caption_rnn_path)

        # Generator and Discriminator:
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET

            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))

        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)

        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            caption_cnn = caption_cnn.cuda()
            caption_rnn = caption_rnn.cuda()
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
        return [text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, epoch]
Ejemplo n.º 25
0
    def sample_images(self):
        sample_size = 24
        save_dir = '../sample_images/'
        save_final = '../sample_finals/'

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        if not os.path.exists(save_final):
            os.makedirs(save_final)

        random.seed(datetime.now())
        depth = cfg.TEST_DEPTH
        res = 32 * 2**depth
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for the model is not found!')
        else:
            # Build and load the generator
            netG = G_NET(depth)
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            model_dict = netG.state_dict()

            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)

            state_dict = {
                k: v
                for k, v in state_dict.items() if k in model_dict
            }

            model_dict.update(state_dict)
            netG.load_state_dict(model_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # Uncomment this to print Generator layers
            # print(netG)

            nz = cfg.GAN.Z_DIM
            noise = torch.FloatTensor(1, nz)
            # noise.data.normal_(0, 1)
            # noise = noise.repeat(1, 1)

            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            netG.eval()

            for i in tqdm(range(sample_size)):
                noise.data.normal_(0, 1)
                bg_code = torch.zeros([1, cfg.FINE_GRAINED_CATEGORIES]).cuda()
                p_code = torch.zeros([1, cfg.SUPER_CATEGORIES]).cuda()
                c_code = torch.zeros([1, cfg.FINE_GRAINED_CATEGORIES]).cuda()
                b = random.randint(0, cfg.FINE_GRAINED_CATEGORIES - 1)
                p = random.randint(0, cfg.SUPER_CATEGORIES - 1)
                c = random.randint(0, cfg.FINE_GRAINED_CATEGORIES - 1)
                bg_code[0][b] = 1
                p_code[0][p] = 1
                c_code[0][c] = 1

                fake_imgs, fg_imgs, mk_imgs, fgmk_imgs = netG(
                    noise, c_code, 1, p_code,
                    bg_code)  # Forward pass through the generator

                self.save_image(fake_imgs[3 * depth + 0][0], save_dir,
                                '%d_bg' % i)
                self.save_image(fake_imgs[3 * depth + 1][0], save_dir,
                                '%d_pf' % i)
                self.save_image(fake_imgs[3 * depth + 2][0], save_dir,
                                '%d_cf' % i)
                self.save_image(fake_imgs[3 * depth + 2][0], save_final,
                                '%d' % i)
                # self.save_image(fg_imgs[2 * depth + 0][0], save_dir, 'parent_foreground')
                # self.save_image(fg_imgs[2 * depth + 1][0], save_dir, 'child_foreground')
                self.save_image(mk_imgs[2 * depth + 0][0], save_dir,
                                '%d_pmk' % i)
                self.save_image(mk_imgs[2 * depth + 1][0], save_dir,
                                '%d_cmk' % i)
Ejemplo n.º 26
0
    def evaluate(self,
                 split_dir,
                 n_samples=4,
                 extractor='googlenet',
                 save_dir=None):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            # Build and load the generator
            if split_dir == 'test':
                split_dir = 'valid'
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            mapper = EXTRACTOR_MAPPING[extractor]()
            mapper = torch.nn.DataParallel(mapper, device_ids=self.gpus)
            set_parameter_requires_grad(netG, False)
            set_parameter_requires_grad(mapper, False)

            print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            if save_dir is None:
                # the path to save generated images
                s_tmp = cfg.TRAIN.NET_G
                istart = s_tmp.rfind('_') + 1
                iend = s_tmp.rfind('.')
                iteration = int(s_tmp[istart:iend])
                s_tmp = s_tmp[:s_tmp.rfind('/')]
                save_dir = '%s/iteration%d' % (s_tmp, iteration)

            nz = cfg.GAN.Z_DIM

            if cfg.CUDA:
                netG.cuda()
                mapper.cuda()

            # switch to evaluate mode
            netG.eval()
            mapper.eval()

            synthetic_ds = SyntheticDataset(save_dir)

            for class_embeddings, synthetic_id in \
                    self.data_loader.dataset.embeddings_by_class():

                if cfg.CUDA:
                    class_embeddings = class_embeddings.cuda()

                # mean of the 10 captions per image
                class_embeddings = class_embeddings.mean(dim=1)
                for i in range(class_embeddings.size(0)):
                    image_embeddings = class_embeddings[i].repeat(n_samples, 1)
                    noise = torch.randn(n_samples, nz)
                    if cfg.CUDA:
                        noise = noise.cuda()

                    imgs, _, _ = netG(noise, image_embeddings)
                    imgs = imgs[-1]
                    samples = mapper(imgs)

                    synthetic_ds.save_pairs(samples, synthetic_id)
Ejemplo n.º 27
0
    def build_models(self):
        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            raise FileNotFoundError(
                'No pretrained text encoder found in directory DAMSMencoders/.\n'
                'Please train the DAMSM first before training the GAN '
                '(see README for details).')

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        if self.text_encoder_type == 'rnn':
            text_encoder = \
                RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        elif self.text_encoder_type == 'transformer':
            text_encoder = GPT2Model.from_pretrained(TRANSFORMER_ENCODER)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # #######################generator and discriminators############## #
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            # TODO: elif cfg.TREE.BRANCH_NUM > 3:
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        elif cfg.GAN.B_STYLEGEN:
            netG = G_NET_STYLED()
            if cfg.GAN.B_STYLEDISC:
                from model import D_NET_STYLED64, D_NET_STYLED128, D_NET_STYLED256
                if cfg.TREE.BRANCH_NUM > 0:
                    netsD.append(D_NET_STYLED64())
                if cfg.TREE.BRANCH_NUM > 1:
                    netsD.append(D_NET_STYLED128())
                if cfg.TREE.BRANCH_NUM > 2:
                    netsD.append(D_NET_STYLED256())
                # TODO: if cfg.TREE.BRANCH_NUM > 3:
            else:
                from model import D_NET64, D_NET128, D_NET256
                if cfg.TREE.BRANCH_NUM > 0:
                    netsD.append(D_NET64())
                if cfg.TREE.BRANCH_NUM > 1:
                    netsD.append(D_NET128())
                if cfg.TREE.BRANCH_NUM > 2:
                    netsD.append(D_NET256())
                # TODO: if cfg.TREE.BRANCH_NUM > 3:
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
            # TODO: if cfg.TREE.BRANCH_NUM > 3:
            netG.apply(weights_init)
            # print(netG)
            for i in range(len(netsD)):
                netsD[i].apply(weights_init)
                # print(netsD[i])
        print(netG.__class__)
        for i in netsD:
            print(i.__class__)
        print('# of netsD', len(netsD))
        #
        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = torch.load(cfg.TRAIN.NET_G,
                                    map_location=lambda storage, loc: storage)
            if cfg.GAN.B_STYLEGEN:
                netG.w_ewma = state_dict['w_ewma']
                if cfg.CUDA:
                    netG.w_ewma = netG.w_ewma.to('cuda:' + str(cfg.GPU_ID))
                netG.load_state_dict(state_dict['netG_state_dict'])
            else:
                netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
        return [text_encoder, image_encoder, netG, netsD, epoch]
Ejemplo n.º 28
0
    def build_models(self):
        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = torch.load(img_encoder_path,
                                map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = torch.load(cfg.TRAIN.NET_E,
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # #######################generator and discriminators############## #
        netsD = []
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())

        netG.apply(weights_init)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
        print('# of netsD', len(netsD))
        epoch = 0

        if self.resume:
            checkpoint_list = sorted(
                [ckpt for ckpt in glob.glob(self.model_dir + "/" + '*.pth')])
            latest_checkpoint = checkpoint_list[-1]
            state_dict = torch.load(latest_checkpoint,
                                    map_location=lambda storage, loc: storage)

            netG.load_state_dict(state_dict["netG"])
            for i in range(len(netsD)):
                netsD[i].load_state_dict(state_dict["netD"][i])
            epoch = int(latest_checkpoint[-8:-4]) + 1
            print("Resuming training from checkpoint {} at epoch {}.".format(
                latest_checkpoint, epoch))

        #
        if cfg.TRAIN.NET_G != '':
            state_dict = torch.load(cfg.TRAIN.NET_G,
                                    map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = torch.load(
                        Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
        return [text_encoder, image_encoder, netG, netsD, epoch]
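
The resume branch above relies on zero-padded epoch numbers in the checkpoint filenames, so a plain lexical sort finds the newest file and the epoch can be parsed back out of the name. A standalone sketch of that convention (the filename layout is an assumption):

import glob
import os

def latest_checkpoint(model_dir):
    # assumes names like checkpoint_0042.pth, i.e. a zero-padded epoch
    # in the last four characters before the extension
    ckpts = sorted(glob.glob(os.path.join(model_dir, '*.pth')))
    if not ckpts:
        return None, 0
    latest = ckpts[-1]
    return latest, int(latest[-8:-4]) + 1   # path, epoch to resume at
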
Ejemplo n.º 29
0
    def evaluate_finegan(self):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for the model is not found!')
        else:
            # Build and load the generator
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            model_dict = netG.state_dict()

            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)

            state_dict = {
                k: v
                for k, v in state_dict.items() if k in model_dict
            }

            model_dict.update(state_dict)
            netG.load_state_dict(model_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # Uncomment this to print Generator layers
            # print(netG)

            nz = cfg.GAN.Z_DIM
            noise = torch.FloatTensor(self.batch_size, nz)
            noise.data.normal_(0, 1)

            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            netG.eval()

            background_class = cfg.TEST_BACKGROUND_CLASS
            parent_class = cfg.TEST_PARENT_CLASS
            child_class = cfg.TEST_CHILD_CLASS
            bg_code = torch.zeros(
                [self.batch_size, cfg.FINE_GRAINED_CATEGORIES])
            p_code = torch.zeros([self.batch_size, cfg.SUPER_CATEGORIES])
            c_code = torch.zeros(
                [self.batch_size, cfg.FINE_GRAINED_CATEGORIES])

            for j in range(self.batch_size):
                bg_code[j][background_class] = 1
                p_code[j][parent_class] = 1
                c_code[j][child_class] = 1

            fake_imgs, fg_imgs, mk_imgs, fgmk_imgs = netG(
                noise, c_code, p_code,
                bg_code)  # Forward pass through the generator

            self.save_image(fake_imgs[0][0], self.save_dir, 'background')
            self.save_image(fake_imgs[1][0], self.save_dir, 'parent_final')
            self.save_image(fake_imgs[2][0], self.save_dir, 'child_final')
            self.save_image(fg_imgs[0][0], self.save_dir, 'parent_foreground')
            self.save_image(fg_imgs[1][0], self.save_dir, 'child_foreground')
            self.save_image(mk_imgs[0][0], self.save_dir, 'parent_mask')
            self.save_image(mk_imgs[1][0], self.save_dir, 'child_mask')
            self.save_image(fgmk_imgs[0][0], self.save_dir,
                            'parent_foreground_masked')
            self.save_image(fgmk_imgs[1][0], self.save_dir,
                            'child_foreground_masked')
Ejemplo n.º 30
0
    def sampling(self, split_dir):
        if cfg.TRAIN.NET_G == '' or cfg.TRAIN.NET_C == '':
            print('Error: the path for main module or DCM is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'

            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            netG.apply(weights_init)
            netG.cuda()
            netG.eval()
            # The text encoder
            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()
            # The image encoder
            image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                       'image_encoder')
            state_dict = \
                torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            print('Load image encoder from:', img_encoder_path)
            image_encoder = image_encoder.cuda()
            image_encoder.eval()

            # The VGG network
            VGG = VGGNet()
            print("Load the VGG model")
            VGG.cuda()
            VGG.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM
            # volatile=True disables autograd tracking (pre-0.4 PyTorch idiom)
            noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
            noise = noise.cuda()

            # The DCM
            netDCM = DCM_Net()
            if cfg.TRAIN.NET_C != '':
                state_dict = \
                    torch.load(cfg.TRAIN.NET_C, map_location=lambda storage, loc: storage)
                netDCM.load_state_dict(state_dict)
                print('Load DCM from: ', cfg.TRAIN.NET_C)
            netDCM.cuda()
            netDCM.eval()

            model_dir = cfg.TRAIN.NET_G
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')]
            save_dir = '%s/%s' % (s_tmp, split_dir)
            mkdir_p(save_dir)

            cnt = 0
            idx = 0
            for _ in range(5):  # hard-coded stand-in for cfg.TEXT.CAPTIONS_PER_IMAGE
                for step, data in enumerate(self.data_loader, 0):
                    cnt += batch_size
                    if step % 100 == 0:
                        print('step: ', step)

                    imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                        wrong_caps_len, wrong_cls_id = prepare_data(data)

                    #######################################################
                    # (1) Extract text and image embeddings
                    ######################################################

                    hidden = text_encoder.init_hidden(batch_size)

                    words_embs, sent_emb = text_encoder(
                        wrong_caps, wrong_caps_len, hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                    mask = (wrong_caps == 0)  # True at padding positions
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    region_features, cnn_code = \
                        image_encoder(imgs[cfg.TREE.BRANCH_NUM - 1])

                    #######################################################
                    # (2) Modify real images
                    ######################################################

                    noise.data.normal_(0, 1)
                    fake_imgs, attention_maps, mu, logvar, h_code, c_code = netG(
                        noise, sent_emb, words_embs, mask, cnn_code,
                        region_features)

                    real_img = imgs[cfg.TREE.BRANCH_NUM - 1]
                    real_features = VGG(real_img)[0]

                    fake_img = netDCM(h_code, real_features, sent_emb,
                                      words_embs, mask, c_code)
                    for j in range(batch_size):
                        s_tmp = '%s/single' % (save_dir)
                        folder = s_tmp[:s_tmp.rfind('/')]
                        if not os.path.isdir(folder):
                            print('Make a new folder: ', folder)
                            mkdir_p(folder)
                        im = fake_img[j].data.cpu().numpy()
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_s%d.png' % (s_tmp, idx)
                        idx = idx + 1
                        im.save(fullpath)
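The tail of the loop above denormalizes a generator output from [-1, 1] to 8-bit RGB and reorders channels for PIL. A small helper capturing just that conversion, assuming a single CHW image tensor:

import numpy as np
from PIL import Image

def tensor_to_pil(img_tensor):
    # img_tensor: CHW float tensor with values in [-1, 1]
    im = img_tensor.data.cpu().numpy()
    im = (im + 1.0) * 127.5           # [-1, 1] -> [0, 255]
    im = im.astype(np.uint8)
    im = np.transpose(im, (1, 2, 0))  # CHW -> HWC
    return Image.fromarray(im)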
Example No. 31
    def build_models(self):
        # ################### models ######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return
        if cfg.TRAIN.NET_G == '':
            print('Error: no pretrained main module')
            return

        VGG = VGGNet()

        for p in VGG.parameters():
            p.requires_grad = False

        print("Load the VGG model")
        VGG.eval()

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
            from model import D_NET256 as D_NET
            netD = D_NET(b_jcu=False)
        else:
            from model import D_NET256
            netG = G_NET()
            netD = D_NET256()

        netD.apply(weights_init)

        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        netG.eval()
        print('Load G from: ', cfg.TRAIN.NET_G)

        epoch = 0
        netDCM = DCM_Net()
        if cfg.TRAIN.NET_C != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_C, map_location=lambda storage, loc: storage)
            netDCM.load_state_dict(state_dict)
            print('Load DCM from: ', cfg.TRAIN.NET_C)
            istart = cfg.TRAIN.NET_C.rfind('_') + 1
            iend = cfg.TRAIN.NET_C.rfind('.')
            epoch = cfg.TRAIN.NET_C[istart:iend]
            epoch = int(epoch) + 1

        if cfg.TRAIN.NET_D != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_D, map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load DCM Discriminator from: ', cfg.TRAIN.NET_D)

        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            netDCM.cuda()
            VGG = VGG.cuda()
            netD.cuda()
        return [text_encoder, image_encoder, netG, netD, epoch, VGG, netDCM]
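The method returns its models as a positional list, so any caller has to unpack them in the same order. A sketch of how a trainer might consume it (caller names assumed, not taken from the source):

# Hypothetical caller; build_models is the method defined above.
text_encoder, image_encoder, netG, netD, epoch, VGG, netDCM = \
    self.build_models()
print('Resuming DCM training from epoch', epoch)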
Example No. 32
    def evaluate(self, split_dir):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for the models is not found!')
        else:
            # Build and load the generator
            if split_dir == 'test':
                split_dir = 'valid'
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # the path to save generated images
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d' % (s_tmp, iteration)

            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            for step, data in enumerate(self.data_loader, 0):
                imgs, t_embeddings, filenames = data
                if cfg.CUDA:
                    t_embeddings = Variable(t_embeddings).cuda()
                else:
                    t_embeddings = Variable(t_embeddings)
                # print(t_embeddings[:, 0, :], t_embeddings.size(1))

                embedding_dim = t_embeddings.size(1)  # captions per image
                batch_size = imgs[0].size(0)
                noise.data.resize_(batch_size, nz)
                noise.data.normal_(0, 1)

                fake_img_list = []
                for i in range(embedding_dim):
                    fake_imgs, _, _ = netG(noise, t_embeddings[:, i, :])
                    if cfg.TEST.B_EXAMPLE:
                        # fake_img_list.append(fake_imgs[0].data.cpu())
                        # fake_img_list.append(fake_imgs[1].data.cpu())
                        fake_img_list.append(fake_imgs[2].data.cpu())
                    else:
                        self.save_singleimages(fake_imgs[-1], filenames,
                                               save_dir, split_dir, i, 256)
                        # self.save_singleimages(fake_imgs[-2], filenames,
                        #                        save_dir, split_dir, i, 128)
                        # self.save_singleimages(fake_imgs[-3], filenames,
                        #                        save_dir, split_dir, i, 64)
                    # break
                if cfg.TEST.B_EXAMPLE:
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 64)
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 128)
                    self.save_superimages(fake_img_list, filenames,
                                          save_dir, split_dir, 256)
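Like several other snippets here, this one recovers the iteration number by slicing the checkpoint filename between the last `_` and the last `.`. A standalone sketch of that convention, assuming names like `netG_60000.pth`:

def parse_iteration(ckpt_path):
    # 'output/netG_60000.pth' -> 60000, mirroring the rfind slicing above
    istart = ckpt_path.rfind('_') + 1
    iend = ckpt_path.rfind('.')
    return int(ckpt_path[istart:iend])

assert parse_iteration('output/netG_60000.pth') == 60000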
Example No. 33
    def build_models(self):
        print('Building models...')
        print('N_words: ', self.n_words)

        #####################
        ##  TEXT ENCODERS  ##
        #####################
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        print('Built image encoder: ', image_encoder)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Built text encoder: ', text_encoder)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        ######################
        ##  CAPTION MODELS  ##
        ######################

        # cnn_encoder and rnn_encoder
        if cfg.CAP.USE_ORIGINAL:
            caption_cnn = CAPTION_CNN(embed_size=cfg.TEXT.EMBEDDING_DIM)
            caption_rnn = CAPTION_RNN(embed_size=cfg.TEXT.EMBEDDING_DIM,
                                      hidden_size=cfg.CAP.HIDDEN_SIZE,
                                      vocab_size=self.n_words,
                                      num_layers=cfg.CAP.NUM_LAYERS)
        else:
            caption_cnn = Encoder()
            caption_rnn = Decoder(idx2word=self.ixtoword)

        caption_cnn_checkpoint = torch.load(
            cfg.CAP.CAPTION_CNN_PATH,
            map_location=lambda storage, loc: storage)
        caption_rnn_checkpoint = torch.load(
            cfg.CAP.CAPTION_RNN_PATH,
            map_location=lambda storage, loc: storage)
        caption_cnn.load_state_dict(caption_cnn_checkpoint['model_state_dict'])
        caption_rnn.load_state_dict(caption_rnn_checkpoint['model_state_dict'])

        for p in caption_cnn.parameters():
            p.requires_grad = False
        print('Load caption model from: ', cfg.CAP.CAPTION_CNN_PATH)
        caption_cnn.eval()

        for p in caption_rnn.parameters():
            p.requires_grad = False
        print('Load caption model from: ', cfg.CAP.CAPTION_RNN_PATH)
        caption_rnn.eval()

        #################################
        ##  GENERATOR & DISCRIMINATOR  ##
        #################################
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET

            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
        netG.apply(weights_init)
        # print(netG)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            # print(netsD[i])
        print('# of netsD', len(netsD))

        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)

        text_encoder = text_encoder.to(cfg.DEVICE)
        image_encoder = image_encoder.to(cfg.DEVICE)
        caption_cnn = caption_cnn.to(cfg.DEVICE)
        caption_rnn = caption_rnn.to(cfg.DEVICE)
        netG.to(cfg.DEVICE)
        for i in range(len(netsD)):
            netsD[i].to(cfg.DEVICE)
        return [
            text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD,
            epoch
        ]
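Unlike the earlier examples, this snippet moves modules with `.to(cfg.DEVICE)` instead of `.cuda()`, which keeps the code runnable on CPU-only machines. A sketch of the assumed convention, where the device is chosen once at startup:

import torch

# Assumed convention for cfg.DEVICE: pick the device once, reuse everywhere.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = torch.nn.Linear(8, 8).to(DEVICE)  # no-op move on CPU-only machines
x = torch.randn(4, 8, device=DEVICE)
y = model(x)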