def forward(self, kwargs):
    """Compute the training loss for one batch.

    While ``self.rl_stage == False`` the wrapped model is run on ``kwargs``
    and its output is scored by ``self.xe_criterion`` against
    ``kwargs[cfg.PARAM.TARGET_SENT]``.

    Otherwise two decodes are run through ``self.model.module.decode``:
    a greedy one (model in eval mode, no gradients) and a sampled one on
    inputs expanded ``cfg.COCO_DATA_LOADER.SEQ_PER_IMG`` times.  Both are
    scored with ``self.scorer`` and ``self.rl_criterion`` receives the
    difference ``sample reward - greedy reward``.

    Args:
        kwargs: dict of batch inputs (passed positionally, not unpacked).
            NOTE: in the RL branch it is modified in place: the decode
            flags ``'BEAM_SIZE'`` / ``'GREEDY_DECODE'`` are added and the
            feature entries end up holding the expanded tensors.

    Returns:
        ``(loss, loss_info)``.  In the RL branch ``loss_info`` holds the
        scorer's info entries suffixed with ``'_sample'`` and ``'_max'``.
    """
    if self.rl_stage == False:  # noqa: E712 - exact comparison kept on purpose
        logit = self.model(**kwargs)
        loss, loss_info = self.xe_criterion(logit, kwargs[cfg.PARAM.TARGET_SENT])
        return loss, loss_info

    image_ids = kwargs[cfg.PARAM.INDICES]
    global_feat = kwargs[cfg.PARAM.GLOBAL_FEAT]
    region_feats = kwargs[cfg.PARAM.ATT_FEATS]
    region_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]

    # --- greedy pass: one decode per input, used as the reward baseline ---
    kwargs.update({
        'BEAM_SIZE': 1,
        'GREEDY_DECODE': True,
        cfg.PARAM.GLOBAL_FEAT: global_feat,
        cfg.PARAM.ATT_FEATS: region_feats,
        cfg.PARAM.ATT_FEATS_MASK: region_mask,
    })
    self.model.eval()
    with torch.no_grad():
        greedy_seq, _ = self.model.module.decode(**kwargs)
    self.model.train()

    greedy_tokens = greedy_seq.data.cpu().numpy().tolist()
    greedy_rewards, greedy_info = self.scorer(image_ids, greedy_tokens)
    greedy_rewards = utils.expand_numpy(greedy_rewards)

    # --- sampling pass: inputs are expanded before decoding ---
    image_ids = utils.expand_numpy(image_ids)
    seq_per_img = cfg.COCO_DATA_LOADER.SEQ_PER_IMG
    global_feat = utils.expand_tensor(global_feat, seq_per_img)
    region_feats = utils.expand_tensor(region_feats, seq_per_img)
    region_mask = utils.expand_tensor(region_mask, seq_per_img)

    kwargs.update({
        'BEAM_SIZE': 1,
        'GREEDY_DECODE': False,
        cfg.PARAM.GLOBAL_FEAT: global_feat,
        cfg.PARAM.ATT_FEATS: region_feats,
        cfg.PARAM.ATT_FEATS_MASK: region_mask,
    })
    sampled_seq, sampled_logp = self.model.module.decode(**kwargs)
    sampled_tokens = sampled_seq.data.cpu().numpy().tolist()
    sample_rewards, sample_info = self.scorer(image_ids, sampled_tokens)

    # Reward relative to the greedy baseline, moved to the GPU for the loss.
    advantage = torch.from_numpy(sample_rewards - greedy_rewards).float().cuda()
    loss = self.rl_criterion(sampled_seq, sampled_logp, advantage)

    loss_info = {key + '_sample': sample_info[key] for key in sample_info}
    loss_info.update({key + '_max': greedy_info[key] for key in greedy_info})
    return loss, loss_info
def init_gx_encoder_out_p_att_feats_att_mask(self, **kwargs):
    """Encode the attention features once and replicate the results.

    Everything runs under ``torch.no_grad()``.  The features from
    ``kwargs[cfg.PARAM.ATT_FEATS]`` are embedded and encoded; the encoder
    outputs and the attention mask are then expanded with
    ``utils.expand_tensor`` by ``kwargs["REPEAT_FACTOR"]`` and the decoder's
    ``precompute`` is applied to the expanded encoder output.

    Returns:
        Tuple ``(gx, encoder_out, p_att_feats, att_mask)``.
    """
    with torch.no_grad():
        raw_feats = kwargs[cfg.PARAM.ATT_FEATS]
        mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]
        times = kwargs["REPEAT_FACTOR"]

        gx, encoder_out = self.encoder(self.att_embed(raw_feats), mask)

        # Replicate in the same order as before: gx, encoder_out, mask.
        gx, encoder_out, mask = (
            utils.expand_tensor(item, times) for item in (gx, encoder_out, mask)
        )
        p_att_feats = self.decoder.precompute(encoder_out)

    return (gx, encoder_out, p_att_feats, mask)
