    def greedy_baseline(self, fc_feats, att_feats, att_masks, retrieval_loss,
                        _seqs, _sampleLogProbs, _masks):
        # Greedily decode captions (no gradients) to serve as the
        # self-critical baseline.
        if att_masks is not None:
            wrapper = [fc_feats, att_feats, att_masks]
            _seqs_greedy, _sampleLogProbs_greedy = \
                self.caption_generator.sample(
                    *utils.var_wrapper(wrapper,
                                       cuda=torch.cuda.is_available(),
                                       volatile=True),
                    opt={'sample_max': 1, 'temperature': 1})
        else:
            wrapper = [fc_feats, att_feats]
            _seqs_greedy, _sampleLogProbs_greedy = \
                self.caption_generator.sample(
                    *utils.var_wrapper(wrapper,
                                       cuda=torch.cuda.is_available(),
                                       volatile=True),
                    None,  # att_masks
                    opt={'sample_max': 1, 'temperature': 1})
        greedy_res = _seqs_greedy

        # Mask for the greedy captions: two leading ones (BOS plus the first
        # token/EOS slot), then the non-zero token indicator shifted right by
        # one. The else branch guards the case where the slice collapses to
        # 1-D.
        if (_seqs_greedy > 0).float()[:, :-1].dim() > 1:
            _masks_greedy = torch.cat([
                Variable(
                    _seqs_greedy.data.new(_seqs.size(0), 2).fill_(1).float()),
                (_seqs_greedy > 0).float()[:, :-1]
            ], 1)
        else:
            _masks_greedy = torch.cat([
                Variable(
                    _seqs_greedy.data.new(_seqs.size(0), 2).fill_(1).float()),
                torch.unsqueeze((_seqs_greedy > 0).float()[:, :-1], 1)
            ], 1)

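        # Prepend the BOS token (index vocab_size + 1) before scoring the
        # captions with the retrieval model.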
        _seqs_greedy = torch.cat([
            Variable(
                _seqs_greedy.data.new(
                    _seqs_greedy.size(0),
                    1).fill_(self.caption_generator.vocab_size + 1)),
            _seqs_greedy
        ], 1)

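        # Score the greedy captions with the retrieval (VSE) model; this is
        # the baseline b in the REINFORCE update.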
        baseline = self.vse(fc_feats,
                            att_feats,
                            _seqs_greedy,
                            _masks_greedy,
                            True,
                            only_one_retrieval=self.only_one_retrieval)

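        # REINFORCE with baseline: weight each token's log-probability by the
        # detached advantage (sampled retrieval loss minus greedy baseline),
        # masked to real tokens.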
        sc_loss = _sampleLogProbs * (
            utils.var_wrapper(retrieval_loss, cuda=torch.cuda.is_available()) -
            utils.var_wrapper(baseline, cuda=torch.cuda.is_available())
        ).detach().unsqueeze(1) * (_masks[:, 1:].detach().float())
        return baseline, sc_loss, greedy_res
    def gt_baseline(self, fc_feats, att_feats, att_masks, retrieval_loss,
                    _seqs, _sampleLogProbs, _masks, seq, masks):

        # Baseline: retrieval loss of the ground-truth captions.
        baseline = self.vse(fc_feats,
                            att_feats,
                            seq,
                            masks,
                            True,
                            only_one_retrieval=self.only_one_retrieval)
        # REINFORCE with baseline: advantage = retrieval_loss - baseline.
        sc_loss = _sampleLogProbs * (
            utils.var_wrapper(retrieval_loss, cuda=torch.cuda.is_available()) -
            utils.var_wrapper(baseline, cuda=torch.cuda.is_available())
        ).detach().unsqueeze(1) * (_masks[:, 1:].detach().float())
        return baseline, sc_loss
    def traditional_cider(self, fc_feats, att_feats, att_masks, data, loss,
                          gen_result, greedy_res, sample_logprobs, gen_masks):

        if self.use_gen_cider_scores == 0:
            # Use the differenced (self-critical) rewards: sampled score
            # minus greedy score.
            reward, cider_greedy = rewards.get_self_critical_reward(
                data, gen_result, greedy_res)
        else:
            # Use the original (undifferenced) scores of the sampled captions.
            reward, _, cider_greedy = \
                rewards.get_self_critical_reward(
                    data, gen_result, greedy_res,
                    return_gen_scores=True)
        self._loss['avg_reward'] = reward.mean()
        self._loss['cider_greedy'] = cider_greedy

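        # Policy-gradient CIDEr loss: token log-probabilities weighted by the
        # negated reward, masked to real tokens.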
        loss_cider = sample_logprobs * utils.var_wrapper(
            -reward.astype('float32'),
            cuda=torch.cuda.is_available()).unsqueeze(1) * (
                gen_masks[:, 1:].detach())

        loss_cider = loss_cider.sum() / \
                     gen_masks[:, 1:].data.float().sum()
        loss += self.cider_optimization * loss_cider
        self._loss['loss_cider'] = loss_cider.item()

        return loss
    def greedy_res_for_cider(self, fc_feats, att_feats, att_masks):
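        # Greedy decoding (sample_max=1) without gradients; used as the
        # self-critical baseline for the CIDEr reward.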
        if att_masks is not None:
            greedy_res, _ = self.caption_generator.sample(
                *utils.var_wrapper([fc_feats, att_feats, att_masks],
                                   cuda=torch.cuda.is_available(),
                                   volatile=True),
                opt={'sample_max': 1})
        else:
            greedy_res, _ = self.caption_generator.sample(
                *utils.var_wrapper([fc_feats, att_feats],
                                   cuda=torch.cuda.is_available(),
                                   volatile=True),
                att_masks,
                opt={'sample_max': 1})

        return greedy_res
Example #5
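# encode_data: run the VSE image and text encoders over one split and return
# the stacked numpy embeddings (seq_per_img is forced to 5, so there is one
# image-embedding row per caption).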
def encode_data(model, loader, eval_kwargs={}):
    num_images = eval_kwargs.get('num_images',
                                 eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    dataset = eval_kwargs.get('dataset', 'coco')

    # Make sure the model is in evaluation mode
    model.eval()

    loader_seq_per_img = loader.seq_per_img
    loader.seq_per_img = 5
    loader.reset_iterator(split)

    n = 0
    img_embs = []
    cap_embs = []
    while True:
        data = loader.get_batch(split)
        n = n + loader.batch_size

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = utils.var_wrapper(tmp)
        fc_feats, att_feats, labels, masks = tmp

        with torch.no_grad():
            img_emb = model.vse.img_enc(fc_feats)
            cap_emb = model.vse.txt_enc(labels, masks)

        # if we wrapped around the split or used up val imgs budget then bail
        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']
        if num_images != -1:
            ix1 = min(ix1, num_images)

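        # Trim the overshoot past the split boundary: ix1 - n is negative, so
        # the slice drops the surplus rows from the end.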
        if n > ix1:
            img_emb = img_emb[:(ix1 - n) * loader.seq_per_img]
            cap_emb = cap_emb[:(ix1 - n) * loader.seq_per_img]

        # preserve the embeddings by copying from gpu and converting to np
        img_embs.append(img_emb.data.cpu().numpy().copy())
        cap_embs.append(cap_emb.data.cpu().numpy().copy())

        if data['bounds']['wrapped']:
            break
        if num_images >= 0 and n >= num_images:
            break

        print("%d/%d" % (n, ix1))

    img_embs = np.vstack(img_embs)
    cap_embs = np.vstack(cap_embs)

    assert img_embs.shape[0] == ix1 * loader.seq_per_img

    loader.seq_per_img = loader_seq_per_img

    return img_embs, cap_embs
    def no_baseline(self, retrieval_loss, _sampleLogProbs, _masks):
        # Plain REINFORCE: no baseline is subtracted, so the retrieval loss
        # itself acts as the (negative) reward.
        baseline = 0
        sc_loss = _sampleLogProbs * utils.var_wrapper(
            retrieval_loss,
            cuda=torch.cuda.is_available()).detach().unsqueeze(1) * (
                _masks[:, 1:].detach().float())
        return baseline, sc_loss
Example #7
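    # forward: total loss = weighted caption cross-entropy + weighted VSE
    # retrieval loss, plus an optional REINFORCE term that rewards captions
    # which retrieve their own image, and an optional self-critical CIDEr
    # term.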
    def forward(self, fc_feats, att_feats, att_masks, seq, masks, data):
        if self.caption_loss_weight > 0 and not self.cider_optimization:
            loss_cap = self.caption_generator(fc_feats, att_feats, att_masks,
                                              seq, masks)
        else:
            loss_cap = Variable(torch.cuda.FloatTensor([0]))
        if self.vse_loss_weight > 0:
            loss_vse = self.vse(fc_feats,
                                att_feats,
                                seq,
                                masks,
                                only_one_retrieval=self.only_one_retrieval)
        else:
            loss_vse = Variable(torch.cuda.FloatTensor([0]))

        loss = self.caption_loss_weight * loss_cap + self.vse_loss_weight * loss_vse

        if self.retrieval_reward_weight > 0:
            # Sample captions (multinomial, temperature 1) together with their
            # per-token log-probabilities.
            _seqs, _sampleLogProbs = self.caption_generator.sample(
                fc_feats, att_feats, att_masks, {
                    'sample_max': 0,
                    'temperature': 1
                })
            gen_result, sample_logprobs = _seqs, _sampleLogProbs
            # 0/1 mask: two leading ones cover the BOS slot and keep the first
            # EOS position counted; the rest follows the non-zero tokens.
            _masks = torch.cat([
                Variable(
                    _seqs.data.new(_seqs.size(0), 2).fill_(1).float()),
                (_seqs > 0).float()[:, :-1]
            ], 1)

            gen_masks = _masks

            # Prepend the BOS token (vocab_size + 1) before scoring with VSE.
            _seqs = torch.cat([
                Variable(
                    _seqs.data.new(
                        _seqs.size(0),
                        1).fill_(self.caption_generator.vocab_size + 1)),
                _seqs
            ], 1)

            # Retrieval loss of the sampled captions: the (negative) reward.
            retrieval_loss = self.vse(
                fc_feats,
                att_feats,
                _seqs,
                _masks,
                True,
                only_one_retrieval=self.only_one_retrieval)
            if self.reinforce_baseline_type == 'greedy':
                # Baseline: retrieval loss of the greedily decoded captions.
                _seqs_greedy, _sampleLogProbs_greedy = self.caption_generator.sample(
                    *utils.var_wrapper(
                        [fc_feats, att_feats, att_masks],
                        volatile=True),
                    opt={
                        'sample_max': 1,
                        'temperature': 1
                    })
                greedy_res = _seqs_greedy
                # Do we need word weights here? Word-weighted masks
                # (self.get_word_weights_mask) are disabled in favor of plain
                # 0/1 masks.
                _masks_greedy = torch.cat([
                    Variable(
                        _seqs_greedy.data.new(_seqs.size(0),
                                              2).fill_(1).float()),
                    (_seqs_greedy > 0).float()[:, :-1]
                ], 1)

                _seqs_greedy = torch.cat([
                    Variable(
                        _seqs_greedy.data.new(_seqs_greedy.size(0), 1).
                        fill_(self.caption_generator.vocab_size + 1)),
                    _seqs_greedy
                ], 1)

                baseline = self.vse(
                    fc_feats,
                    att_feats,
                    _seqs_greedy,
                    _masks_greedy,
                    True,
                    only_one_retrieval=self.only_one_retrieval)
            elif self.reinforce_baseline_type == 'gt':
                # Baseline: retrieval loss of the ground-truth captions.
                baseline = self.vse(
                    fc_feats,
                    att_feats,
                    seq,
                    masks,
                    True,
                    only_one_retrieval=self.only_one_retrieval)
            else:
                baseline = 0

            # REINFORCE with baseline: weight each token's log-probability by
            # the detached advantage (retrieval_loss - baseline), masked to
            # real tokens, then normalize by the number of tokens.
            sc_loss = _sampleLogProbs * (
                utils.var_wrapper(retrieval_loss) -
                utils.var_wrapper(baseline)).detach().unsqueeze(1) * (
                    _masks[:, 1:].detach().float())
            sc_loss = sc_loss.sum() / _masks[:, 1:].data.float().sum()

            loss += self.retrieval_reward_weight * sc_loss

            self._loss['retrieval_sc_loss'] = sc_loss.item()

            self._loss['retrieval_loss'] = retrieval_loss.sum().item()
            self._loss['retrieval_loss_greedy'] = baseline.sum().item(
            ) if isinstance(baseline, Variable) else baseline

        if self.cider_optimization:
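            # Reuse the captions sampled for the retrieval reward when they
            # exist; otherwise draw fresh samples and a greedy baseline here.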
            if 'gen_result' not in locals():
                gen_result, sample_logprobs = self.caption_generator.sample(
                    fc_feats, att_feats, att_masks, opt={'sample_max': 0})
                gen_masks = torch.cat([
                    Variable(
                        gen_result.data.new(gen_result.size(0),
                                            2).fill_(1).float()),
                    (gen_result > 0).float()[:, :-1]
                ], 1)
            if 'greedy_res' not in locals():
                greedy_res, _ = self.caption_generator.sample(
                    *utils.var_wrapper([fc_feats, att_feats, att_masks],
                                       volatile=True),
                    opt={'sample_max': 1})
            reward, cider_greedy = rewards.get_self_critical_reward(
                data, gen_result, greedy_res)
            self._loss['avg_reward'] = reward.mean()
            self._loss['cider_greedy'] = cider_greedy
            loss_cap = sample_logprobs * utils.var_wrapper(-reward.astype(
                'float32')).unsqueeze(1) * (gen_masks[:, 1:].detach())
            loss_cap = loss_cap.sum() / gen_masks[:, 1:].data.float().sum()

            loss += self.caption_loss_weight * loss_cap

        self._loss['loss_cap'] = loss_cap.item()
        self._loss['loss_vse'] = loss_vse.item()
        self._loss['loss'] = loss.item()

        return loss
Example #8
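# train: main training loop. Handles epoch-driven learning rate decay,
# scheduled sampling, retrieval-reward weight decay, periodic validation, and
# checkpointing of the best captioning and VSE models.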
def train(opt):
    opt.use_att = utils.if_use_att(opt)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None
    best_val_score_vse = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_val_score_vse = infos.get('best_val_score_vse', None)

    model = models.JointModel(opt)
    model.cuda()

    update_lr_flag = True
    # Make sure the model is in training mode
    model.train()

    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, 'optimizer.pth')):
        state_dict = torch.load(os.path.join(opt.start_from, 'optimizer.pth'))
        if len(state_dict['state']) == len(optimizer.state_dict()['state']):
            optimizer.load_state_dict(state_dict)
        else:
            print(
                'Optimizer param group number not matched? There must be new parameters. Reinit the optimizer.'
            )

    init_scorer(opt.cached_tokens)
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.caption_generator.ss_prob = opt.ss_prob
            # Assign retrieval loss weight
            if epoch > opt.retrieval_reward_weight_decay_start and opt.retrieval_reward_weight_decay_start >= 0:
                frac = (epoch - opt.retrieval_reward_weight_decay_start
                        ) // opt.retrieval_reward_weight_decay_every
                model.retrieval_reward_weight = opt.retrieval_reward_weight * (
                    opt.retrieval_reward_weight_decay_rate**frac)
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['att_masks'],
            data['labels'], data['masks']
        ]
        tmp = utils.var_wrapper(tmp)
        fc_feats, att_feats, att_masks, labels, masks = tmp

        optimizer.zero_grad()

        loss = model(fc_feats, att_feats, att_masks, labels, masks, data)
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))
        prt_str = ""
        for k, v in model.loss().items():
            prt_str += "{} = {:.3f} ".format(k, v)
        print(prt_str)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                tf_summary_writer.add_scalar('train_loss', train_loss,
                                             iteration)
                for k, v in model.loss().items():
                    tf_summary_writer.add_scalar(k, v, iteration)
                tf_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                tf_summary_writer.add_scalar('scheduled_sampling_prob',
                                             model.caption_generator.ss_prob,
                                             iteration)
                tf_summary_writer.add_scalar('retrieval_reward_weight',
                                             model.retrieval_reward_weight,
                                             iteration)
                tf_summary_writer.file_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.caption_generator.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            # Evaluate captioning and retrieval on the validation split
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                for k, v in val_loss.items():
                    tf_summary_writer.add_scalar('validation ' + k, v,
                                                 iteration)
                for k, v in lang_stats.items():
                    tf_summary_writer.add_scalar(k, v, iteration)
                tf_summary_writer.add_text(
                    'Captions',
                    '.\n\n'.join([_['caption'] for _ in predictions[:100]]),
                    iteration)
                #tf_summary_writer.add_image('images', utils.make_summary_image(), iteration)
                #utils.make_html(opt.id, iteration)
                tf_summary_writer.file_writer.flush()

            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['SPICE'] * 100
            else:
                current_score = -val_loss['loss_cap']
            current_score_vse = val_loss.get(opt.vse_eval_criterion, 0) * 100

            best_flag = False
            best_flag_vse = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            if best_val_score_vse is None or current_score_vse > best_val_score_vse:
                best_val_score_vse = current_score_vse
                best_flag_vse = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           'model-%d.pth' % (iteration))
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['best_val_score_vse'] = best_val_score_vse
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '-%d.pkl' % (iteration)),
                    'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '-best.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)
            if best_flag_vse:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model_vse-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_vse_' + opt.id + '-best.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Example #9
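# eval_split: compute validation losses, generate one caption per image, and
# optionally run language metrics and retrieval ranking.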
def eval_split(model, loader, eval_kwargs={}):
    verbose = eval_kwargs.get('verbose', True)
    num_images = eval_kwargs.get('num_images',
                                 eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    rank_eval = eval_kwargs.get('rank_eval', 0)
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)

    # Make sure the model is in evaluation mode
    model.eval()

    np.random.seed(123)
    loader.reset_iterator(split)

    n = 0
    losses = {}
    loss_evals = 1e-8
    # Save the discriminative results. Used for further html visualization.
    predictions = []
    while True:
        data = loader.get_batch(split)
        n = n + loader.batch_size

        if data.get('labels', None) is not None:
            # forward the model to get loss
            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [
                Variable(torch.from_numpy(_), volatile=True).cuda()
                for _ in tmp
            ]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            loss = model(fc_feats, att_feats, att_masks, labels, masks, data)
            loss = loss.item()
            for k, v in model.loss().items():
                if k not in losses:
                    losses[k] = 0
                losses[k] += v

            loss_evals = loss_evals + 1

        # Forward the model to also get generated samples for each image.
        # The loader repeats each image seq_per_img times, so keep only one
        # feature row per image to avoid duplicated samples.
        tmp = [
            data['fc_feats'][np.arange(loader.batch_size) *
                             loader.seq_per_img],
            data['att_feats'][np.arange(loader.batch_size) *
                              loader.seq_per_img],
            data['att_masks'][np.arange(loader.batch_size) *
                              loader.seq_per_img]
        ]
        tmp = utils.var_wrapper(tmp, volatile=True)
        fc_feats, att_feats, att_masks = tmp
        seq, _ = model.sample(fc_feats, att_feats, att_masks, opt=eval_kwargs)

        sents = utils.decode_sequence(loader.get_vocab(), seq.data)

        for k, sent in enumerate(sents):
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                cmd = 'cp "' + os.path.join(
                    eval_kwargs['image_root'],
                    data['infos'][k]['file_path']) + '" vis/imgs/img' + str(
                        len(predictions)) + '.jpg'  # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('image %s: %s' % (entry['image_id'], entry['caption']))

        # if we wrapped around the split or used up val imgs budget then bail
        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']
        if num_images != -1:
            ix1 = min(ix1, num_images)
        for i in range(n - ix1):
            predictions.pop()

        if verbose:
            print('evaluating validation performance... %d/%d (%f)' %
                  (ix0 - 1, ix1, loss))

        if data['bounds']['wrapped']:
            break
        if num_images >= 0 and n >= num_images:
            break

    if lang_eval == 1:
        lang_stats = language_eval(dataset, predictions, eval_kwargs['id'],
                                   split)
    else:
        lang_stats = {}

    ranks = evalrank(model, loader, eval_kwargs) if rank_eval else {}

    # Switch back to training mode
    model.train()
    losses = {k: v / loss_evals for k, v in losses.items()}
    losses.update(ranks)
    return losses, predictions, lang_stats
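
# encode_data_generated: like encode_data, but encodes externally supplied
# captions (e.g. model-generated ones) in place of the ground-truth labels.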
def encode_data_generated(model, loader, captions, eval_kwargs={}):
    num_images = eval_kwargs.get('num_images',
                                 eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    dataset = eval_kwargs.get('dataset', 'coco')

    # Make sure the model is in evaluation mode
    model.eval()

    loader_seq_per_img = loader.seq_per_img
    loader.seq_per_img = 5
    loader.reset_iterator(split)
    print('num_images', num_images)
    print(captions.size())
    print(len(loader))
    n = 0
    img_embs = []
    cap_embs = []
    while True:
        data = loader.get_batch(split)
        labels = captions[n:(n + loader.batch_size)]
        # 0/1 mask over real tokens, extended by one position so the EOS slot
        # after the last real token is also counted.
        masks = (labels > 0).float()
        for i in range(labels.size(0)):
            for j in range(labels.size(1) - 1):
                if labels[i, j].item() > 0.5 and labels[i, j + 1].item() < 0.5:
                    masks[i, j + 1] = 1.0
                    break

        n = n + loader.batch_size

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = utils.var_wrapper(tmp, volatile=True)
        fc_feats, att_feats, _, __ = tmp
        fc_feats = fc_feats.cuda()
        labels = labels.cuda()
        masks = masks.cuda()
        img_emb = model.vse.img_enc(fc_feats)
        cap_emb = model.vse.txt_enc(labels, masks)

        # if we wrapped around the split or used up val imgs budget then bail
        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']
        if num_images != -1:
            ix1 = min(ix1, num_images)

        if n > ix1:
            # trim the overshoot past the split boundary before storing
            img_emb = img_emb[:(ix1 - n) * loader.seq_per_img]
            cap_emb = cap_emb[:(ix1 - n) * loader.seq_per_img]

        # preserve the embeddings by copying from gpu and converting to np
        print(img_emb.size())
        print(cap_emb.size())
        img_embs.append(img_emb.data.cpu().numpy().copy())
        cap_embs.append(cap_emb.data.cpu().numpy().copy())
        if data['bounds']['wrapped']:
            break
        if num_images >= 0 and n >= num_images:
            break

        print("%d/%d" % (n, ix1))
    print('start stack')
    img_embs = np.vstack(img_embs)[:num_images * 5]
    cap_embs = np.vstack(cap_embs)[:num_images]
    print(img_embs.shape)
    print(cap_embs.shape)
    print('stack')

    #assert img_embs.shape[0] == ix1 * loader.seq_per_img

    loader.seq_per_img = loader_seq_per_img

    return img_embs, cap_embs