Exemplo n.º 1
0
    def forward(self, kwargs):
        """Compute the training loss for one batch.

        When ``self.rl_stage == False`` the model is run on ``kwargs`` and its
        output is scored by ``self.xe_criterion`` against
        ``kwargs[cfg.PARAM.TARGET_SENT]``.

        Otherwise two decodes are performed: a greedy decode (model in eval
        mode, no gradients) whose reward acts as a baseline, and a sampled
        decode on inputs expanded ``cfg.COCO_DATA_LOADER.SEQ_PER_IMG`` times.
        The reward difference ``sample - greedy`` is fed to
        ``self.rl_criterion`` together with the sampled sequence and its
        log-probabilities.

        Args:
            kwargs: Mapping of model inputs, passed positionally. In the RL
                branch it is modified in place (decode flags and the expanded
                features are written back into it).

        Returns:
            Tuple ``(loss, loss_info)``. In the RL branch ``loss_info`` holds
            the scorer's info entries suffixed with ``'_sample'`` and
            ``'_max'``.
        """
        if self.rl_stage == False:
            logit = self.model(**kwargs)
            target = kwargs[cfg.PARAM.TARGET_SENT]
            loss, loss_info = self.xe_criterion(logit, target)
            return loss, loss_info

        ids = kwargs[cfg.PARAM.INDICES]
        gv_feat = kwargs[cfg.PARAM.GLOBAL_FEAT]
        att_feats = kwargs[cfg.PARAM.ATT_FEATS]
        att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]

        # Greedy pass: provides the baseline reward, so no gradients needed.
        kwargs.update({
            'BEAM_SIZE': 1,
            'GREEDY_DECODE': True,
            cfg.PARAM.GLOBAL_FEAT: gv_feat,
            cfg.PARAM.ATT_FEATS: att_feats,
            cfg.PARAM.ATT_FEATS_MASK: att_mask,
        })
        self.model.eval()
        with torch.no_grad():
            greedy_seq, _ = self.model.module.decode(**kwargs)
        # NOTE(review): train mode is forced here regardless of the mode the
        # model was in before the call.
        self.model.train()
        greedy_sents = greedy_seq.data.cpu().numpy().tolist()
        rewards_max, rewards_info_max = self.scorer(ids, greedy_sents)
        rewards_max = utils.expand_numpy(rewards_max)

        # Expand ids and features so they line up with the sampled captions.
        ids = utils.expand_numpy(ids)
        seq_per_img = cfg.COCO_DATA_LOADER.SEQ_PER_IMG
        gv_feat = utils.expand_tensor(gv_feat, seq_per_img)
        att_feats = utils.expand_tensor(att_feats, seq_per_img)
        att_mask = utils.expand_tensor(att_mask, seq_per_img)

        # Sampling pass: gradients flow through the sampled log-probs.
        kwargs.update({
            'BEAM_SIZE': 1,
            'GREEDY_DECODE': False,
            cfg.PARAM.GLOBAL_FEAT: gv_feat,
            cfg.PARAM.ATT_FEATS: att_feats,
            cfg.PARAM.ATT_FEATS_MASK: att_mask,
        })
        seq_sample, logP_sample = self.model.module.decode(**kwargs)
        sampled_sents = seq_sample.data.cpu().numpy().tolist()
        rewards_sample, rewards_info_sample = self.scorer(ids, sampled_sents)

        advantage = rewards_sample - rewards_max
        advantage = torch.from_numpy(advantage).float().cuda()
        loss = self.rl_criterion(seq_sample, logP_sample, advantage)

        loss_info = {
            key + '_sample': rewards_info_sample[key]
            for key in rewards_info_sample
        }
        loss_info.update({
            key + '_max': rewards_info_max[key]
            for key in rewards_info_max
        })
        return loss, loss_info
Exemplo n.º 2
0
    def init_gx_encoder_out_p_att_feats_att_mask(self, **kwargs):
        """Encode the attention features once and repeat the results.

        Runs entirely under ``torch.no_grad()``. The features from
        ``kwargs[cfg.PARAM.ATT_FEATS]`` are embedded and encoded, then the
        global feature, the encoder output and the mask are each expanded by
        ``kwargs["REPEAT_FACTOR"]`` before the decoder's attention
        projections are precomputed from the expanded encoder output.

        Returns:
            Tuple ``(gx, encoder_out, p_att_feats, att_mask)``.
        """
        with torch.no_grad():
            raw_feats = kwargs[cfg.PARAM.ATT_FEATS]
            mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]
            repeat_factor = kwargs["REPEAT_FACTOR"]

            gx, encoder_out = self.encoder(self.att_embed(raw_feats), mask)

            # Expansion order matches the original: gx, encoder_out, mask.
            gx, encoder_out, mask = (
                utils.expand_tensor(tensor, repeat_factor)
                for tensor in (gx, encoder_out, mask))

            p_att_feats = self.decoder.precompute(encoder_out)
            return (gx, encoder_out, p_att_feats, mask)