def _prefix_rewards(self, kwargs, seq_prefix):
    """Extend ``seq_prefix`` with the predictor and return the sequences.

    ``kwargs`` is modified in place: ``cfg.PARAM.MAX_GEN_LEN`` is set to the
    size of the last dimension of ``seq_prefix`` and ``cfg.PARAM.GEN_RESULT``
    to the prefix expanded by a factor of 5.
    NOTE(review): the 5 is hard-coded; presumably a rollout / captions-per-
    image count - confirm against the caller.

    Despite the name, no reward is computed here; the return value is the
    nested Python list obtained from ``self.predictor.extend_trajectory``.
    """
    prefix_len = seq_prefix.shape[-1]
    kwargs[cfg.PARAM.MAX_GEN_LEN] = prefix_len
    kwargs[cfg.PARAM.GEN_RESULT] = utils.expand_tensor(seq_prefix, 5)

    with torch.no_grad():
        extended = self.predictor.extend_trajectory(**kwargs)
        as_list = extended.detach().cpu().numpy().tolist()
    return as_list
def forward(self, **kwargs):
    """Forward pass: visual backbone -> embedding -> encoder -> decoder.

    Reads ``cfg.PARAM.ATT_FEATS``, ``cfg.PARAM.INPUT_SENT`` and
    ``cfg.PARAM.ATT_FEATS_MASK`` from ``kwargs``.  Features and mask are
    expanded ``cfg.DATA_LOADER.SEQ_PER_IMG`` times.  The output of
    ``self.get_visual_features`` must be 4-D; its last two dimensions are
    flattened and moved in front of the feature dimension before embedding.

    Returns:
        Whatever ``self.decoder`` returns for the input sentence.
    """
    visual = kwargs[cfg.PARAM.ATT_FEATS]
    tokens = kwargs[cfg.PARAM.INPUT_SENT]
    visual_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]

    per_img = cfg.DATA_LOADER.SEQ_PER_IMG
    visual_mask = utils.expand_tensor(visual_mask, per_img)
    visual = utils.expand_tensor(visual, per_img)

    # Token mask: 1 where the token id is > 0; column 0 is incremented so
    # the first position counts as valid as well (assumes seq[:, 0] == 0 -
    # TODO confirm).  It is then AND-ed with subsequent_mask().
    token_mask = (tokens > 0).type(torch.cuda.IntTensor)
    token_mask[:, 0] += 1
    token_mask = token_mask.unsqueeze(-2)
    sub_mask = subsequent_mask(tokens.size(-1)).to(token_mask)
    token_mask = (token_mask & sub_mask).type(torch.cuda.FloatTensor)

    # Backbone output is unpacked as (batch, feat, _, _); reshape + permute
    # turn it into (batch, positions, feat).
    visual = self.get_visual_features(visual)
    n_items, n_feat, _, _ = visual.shape
    visual = visual.reshape(n_items, n_feat, -1).permute(0, 2, 1)

    embedded = self.att_embed(visual)
    gx, encoder_out = self.encoder(embedded, visual_mask)
    return self.decoder(gx, tokens, encoder_out, visual_mask, token_mask)
