def _forward(self,
                 segs_feat,
                 input_seq,
                 gt_seq,
                 ppls,
                 gt_boxes,
                 mask_boxes,
                 num,
                 ppls_feat,
                 frm_mask,
                 sample_idx,
                 pnt_mask,
                 eval_obj_ground=False):
        """Training / grounding-evaluation forward pass.

        Computes the region-class similarity ("grounder") for every proposal,
        builds per-region features, and then either delegates captioning to a
        transformer (``self.att_model == 'transformer'``) or unrolls the
        top-down core over the ground-truth caption
        (``self.att_model == 'topdown'``).

        Args:
            segs_feat: temporal segment features; ``segs_feat.size(0)`` is the
                batch size B and the mean over dim 1 forms ``fc_feats``.
            input_seq: per-word annotations, reshaped here to
                ``(B*seq_per_img, input_seq.size(2), input_seq.size(3))``;
                column 0 holds the word index (visual words are offset by
                ``self.vocab_size``).
            gt_seq: ground-truth captions; the first ``self.seq_per_img`` per
                sample are used and a leading 0 column is prepended.
            ppls: proposals ``(B, rois_num, >=5)``; columns 0-3 are box
                coordinates (divided by 720) and column 4 is divided by
                ``self.num_sampled_frm`` (presumably a frame index).
            gt_boxes: ground-truth boxes; column 5 is passed to
                ``utils.sim_mat_target`` (presumably the class label).
            mask_boxes: box masks indexed as ``mask_boxes[:, :, :, t]`` for
                caption position ``t``.
            num: per-sample info; columns 3:7 are embedded with
                ``self.seg_info_embed``.
            ppls_feat: per-proposal features.
            frm_mask: proposal / gt-box frame mask.
            sample_idx: ``(B, 2)`` start/end of the valid temporal positions
                of each sample; positions outside are zeroed in ``conv_feats``.
            pnt_mask: proposal padding mask of width ``rois_num + 1``; column 0
                is dropped wherever a per-proposal mask is needed.
            eval_obj_ground: if True, return grounding predictions instead of
                training losses.

        Returns:
            * transformer: six 1-element tensors (LM loss, then zeros).
            * topdown, training: ``(lm_loss, att2_loss, ground_loss,
              cls_loss)``, each with a leading dim of size 1.
            * topdown, ``eval_obj_ground``: ``(cls_pred, att2 argmax,
              ground argmax)``; the argmaxes pick the best proposal inside each
              of the ``self.num_sampled_frm`` frames.
            * any other ``att_model``: falls through and returns ``None``.
        """

        seq = gt_seq[:, :self.seq_per_img, :].clone().view(
            -1, gt_seq.size(2))  # choose the first seq_per_img
        # prepend a 0 column (start token) to every caption
        seq = torch.cat((Variable(seq.data.new(seq.size(0), 1).fill_(0)), seq),
                        1)
        input_seq = input_seq.view(
            -1, input_seq.size(2),
            input_seq.size(3))  # B*self.seq_per_img, self.seq_length+1, 5
        input_seq_update = input_seq.data.clone()

        batch_size = segs_feat.size(0)  # B
        seq_batch_size = seq.size(0)  # B*self.seq_per_img
        rois_num = ppls.size(1)  # max_num_proposal of the batch

        state = self.init_hidden(
            seq_batch_size
        )  # self.num_layers, B*self.seq_per_img, self.rnn_size
        rnn_output = []
        roi_labels = []  # store which proposal match the gt box
        att2_weights = []
        h_att_output = []
        max_grd_output = []
        frm_mask_output = []

        conv_feats = segs_feat
        # True = temporal position outside [sample_idx[i, 0], sample_idx[i, 1]).
        # Must be a bool mask: it is used with masked_fill below, and recent
        # PyTorch versions reject uint8 masks there.
        sample_idx_mask = conv_feats.new(batch_size, conv_feats.size(1),
                                         1).fill_(1).bool()
        for i in range(batch_size):
            sample_idx_mask[i, sample_idx[i, 0]:sample_idx[i, 1]] = 0
        fc_feats = torch.mean(segs_feat, dim=1)
        fc_feats = torch.cat((F.layer_norm(fc_feats, [self.fc_feat_size-self.seg_info_size]), \
                              F.layer_norm(self.seg_info_embed(num[:, 3:7].float()), [self.seg_info_size])), dim=-1)

        # pooling the conv_feats
        pool_feats = ppls_feat
        pool_feats = self.ctx2pool_grd(pool_feats)
        g_pool_feats = pool_feats

        # calculate the overlaps between the rois/rois and rois/gt_bbox.
        # apply both frame mask and proposal mask
        overlaps = utils.bbox_overlaps(ppls.data, gt_boxes.data, \
                                      (frm_mask | pnt_mask[:, 1:].unsqueeze(-1)).data)

        # visual words embedding: indices 0..detect_size
        vis_word = Variable(
            torch.Tensor(range(0,
                               self.detect_size + 1)).type(input_seq.type()))
        vis_word_embed = self.vis_embed(vis_word)
        assert (vis_word_embed.size(0) == self.detect_size + 1)

        # one copy of the visual-word table per sample: (B, detect_size+1, vis_encoding_size)
        p_vis_word_embed = vis_word_embed.view(1, self.detect_size+1, self.vis_encoding_size) \
            .expand(batch_size, self.detect_size+1, self.vis_encoding_size).contiguous()

        if hasattr(self, 'vis_classifiers_bias'):
            bias = self.vis_classifiers_bias.type(p_vis_word_embed.type()) \
                                                  .view(1,-1,1).expand(p_vis_word_embed.size(0), \
                                                  p_vis_word_embed.size(1), g_pool_feats.size(1))
        else:
            bias = None

        # region-class similarity matrix
        sim_mat_static = self._grounder(p_vis_word_embed, g_pool_feats,
                                        pnt_mask[:, 1:], bias)
        # un-normalized scores replicated per caption: (B*seq_per_img, detect_size+1, rois_num)
        sim_mat_static_update = sim_mat_static.view(batch_size, 1, self.detect_size+1, rois_num) \
            .expand(batch_size, self.seq_per_img, self.detect_size+1, rois_num).contiguous() \
            .view(seq_batch_size, self.detect_size+1, rois_num)
        # normalize over the class dimension
        sim_mat_static = F.softmax(sim_mat_static, dim=1)

        # NOTE(review): in test_mode ``cls_loss`` is never assigned, so the
        # training return path below (eval_obj_ground=False) would raise
        # NameError; presumably test_mode is only combined with
        # eval_obj_ground=True or with a separate sampling method — confirm.
        if self.test_mode:
            cls_pred = 0
        else:
            sim_target = utils.sim_mat_target(
                overlaps, gt_boxes[:, :, 5].data)  # B, num_box, num_rois
            sim_mask = (sim_target > 0)
            if not eval_obj_ground:
                # push the similarity of every matched (class, region) pair to 1
                masked_sim = torch.gather(sim_mat_static, 1, sim_target)
                masked_sim = torch.masked_select(masked_sim, sim_mask)
                cls_loss = F.binary_cross_entropy(
                    masked_sim,
                    masked_sim.new(masked_sim.size()).fill_(1))
            else:
                # region classification accuracy
                sim_target_masked = torch.masked_select(sim_target, sim_mask)
                sim_mat_masked = torch.masked_select(
                    torch.max(sim_mat_static,
                              dim=1)[1].unsqueeze(1).expand_as(sim_target),
                    sim_mask)
                # pairs of (target class, predicted class) per matched region
                cls_pred = torch.stack((sim_target_masked, sim_mat_masked),
                                       dim=1).data

        if not self.enable_BUTD:
            # location features: 4 box coords / 720 and column 4 / num_sampled_frm
            loc_input = ppls.data.new(batch_size, rois_num, 5)
            loc_input[:, :, :4] = ppls.data[:, :, :4] / 720.
            loc_input[:, :, 4] = ppls.data[:, :, 4] * 1. / self.num_sampled_frm
            loc_feats = self.loc_fc(
                Variable(loc_input))  # encode the locations
            label_feat = sim_mat_static.permute(0, 2, 1).contiguous()
            pool_feats = torch.cat((F.layer_norm(pool_feats, [pool_feats.size(-1)]), \
                F.layer_norm(loc_feats, [loc_feats.size(-1)]), F.layer_norm(label_feat, [label_feat.size(-1)])), 2)

        # replicate the feature to map the seq size.
        fc_feats = fc_feats.view(batch_size, 1, self.fc_feat_size)\
                .expand(batch_size, self.seq_per_img, self.fc_feat_size)\
                .contiguous().view(-1, self.fc_feat_size)
        pool_feats = pool_feats.view(batch_size, 1, rois_num, self.pool_feat_size)\
                .expand(batch_size, self.seq_per_img, rois_num, self.pool_feat_size)\
                .contiguous().view(-1, rois_num, self.pool_feat_size)
        g_pool_feats = g_pool_feats.view(batch_size, 1, rois_num, self.vis_encoding_size) \
                .expand(batch_size, self.seq_per_img, rois_num, self.vis_encoding_size) \
                .contiguous().view(-1, rois_num, self.vis_encoding_size)
        pnt_mask = pnt_mask.view(batch_size, 1, rois_num+1).expand(batch_size, self.seq_per_img, rois_num+1)\
                .contiguous().view(-1, rois_num+1)
        overlaps = overlaps.view(batch_size, 1, rois_num, overlaps.size(2)) \
                .expand(batch_size, self.seq_per_img, rois_num, overlaps.size(2)) \
                .contiguous().view(-1, rois_num, overlaps.size(2))

        # embed fc and att feats
        fc_feats = self.fc_embed(fc_feats)
        pool_feats = self.pool_embed(pool_feats)

        # object region interactions
        if hasattr(self, 'obj_interact'):
            pool_feats = self.obj_interact(pool_feats)

        # Project the attention feats first to reduce memory and computation consumption.
        p_pool_feats = self.ctx2pool(pool_feats)  # same here

        if self.att_input_mode in ('both', 'featmap'):
            # embed the temporal features in 2048-wide chunks, one module per chunk
            conv_feats_splits = torch.split(conv_feats, 2048, 2)
            conv_feats = torch.cat(
                [m(c) for (m, c) in zip(self.att_embed, conv_feats_splits)],
                dim=2)
            conv_feats = conv_feats.permute(0, 2, 1).contiguous(
            )  # inconsistency between Torch TempConv and PyTorch Conv1d
            conv_feats = self.att_embed_aux(conv_feats)
            conv_feats = conv_feats.permute(0, 2, 1).contiguous(
            )  # inconsistency between Torch TempConv and PyTorch Conv1d
            conv_feats = self.context_enc(conv_feats)[0]

            # zero out temporal positions outside each sample's valid range
            conv_feats = conv_feats.masked_fill(sample_idx_mask, 0)
            conv_feats = conv_feats.view(batch_size, 1, self.t_attn_size, self.rnn_size)\
                .expand(batch_size, self.seq_per_img, self.t_attn_size, self.rnn_size)\
                .contiguous().view(-1, self.t_attn_size, self.rnn_size)
            p_conv_feats = self.ctx2att(
                conv_feats)  # self.rnn_size (1024) -> self.att_hid_size (512)
        else:
            # dummy
            conv_feats = pool_feats.new(1, 1).fill_(0)
            p_conv_feats = pool_feats.new(1, 1).fill_(0)

        if self.att_model == 'transformer':  # Masked Transformer does not support box supervision yet
            if self.att_input_mode == 'both':
                lm_loss = self.cap_model([conv_feats, pool_feats], seq)
            elif self.att_input_mode == 'featmap':
                lm_loss = self.cap_model([conv_feats, conv_feats], seq)
            elif self.att_input_mode == 'region':
                lm_loss = self.cap_model([pool_feats, pool_feats], seq)
            return lm_loss.unsqueeze(0), lm_loss.new(1).fill_(0), lm_loss.new(1).fill_(0), \
                lm_loss.new(1).fill_(0), lm_loss.new(1).fill_(0), lm_loss.new(1).fill_(0)
        elif self.att_model == 'topdown':
            for i in range(self.seq_length):
                it = seq[:, i].clone()

                # break if all the sequences end
                if i >= 1 and seq[:, i].data.sum() == 0:
                    break

                xt = self.embed(it)

                if not eval_obj_ground:
                    roi_label = utils.bbox_target(mask_boxes[:,:,:,i+1], overlaps, input_seq[:,i+1], \
                        input_seq_update[:,i+1], self.vocab_size) # roi_label if for the target seq
                    roi_labels.append(roi_label.view(seq_batch_size, -1))

                    # use frame mask during training
                    box_mask = mask_boxes[:, 0, :, i +
                                          1].contiguous().unsqueeze(1).expand(
                                              (batch_size, rois_num,
                                               mask_boxes.size(2)))
                    # A proposal is masked (True) when, for every gt box, either
                    # box_mask or frm_mask is set. The masks are cast to bool so
                    # that ``~`` is a logical NOT; on uint8 tensors it is a
                    # bitwise NOT (0 -> 255), which would disable the mask.
                    frm_mask_on_prop = (torch.sum(
                        (~(box_mask.bool() | frm_mask.bool())), dim=2) <= 0)
                    # prepend an unmasked column for slot 0 and add proposal padding
                    frm_mask_on_prop = torch.cat((frm_mask_on_prop.new(batch_size, 1).fill_(0.), \
                        frm_mask_on_prop), dim=1) | pnt_mask
                    output, state, att2_weight, att_h, max_grd_val, grd_val = self.core(xt, fc_feats, \
                        conv_feats, p_conv_feats, pool_feats, p_pool_feats, pnt_mask, frm_mask_on_prop, \
                        state, sim_mat_static_update)
                    frm_mask_output.append(frm_mask_on_prop)
                else:
                    output, state, att2_weight, att_h, max_grd_val, grd_val = self.core(xt, fc_feats, \
                        conv_feats, p_conv_feats, pool_feats, p_pool_feats, pnt_mask, pnt_mask, \
                        state, sim_mat_static_update)

                att2_weights.append(att2_weight)
                h_att_output.append(
                    att_h)  # the hidden state of attention LSTM
                rnn_output.append(output)
                max_grd_output.append(max_grd_val)

            seq_cnt = len(rnn_output)
            rnn_output = torch.cat([_.unsqueeze(1) for _ in rnn_output],
                                   1)  # seq_batch_size, seq_cnt, vocab
            h_att_output = torch.cat([_.unsqueeze(1) for _ in h_att_output], 1)
            att2_weights = torch.cat([_.unsqueeze(1) for _ in att2_weights],
                                     1)  # seq_batch_size, seq_cnt, att_size
            max_grd_output = torch.cat(
                [_.unsqueeze(1) for _ in max_grd_output], 1)
            if not eval_obj_ground:
                frm_mask_output = torch.cat(
                    [_.unsqueeze(1) for _ in frm_mask_output], 1)
                roi_labels = torch.cat([_.unsqueeze(1) for _ in roi_labels], 1)

            decoded = F.log_softmax(self.beta * self.logit(rnn_output),
                                    dim=2)  # text word prob
            decoded = decoded.view((seq_cnt) * seq_batch_size, -1)

            # object grounding
            h_att_all = h_att_output  # hidden states from the Attention LSTM
            # visual-word class per position (text words clamp to class 0)
            xt_clamp = torch.clamp(input_seq[:, 1:seq_cnt + 1, 0].clone() -
                                   self.vocab_size,
                                   min=0)
            xt_all = self.vis_embed(xt_clamp)

            if hasattr(self, 'vis_classifiers_bias'):
                bias = self.vis_classifiers_bias[xt_clamp].type(xt_all.type()) \
                                                .unsqueeze(2).expand(seq_batch_size, seq_cnt, rois_num)
            else:
                bias = 0

            if not eval_obj_ground:
                # att2_weights/ground_weights with both proposal mask and frame mask
                ground_weights = self._grounder(xt_all, g_pool_feats,
                                                frm_mask_output[:, :, 1:],
                                                bias + att2_weights)
                lm_loss, att2_loss, ground_loss = self.critLM(decoded, att2_weights, ground_weights, \
                    seq[:, 1:seq_cnt+1].clone(), roi_labels[:, :seq_cnt, :].clone(), input_seq[:, 1:seq_cnt+1, 0].clone())
                return lm_loss.unsqueeze(0), att2_loss.unsqueeze(
                    0), ground_loss.unsqueeze(0), cls_loss.unsqueeze(0)
            else:
                # att2_weights/ground_weights with proposal mask only
                ground_weights = self._grounder(xt_all, g_pool_feats,
                                                pnt_mask[:, 1:],
                                                bias + att2_weights)
                # best proposal inside each sampled frame, for attention and grounding
                return cls_pred, torch.max(att2_weights.view(seq_batch_size, seq_cnt, self.num_sampled_frm, \
                    self.num_prop_per_frm), dim=-1)[1], torch.max(ground_weights.view(seq_batch_size, \
                    seq_cnt, self.num_sampled_frm, self.num_prop_per_frm), dim=-1)[1]
    def _sample(self, segs_feat, seq, proposals, gt_caption, num, mask_boxes,
                gt_boxes, region_feats, frm_mask, sample_idx, pnt_mask):
        """Greedily decode one caption per sample while never emitting <unk>.

        At each step the two most likely words are looked up; if the best one
        is ``self.unk_idx`` the runner-up is emitted instead. Decoding always
        runs for exactly ``self.seq_length`` steps (no early stop).

        :param segs_feat: segment features; ``segs_feat.size(0)`` is the batch size.
        :param seq: not used by this method (kept for interface compatibility).
        :param proposals: (batch_size, max num of proposals allowed (200), 6), 6-dim: (4 points for the box, proposal index to be used for self.itod to get detection, proposal score)
        :param gt_caption: not used by this method (kept for interface compatibility).
        :param num: (batch_size, 3), 3-dim: (# of captions for this img, # of proposals, # of GT bboxs)
        :param mask_boxes: (batch_size, # of captions for this img, max num of GT bbox allowed (100), max caption length)
        :param gt_boxes: ground-truth boxes, used for the overlap computation.
        :param region_feats: per-proposal features, forwarded to ``self.roi_feat_extractor``.
        :param frm_mask: frame mask, combined with ``pnt_mask`` for the overlaps.
        :param sample_idx: forwarded to ``self.roi_feat_extractor``.
        :param pnt_mask: proposal padding mask; column 0 is dropped for the decoder.
        :return: ``(words, att2_weights, None)``: ``words`` stacks the chosen word
            of every step along dim 1 and ``att2_weights`` stacks each step's ROI
            attention along dim 1.
        """
        batch_size = segs_feat.size(0)
        decoder_state = self.init_hidden(batch_size, self.decoder_num_layers)

        # calculate the overlaps between the rois/rois and rois/gt_bbox.
        # apply both frame mask and proposal mask
        overlaps = utils.bbox_overlaps(proposals.data, gt_boxes.data,
                                       (frm_mask
                                        | pnt_mask[:, 1:].unsqueeze(-1)).data)

        # Sampling only needs the features and the (possibly expanded) pointer
        # mask; the grounding features, overlaps and class outputs are ignored.
        fc_feats, conv_feats, p_conv_feats, pool_feats, p_pool_feats, _, \
            pnt_mask, _, _, _ = self.roi_feat_extractor(
                segs_feat, proposals, num, mask_boxes, region_feats, gt_boxes, overlaps, sample_idx)

        word = fc_feats.new_zeros(batch_size).long()  # <bos> (index 0)
        seq_output = []
        att2_weights = []

        for _step in range(self.seq_length):
            output, decoder_state, roi_attn, _, _ = self.decoder_core(
                self.embed(word),
                fc_feats,
                conv_feats,
                p_conv_feats,
                pool_feats,
                p_pool_feats,
                pnt_mask[:, 1:],
                decoder_state,
                with_sentinel=False)
            att2_weights.append(roi_attn)

            logprobs = F.log_softmax(self.logit(output), dim=1)
            # Greedy choice that skips <unk>: use the runner-up whenever the
            # arg-max word is self.unk_idx.
            _, top2 = torch.topk(logprobs.data, 2, dim=1)
            word = torch.where(top2[:, 0] != self.unk_idx, top2[:, 0],
                               top2[:, 1]).view(-1)
            seq_output.append(word)

        return torch.stack(seq_output, 1), torch.stack(att2_weights, 1), None