Exemplo n.º 3
0
 def _prefix_rewards(self, kwargs, seq_prefix):
     """Extend a sequence prefix with the predictor and return token lists.

     Writes the prefix length and the prefix itself (expanded 5 times via
     ``utils.expand_tensor``) into ``kwargs`` in place, then calls
     ``self.predictor.extend_trajectory`` without gradients.

     NOTE(review): despite the name, this returns the extended sequences as
     nested Python lists, not reward values.
     """
     prefix_len = seq_prefix.shape[-1]
     kwargs[cfg.PARAM.MAX_GEN_LEN] = prefix_len
     kwargs[cfg.PARAM.GEN_RESULT] = utils.expand_tensor(seq_prefix, 5)
     with torch.no_grad():
         extended = self.predictor.extend_trajectory(**kwargs)
         as_lists = extended.detach().cpu().numpy().tolist()
     return as_lists
Exemplo n.º 4
0
    def forward(self, **kwargs):
        """Run the visual backbone, encoder and decoder on one batch.

        Reads the raw visual input (``cfg.PARAM.ATT_FEATS``), the input token
        ids (``cfg.PARAM.INPUT_SENT``) and the attention mask
        (``cfg.PARAM.ATT_FEATS_MASK``) from ``kwargs``. Both the visual input
        and the mask are expanded ``cfg.DATA_LOADER.SEQ_PER_IMG`` times.

        Returns:
            The output of ``self.decoder``.
        """
        visual_input = kwargs[cfg.PARAM.ATT_FEATS]
        seq = kwargs[cfg.PARAM.INPUT_SENT]
        att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]

        seq_per_img = cfg.DATA_LOADER.SEQ_PER_IMG
        att_mask = utils.expand_tensor(att_mask, seq_per_img)
        visual_input = utils.expand_tensor(visual_input, seq_per_img)

        # Token mask: non-zero ids are kept and column 0 is bumped by one.
        # NOTE(review): the bitwise AND below only keeps column 0 if
        # `seq[:, 0]` is 0 (so the bump yields exactly 1) — confirm.
        token_mask = (seq > 0).type(torch.cuda.IntTensor)
        token_mask[:, 0] += 1
        token_mask = token_mask.unsqueeze(-2)
        causal_mask = subsequent_mask(seq.size(-1)).to(token_mask)
        seq_mask = (token_mask & causal_mask).type(torch.cuda.FloatTensor)

        # The backbone must return a 4-D tensor; trailing dims are flattened
        # and moved in front of the feature dim before embedding.
        feats = self.get_visual_features(visual_input)
        batch_size, feat_size, _, _ = feats.shape
        feats = feats.reshape(batch_size, feat_size, -1).transpose(1, 2)
        feats = self.att_embed(feats)

        gx, encoder_out = self.encoder(feats, att_mask)
        return self.decoder(gx, seq, encoder_out, att_mask, seq_mask)