def forward(self, **kwargs):
    """Step through the input sentence and collect per-step logits.

    The tensors returned by ``self.preprocess`` are expanded
    ``cfg.DATA_LOADER.SEQ_PER_IMG`` times.  At each time step the input word
    is the token from ``kwargs[cfg.PARAM.INPUT_SENT]``; in training mode
    with ``self.ss_prob > 0`` and ``t >= 1``, rows selected by a uniform
    draw below ``self.ss_prob`` instead get a word sampled from
    ``exp(previous step output)``.

    The loop stops early at the first step ``t >= 1`` whose input tokens
    are all 0; the remaining output slots stay zero.

    Returns:
        Tensor of size ``(batch, seq.size(1), self.vocab_size)`` with the
        output of ``self.logit`` for every processed step.
    """
    tokens = kwargs[cfg.PARAM.INPUT_SENT]
    gv_feat, att_feats, att_mask, p_att_feats = self.preprocess(**kwargs)

    per_img = cfg.DATA_LOADER.SEQ_PER_IMG
    gv_feat, att_feats, att_mask, p_att_feats = (
        utils.expand_tensor(item, per_img)
        for item in (gv_feat, att_feats, att_mask, p_att_feats)
    )

    n_rows = gv_feat.size(0)
    n_steps = tokens.size(1)
    state = self.init_hidden(n_rows)
    outputs = Variable(torch.zeros(n_rows, n_steps, self.vocab_size).cuda())

    for t in range(n_steps):
        scheduled_sampling = self.training and t >= 1 and self.ss_prob > 0
        if scheduled_sampling:
            draw = torch.empty(n_rows).cuda().uniform_(0, 1)
            replace = draw < self.ss_prob
            if replace.sum() == 0:
                word = tokens[:, t].clone()
            else:
                rows = replace.nonzero().view(-1)
                word = tokens[:, t].data.clone()
                # One sample is drawn for every row; only the rows picked
                # above are copied into the input word.
                prev_prob = torch.exp(outputs[:, t - 1].detach())
                sampled = torch.multinomial(prev_prob, 1).view(-1)
                word.index_copy_(0, rows, sampled.index_select(0, rows))
        else:
            word = tokens[:, t].clone()

        # Nothing but token id 0 left in this column: stop.
        if t >= 1 and tokens[:, t].max() == 0:
            break

        step_kwargs = self.make_kwargs(word, gv_feat, att_feats, att_mask,
                                       p_att_feats, state)
        output, state = self.Forward(**step_kwargs)
        if self.dropout_lm is not None:
            output = self.dropout_lm(output)
        outputs[:, t] = self.logit(output)

    return outputs
def forward(self, **kwargs):
    """Forward pass: embed features, encode them and decode the sentence.

    Reads ``cfg.PARAM.ATT_FEATS``, ``cfg.PARAM.INPUT_SENT`` and
    ``cfg.PARAM.ATT_FEATS_MASK`` from ``kwargs``; features and mask are
    expanded ``cfg.DATA_LOADER.SEQ_PER_IMG`` times before encoding.

    Returns:
        Whatever ``self.decoder`` returns for the input sentence.
    """
    feats = kwargs[cfg.PARAM.ATT_FEATS]
    tokens = kwargs[cfg.PARAM.INPUT_SENT]
    feats_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]

    per_img = cfg.DATA_LOADER.SEQ_PER_IMG
    feats_mask = utils.expand_tensor(feats_mask, per_img)
    feats = utils.expand_tensor(feats, per_img)

    # 1 where the token id is > 0, with column 0 incremented so the first
    # position is kept too (assumes seq[:, 0] == 0 - TODO confirm), then
    # AND-ed with subsequent_mask() and converted to float.
    token_mask = (tokens > 0).type(torch.cuda.IntTensor)
    token_mask[:, 0] += 1
    token_mask = token_mask.unsqueeze(-2)
    sub_mask = subsequent_mask(tokens.size(-1)).to(token_mask)
    token_mask = (token_mask & sub_mask).type(torch.cuda.FloatTensor)

    gx, encoder_out = self.encoder(self.att_embed(feats), feats_mask)
    return self.decoder(gx, tokens, encoder_out, feats_mask, token_mask)
def forward(self, **kwargs):
    """Forward pass for the two-stream (cnn / gcn) variant.

    ``self.get_visual_features`` yields two feature tensors.  A mask is
    built with ``makeMask`` for each stream and for their concatenation
    along dim 1; the per-stream masks go to the encoder, the combined mask
    to the decoder.

    Returns:
        Whatever ``self.decoder`` returns for the input sentence.
    """
    raw_feats = kwargs[cfg.PARAM.ATT_FEATS]
    tokens = kwargs[cfg.PARAM.INPUT_SENT]
    given_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]

    per_img = cfg.DATA_LOADER.SEQ_PER_IMG
    # NOTE(review): the mask supplied by the caller is expanded and then
    # discarded - the decoder mask is regenerated from the features below.
    # The lookup and call are kept so behaviour stays identical.
    utils.expand_tensor(given_mask, per_img)
    raw_feats = utils.expand_tensor(raw_feats, per_img)

    # 1 where the token id is > 0, column 0 incremented (assumes
    # seq[:, 0] == 0 - TODO confirm), AND-ed with subsequent_mask().
    token_mask = (tokens > 0).type(torch.cuda.IntTensor)
    token_mask[:, 0] += 1
    token_mask = token_mask.unsqueeze(-2)
    sub_mask = subsequent_mask(tokens.size(-1)).to(token_mask)
    token_mask = (token_mask & sub_mask).type(torch.cuda.FloatTensor)

    cnn_feats, gcn_feats = self.get_visual_features(raw_feats)

    # Masks are derived from the un-embedded features.
    joint_mask = makeMask(torch.cat([cnn_feats, gcn_feats], dim=1))
    cnn_mask = makeMask(cnn_feats)
    gcn_mask = makeMask(gcn_feats)

    cnn_embedded = self.cnn_embed(cnn_feats)
    gcn_embedded = self.gcn_embed(gcn_feats)

    gx, encoder_out = self.encoder(cnn_embedded, gcn_embedded, cnn_mask, gcn_mask)
    return self.decoder(gx, tokens, encoder_out, joint_mask, token_mask)
