def save_img_results(self, netG, noise, sent_emb, words_embs, w_sent_emb,
                         w_words_embs, caption, w_caption, gen_iterations,
                         real_imgs, mask, VGG):
        """Generate one batch with ``netG`` and save a comparison image.

        The last entry of ``real_imgs`` is passed through ``VGG`` and, together
        with the text embeddings, noise and mask, through ``netG`` (via
        ``nn.parallel.data_parallel`` on ``self.gpus``). The first of the three
        generator outputs is handed to ``build_images`` together with the
        source image and both captions, and the result is written to
        ``<self.image_dir>/average_<gen_iterations>.png``.

        Args:
            netG: generator module; must return a 3-tuple whose first item is
                the generated image batch.
            noise: noise input forwarded to ``netG``.
            sent_emb: sentence embedding forwarded to ``netG``.
            words_embs: word embeddings forwarded to ``netG``.
            w_sent_emb: unused here; kept for interface compatibility.
            w_words_embs: unused here; kept for interface compatibility.
            caption: captions forwarded to ``build_images``.
            w_caption: second caption set forwarded to ``build_images``
                (presumably the mismatched captions -- TODO confirm).
            gen_iterations: integer used in the output file name.
            real_imgs: indexable collection of image batches; only the last
                entry is used.
            mask: mask forwarded to ``netG``.
            VGG: callable producing the encoder features for ``netG``.
        """
        if not os.path.isdir(self.image_dir):
            mkdir_p(self.image_dir)

        # Only the last entry of real_imgs is used throughout.
        source_img = real_imgs[-1]
        enc_features = VGG(source_img)
        fake_img, _, _ = nn.parallel.data_parallel(
            netG,
            (source_img, sent_emb, words_embs, noise, mask, enc_features),
            self.gpus)

        img_set = build_images(source_img, fake_img, caption, w_caption,
                               self.ixtoword)
        fullpath = '%s/average_%d.png' % (self.image_dir, gen_iterations)
        Image.fromarray(img_set).save(fullpath)
        '''
    def sampling(self, split_dir):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for main module is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'

            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = EncDecNet()
            netG.apply(weights_init)
            netG.cuda()
            netG.eval()
            # The text encoder
            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()
            # The image encoder
            """
            image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
            state_dict = \
                torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            print('Load image encoder from:', img_encoder_path)
            image_encoder = image_encoder.cuda()
            image_encoder.eval()
            """

            # The VGG network
            VGG = VGG16()
            print("Load the VGG model")
            VGG.cuda()
            VGG.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
            noise = noise.cuda()

            model_dir = os.path.join(cfg.DATA_DIR, 'output', self.args.netG,
                                     'Model/netG_epoch_600.pth')
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            # the path to save modified images
            save_dir_valid = os.path.join(cfg.DATA_DIR, 'output',
                                          self.args.netG, 'valid')
            #mkdir_p(save_dir)

            cnt = 0
            idx = 0
            for i in range(5):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                # the path to save modified images
                save_dir = os.path.join(save_dir_valid, 'valid_%d' % i)
                save_dir_super = os.path.join(save_dir, 'super')
                save_dir_single = os.path.join(save_dir, 'single')
                mkdir_p(save_dir_super)
                mkdir_p(save_dir_single)
                for step, data in enumerate(self.data_loader, 0):
                    cnt += batch_size
                    if step % 100 == 0:
                        print('step: ', step)

                    imgs, w_imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                                wrong_caps_len, wrong_cls_id = prepare_data(data)

                    #######################################################
                    # (1) Extract text and image embeddings
                    ######################################################

                    hidden = text_encoder.init_hidden(batch_size)

                    words_embs, sent_emb = text_encoder(
                        wrong_caps, wrong_caps_len, hidden)
                    words_embs, sent_emb = words_embs.detach(
                    ), sent_emb.detach()

                    mask = (wrong_caps == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    #######################################################
                    # (2) Modify real images
                    ######################################################

                    noise.data.normal_(0, 1)

                    fake_img, mu, logvar = netG(imgs[-1], sent_emb, words_embs,
                                                noise, mask, VGG)

                    img_set = build_images(imgs[-1], fake_img, captions,
                                           wrong_caps, self.ixtoword)
                    img = Image.fromarray(img_set)
                    full_path = '%s/super_step%d.png' % (save_dir_super, step)
                    img.save(full_path)

                    for j in range(batch_size):
                        s_tmp = '%s/single' % (save_dir_single)
                        folder = s_tmp[:s_tmp.rfind('/')]
                        if not os.path.isdir(folder):
                            print('Make a new folder: ', folder)
                            mkdir_p(folder)
                        k = -1
                        im = fake_img[j].data.cpu().numpy()
                        #im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_s%d.png' % (s_tmp, idx)
                        idx = idx + 1
                        im.save(fullpath)
    def save_img_results(self,
                         netsG,
                         noise,
                         atts,
                         image_atts,
                         inception_model,
                         classifiers,
                         real_imgs,
                         gen_iterations,
                         name='current'):
        """Run the staged generators once and save their outputs as one image.

        Up to three generators from ``netsG`` are run, depending on
        ``cfg.TREE.BRANCH_NUM``; nothing is generated when
        ``self.args.kl_loss`` is truthy. After stages 1 and 2,
        ``classifier_loss`` supplies the attribute embeddings fed to the next
        stage. The collected images go to ``build_images`` and the result is
        written to ``<self.image_dir>/G_<gen_iterations>.png``.

        NOTE(review): ``name`` is accepted but never used, and the accumulated
        ``C_losses`` value is never read after the stages finish.
        NOTE(review): with ``self.args.kl_loss`` set, ``build_images`` receives
        an empty list -- confirm that this is supported.
        """
        mkdir_p(self.image_dir)

        # Save images
        """
        if self.args.kl_loss:
            fake_imgs, _, _ = netG(noise, atts) ##
        else:
            fake_imgs, _  = netG(noise, atts, image_att, inception_model, classifiers, imgs)
        """

        def attribute_embeddings(stage_idx, generated, losses):
            # Returns (embeddings for the next stage, updated loss value).
            if self.args.split == 'train':
                # Training: use the features of the real images.
                embeddings, losses = classifier_loss(
                    classifiers, inception_model, real_imgs[stage_idx],
                    image_atts, losses)
                _, losses = classifier_loss(classifiers, inception_model,
                                            generated, image_atts, losses)
            else:
                # Original author's remark on this branch: "not needed".
                embeddings, _ = classifier_loss(
                    classifiers, inception_model, generated, image_atts)
            return embeddings, losses

        outputs = []
        C_losses = None
        # No stage runs at all on the kl_loss path.
        n_stages = 0 if self.args.kl_loss else cfg.TREE.BRANCH_NUM

        if n_stages > 0:
            stage1_img, h_code1 = nn.parallel.data_parallel(
                netsG[0], (noise, atts, image_atts), self.gpus)
            outputs.append(stage1_img)
            att_embeddings, C_losses = attribute_embeddings(
                0, stage1_img, C_losses)

        if n_stages > 1:
            stage2_img, h_code2 = nn.parallel.data_parallel(
                netsG[1], (h_code1, att_embeddings), self.gpus)
            outputs.append(stage2_img)
            # NOTE(review): the stage-1 image (not stage-2) is classified
            # here, exactly as before -- possibly a copy-paste slip; confirm.
            att_embeddings, C_losses = attribute_embeddings(
                1, stage1_img, C_losses)

        if n_stages > 2:
            stage3_img = nn.parallel.data_parallel(
                netsG[2], (h_code2, att_embeddings), self.gpus)
            outputs.append(stage3_img)

        # Assemble the image set and write it out.
        img_set = build_images(outputs)
        full_path = '%s/G_%s.png' % (self.image_dir, gen_iterations)
        Image.fromarray(img_set).save(full_path)
    def sampling(self):
        """Generate and save images for every class of the selected test split.

        Loads the class list from ``<cfg.DATA_DIR>/<split>/class_data.pickle``
        (``test_unseen`` when ``self.args.split`` says so, ``test_seen``
        otherwise), the per-class attribute table from the CUB attributes
        directory and the generator weights from
        ``<self.args.netG>/Model/netG_epoch_600.pth``. For each class it draws
        ``self.sample_num`` noise batches and writes

        * every generated image of every stage to
          ``<class_dir>/stage_<k>/single_<n>.png`` and
        * grids of 20 images per stage (built by ``build_images``) to
          ``<class_dir>/super/super_<n>.png``.

        File indices keep counting across the ``self.sample_num`` batches, so
        later batches no longer overwrite the files of the first one; the
        names produced for the first batch are unchanged.

        Prints an error and returns when ``self.args.netG`` is empty.

        Raises:
            FileNotFoundError: if the class pickle or attribute file is
                missing (previously a missing pickle surfaced as a
                ``NameError``).
        """
        if self.args.netG == '':
            print('Error: the path for models is not found!')
            return

        data_dir = cfg.DATA_DIR
        if self.args.split == "test_unseen":
            filepath = os.path.join(data_dir, "test_unseen/class_data.pickle")
        else:  # test_seen
            filepath = os.path.join(data_dir, "test_seen/class_data.pickle")
        # Open unconditionally so that a missing file fails right here with a
        # clear error instead of an undefined-variable error further down.
        # NOTE: pickle can execute arbitrary code -- only load trusted files.
        with open(filepath, "rb") as f:
            data_dic = pkl.load(f)
        class_names = data_dic['classes']
        class_ids = data_dic['class_info']

        att_dir = os.path.join(data_dir, "CUB_200_2011/attributes")
        att_np = np.zeros((312, 200))  # for CUB
        with open(att_dir + "/class_attribute_labels_continuous.txt",
                  "r") as f:
            # One line per class; each line fills one column of the table.
            for ind, line in enumerate(f):
                att_np[:, ind] = [float(value) for value in line.split()]

        netG = G_NET() if self.args.kl_loss else G_NET_not_CA()
        test_model = "netG_epoch_600.pth"
        model_path = os.path.join(self.args.netG, "Model", test_model)
        state_dic = torch.load(model_path,
                               map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dic)
        # NOTE(review): the model is moved to CUDA unconditionally although
        # the inputs below honour cfg.CUDA -- confirm CPU runs are unsupported.
        netG.cuda()
        netG.eval()

        noise = torch.FloatTensor(self.batch_size, cfg.GAN.Z_DIM)
        if cfg.CUDA:
            noise = noise.cuda()

        model_stem = test_model[:test_model.rfind(".")]
        set_size = 20  # images per stage in one "super" grid
        sets_per_batch = self.batch_size // set_size

        for class_name, class_id in zip(class_names, class_ids):
            print("now generating, ", class_name)
            class_dir = os.path.join(self.args.netG, 'valid', model_stem,
                                     self.args.split, class_name)
            super_dir = os.path.join(class_dir, "super")

            # class_id is 1-based; replicate the class column over the batch.
            atts = np.expand_dims(att_np[:, class_id - 1], axis=0)
            atts = atts.repeat(self.batch_size, axis=0)
            assert atts.shape == (self.batch_size, 312)
            if cfg.CUDA:
                atts = torch.cuda.FloatTensor(atts)
            else:
                atts = torch.FloatTensor(atts)

            for i in range(self.sample_num):
                noise.normal_(0, 1)
                # Pure inference: no autograd graph is needed.
                with torch.no_grad():
                    if self.args.kl_loss:
                        fake_imgs, _, _ = nn.parallel.data_parallel(
                            netG, (noise, atts), self.gpus)
                    else:
                        fake_imgs = nn.parallel.data_parallel(
                            netG, (noise, atts), self.gpus)

                for stage, stage_imgs in enumerate(fake_imgs):
                    stage_dir = os.path.join(class_dir, "stage_%d" % stage)
                    mkdir_p(stage_dir)
                    # Continue numbering after the previous batches.
                    first_index = i * len(stage_imgs)
                    for num, im in enumerate(stage_imgs):
                        # Out-of-place arithmetic so the generator output,
                        # reused for the grids below, is never modified.
                        im = (im.detach().cpu() + 1) / 2 * 255
                        im = im.numpy().astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        img_path = os.path.join(
                            stage_dir, "single_%d.png" % (first_index + num))
                        Image.fromarray(im).save(img_path)

                # Build the grids once per batch (not once per stage).
                # The three stage indices are hard-coded, so a generator with
                # fewer than three outputs fails here with IndexError.
                mkdir_p(super_dir)
                for j in range(sets_per_batch):
                    lo, hi = j * set_size, (j + 1) * set_size
                    one_set = [
                        fake_imgs[0][lo:hi],
                        fake_imgs[1][lo:hi],
                        fake_imgs[2][lo:hi]
                    ]
                    img_set = Image.fromarray(build_images(one_set))
                    img_path = os.path.join(
                        super_dir,
                        "super_%d.png" % (i * sets_per_batch + j))
                    img_set.save(img_path)
# Example #5
# 0
    def train(self):
        """Adversarially train ``self.netG`` against ``self.netD``.

        For every batch the captions are sorted by length (descending) and
        encoded by the text encoder; ``netG`` turns noise plus the sentence
        embedding into fake shapes. ``netD`` is updated with a BCE loss, but
        only while its accuracy on the current batch is below 0.85. ``netG``
        is updated with a BCE loss against "real" labels plus ``KL_loss``, and
        every 10th step also with the word/sentence matching losses computed
        from ``image_encoder`` features. Metrics go to ``self.exp``, a preview
        is written every 500 steps, and both networks are saved after each
        epoch to ``self.model_dir``.

        Every batch is assumed to contain exactly ``self.batch_size`` samples
        (the label tensors and ``init_hidden`` are sized with it) --
        presumably the data loader drops the last partial batch; TODO confirm.
        """
        text_encoder, image_encoder = self.build_models()
        netG, netD = self.netG, self.netD
        optimizerG = self.optimizerG
        optimizerD = self.optimizerD
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        start_epoch = 0

        nz = cfg.GAN.Z_DIM
        noise = torch.FloatTensor(self.batch_size, nz).to(device)
        # Not used below; kept so the random-number stream consumed before
        # training starts stays the same as before.
        fixed_noise = torch.FloatTensor(self.batch_size, nz).normal_(0, 1).to(device)

        # Loop-invariant targets and criterion, created once.
        bce = nn.BCELoss()
        real_labels = torch.ones(self.batch_size, device=device)
        fake_labels = torch.zeros(self.batch_size, device=device)
        # Matching-loss targets must live on the same device as the features;
        # previously they stayed on the CPU, which breaks on CUDA.
        labels = torch.LongTensor(range(self.batch_size)).to(device)

        # gen_iterations = start_epoch * self.num_batches
        for epoch_idx in range(start_epoch, self.max_epoch):
            epoch = epoch_idx + self.start_epoch
            for step, data in enumerate(self.data_loader):
                start_t = time.time()

                shape, cap, cap_len, cls_id, _key = data

                # Sort the batch by caption length, longest first.
                sorted_cap_lens, sorted_cap_indices = torch.sort(cap_len, 0, True)
                shapes = shape[sorted_cap_indices].squeeze().to(device)
                captions = cap[sorted_cap_indices].squeeze().to(device)
                class_ids = cls_id[sorted_cap_indices].squeeze().numpy()

                hidden = text_encoder.init_hidden(self.batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, sorted_cap_lens, hidden)
                # words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.normal_(0, 1)
                fake_shapes, mu, logvar = netG(noise, sent_emb)

                #######################################################
                # (3) Update D network
                ######################################################
                netD.zero_grad()

                real_features = netD(shapes)
                cond_real_errD = bce(real_features, real_labels)
                # NOTE(review): fake_shapes is not detached, so the D backward
                # pass also runs through the generator graph (hence
                # retain_graph=True below); netG.zero_grad() later discards
                # what lands on netG. Consider fake_shapes.detach() here.
                fake_features = netD(fake_shapes)
                cond_fake_errD = bce(fake_features, fake_labels)

                # NOTE(review): precedence makes this real + (fake / 2), not
                # the mean of both terms -- confirm this is intended.
                errD_total = cond_real_errD + cond_fake_errD / 2.

                d_real_acu = torch.ge(real_features.squeeze(), 0.5).float()
                d_fake_acu = torch.le(fake_features.squeeze(), 0.5).float()
                d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

                # Skip the D step while D is already accurate on this batch.
                if d_total_acu < 0.85:
                    errD_total.backward(retain_graph=True)
                    optimizerD.step()

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                netG.zero_grad()

                # Re-score the fakes with the (possibly just updated) D.
                g_features = netD(fake_shapes)
                cond_real_errG = bce(g_features, real_labels)

                kl_loss = KL_loss(mu, logvar)
                errG_total = kl_loss + cond_real_errG

                # The text/shape matching losses are only added every 10 steps.
                if step % 10 == 0:
                    region_features, cnn_code = image_encoder(fake_shapes)

                    w_loss0, w_loss1, _ = words_loss(region_features, words_embs,
                                                     labels, sorted_cap_lens,
                                                     class_ids, self.batch_size)
                    w_loss = (w_loss0 + w_loss1) * cfg.TRAIN.SMOOTH.LAMBDA

                    s_loss0, s_loss1 = sent_loss(cnn_code, sent_emb,
                                                 labels, class_ids, self.batch_size)
                    s_loss = (s_loss0 + s_loss1) * cfg.TRAIN.SMOOTH.LAMBDA

                    errG_total = errG_total + s_loss + w_loss
                    self.exp.metric('s_loss', s_loss.item())
                    self.exp.metric('w_loss', w_loss.item())

                # backward and update parameters
                errG_total.backward()
                optimizerG.step()

                end_t = time.time()

                self.exp.metric('d_loss', errD_total.item())
                self.exp.metric('g_loss', errG_total.item())
                self.exp.metric('act', d_total_acu.item())

                print('[%d/%d][%d/%d] Loss_D: %.2f Loss_G: %.2f Time: %.2fs'
                      % (epoch, self.max_epoch, step, self.num_batches,
                         errD_total.item(), errG_total.item(),
                         end_t - start_t))

                if step % 500 == 0:
                    fullpath = '%s/lean_%d_%d.png' % (self.image_dir, epoch, step)
                    build_images(fake_shapes, captions, self.ixtoword, fullpath)

            torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (self.model_dir, epoch))
            torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (self.model_dir, epoch))
            print('Save G/Ds models.')