Exemplo n.º 5
0
    def forward(self, **kwargs):
        """Step the recurrent decoder over the input sentence.

        The features returned by ``self.preprocess`` are expanded
        ``cfg.DATA_LOADER.SEQ_PER_IMG`` times. At each time step the input
        word is the ground-truth token, except that during training (for
        ``t >= 1`` and ``self.ss_prob > 0``) a random subset of rows has its
        token replaced by a sample drawn from the previous step's output
        (scheduled sampling).

        Decoding stops early once a whole column of ``seq`` is zero; the
        remaining time steps of the result stay zero.

        Returns:
            Tensor of size ``(batch_size, seq.size(1), self.vocab_size)``
            holding ``self.logit`` outputs per time step.
        """
        seq = kwargs[cfg.PARAM.INPUT_SENT]
        gv_feat, att_feats, att_mask, p_att_feats = self.preprocess(**kwargs)

        seq_per_img = cfg.DATA_LOADER.SEQ_PER_IMG
        gv_feat, att_feats, att_mask, p_att_feats = (
            utils.expand_tensor(feat, seq_per_img)
            for feat in (gv_feat, att_feats, att_mask, p_att_feats))

        batch_size = gv_feat.size(0)
        state = self.init_hidden(batch_size)

        num_steps = seq.size(1)
        outputs = Variable(
            torch.zeros(batch_size, num_steps, self.vocab_size).cuda())

        for t in range(num_steps):
            wt = None
            if self.training and t >= 1 and self.ss_prob > 0:
                draw = torch.empty(batch_size).cuda().uniform_(0, 1)
                replace = draw < self.ss_prob
                if replace.sum() != 0:
                    rows = replace.nonzero().view(-1)
                    wt = seq[:, t].data.clone()
                    # NOTE(review): exp() treats the stored outputs as
                    # log-probabilities — confirm that self.logit emits them.
                    prev_probs = torch.exp(outputs[:, t - 1].detach())
                    sampled = torch.multinomial(prev_probs, 1).view(-1)
                    wt.index_copy_(0, rows, sampled.index_select(0, rows))
            if wt is None:
                # No scheduled sampling for this step: feed the ground truth.
                wt = seq[:, t].clone()

            if t >= 1 and seq[:, t].max() == 0:
                break

            step_kwargs = self.make_kwargs(wt, gv_feat, att_feats, att_mask,
                                           p_att_feats, state)
            output, state = self.Forward(**step_kwargs)
            if self.dropout_lm is not None:
                output = self.dropout_lm(output)

            outputs[:, t] = self.logit(output)

        return outputs
Exemplo n.º 6
0
    def forward(self, **kwargs):
        """Encode the attention features and decode the input sentence.

        Reads ``cfg.PARAM.ATT_FEATS``, ``cfg.PARAM.INPUT_SENT`` and
        ``cfg.PARAM.ATT_FEATS_MASK`` from ``kwargs``; features and mask are
        expanded ``cfg.DATA_LOADER.SEQ_PER_IMG`` times before encoding.

        Returns:
            The output of ``self.decoder``.
        """
        feats = kwargs[cfg.PARAM.ATT_FEATS]
        seq = kwargs[cfg.PARAM.INPUT_SENT]
        att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]

        seq_per_img = cfg.DATA_LOADER.SEQ_PER_IMG
        att_mask = utils.expand_tensor(att_mask, seq_per_img)
        feats = utils.expand_tensor(feats, seq_per_img)

        # Token mask: non-zero ids are kept and column 0 is bumped by one.
        # NOTE(review): the bitwise AND below only keeps column 0 if
        # `seq[:, 0]` is 0 (so the bump yields exactly 1) — confirm.
        token_mask = (seq > 0).type(torch.cuda.IntTensor)
        token_mask[:, 0] += 1
        token_mask = token_mask.unsqueeze(-2)
        causal_mask = subsequent_mask(seq.size(-1)).to(token_mask)
        seq_mask = (token_mask & causal_mask).type(torch.cuda.FloatTensor)

        gx, encoder_out = self.encoder(self.att_embed(feats), att_mask)
        return self.decoder(gx, seq, encoder_out, att_mask, seq_mask)
Exemplo n.º 7
0
    def forward(self, **kwargs):
        """Encode CNN and GCN visual features and decode the input sentence.

        ``self.get_visual_features`` yields two feature tensors. Each gets
        its own mask and embedding for the encoder, while the decoder mask is
        built from their concatenation along dim 1.

        Returns:
            The output of ``self.decoder``.
        """
        visual_input = kwargs[cfg.PARAM.ATT_FEATS]
        seq = kwargs[cfg.PARAM.INPUT_SENT]
        att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]

        seq_per_img = cfg.DATA_LOADER.SEQ_PER_IMG
        att_mask = utils.expand_tensor(att_mask, seq_per_img)
        visual_input = utils.expand_tensor(visual_input, seq_per_img)
        # The caller-supplied mask is deliberately dropped; the decoder mask
        # is regenerated from the extracted features further down.
        att_mask = None

        # Token mask: non-zero ids are kept and column 0 is bumped by one.
        # NOTE(review): the bitwise AND below only keeps column 0 if
        # `seq[:, 0]` is 0 (so the bump yields exactly 1) — confirm.
        token_mask = (seq > 0).type(torch.cuda.IntTensor)
        token_mask[:, 0] += 1
        token_mask = token_mask.unsqueeze(-2)
        causal_mask = subsequent_mask(seq.size(-1)).to(token_mask)
        seq_mask = (token_mask & causal_mask).type(torch.cuda.FloatTensor)

        cnn_feats, gcn_feats = self.get_visual_features(visual_input)
        all_feats = torch.cat([cnn_feats, gcn_feats], dim=1)

        # Masks are derived from the raw (pre-embedding) features.
        att_mask = makeMask(all_feats)
        cnn_mask = makeMask(cnn_feats)
        gcn_mask = makeMask(gcn_feats)

        cnn_embedded = self.cnn_embed(cnn_feats)
        gcn_embedded = self.gcn_embed(gcn_feats)

        gx, encoder_out = self.encoder(cnn_embedded, gcn_embedded, cnn_mask,
                                       gcn_mask)
        return self.decoder(gx, seq, encoder_out, att_mask, seq_mask)