def decode_beam(self, **kwargs):
    """Beam-search decoding for the encoder/decoder model with buffers.

    Reads ``cfg.PARAM.ATT_FEATS``, ``cfg.PARAM.ATT_FEATS_MASK`` and
    ``'BEAM_SIZE'`` from ``kwargs``.  ``kwargs`` is reused as the argument
    bundle for ``self.get_logprobs_state`` and is updated on every step.

    Returns:
        ``(outputs, log_probs)``: for every batch entry the token ids and
        the per-token log-probabilities of the beam that ranks first after
        sorting by accumulated sequence log-probability.

    Fix over the previous revision: the beam index was computed with ``/``.
    On integer tensors that is true division in current PyTorch and gives a
    float tensor, which cannot be used as a ``torch.gather`` index.
    ``torch.floor_divide`` is used instead.
    """
    att_feats = kwargs[cfg.PARAM.ATT_FEATS]
    att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]
    beam_size = kwargs['BEAM_SIZE']
    batch_size = att_feats.size(0)
    seq_logprob = torch.zeros((batch_size, 1, 1)).cuda()
    log_probs = []
    selected_words = None
    # 1 while a beam is still running, 0 once it has produced token id 0.
    seq_mask = torch.ones((batch_size, beam_size, 1)).cuda()

    att_feats = self.att_embed(att_feats)
    gx, encoder_out = self.encoder(att_feats, att_mask)
    p_att_feats = self.decoder.precompute(encoder_out)

    state = None
    wt = Variable(torch.zeros(batch_size, dtype=torch.long).cuda())
    kwargs[cfg.PARAM.ATT_FEATS] = encoder_out
    kwargs[cfg.PARAM.GLOBAL_FEAT] = gx
    kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats

    outputs = []
    self.decoder.init_buffer(batch_size)
    for t in range(cfg.MODEL.SEQ_LEN):
        # Only one hypothesis per batch entry exists before the first pick.
        cur_beam_size = 1 if t == 0 else beam_size

        kwargs[cfg.PARAM.WT] = wt
        kwargs[cfg.PARAM.STATE] = state
        word_logprob, state = self.get_logprobs_state(**kwargs)
        word_logprob = word_logprob.view(batch_size, cur_beam_size, -1)
        candidate_logprob = seq_logprob + word_logprob

        # Mask sequence if it reaches EOS
        if t > 0:
            mask = (selected_words.view(batch_size, cur_beam_size) !=
                    0).float().unsqueeze(-1)
            seq_mask = seq_mask * mask
            word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
            # A finished beam keeps its score on word 0 only; every other
            # continuation gets -999 so it cannot be selected.
            old_seq_logprob = seq_logprob.expand_as(
                candidate_logprob).contiguous()
            old_seq_logprob[:, :, 1:] = -999
            candidate_logprob = (seq_mask * candidate_logprob +
                                 old_seq_logprob * (1 - seq_mask))

        selected_idx, selected_logprob = self.select(
            batch_size, beam_size, t, candidate_logprob)
        # selected_idx is decomposed as beam * vocab + word (presumably
        # select() indexes the flattened (beam, vocab) candidates -
        # confirm).  Integer floor division keeps the beam index usable as
        # a gather index.
        vocab_size = candidate_logprob.shape[-1]
        selected_beam = torch.floor_divide(selected_idx, vocab_size)
        selected_words = selected_idx - selected_beam * vocab_size

        self.decoder.apply_to_states(
            self._expand_state(batch_size, beam_size, cur_beam_size,
                               selected_beam))
        seq_logprob = selected_logprob.unsqueeze(-1)
        seq_mask = torch.gather(seq_mask, 1, selected_beam.unsqueeze(-1))
        # Re-order the history so every beam carries its parent's tokens.
        outputs = [
            torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs
        ]
        outputs.append(selected_words.unsqueeze(-1))

        this_word_logprob = torch.gather(
            word_logprob, 1,
            selected_beam.unsqueeze(-1).expand(batch_size, beam_size,
                                               word_logprob.shape[-1]))
        this_word_logprob = torch.gather(this_word_logprob, 2,
                                         selected_words.unsqueeze(-1))
        log_probs = [
            torch.gather(
                o, 1,
                selected_beam.unsqueeze(-1).expand(batch_size, beam_size, 1))
            for o in log_probs
        ]
        log_probs.append(this_word_logprob)
        selected_words = selected_words.view(-1, 1)
        wt = selected_words.squeeze(-1)

        if t == 0:
            # From the second step on every input is replicated per beam.
            encoder_out = utils.expand_tensor(encoder_out, beam_size)
            gx = utils.expand_tensor(gx, beam_size)
            att_mask = utils.expand_tensor(att_mask, beam_size)
            state[0] = state[0].squeeze(0)
            state[0] = utils.expand_tensor(state[0], beam_size)
            state[0] = state[0].unsqueeze(0)

            p_att_feats_tmp = []
            for p_key, p_value2 in p_att_feats:
                p_key = utils.expand_tensor(p_key, beam_size)
                p_value2 = utils.expand_tensor(p_value2, beam_size)
                p_att_feats_tmp.append((p_key, p_value2))

            kwargs[cfg.PARAM.ATT_FEATS] = encoder_out
            kwargs[cfg.PARAM.GLOBAL_FEAT] = gx
            kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask
            kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats_tmp

    # Rank the beams by accumulated score and keep the best one.
    seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True)
    outputs = torch.cat(outputs, -1)
    outputs = torch.gather(
        outputs, 1,
        sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN))
    log_probs = torch.cat(log_probs, -1)
    log_probs = torch.gather(
        log_probs, 1,
        sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN))

    outputs = outputs.contiguous()[:, 0]
    log_probs = log_probs.contiguous()[:, 0]

    self.decoder.clear_buffer()
    return outputs, log_probs