# Example #3
0
    def _forward(self, img, seq, ppls, gt_boxes, mask_boxes, num):
        """Teacher-forced training pass returning captioning and cascade losses.

        Args:
            img: input images for ``self.cnn``; ``img.size(0)`` is the batch
                size B.
            seq: caption annotations, reshaped here to
                ``(B*seq_per_img, seq.size(2), seq.size(3))``. Column 0 is the
                word index, column 1 the target for ``self.critBN`` and
                column 2 the target for ``self.critFG``.
            ppls: proposals ``(B, rois_num, >=6)``: columns 0-3 are box
                coordinates (divided by ``self.image_crop_size`` for the
                location features), column 4 is fed to ``self.det_fc`` and
                column 5 is the last location feature.
            gt_boxes: ground-truth boxes, passed to ``utils.bbox_overlaps``.
            mask_boxes: box masks indexed as ``mask_boxes[:, :, :, t]`` for
                caption position ``t``.
            num: per-sample counts; ``num[i, 1]`` is the number of valid
                proposals of sample ``i``.

        Returns:
            ``(lm_loss, bn_loss, fg_loss)``.
        """

        seq = seq.view(-1, seq.size(2), seq.size(3))
        seq_update = seq.data.clone()

        batch_size = img.size(0)
        seq_batch_size = seq.size(0)
        rois_num = ppls.size(1)
        # constructing the mask: 1 marks padding; slot 0 and the first
        # num[i, 1] proposals of sample i stay visible.

        pnt_mask = mask_boxes.data.new(batch_size,
                                       rois_num + 1).byte().fill_(1)
        for i in range(batch_size):
            pnt_mask[i, :num.data[i, 1] + 1] = 0
        pnt_mask = Variable(pnt_mask)

        state = self.init_hidden(seq_batch_size)
        rnn_output = []
        roi_labels = []
        det_output = []

        if self.finetune_cnn:
            conv_feats, fc_feats = self.cnn(img)
        else:
            # NOTE(review): ``volatile`` only takes effect on PyTorch < 0.4;
            # newer versions ignore it (use ``with torch.no_grad():``) — the
            # ``.data`` re-wrapping below still detaches the CNN outputs.
            conv_feats, fc_feats = self.cnn(Variable(img.data, volatile=True))
            conv_feats = Variable(conv_feats.data)
            fc_feats = Variable(fc_feats.data)
        # pooling the conv_feats; each roi row is (batch index, 4 box coords)
        rois = ppls.data.new(batch_size, rois_num, 5)
        rois[:, :, 1:] = ppls.data[:, :, :4]

        for i in range(batch_size):
            rois[i, :, 0] = i
        pool_feats = self.roi_align(conv_feats, Variable(rois.view(-1, 5)))
        pool_feats = pool_feats.view(batch_size, rois_num, self.att_feat_size)

        loc_input = ppls.data.new(batch_size, rois_num, 5)
        loc_input[:, :, :4] = ppls.data[:, :, :4] / self.image_crop_size
        loc_input[:, :, 4] = ppls.data[:, :, 5]
        loc_feats = self.loc_fc(Variable(loc_input))

        label_input = seq.data.new(batch_size, rois_num)
        label_input[:, :] = ppls.data[:, :, 4]
        label_feat = self.det_fc(Variable(label_input))

        # pool_feats = pool_feats + label_feat
        pool_feats = torch.cat((pool_feats, loc_feats, label_feat), 2)
        # transpose the conv_feats
        conv_feats = conv_feats.view(batch_size, self.att_feat_size,
                                     -1).transpose(1, 2).contiguous()

        # replicate the feature to map the seq size.
        fc_feats = fc_feats.view(batch_size, 1, self.fc_feat_size)\
                .expand(batch_size, self.seq_per_img, self.fc_feat_size)\
                .contiguous().view(-1, self.fc_feat_size)
        conv_feats = conv_feats.view(batch_size, 1, self.att_size*self.att_size, self.att_feat_size)\
                .expand(batch_size, self.seq_per_img, self.att_size*self.att_size, self.att_feat_size)\
                .contiguous().view(-1, self.att_size*self.att_size, self.att_feat_size)
        pool_feats = pool_feats.view(batch_size, 1, rois_num, self.pool_feat_size)\
                .expand(batch_size, self.seq_per_img, rois_num, self.pool_feat_size)\
                .contiguous().view(-1, rois_num, self.pool_feat_size)
        pnt_mask = pnt_mask.view(batch_size, 1, rois_num+1).expand(batch_size, self.seq_per_img, rois_num+1)\
                    .contiguous().view(-1, rois_num+1)

        # embed fc and att feats
        fc_feats = self.fc_embed(fc_feats)
        conv_feats = self.att_embed(conv_feats)
        pool_feats = self.pool_embed(pool_feats)

        # Project the attention feats first to reduce memory and computation consumption.
        p_conv_feats = self.ctx2att(conv_feats)
        p_pool_feats = self.ctx2pool(pool_feats)

        # calculate the overlaps between the rois and gt_bbox.
        overlaps = utils.bbox_overlaps(ppls.data, gt_boxes.data)
        for i in range(self.seq_length):
            if self.training and i >= 1 and self.ss_prob > 0.0:  # otherwise no need to sample
                # Scheduled sampling: one coin flip per caption row.
                # fc_feats was already expanded to seq_batch_size rows above,
                # matching the rows of ``it``.
                sample_prob = fc_feats.data.new(seq_batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i, 0].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i, 0].data.clone()
                    # Previous step's text-word distribution, computed the same
                    # way as ``decoded`` below (beta-scaled logits, softmax):
                    # shape (seq_batch_size, vocab). rnn_output is non-empty
                    # here because i >= 1 and the loop has not broken yet.
                    prob_prev = F.softmax(self.beta * self.logit(rnn_output[-1]),
                                          dim=1).data
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it)
            else:
                it = seq[:, i, 0].clone()
            # break if all the sequences end
            if i >= 1 and seq[:, i, 0].data.sum() == 0:
                break

            roi_label = utils.bbox_target(
                ppls, mask_boxes[:, :, :, i + 1], overlaps, seq[:, i + 1],
                seq_update[:, i + 1],
                self.vocab_size)  # roi_label if for the target seq

            roi_labels.append(roi_label.view(seq_batch_size, -1))

            xt = self.embed(it)
            output, det_prob, state = self.core(xt, fc_feats, conv_feats,
                                                p_conv_feats, pool_feats,
                                                p_pool_feats, pnt_mask,
                                                pnt_mask, state)

            rnn_output.append(output)
            det_output.append(det_prob)

        seq_cnt = len(det_output)
        rnn_output = torch.cat([_.unsqueeze(1) for _ in rnn_output], 1)
        det_output = torch.cat([_.unsqueeze(1) for _ in det_output], 1)
        roi_labels = torch.cat([_.unsqueeze(1) for _ in roi_labels], 1)

        det_output = F.log_softmax(det_output, dim=2)
        decoded = F.log_softmax(self.beta * self.logit(rnn_output), dim=2)
        # slot 0 of the detection log-probs is added to every text-word
        # log-prob; slots 1: are per-proposal log-probs.
        lambda_v = det_output[:, :, 0].contiguous()
        prob_det = det_output[:, :, 1:].contiguous()

        decoded = decoded + lambda_v.view(seq_batch_size, seq_cnt,
                                          1).expand_as(decoded)
        decoded = decoded.view((seq_cnt) * seq_batch_size, -1)

        roi_labels = Variable(roi_labels)
        prob_det = prob_det * roi_labels

        # mean log-prob over the matched proposals (avoid dividing by 0)
        roi_cnt = roi_labels.sum(2)
        roi_cnt[roi_cnt == 0] = 1
        vis_prob = prob_det.sum(2) / roi_cnt
        vis_prob = vis_prob.view(-1, 1)

        seq_update = Variable(seq_update)
        lm_loss = self.critLM(decoded, vis_prob, seq_update[:, 1:seq_cnt + 1,
                                                            0].clone())
        # do the cascade object recognition.
        # visual-word class per position (text words clamp to 0)
        vis_idx = seq_update[:, 1:seq_cnt + 1, 0].data - self.vocab_size
        vis_idx[vis_idx < 0] = 0
        vis_idx = Variable(vis_idx.view(-1))

        bn_logprob, fg_logprob = self.ccr_core(vis_idx, pool_feats, rnn_output,
                                               roi_labels, seq_batch_size,
                                               seq_cnt)

        bn_loss = self.critBN(bn_logprob, seq_update[:, 1:seq_cnt + 1,
                                                     1].clone())
        fg_loss = self.critFG(fg_logprob, seq_update[:, 1:seq_cnt + 1,
                                                     2].clone())

        return lm_loss, bn_loss, fg_loss
    def _forward_3_loops(self, segs_feat, input_seq, proposals, gt_caption,
                         num, mask_boxes, gt_boxes, region_feats, frm_mask,
                         sample_idx, pnt_mask):
        """Three-pass training forward: decode, localize, then reconstruct.

        1. A regular decoder is teacher-forced on ``gt_caption`` with
           frame-masked ROI attention; its language, attention and grounding
           outputs feed ``self.critLM``.
        2. Unless ``self.opts.train_decoder_only`` is set, the arg-max words of
           pass 1 drive ``self.localizer_core``, which returns localized
           features for each step.
        3. ``self.attended_roi_decoder_core`` re-generates the ground-truth
           caption from those localized features; ``self.xe_criterion`` on its
           output is the reconstruction loss.

        :param segs_feat: segment features; ``segs_feat.size(0)`` is the batch size.
        :param input_seq: per-word annotations, reshaped to
            ``(B*seq_per_img, input_seq.size(2), input_seq.size(3))``; column 0
            is the word index (visual words are offset by ``self.vocab_size``).
        :param proposals: (batch_size, max num of proposals allowed (200), 6), 6-dim: (4 points for the box, proposal index to be used for self.itod to get detection, proposal score)
        :param gt_caption: ground-truth captions; the first ``self.seq_per_img``
            per sample are used and a leading 0 column is prepended.
        :param num: (batch_size, 3), 3-dim: (# of captions for this img, # of proposals, # of GT bboxs)
        :param mask_boxes: (batch_size, # of captions for this img, max num of GT bbox allowed (100), max caption length)
        :param gt_boxes: (batch_size, max num of GT bbox allowed (100), 5), 5-dim: (4 points for the box, bbox index)
        :param region_feats: per-proposal features, forwarded to ``self.roi_feat_extractor``.
        :param frm_mask: proposal / gt-box frame mask, shaped like the expanded
            ``box_mask`` (batch_size, num_rois, num_gt).
        :param sample_idx: forwarded to ``self.roi_feat_extractor``.
        :param pnt_mask: proposal padding mask; column 0 is dropped wherever a
            per-proposal mask is needed (``pnt_mask[:, 1:]``).
        :return: ``(lm_loss, att2_loss, ground_loss, cls_loss)`` when
            ``self.opts.train_decoder_only``; otherwise those four followed by
            ``lm_recon_loss``. Every loss has a leading dim of size 1.
        """
        batch_size = segs_feat.size(0)
        num_rois = proposals.size(1)
        gt_caption = gt_caption[:, :self.seq_per_img, :].clone().view(
            -1, gt_caption.size(2))  # choose the first seq_per_img
        # prepend a 0 column (start token) to every caption
        gt_caption = torch.cat(
            (gt_caption.new(gt_caption.size(0), 1).fill_(0), gt_caption), 1)

        # B*self.seq_per_img, self.seq_length+1, 5
        input_seq = input_seq.view(-1, input_seq.size(2), input_seq.size(3))
        input_seq_update = input_seq.data.clone()

        gt_caption_batch_size = gt_caption.size(0)

        decoder_state = self.init_hidden(gt_caption_batch_size,
                                         self.decoder_num_layers)

        # calculate the overlaps between the rois and gt_bbox.
        # overlaps, overlaps_no_replicate = utils.bbox_overlaps(proposals.data, gt_boxes.data)
        # calculate the overlaps between the rois/rois and rois/gt_bbox.
        # apply both frame mask and proposal mask
        overlaps = utils.bbox_overlaps(proposals.data, gt_boxes.data,
                                       (frm_mask
                                        | pnt_mask[:, 1:].unsqueeze(-1)).data)

        # NOTE(review): ``overlaps_expanded`` and ``cls_pred`` are unused below;
        # bbox_target is given the un-expanded ``overlaps``. Presumably equivalent
        # only when seq_per_img == 1 — confirm against utils.bbox_target.
        fc_feats, conv_feats, p_conv_feats, pool_feats, \
            p_pool_feats, g_pool_feats, pnt_mask, overlaps_expanded, cls_pred, cls_loss = self.roi_feat_extractor(
                segs_feat, proposals, num, mask_boxes, region_feats, gt_boxes, overlaps, sample_idx)

        # ====================================================================================
        # regular decoder
        # ====================================================================================
        lang_outputs = []
        decoder_roi_attn, decoder_masked_attn, roi_labels = [], [], []
        frm_mask_output = []
        # unfold the GT caption and localize the ROIs word by word
        for word_idx in range(self.seq_length):
            word = gt_caption[:, word_idx].clone()
            embedded_word = self.embed(word)

            roi_label = utils.bbox_target(
                mask_boxes[:, :, :, word_idx + 1], overlaps,
                input_seq[:, word_idx + 1], input_seq_update[:, word_idx + 1],
                self.vocab_size)  # roi_label for the target seq
            roi_labels.append(roi_label.view(gt_caption_batch_size, -1))

            # use frame mask during training
            box_mask = mask_boxes[:, 0, :, word_idx +
                                  1].contiguous().unsqueeze(1).expand(
                                      (batch_size, num_rois,
                                       mask_boxes.size(2)))
            # frm_mask_on_prop = (
            #     torch.sum((1 - (box_mask | frm_mask)), dim=2) <= 0)
            # A proposal is masked (True) when, for every gt box, either
            # box_mask or frm_mask is set. NOTE(review): ``~`` is a logical NOT
            # only for bool masks; on uint8 masks it is bitwise — confirm dtype.
            frm_mask_on_prop = (torch.sum(
                (~(box_mask | frm_mask)), dim=2) <= 0)

            # Prepend an unmasked column for slot 0 and add proposal padding.
            # NOTE(review): the left operand has batch_size rows while pnt_mask
            # comes from roi_feat_extractor; the OR only lines up when
            # seq_per_img == 1 (or via broadcasting) — confirm.
            frm_mask_on_prop = torch.cat((frm_mask_on_prop.new(
                batch_size, 1).fill_(0.), frm_mask_on_prop),
                                         dim=1) | pnt_mask.bool()
            frm_mask_output.append(frm_mask_on_prop)

            output, decoder_state, roi_attn, frame_masked_attn, weighted_pool_feat = self.decoder_core(
                embedded_word,
                fc_feats,
                conv_feats,
                p_conv_feats,
                pool_feats,
                p_pool_feats,
                pnt_mask[:, 1:],
                decoder_state,
                proposal_frame_mask=frm_mask_on_prop[:, 1:],
                with_sentinel=False)

            output = F.log_softmax(self.logit(output), dim=1)
            decoder_roi_attn.append(roi_attn)
            decoder_masked_attn.append(frame_masked_attn)

            lang_outputs.append(output)

        # one entry per caption position along dim 1; must broadcast with
        # ``bias`` below, i.e. (gt_caption_batch_size, self.seq_length, num_rois)
        att2_weights = torch.stack(decoder_masked_attn, dim=1)

        # decoder output for captioning task
        lang_outputs = torch.cat([_.unsqueeze(1) for _ in lang_outputs], dim=1)
        roi_labels = torch.cat([_.unsqueeze(1) for _ in roi_labels], dim=1)
        frm_mask_output = torch.cat([_.unsqueeze(1) for _ in frm_mask_output],
                                    1)

        # object grounding: visual-word class per position (text words clamp to 0)
        xt_clamp = torch.clamp(input_seq[:, 1:self.seq_length + 1, 0].clone() -
                               self.vocab_size,
                               min=0)
        xt_all = self.roi_feat_extractor.vis_embed(xt_clamp)

        if hasattr(self.roi_feat_extractor, 'vis_classifiers_bias'):
            bias = self.roi_feat_extractor.vis_classifiers_bias[xt_clamp].type(xt_all.type()) \
                .unsqueeze(2).expand(gt_caption_batch_size, self.seq_length, num_rois)
        else:
            bias = 0

        # att2_weights/ground_weights with both proposal mask and frame mask
        ground_weights = self._grounder(xt_all, g_pool_feats,
                                        frm_mask_output[:, :, 1:],
                                        bias + att2_weights)

        # if we are only training for the decoder, directly return language output
        if self.opts.train_decoder_only:
            lang_outputs = lang_outputs.view(-1, lang_outputs.size(2))

            lm_loss, att2_loss, ground_loss = self.critLM(
                lang_outputs, att2_weights, ground_weights,
                gt_caption[:, 1:self.seq_length + 1].clone(),
                roi_labels[:, :self.seq_length, :].clone(),
                input_seq[:, 1:self.seq_length + 1, 0].clone())

            return lm_loss.unsqueeze(0), att2_loss.unsqueeze(
                0), ground_loss.unsqueeze(0), cls_loss.unsqueeze(0)

        # ====================================================================================
        # greedy select words with highest probability and use them for localization
        # ====================================================================================
        # max() indices carry no gradient, so the localizer is not trained
        # through the word choice.
        _, output_seq = lang_outputs.max(2)
        localizer_state = self.init_hidden(gt_caption_batch_size,
                                           self.localizer_num_layers)

        localized_weighted_feats = []
        localized_conv_feats = []
        # NOTE(review): collected but never used (its stacking is commented out).
        localized_roi_probs = []
        for word_idx in range(self.seq_length):
            word = output_seq[:, word_idx].clone()
            embedded_word = self.embed(word)

            localized_weighted_feat, localized_conv_feat, localized_roi_prob, localizer_state = self.localizer_core(
                embedded_word,
                fc_feats,
                conv_feats,
                p_conv_feats,
                pool_feats,
                p_pool_feats,
                pnt_mask[:, 1:],
                localizer_state,
                None,
                proposal_frame_mask=frm_mask_output[:, word_idx, 1:],
                with_sentinel=False)

            # if self.opts.localizer_only_groundable:
            #     visually_groundable_mask = np.in1d(np.array(word.cpu()), self.i_to_visually_groundable).astype('float')
            #     visually_groundable_mask = torch.from_numpy(visually_groundable_mask).float().to(localized_weighted_feat.device)
            #     localized_weighted_feat = visually_groundable_mask.unsqueeze(1) * localized_weighted_feat
            #     localized_conv_feat = visually_groundable_mask.unsqueeze(1) * localized_conv_feat
            #     localized_roi_prob = visually_groundable_mask.unsqueeze(1) * localized_roi_prob

            localized_weighted_feats.append(localized_weighted_feat)
            localized_conv_feats.append(localized_conv_feat)
            localized_roi_probs.append(localized_roi_prob)

        # localized_roi_probs = torch.stack(localized_roi_probs, dim=1)  # (batch_size, seq_length - 1, max # of ROIs)

        # ====================================================================================
        # decoder use the localized ROIs to generate caption again
        # ====================================================================================
        consistent_outputs = []
        consistent_decoder_state = self.init_hidden(gt_caption_batch_size,
                                                    self.decoder_num_layers)
        for word_idx in range(self.seq_length):
            # teacher forcing on the ground-truth caption again
            word = gt_caption[:, word_idx].clone()
            embedded_word = self.embed(word)

            localized_weighted_feat = localized_weighted_feats[word_idx]
            localized_conv_feat = localized_conv_feats[word_idx]

            output, consistent_decoder_state = self.attended_roi_decoder_core(
                embedded_word,
                fc_feats,
                localized_weighted_feat,
                localized_conv_feat,
                consistent_decoder_state,
                with_sentinel=False)

            output = F.log_softmax(self.logit(output), dim=1)
            consistent_outputs.append(output)

        consistent_outputs = torch.cat(
            [_.unsqueeze(1) for _ in consistent_outputs], dim=1)

        lang_outputs = lang_outputs.view(-1, lang_outputs.size(2))
        lm_loss, att2_loss, ground_loss = self.critLM(
            lang_outputs, att2_weights, ground_weights,
            gt_caption[:, 1:self.seq_length + 1].clone(),
            roi_labels[:, :self.seq_length, :].clone(),
            input_seq[:, 1:self.seq_length + 1, 0].clone())

        # compute another loss for the reconstructor
        consistent_outputs = consistent_outputs.view(
            -1, consistent_outputs.size(2))
        lm_recon_loss = self.xe_criterion(
            consistent_outputs, gt_caption[:, 1:self.seq_length + 1].clone())

        return lm_loss.unsqueeze(0), att2_loss.unsqueeze(0), ground_loss.unsqueeze(0), \
            cls_loss.unsqueeze(0), lm_recon_loss.unsqueeze(0)