Exemplo n.º 8
0
    def decode_beam(self, **kwargs):
        """Decode one sequence per input with beam search.

        Reads ``cfg.PARAM.ATT_FEATS``, ``cfg.PARAM.ATT_FEATS_MASK`` and
        ``'BEAM_SIZE'`` from ``kwargs``. The features are embedded and
        encoded once, then ``cfg.MODEL.SEQ_LEN`` decoding steps are run while
        keeping ``beam_size`` hypotheses per input. Word id 0 marks the end
        of a sequence: finished beams are frozen and can only be extended
        with word 0 at unchanged score.

        Returns:
            Tuple ``(outputs, log_probs)`` with the word ids and per-word
            log-probabilities of the best-scoring beam of every input, each
            of size ``(batch_size, cfg.MODEL.SEQ_LEN)``.
        """
        att_feats = kwargs[cfg.PARAM.ATT_FEATS]
        att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]
        beam_size = kwargs['BEAM_SIZE']
        batch_size = att_feats.size(0)
        seq_logprob = torch.zeros((batch_size, 1, 1)).cuda()
        log_probs = []
        selected_words = None
        seq_mask = torch.ones((batch_size, beam_size, 1)).cuda()

        att_feats = self.att_embed(att_feats)
        gx, encoder_out = self.encoder(att_feats, att_mask)
        p_att_feats = self.decoder.precompute(encoder_out)

        state = None
        wt = Variable(torch.zeros(batch_size, dtype=torch.long).cuda())
        kwargs[cfg.PARAM.ATT_FEATS] = encoder_out
        kwargs[cfg.PARAM.GLOBAL_FEAT] = gx
        kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats

        outputs = []
        self.decoder.init_buffer(batch_size)
        for t in range(cfg.MODEL.SEQ_LEN):
            # Only one hypothesis per input exists before the first step.
            cur_beam_size = 1 if t == 0 else beam_size

            kwargs[cfg.PARAM.WT] = wt
            kwargs[cfg.PARAM.STATE] = state
            word_logprob, state = self.get_logprobs_state(**kwargs)
            word_logprob = word_logprob.view(batch_size, cur_beam_size, -1)
            candidate_logprob = seq_logprob + word_logprob

            # Mask sequence if it reaches EOS
            if t > 0:
                mask = (selected_words.view(batch_size, cur_beam_size) !=
                        0).float().unsqueeze(-1)
                seq_mask = seq_mask * mask
                word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
                old_seq_logprob = seq_logprob.expand_as(
                    candidate_logprob).contiguous()
                # A finished beam may only continue with word 0, keeping its
                # score; every other continuation is made unattractive.
                old_seq_logprob[:, :, 1:] = -999
                candidate_logprob = seq_mask * candidate_logprob + old_seq_logprob * (
                    1 - seq_mask)

            selected_idx, selected_logprob = self.select(
                batch_size, beam_size, t, candidate_logprob)
            # Split the flat (beam * vocab) index into beam and word parts.
            # Integer floor division is required: `/` is true division on
            # current PyTorch and would yield float tensors that cannot be
            # used as gather indices below.
            selected_beam = torch.floor_divide(selected_idx,
                                               candidate_logprob.shape[-1])
            selected_words = selected_idx - selected_beam * candidate_logprob.shape[
                -1]

            self.decoder.apply_to_states(
                self._expand_state(batch_size, beam_size, cur_beam_size,
                                   selected_beam))
            seq_logprob = selected_logprob.unsqueeze(-1)
            seq_mask = torch.gather(seq_mask, 1, selected_beam.unsqueeze(-1))
            # Reorder the history so every beam carries its parent's words.
            outputs = list(
                torch.gather(o, 1, selected_beam.unsqueeze(-1))
                for o in outputs)
            outputs.append(selected_words.unsqueeze(-1))

            this_word_logprob = torch.gather(
                word_logprob, 1,
                selected_beam.unsqueeze(-1).expand(batch_size, beam_size,
                                                   word_logprob.shape[-1]))
            this_word_logprob = torch.gather(this_word_logprob, 2,
                                             selected_words.unsqueeze(-1))
            log_probs = list(
                torch.gather(
                    o, 1,
                    selected_beam.unsqueeze(-1).expand(batch_size, beam_size,
                                                       1)) for o in log_probs)
            log_probs.append(this_word_logprob)
            selected_words = selected_words.view(-1, 1)
            wt = selected_words.squeeze(-1)

            if t == 0:
                # After the first step there are beam_size hypotheses per
                # input, so everything the decoder consumes is expanded.
                encoder_out = utils.expand_tensor(encoder_out, beam_size)
                gx = utils.expand_tensor(gx, beam_size)
                att_mask = utils.expand_tensor(att_mask, beam_size)
                state[0] = state[0].squeeze(0)
                state[0] = utils.expand_tensor(state[0], beam_size)
                state[0] = state[0].unsqueeze(0)

                p_att_feats_tmp = []
                for p_feat in p_att_feats:
                    p_key, p_value2 = p_feat
                    p_key = utils.expand_tensor(p_key, beam_size)
                    p_value2 = utils.expand_tensor(p_value2, beam_size)
                    p_att_feats_tmp.append((p_key, p_value2))

                kwargs[cfg.PARAM.ATT_FEATS] = encoder_out
                kwargs[cfg.PARAM.GLOBAL_FEAT] = gx
                kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask
                kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats_tmp

        # Order beams by final score and keep the best one per input.
        seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True)
        outputs = torch.cat(outputs, -1)
        outputs = torch.gather(
            outputs, 1,
            sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN))
        log_probs = torch.cat(log_probs, -1)
        log_probs = torch.gather(
            log_probs, 1,
            sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN))

        outputs = outputs.contiguous()[:, 0]
        log_probs = log_probs.contiguous()[:, 0]

        self.decoder.clear_buffer()
        return outputs, log_probs