def decode(self, **kwargs):
    """Decode one sequence per (replicated) batch entry.

    Modes, selected through ``kwargs``:

    * ``'GREEDY_DECODE'`` true: take the arg-max word at every step.
    * otherwise, without ``cfg.PARAM.GEN_RESULT``: sample each word with
      ``torch.multinomial`` from ``exp(log-probs)``.
    * otherwise, with ``cfg.PARAM.GEN_RESULT``: replay the given tokens.

    ``logprobs`` holds the chosen word's log-probability per step, or the
    full per-step distribution when ``"NEED_PD"`` or
    ``cfg.PARAM.GEN_RESULT`` is present.
    NOTE(review): combining greedy decoding with either of those two keys
    would assign a 1-D value into a 2-D slot - confirm callers never do.

    Optional ``"REPEAT_FACTOR"`` (default 1) replicates the encoder outputs
    so the input is encoded once and decoded several times.

    Fix over the previous revision: with a repeat factor other than 1 the
    encoder outputs were expanded but the batch size (decoder buffer,
    ``wt``, output tensors, the reshape of the step output) and the
    attention mask kept the un-expanded size.  They now follow the factor.
    With the default factor of 1 nothing changes.

    Returns:
        ``(sents, logprobs)``; positions after a sequence produced token
        id 0 stay zero.
    """
    greedy_decode = kwargs['GREEDY_DECODE']
    att_feats = kwargs[cfg.PARAM.ATT_FEATS]
    att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]

    att_feats = self.att_embed(att_feats)
    gx, encoder_out = self.encoder(att_feats, att_mask)

    repeat_factor = kwargs.get("REPEAT_FACTOR", 1)
    gx = utils.expand_tensor(gx, repeat_factor)
    encoder_out = utils.expand_tensor(encoder_out, repeat_factor)
    if repeat_factor != 1:
        # Keep the mask aligned with the replicated encoder output.
        att_mask = utils.expand_tensor(att_mask, repeat_factor)
        kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask
    p_att_feats = self.decoder.precompute(encoder_out)

    # Number of sequences actually decoded (inputs x repeat factor).
    batch_size = att_feats.size(0) * repeat_factor
    self.decoder.init_buffer(batch_size)

    state = None
    seq_len = cfg.MODEL.SEQ_LEN
    need_pd = "NEED_PD" in kwargs
    sents = torch.zeros((batch_size, seq_len), dtype=torch.long).cuda()
    if cfg.PARAM.GEN_RESULT in kwargs:
        gen_result = kwargs[cfg.PARAM.GEN_RESULT].view(batch_size, seq_len)
    else:
        gen_result = None
    if gen_result is not None or need_pd:
        logprobs = torch.zeros(batch_size, seq_len, self.vocab_size).cuda()
    else:
        logprobs = torch.zeros(batch_size, seq_len).cuda()

    wt = torch.zeros(batch_size, dtype=torch.long).cuda()
    unfinished = wt.eq(wt)  # all True: every sequence starts unfinished
    kwargs[cfg.PARAM.ATT_FEATS] = encoder_out
    kwargs[cfg.PARAM.GLOBAL_FEAT] = gx
    kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats

    for t in range(seq_len):
        kwargs[cfg.PARAM.WT] = wt
        kwargs[cfg.PARAM.STATE] = state
        logprobs_t, state = self.get_logprobs_state(**kwargs)
        logprobs_t = logprobs_t.reshape(batch_size, self.vocab_size)

        if greedy_decode:
            logP_t, wt = torch.max(logprobs_t, 1)
        elif gen_result is None:
            probs_t = torch.exp(logprobs_t)
            wt = torch.multinomial(probs_t, 1)
            if need_pd:
                logP_t = logprobs_t
            else:
                logP_t = logprobs_t.gather(1, wt).view(-1)
        else:
            wt = gen_result[:, t]
            logP_t = logprobs_t

        wt = wt.view(-1).long()
        # Token id 0 ends a sequence; finished rows keep emitting 0.
        unfinished = unfinished * (wt > 0)
        wt = wt * unfinished.type_as(wt)
        sents[:, t] = wt
        logprobs[:, t] = logP_t

        if unfinished.sum() == 0:
            break

    self.decoder.clear_buffer()
    return sents, logprobs
def decode_beam(self, **kwargs):
    """Beam-search decoding for the recurrent (explicit state) model.

    Inputs come from ``self.preprocess(**kwargs)`` and ``kwargs['BEAM_SIZE']``.
    ``kwargs`` is reused as the argument bundle for
    ``self.get_logprobs_state`` and is updated on every step.

    Returns:
        ``(outputs, log_probs)``: for every batch entry the token ids and
        the per-token log-probabilities of the beam that ranks first after
        sorting by accumulated sequence log-probability.
    """
    gv_feat, att_feats, att_mask, p_att_feats = self.preprocess(**kwargs)

    beam_size = kwargs['BEAM_SIZE']
    batch_size = att_feats.size(0)
    seq_logprob = torch.zeros((batch_size, 1, 1)).cuda()
    log_probs = []
    selected_words = None
    # 1 while a beam is still running, 0 once it has produced token id 0.
    seq_mask = torch.ones((batch_size, beam_size, 1)).cuda()

    state = self.init_hidden(batch_size)
    wt = Variable(torch.zeros(batch_size, dtype=torch.long).cuda())

    kwargs[cfg.PARAM.ATT_FEATS] = att_feats
    kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat
    kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats

    outputs = []
    for t in range(cfg.MODEL.SEQ_LEN):
        # Only one hypothesis per batch entry exists before the first pick.
        cur_beam_size = 1 if t == 0 else beam_size

        kwargs[cfg.PARAM.WT] = wt
        kwargs[cfg.PARAM.STATE] = state
        word_logprob, state = self.get_logprobs_state(**kwargs)
        word_logprob = word_logprob.view(batch_size, cur_beam_size, -1)
        candidate_logprob = seq_logprob + word_logprob

        # Mask sequence if it reaches EOS
        if t > 0:
            mask = (selected_words.view(batch_size, cur_beam_size) != 0).float().unsqueeze(-1)
            seq_mask = seq_mask * mask
            word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
            # A finished beam keeps its score on word 0 only; every other
            # continuation gets -999 so it cannot be selected.
            old_seq_logprob = seq_logprob.expand_as(
                candidate_logprob).contiguous()
            old_seq_logprob[:, :, 1:] = -999
            candidate_logprob = seq_mask * candidate_logprob + old_seq_logprob * (
                1 - seq_mask)

        selected_idx, selected_logprob = self.select(
            batch_size, beam_size, t, candidate_logprob)
        # selected_idx is decomposed as beam * vocab + word (presumably
        # select() indexes the flattened (beam, vocab) candidates - confirm).
        # Floor division keeps the beam index integral so it can be used as
        # a gather index below.
        selected_beam = torch.floor_divide(selected_idx, candidate_logprob.shape[-1])
        selected_words = selected_idx - selected_beam * candidate_logprob.shape[
            -1]

        # Re-order every state entry to follow the selected parent beams.
        for s in range(len(state)):
            state[s] = self._expand_state(batch_size, beam_size,
                                          cur_beam_size, state[s],
                                          selected_beam)

        seq_logprob = selected_logprob.unsqueeze(-1)
        seq_mask = torch.gather(seq_mask, 1, selected_beam.unsqueeze(-1))
        # Re-order the history so every beam carries its parent's tokens.
        outputs = list(
            torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs)
        outputs.append(selected_words.unsqueeze(-1))

        # Log-probability of the chosen word: first pick the parent beam's
        # row, then the word inside it.
        this_word_logprob = torch.gather(
            word_logprob, 1,
            selected_beam.unsqueeze(-1).expand(batch_size, beam_size,
                                               word_logprob.shape[-1]))
        this_word_logprob = torch.gather(this_word_logprob, 2,
                                         selected_words.unsqueeze(-1))
        log_probs = list(
            torch.gather(
                o, 1,
                selected_beam.unsqueeze(-1).expand(batch_size, beam_size, 1))
            for o in log_probs)
        log_probs.append(this_word_logprob)
        selected_words = selected_words.view(-1, 1)
        wt = selected_words.squeeze(-1)

        if t == 0:
            # From the second step on every input is replicated per beam.
            att_feats = utils.expand_tensor(att_feats, beam_size)
            gv_feat = utils.expand_tensor(gv_feat, beam_size)
            att_mask = utils.expand_tensor(att_mask, beam_size)
            p_att_feats = utils.expand_tensor(p_att_feats, beam_size)

            kwargs[cfg.PARAM.ATT_FEATS] = att_feats
            kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat
            kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask
            kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats

    # Rank the beams by accumulated score and keep the best one.
    seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True)
    outputs = torch.cat(outputs, -1)
    outputs = torch.gather(
        outputs, 1,
        sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN))
    log_probs = torch.cat(log_probs, -1)
    log_probs = torch.gather(
        log_probs, 1,
        sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN))

    outputs = outputs.contiguous()[:, 0]
    log_probs = log_probs.contiguous()[:, 0]

    return outputs, log_probs