Exemplo n.º 9
0
    def decode(self, **kwargs):
        """Generate sequences greedily, by sampling, or by replaying tokens.

        Behaviour is selected through ``kwargs``:

        * ``'GREEDY_DECODE'`` truthy: take the arg-max word at every step.
        * otherwise, if ``cfg.PARAM.GEN_RESULT`` is present: feed those
          tokens and record the full per-step log-prob distribution.
        * otherwise: sample from the predicted distribution. The full
          distribution is recorded when ``"NEED_PD"`` is present, else only
          the log-prob of the sampled word.

        Word id 0 ends a sequence; decoding stops once every row is finished.

        NOTE(review): ``batch_size`` is taken before the ``"REPEAT_FACTOR"``
        expansion is applied to the encoder output — confirm this is
        intended when the factor is greater than 1.

        Returns:
            Tuple ``(sents, logprobs)``.
        """
        greedy_decode = kwargs['GREEDY_DECODE']
        att_feats = kwargs[cfg.PARAM.ATT_FEATS]
        att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]

        att_feats = self.att_embed(att_feats)
        gx, encoder_out = self.encoder(att_feats, att_mask)

        repeat_factor = kwargs.get("REPEAT_FACTOR", 1)
        gx = utils.expand_tensor(gx, repeat_factor)
        encoder_out = utils.expand_tensor(encoder_out, repeat_factor)

        p_att_feats = self.decoder.precompute(encoder_out)
        batch_size = att_feats.size(0)
        self.decoder.init_buffer(batch_size)

        max_len = cfg.MODEL.SEQ_LEN
        state = None
        sents = torch.zeros((batch_size, max_len), dtype=torch.long).cuda()

        need_pd = "NEED_PD" in kwargs
        gen_result = None
        if cfg.PARAM.GEN_RESULT in kwargs:
            gen_result = kwargs[cfg.PARAM.GEN_RESULT].view(batch_size, max_len)

        # Full distributions are stored when replaying tokens or when
        # explicitly requested; otherwise one log-prob per step.
        if gen_result is not None or need_pd:
            logprobs = torch.zeros(batch_size, max_len,
                                   self.vocab_size).cuda()
        else:
            logprobs = torch.zeros(batch_size, max_len).cuda()

        wt = torch.zeros(batch_size, dtype=torch.long).cuda()
        unfinished = wt.eq(wt)  # all-true flag per row
        kwargs[cfg.PARAM.ATT_FEATS] = encoder_out
        kwargs[cfg.PARAM.GLOBAL_FEAT] = gx
        kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats

        for t in range(max_len):
            kwargs[cfg.PARAM.WT] = wt
            kwargs[cfg.PARAM.STATE] = state
            step_logprobs, state = self.get_logprobs_state(**kwargs)
            step_logprobs = step_logprobs.reshape(batch_size, self.vocab_size)

            if greedy_decode:
                logP_t, wt = torch.max(step_logprobs, 1)
            elif gen_result is not None:
                wt = gen_result[:, t]
                logP_t = step_logprobs
            else:
                wt = torch.multinomial(torch.exp(step_logprobs), 1)
                if need_pd:
                    logP_t = step_logprobs
                else:
                    logP_t = step_logprobs.gather(1, wt).view(-1)

            wt = wt.view(-1).long()
            # Once a row emits 0 it stays finished and keeps emitting 0.
            unfinished = unfinished * (wt > 0)
            wt = wt * unfinished.type_as(wt)
            sents[:, t] = wt
            logprobs[:, t] = logP_t

            if unfinished.sum() == 0:
                break

        self.decoder.clear_buffer()
        return sents, logprobs
Exemplo n.º 10
0
    def decode_beam(self, **kwargs):
        """Decode one sequence per input with beam search.

        Features come from ``self.preprocess`` and the beam width from
        ``kwargs['BEAM_SIZE']``. ``cfg.MODEL.SEQ_LEN`` steps are run while
        keeping ``beam_size`` hypotheses per input; word id 0 marks the end
        of a sequence, after which a beam is frozen at its current score.

        Returns:
            Tuple ``(outputs, log_probs)`` with the word ids and per-word
            log-probabilities of the best-scoring beam of every input, each
            of size ``(batch_size, cfg.MODEL.SEQ_LEN)``.
        """
        gv_feat, att_feats, att_mask, p_att_feats = self.preprocess(**kwargs)

        beam_size = kwargs['BEAM_SIZE']
        batch_size = att_feats.size(0)
        beam_scores = torch.zeros((batch_size, 1, 1)).cuda()
        logprob_steps = []
        chosen_words = None
        alive = torch.ones((batch_size, beam_size, 1)).cuda()

        state = self.init_hidden(batch_size)
        wt = Variable(torch.zeros(batch_size, dtype=torch.long).cuda())

        kwargs[cfg.PARAM.ATT_FEATS] = att_feats
        kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat
        kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats

        word_steps = []
        max_len = cfg.MODEL.SEQ_LEN
        for t in range(max_len):
            # Only one hypothesis per input exists before the first step.
            cur_beam_size = beam_size if t > 0 else 1

            kwargs[cfg.PARAM.WT] = wt
            kwargs[cfg.PARAM.STATE] = state
            word_logprob, state = self.get_logprobs_state(**kwargs)
            word_logprob = word_logprob.view(batch_size, cur_beam_size, -1)
            candidates = beam_scores + word_logprob

            if t > 0:
                # Beams whose last word was 0 are finished: they keep their
                # score and may only be continued with word 0.
                still_alive = (chosen_words.view(batch_size, cur_beam_size) !=
                               0).float().unsqueeze(-1)
                alive = alive * still_alive
                word_logprob = word_logprob * alive.expand_as(word_logprob)
                frozen = beam_scores.expand_as(candidates).contiguous()
                frozen[:, :, 1:] = -999
                candidates = alive * candidates + frozen * (1 - alive)

            flat_idx, flat_logprob = self.select(batch_size, beam_size, t,
                                                 candidates)
            # Split the flat (beam * vocab) index into beam and word parts.
            vocab = candidates.shape[-1]
            parent_beam = torch.floor_divide(flat_idx, vocab)
            chosen_words = flat_idx - parent_beam * vocab

            for s, layer_state in enumerate(state):
                state[s] = self._expand_state(batch_size, beam_size,
                                              cur_beam_size, layer_state,
                                              parent_beam)

            parent_idx = parent_beam.unsqueeze(-1)
            beam_scores = flat_logprob.unsqueeze(-1)
            alive = torch.gather(alive, 1, parent_idx)
            # Reorder the history so every beam carries its parent's words.
            word_steps = [torch.gather(w, 1, parent_idx) for w in word_steps]
            word_steps.append(chosen_words.unsqueeze(-1))

            step_logprob = torch.gather(
                word_logprob, 1,
                parent_idx.expand(batch_size, beam_size,
                                  word_logprob.shape[-1]))
            step_logprob = torch.gather(step_logprob, 2,
                                        chosen_words.unsqueeze(-1))
            logprob_steps = [
                torch.gather(lp, 1,
                             parent_idx.expand(batch_size, beam_size, 1))
                for lp in logprob_steps
            ]
            logprob_steps.append(step_logprob)
            chosen_words = chosen_words.view(-1, 1)
            wt = chosen_words.squeeze(-1)

            if t == 0:
                # From now on there are beam_size hypotheses per input.
                att_feats = utils.expand_tensor(att_feats, beam_size)
                gv_feat = utils.expand_tensor(gv_feat, beam_size)
                att_mask = utils.expand_tensor(att_mask, beam_size)
                p_att_feats = utils.expand_tensor(p_att_feats, beam_size)

                kwargs[cfg.PARAM.ATT_FEATS] = att_feats
                kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat
                kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask
                kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats

        # Order beams by final score and keep the best one per input.
        beam_scores, order = torch.sort(beam_scores, 1, descending=True)
        words = torch.cat(word_steps, -1)
        words = torch.gather(words, 1,
                             order.expand(batch_size, beam_size, max_len))
        word_logprobs = torch.cat(logprob_steps, -1)
        word_logprobs = torch.gather(
            word_logprobs, 1, order.expand(batch_size, beam_size, max_len))

        return words.contiguous()[:, 0], word_logprobs.contiguous()[:, 0]