def decode_beam(self, **kwargs):
    """Beam-search decoding that can also return attention scores.

    Inputs come from ``self.preprocess(**kwargs)``, ``kwargs['BEAM_SIZE']``
    and ``kwargs['output_attention']``.  ``kwargs`` is reused as the
    argument bundle for ``self.get_logprobs_state`` and is updated on every
    step.

    Returns:
        ``(outputs, log_probs)`` when ``output_attention`` is false.
        Otherwise ``(outputs, log_probs, attention_scores, att_mask0)``,
        where ``attention_scores`` stacks, along a new last dimension, the
        last element of the third ``get_logprobs_state`` output for every
        step of the best beam, and ``att_mask0`` is the un-expanded mask
        returned by ``preprocess``.

    Fix over the previous revision: the beam index was computed with ``/``.
    On integer tensors that is true division in current PyTorch and gives a
    float tensor, which cannot be used as a ``torch.gather`` index.
    ``torch.floor_divide`` is used instead.
    """
    gv_feat, att_feats, att_mask, p_att_feats = self.preprocess(**kwargs)
    att_mask0 = att_mask  # kept un-expanded for the caller

    beam_size = kwargs['BEAM_SIZE']
    output_attention = kwargs['output_attention']
    batch_size = att_feats.size(0)
    seq_logprob = torch.zeros((batch_size, 1, 1)).cuda()
    log_probs, attention_scores = [], []
    selected_words = None
    # 1 while a beam is still running, 0 once it has produced token id 0.
    seq_mask = torch.ones((batch_size, beam_size, 1)).cuda()

    state = self.init_hidden(batch_size)
    wt = Variable(torch.zeros(batch_size, dtype=torch.long).cuda())

    kwargs[cfg.PARAM.ATT_FEATS] = att_feats
    kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat
    kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats

    outputs = []
    for t in range(cfg.MODEL.SEQ_LEN):
        # Only one hypothesis per batch entry exists before the first pick.
        cur_beam_size = 1 if t == 0 else beam_size

        kwargs[cfg.PARAM.WT] = wt
        kwargs[cfg.PARAM.STATE] = state
        logprobs_state_output = self.get_logprobs_state(**kwargs)
        word_logprob, state = logprobs_state_output[0], logprobs_state_output[1]

        if output_attention:
            # Last entry of the attention output (presumably the last
            # layer - confirm), regrouped as (batch, beam, dim1, dim2).
            attention_score = logprobs_state_output[2][-1]
            attention_score = attention_score.view(
                batch_size, -1, attention_score.shape[1],
                attention_score.shape[2])

        word_logprob = word_logprob.view(batch_size, cur_beam_size, -1)
        candidate_logprob = seq_logprob + word_logprob

        # Mask sequence if it reaches EOS
        if t > 0:
            mask = (selected_words.view(batch_size, cur_beam_size) !=
                    0).float().unsqueeze(-1)
            seq_mask = seq_mask * mask
            word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
            # A finished beam keeps its score on word 0 only; every other
            # continuation gets -999 so it cannot be selected.
            old_seq_logprob = seq_logprob.expand_as(
                candidate_logprob).contiguous()
            old_seq_logprob[:, :, 1:] = -999
            candidate_logprob = (seq_mask * candidate_logprob +
                                 old_seq_logprob * (1 - seq_mask))

        selected_idx, selected_logprob = self.select(
            batch_size, beam_size, t, candidate_logprob)
        # selected_idx is decomposed as beam * vocab + word (presumably
        # select() indexes the flattened (beam, vocab) candidates -
        # confirm).  Integer floor division keeps the beam index usable as
        # a gather index.
        vocab_size = candidate_logprob.shape[-1]
        selected_beam = torch.floor_divide(selected_idx, vocab_size)
        selected_words = selected_idx - selected_beam * vocab_size

        # Re-order every state entry to follow the selected parent beams.
        for s, state_s in enumerate(state):
            state[s] = self._expand_state(batch_size, beam_size,
                                          cur_beam_size, state_s,
                                          selected_beam)

        if output_attention:
            selected_beam_ex = selected_beam.unsqueeze(-1).unsqueeze(
                -1).expand(batch_size, beam_size,
                           attention_score.shape[-2],
                           attention_score.shape[-1])
            this_word_attention_score = torch.gather(
                attention_score, 1, selected_beam_ex)
            # Earlier steps are re-ordered to follow their parent beam.
            attention_scores = [
                torch.gather(a, 1, selected_beam_ex)
                for a in attention_scores
            ]
            attention_scores.append(this_word_attention_score)

        seq_logprob = selected_logprob.unsqueeze(-1)
        seq_mask = torch.gather(seq_mask, 1, selected_beam.unsqueeze(-1))
        # Re-order the history so every beam carries its parent's tokens.
        outputs = [
            torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs
        ]
        outputs.append(selected_words.unsqueeze(-1))

        # Log-probability of the chosen word: first pick the parent beam's
        # row, then the word inside it.
        this_word_logprob = torch.gather(
            word_logprob, 1,
            selected_beam.unsqueeze(-1).expand(batch_size, beam_size,
                                               word_logprob.shape[-1]))
        this_word_logprob = torch.gather(this_word_logprob, 2,
                                         selected_words.unsqueeze(-1))
        log_probs = [
            torch.gather(
                o, 1,
                selected_beam.unsqueeze(-1).expand(batch_size, beam_size, 1))
            for o in log_probs
        ]
        log_probs.append(this_word_logprob)
        selected_words = selected_words.view(-1, 1)
        wt = selected_words.squeeze(-1)

        if t == 0:
            # From the second step on every input is replicated per beam.
            att_feats = utils.expand_tensor(att_feats, beam_size)
            gv_feat = utils.expand_tensor(gv_feat, beam_size)
            att_mask = utils.expand_tensor(att_mask, beam_size)
            p_att_feats = utils.expand_tensor(p_att_feats, beam_size)

            kwargs[cfg.PARAM.ATT_FEATS] = att_feats
            kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat
            kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask
            kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats

    # Rank the beams by accumulated score and keep the best one.
    seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True)
    outputs = torch.cat(outputs, -1)
    outputs = torch.gather(
        outputs, 1,
        sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN))
    log_probs = torch.cat(log_probs, -1)
    log_probs = torch.gather(
        log_probs, 1,
        sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN))

    outputs = outputs.contiguous()[:, 0]
    log_probs = log_probs.contiguous()[:, 0]

    if output_attention:
        # Apply the same ranking to the stacked per-step attention scores.
        sort_idx_ex = sort_idxs.unsqueeze(-1).unsqueeze(-1)
        attention_scores = torch.stack(attention_scores, -1)
        sort_idx_ex = sort_idx_ex.expand(batch_size, beam_size,
                                         attention_scores.shape[-3],
                                         attention_scores.shape[-2],
                                         cfg.MODEL.SEQ_LEN)
        attention_scores = torch.gather(attention_scores, 1, sort_idx_ex)
        attention_scores = attention_scores.contiguous()[:, 0]
        return outputs, log_probs, attention_scores, att_mask0
    else:
        return outputs, log_probs