Exemplo n.º 11
0
    def decode_beam(self, **kwargs):
        """Decode with beam search, optionally returning attention scores.

        Features come from ``self.preprocess``; the beam width is read from
        ``kwargs['BEAM_SIZE']`` and the attention switch from
        ``kwargs['output_attention']``. ``cfg.MODEL.SEQ_LEN`` steps are run
        while keeping ``beam_size`` hypotheses per input. Word id 0 marks the
        end of a sequence, after which a beam is frozen at its score.

        Returns:
            ``(outputs, log_probs)`` for the best-scoring beam of each input.
            When ``output_attention`` is truthy the tuple is extended to
            ``(outputs, log_probs, attention_scores, att_mask0)``, where
            ``attention_scores`` stacks the last attention entry of every
            step along a trailing time axis and ``att_mask0`` is the
            unexpanded mask returned by ``self.preprocess``.
        """
        gv_feat, att_feats, att_mask, p_att_feats = self.preprocess(**kwargs)
        att_mask0 = att_mask  # kept unexpanded for the caller
        beam_size = kwargs['BEAM_SIZE']
        output_attention = kwargs['output_attention']
        batch_size = att_feats.size(0)
        seq_logprob = torch.zeros((batch_size, 1, 1)).cuda()
        log_probs, attention_scores = [], []
        selected_words = None
        seq_mask = torch.ones((batch_size, beam_size, 1)).cuda()

        state = self.init_hidden(batch_size)
        wt = Variable(torch.zeros(batch_size, dtype=torch.long).cuda())

        kwargs[cfg.PARAM.ATT_FEATS] = att_feats
        kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat
        kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats

        outputs = []
        for t in range(cfg.MODEL.SEQ_LEN):
            # Only one hypothesis per input exists before the first step.
            cur_beam_size = 1 if t == 0 else beam_size

            kwargs[cfg.PARAM.WT] = wt
            kwargs[cfg.PARAM.STATE] = state
            logprobs_state_output = self.get_logprobs_state(**kwargs)
            word_logprob, state = logprobs_state_output[
                0], logprobs_state_output[1]
            if output_attention:
                # Keep only the last entry of the returned attention list and
                # split its leading dim into (batch, beam).
                attention_score = logprobs_state_output[2][-1]
                attention_score = attention_score.view(
                    batch_size, -1, attention_score.shape[1],
                    attention_score.shape[2])
            word_logprob = word_logprob.view(batch_size, cur_beam_size, -1)
            candidate_logprob = seq_logprob + word_logprob

            # Mask sequence if it reaches EOS
            if t > 0:
                mask = (selected_words.view(batch_size, cur_beam_size) !=
                        0).float().unsqueeze(-1)
                seq_mask = seq_mask * mask
                word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
                old_seq_logprob = seq_logprob.expand_as(
                    candidate_logprob).contiguous()
                # A finished beam may only continue with word 0, keeping its
                # score; every other continuation is made unattractive.
                old_seq_logprob[:, :, 1:] = -999
                candidate_logprob = seq_mask * candidate_logprob + old_seq_logprob * (
                    1 - seq_mask)

            selected_idx, selected_logprob = self.select(
                batch_size, beam_size, t, candidate_logprob)
            # Split the flat (beam * vocab) index into beam and word parts.
            # Integer floor division is required: `/` is true division on
            # current PyTorch and would yield float tensors that cannot be
            # used as gather indices below.
            selected_beam = torch.floor_divide(selected_idx,
                                               candidate_logprob.shape[-1])
            selected_words = selected_idx - selected_beam * candidate_logprob.shape[
                -1]

            for s in range(len(state)):
                state[s] = self._expand_state(batch_size, beam_size,
                                              cur_beam_size, state[s],
                                              selected_beam)

            if output_attention:
                # Reorder current and past attention maps by parent beam.
                selected_beam_ex = selected_beam.unsqueeze(-1).unsqueeze(
                    -1).expand(batch_size, beam_size,
                               attention_score.shape[-2],
                               attention_score.shape[-1])
                this_word_attention_score = torch.gather(
                    attention_score, 1, selected_beam_ex)
                attention_scores = list(
                    torch.gather(a, 1, selected_beam_ex)
                    for a in attention_scores)
                attention_scores.append(this_word_attention_score)

            seq_logprob = selected_logprob.unsqueeze(-1)
            seq_mask = torch.gather(seq_mask, 1, selected_beam.unsqueeze(-1))
            # Reorder the history so every beam carries its parent's words.
            outputs = list(
                torch.gather(o, 1, selected_beam.unsqueeze(-1))
                for o in outputs)
            outputs.append(selected_words.unsqueeze(-1))

            this_word_logprob = torch.gather(
                word_logprob, 1,
                selected_beam.unsqueeze(-1).expand(batch_size, beam_size,
                                                   word_logprob.shape[-1]))
            this_word_logprob = torch.gather(this_word_logprob, 2,
                                             selected_words.unsqueeze(-1))
            log_probs = list(
                torch.gather(
                    o, 1,
                    selected_beam.unsqueeze(-1).expand(batch_size, beam_size,
                                                       1)) for o in log_probs)
            log_probs.append(this_word_logprob)
            selected_words = selected_words.view(-1, 1)
            wt = selected_words.squeeze(-1)

            if t == 0:
                # From now on there are beam_size hypotheses per input.
                att_feats = utils.expand_tensor(att_feats, beam_size)
                gv_feat = utils.expand_tensor(gv_feat, beam_size)
                att_mask = utils.expand_tensor(att_mask, beam_size)
                p_att_feats = utils.expand_tensor(p_att_feats, beam_size)

                kwargs[cfg.PARAM.ATT_FEATS] = att_feats
                kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat
                kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask
                kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats

        # Order beams by final score and keep the best one per input.
        seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True)
        outputs = torch.cat(outputs, -1)
        outputs = torch.gather(
            outputs, 1,
            sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN))
        log_probs = torch.cat(log_probs, -1)
        log_probs = torch.gather(
            log_probs, 1,
            sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN))

        outputs = outputs.contiguous()[:, 0]
        log_probs = log_probs.contiguous()[:, 0]

        if output_attention:
            # Stack the per-step maps on a trailing time axis and apply the
            # same beam ordering as for the words.
            sort_idx_ex = sort_idxs.unsqueeze(-1).unsqueeze(-1)
            attention_scores = torch.stack(attention_scores, -1)
            sort_idx_ex = sort_idx_ex.expand(batch_size, beam_size,
                                             attention_scores.shape[-3],
                                             attention_scores.shape[-2],
                                             cfg.MODEL.SEQ_LEN)
            attention_scores = torch.gather(attention_scores, 1, sort_idx_ex)
            attention_scores = attention_scores.contiguous()[:, 0]
            return outputs, log_probs, attention_scores, att_mask0
        else:
            return outputs, log_probs