def _forward(self, img, seq, ppls, gt_boxes, mask_boxes, num):
    seq = seq.view(-1, seq.size(2), seq.size(3))
    seq_update = seq.data.clone()
    batch_size = img.size(0)
    seq_batch_size = seq.size(0)
    rois_num = ppls.size(1)

    # constructing the mask.
    pnt_mask = mask_boxes.data.new(batch_size, rois_num + 1).byte().fill_(1)
    for i in range(batch_size):
        pnt_mask[i, :num.data[i, 1] + 1] = 0
    pnt_mask = Variable(pnt_mask)

    state = self.init_hidden(seq_batch_size)
    rnn_output = []
    roi_labels = []
    det_output = []

    if self.finetune_cnn:
        conv_feats, fc_feats = self.cnn(img)
    else:
        # with torch.no_grad():
        conv_feats, fc_feats = self.cnn(Variable(img.data, volatile=True))
        conv_feats = Variable(conv_feats.data)
        fc_feats = Variable(fc_feats.data)

    # pooling the conv_feats
    rois = ppls.data.new(batch_size, rois_num, 5)
    rois[:, :, 1:] = ppls.data[:, :, :4]
    for i in range(batch_size):
        rois[i, :, 0] = i
    pool_feats = self.roi_align(conv_feats, Variable(rois.view(-1, 5)))
    pool_feats = pool_feats.view(batch_size, rois_num, self.att_feat_size)

    loc_input = ppls.data.new(batch_size, rois_num, 5)
    loc_input[:, :, :4] = ppls.data[:, :, :4] / self.image_crop_size
    loc_input[:, :, 4] = ppls.data[:, :, 5]
    loc_feats = self.loc_fc(Variable(loc_input))

    label_input = seq.data.new(batch_size, rois_num)
    label_input[:, :] = ppls.data[:, :, 4]
    label_feat = self.det_fc(Variable(label_input))

    # pool_feats = pool_feats + label_feat
    pool_feats = torch.cat((pool_feats, loc_feats, label_feat), 2)

    # transpose the conv_feats
    conv_feats = conv_feats.view(batch_size, self.att_feat_size, -1).transpose(1, 2).contiguous()

    # replicate the feature to map the seq size.
    fc_feats = fc_feats.view(batch_size, 1, self.fc_feat_size)\
        .expand(batch_size, self.seq_per_img, self.fc_feat_size)\
        .contiguous().view(-1, self.fc_feat_size)
    conv_feats = conv_feats.view(batch_size, 1, self.att_size*self.att_size, self.att_feat_size)\
        .expand(batch_size, self.seq_per_img, self.att_size*self.att_size, self.att_feat_size)\
        .contiguous().view(-1, self.att_size*self.att_size, self.att_feat_size)
    pool_feats = pool_feats.view(batch_size, 1, rois_num, self.pool_feat_size)\
        .expand(batch_size, self.seq_per_img, rois_num, self.pool_feat_size)\
        .contiguous().view(-1, rois_num, self.pool_feat_size)
    pnt_mask = pnt_mask.view(batch_size, 1, rois_num + 1)\
        .expand(batch_size, self.seq_per_img, rois_num + 1)\
        .contiguous().view(-1, rois_num + 1)

    # embed fc and att feats
    fc_feats = self.fc_embed(fc_feats)
    conv_feats = self.att_embed(conv_feats)
    pool_feats = self.pool_embed(pool_feats)

    # Project the attention feats first to reduce memory and computation consumption.
    p_conv_feats = self.ctx2att(conv_feats)
    p_pool_feats = self.ctx2pool(pool_feats)

    # calculate the overlaps between the rois and gt_bbox.
    overlaps = utils.bbox_overlaps(ppls.data, gt_boxes.data)

    for i in range(self.seq_length):
        if self.training and i >= 1 and self.ss_prob > 0.0:  # otherwise no need to sample
            sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
            sample_mask = sample_prob < self.ss_prob
            if sample_mask.sum() == 0:
                it = seq[:, i, 0].clone()
            else:
                sample_ind = sample_mask.nonzero().view(-1)
                it = seq[:, i, 0].data.clone()
                # prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind))  # fetch prev distribution: shape Nx(M+1)
                # it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
                prob_prev = torch.exp(outputs[-1].data)  # fetch prev distribution: shape Nx(M+1)
                it.index_copy_(0, sample_ind,
                               torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
                it = Variable(it)
        else:
            it = seq[:, i, 0].clone()

        # break if all the sequences end
        if i >= 1 and seq[:, i, 0].data.sum() == 0:
            break

        roi_label = utils.bbox_target(ppls, mask_boxes[:, :, :, i + 1], overlaps,
                                      seq[:, i + 1], seq_update[:, i + 1],
                                      self.vocab_size)  # roi_label is for the target seq
        # pdb.set_trace()
        roi_labels.append(roi_label.view(seq_batch_size, -1))

        xt = self.embed(it)

        output, det_prob, state = self.core(xt, fc_feats, conv_feats, p_conv_feats,
                                            pool_feats, p_pool_feats, pnt_mask, pnt_mask, state)
        rnn_output.append(output)
        det_output.append(det_prob)

    seq_cnt = len(det_output)
    rnn_output = torch.cat([_.unsqueeze(1) for _ in rnn_output], 1)
    det_output = torch.cat([_.unsqueeze(1) for _ in det_output], 1)
    roi_labels = torch.cat([_.unsqueeze(1) for _ in roi_labels], 1)

    det_output = F.log_softmax(det_output, dim=2)
    decoded = F.log_softmax(self.beta * self.logit(rnn_output), dim=2)

    lambda_v = det_output[:, :, 0].contiguous()
    prob_det = det_output[:, :, 1:].contiguous()

    decoded = decoded + lambda_v.view(seq_batch_size, seq_cnt, 1).expand_as(decoded)
    decoded = decoded.view(seq_cnt * seq_batch_size, -1)

    roi_labels = Variable(roi_labels)
    prob_det = prob_det * roi_labels

    roi_cnt = roi_labels.sum(2)
    roi_cnt[roi_cnt == 0] = 1
    vis_prob = prob_det.sum(2) / roi_cnt
    vis_prob = vis_prob.view(-1, 1)

    seq_update = Variable(seq_update)
    lm_loss = self.critLM(decoded, vis_prob, seq_update[:, 1:seq_cnt + 1, 0].clone())

    # do the cascade object recognition.
    vis_idx = seq_update[:, 1:seq_cnt + 1, 0].data - self.vocab_size
    vis_idx[vis_idx < 0] = 0
    vis_idx = Variable(vis_idx.view(-1))

    # roi_labels = Variable(roi_labels)
    bn_logprob, fg_logprob = self.ccr_core(vis_idx, pool_feats, rnn_output, roi_labels,
                                           seq_batch_size, seq_cnt)

    bn_loss = self.critBN(bn_logprob, seq_update[:, 1:seq_cnt + 1, 1].clone())
    fg_loss = self.critFG(fg_logprob, seq_update[:, 1:seq_cnt + 1, 2].clone())

    return lm_loss, bn_loss, fg_loss
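
# Illustrative call sketch (added for readability, not part of the original code).
# The shapes below are assumptions inferred from how _forward indexes its inputs:
#   img        : (B, 3, H, W)                               raw image batch
#   seq        : (B, seq_per_img, seq_length + 1, >=3)      word id / bbox-class / fine-grained-class targets
#   ppls       : (B, rois_num, >=6)                         proposals: box (4), detection label, score
#   gt_boxes   : (B, num_gt, >=4)                           ground-truth boxes
#   mask_boxes : (B, seq_per_img, num_gt, seq_length + 1)   per-word box masks
#   num        : (B, >=2)                                   per-image counts, with num[:, 1] = #proposals
# Under those assumptions a training call would look like:
#   lm_loss, bn_loss, fg_loss = model._forward(img, seq, ppls, gt_boxes, mask_boxes, num)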
def _forward(self, segs_feat, input_seq, gt_seq, ppls, gt_boxes, mask_boxes, num,
             ppls_feat, frm_mask, sample_idx, pnt_mask, eval_obj_ground=False):

    seq = gt_seq[:, :self.seq_per_img, :].clone().view(-1, gt_seq.size(2))  # choose the first seq_per_img
    seq = torch.cat((Variable(seq.data.new(seq.size(0), 1).fill_(0)), seq), 1)
    input_seq = input_seq.view(-1, input_seq.size(2), input_seq.size(3))  # B*self.seq_per_img, self.seq_length+1, 5
    input_seq_update = input_seq.data.clone()

    batch_size = segs_feat.size(0)  # B
    seq_batch_size = seq.size(0)  # B*self.seq_per_img
    rois_num = ppls.size(1)  # max_num_proposal of the batch

    state = self.init_hidden(seq_batch_size)  # self.num_layers, B*self.seq_per_img, self.rnn_size
    rnn_output = []
    roi_labels = []  # store which proposals match the gt boxes
    att2_weights = []
    h_att_output = []
    max_grd_output = []
    frm_mask_output = []

    conv_feats = segs_feat
    sample_idx_mask = conv_feats.new(batch_size, conv_feats.size(1), 1).fill_(1).byte()
    for i in range(batch_size):
        sample_idx_mask[i, sample_idx[i, 0]:sample_idx[i, 1]] = 0
    fc_feats = torch.mean(segs_feat, dim=1)
    fc_feats = torch.cat((F.layer_norm(fc_feats, [self.fc_feat_size - self.seg_info_size]),
                          F.layer_norm(self.seg_info_embed(num[:, 3:7].float()), [self.seg_info_size])), dim=-1)

    # pooling the conv_feats
    pool_feats = ppls_feat
    pool_feats = self.ctx2pool_grd(pool_feats)
    g_pool_feats = pool_feats

    # calculate the overlaps between the rois/rois and rois/gt_bbox.
    # apply both frame mask and proposal mask
    overlaps = utils.bbox_overlaps(ppls.data, gt_boxes.data,
                                   (frm_mask | pnt_mask[:, 1:].unsqueeze(-1)).data)

    # visual words embedding
    vis_word = Variable(torch.Tensor(range(0, self.detect_size + 1)).type(input_seq.type()))
    vis_word_embed = self.vis_embed(vis_word)
    assert (vis_word_embed.size(0) == self.detect_size + 1)

    p_vis_word_embed = vis_word_embed.view(1, self.detect_size + 1, self.vis_encoding_size) \
        .expand(batch_size, self.detect_size + 1, self.vis_encoding_size).contiguous()

    if hasattr(self, 'vis_classifiers_bias'):
        bias = self.vis_classifiers_bias.type(p_vis_word_embed.type()) \
            .view(1, -1, 1).expand(p_vis_word_embed.size(0),
                                   p_vis_word_embed.size(1), g_pool_feats.size(1))
    else:
        bias = None

    # region-class similarity matrix
    sim_mat_static = self._grounder(p_vis_word_embed, g_pool_feats, pnt_mask[:, 1:], bias)
    sim_mat_static_update = sim_mat_static.view(batch_size, 1, self.detect_size + 1, rois_num) \
        .expand(batch_size, self.seq_per_img, self.detect_size + 1, rois_num).contiguous() \
        .view(seq_batch_size, self.detect_size + 1, rois_num)
    sim_mat_static = F.softmax(sim_mat_static, dim=1)

    if self.test_mode:
        cls_pred = 0
    else:
        sim_target = utils.sim_mat_target(overlaps, gt_boxes[:, :, 5].data)  # B, num_box, num_rois
        sim_mask = (sim_target > 0)
        if not eval_obj_ground:
            masked_sim = torch.gather(sim_mat_static, 1, sim_target)
            masked_sim = torch.masked_select(masked_sim, sim_mask)
            cls_loss = F.binary_cross_entropy(masked_sim,
                                              masked_sim.new(masked_sim.size()).fill_(1))
        else:
            # region classification accuracy
            sim_target_masked = torch.masked_select(sim_target, sim_mask)
            sim_mat_masked = torch.masked_select(
                torch.max(sim_mat_static, dim=1)[1].unsqueeze(1).expand_as(sim_target), sim_mask)
            cls_pred = torch.stack((sim_target_masked, sim_mat_masked), dim=1).data

    if not self.enable_BUTD:
        loc_input = ppls.data.new(batch_size, rois_num, 5)
        loc_input[:, :, :4] = ppls.data[:, :, :4] / 720.
        loc_input[:, :, 4] = ppls.data[:, :, 4] * 1. / self.num_sampled_frm
        loc_feats = self.loc_fc(Variable(loc_input))  # encode the locations
        label_feat = sim_mat_static.permute(0, 2, 1).contiguous()
        pool_feats = torch.cat((F.layer_norm(pool_feats, [pool_feats.size(-1)]),
                                F.layer_norm(loc_feats, [loc_feats.size(-1)]),
                                F.layer_norm(label_feat, [label_feat.size(-1)])), 2)

    # replicate the feature to map the seq size.
    fc_feats = fc_feats.view(batch_size, 1, self.fc_feat_size)\
        .expand(batch_size, self.seq_per_img, self.fc_feat_size)\
        .contiguous().view(-1, self.fc_feat_size)
    pool_feats = pool_feats.view(batch_size, 1, rois_num, self.pool_feat_size)\
        .expand(batch_size, self.seq_per_img, rois_num, self.pool_feat_size)\
        .contiguous().view(-1, rois_num, self.pool_feat_size)
    g_pool_feats = g_pool_feats.view(batch_size, 1, rois_num, self.vis_encoding_size)\
        .expand(batch_size, self.seq_per_img, rois_num, self.vis_encoding_size)\
        .contiguous().view(-1, rois_num, self.vis_encoding_size)
    pnt_mask = pnt_mask.view(batch_size, 1, rois_num + 1)\
        .expand(batch_size, self.seq_per_img, rois_num + 1)\
        .contiguous().view(-1, rois_num + 1)
    overlaps = overlaps.view(batch_size, 1, rois_num, overlaps.size(2))\
        .expand(batch_size, self.seq_per_img, rois_num, overlaps.size(2))\
        .contiguous().view(-1, rois_num, overlaps.size(2))

    # embed fc and att feats
    fc_feats = self.fc_embed(fc_feats)
    pool_feats = self.pool_embed(pool_feats)

    # object region interactions
    if hasattr(self, 'obj_interact'):
        pool_feats = self.obj_interact(pool_feats)

    # Project the attention feats first to reduce memory and computation consumption.
    p_pool_feats = self.ctx2pool(pool_feats)  # same here

    if self.att_input_mode in ('both', 'featmap'):
        conv_feats_splits = torch.split(conv_feats, 2048, 2)
        conv_feats = torch.cat([m(c) for (m, c) in zip(self.att_embed, conv_feats_splits)], dim=2)
        conv_feats = conv_feats.permute(0, 2, 1).contiguous()  # inconsistency between Torch TempConv and PyTorch Conv1d
        conv_feats = self.att_embed_aux(conv_feats)
        conv_feats = conv_feats.permute(0, 2, 1).contiguous()  # inconsistency between Torch TempConv and PyTorch Conv1d
        conv_feats = self.context_enc(conv_feats)[0]

        conv_feats = conv_feats.masked_fill(sample_idx_mask, 0)
        conv_feats = conv_feats.view(batch_size, 1, self.t_attn_size, self.rnn_size)\
            .expand(batch_size, self.seq_per_img, self.t_attn_size, self.rnn_size)\
            .contiguous().view(-1, self.t_attn_size, self.rnn_size)
        p_conv_feats = self.ctx2att(conv_feats)  # self.rnn_size (1024) -> self.att_hid_size (512)
    else:
        # dummy
        conv_feats = pool_feats.new(1, 1).fill_(0)
        p_conv_feats = pool_feats.new(1, 1).fill_(0)

    if self.att_model == 'transformer':  # Masked Transformer does not support box supervision yet
        if self.att_input_mode == 'both':
            lm_loss = self.cap_model([conv_feats, pool_feats], seq)
        elif self.att_input_mode == 'featmap':
            lm_loss = self.cap_model([conv_feats, conv_feats], seq)
        elif self.att_input_mode == 'region':
            lm_loss = self.cap_model([pool_feats, pool_feats], seq)
        return lm_loss.unsqueeze(0), lm_loss.new(1).fill_(0), lm_loss.new(1).fill_(0), \
            lm_loss.new(1).fill_(0), lm_loss.new(1).fill_(0), lm_loss.new(1).fill_(0)
    elif self.att_model == 'topdown':
        for i in range(self.seq_length):
            it = seq[:, i].clone()

            # break if all the sequences end
            if i >= 1 and seq[:, i].data.sum() == 0:
                break

            xt = self.embed(it)

            if not eval_obj_ground:
                roi_label = utils.bbox_target(mask_boxes[:, :, :, i + 1], overlaps,
                                              input_seq[:, i + 1], input_seq_update[:, i + 1],
                                              self.vocab_size)  # roi_label is for the target seq
                roi_labels.append(roi_label.view(seq_batch_size, -1))

                # use frame mask during training
                box_mask = mask_boxes[:, 0, :, i + 1].contiguous().unsqueeze(1)\
                    .expand((batch_size, rois_num, mask_boxes.size(2)))
                frm_mask_on_prop = (torch.sum((~(box_mask | frm_mask)), dim=2) <= 0)
                frm_mask_on_prop = torch.cat((frm_mask_on_prop.new(batch_size, 1).fill_(0.),
                                              frm_mask_on_prop), dim=1) | pnt_mask

                output, state, att2_weight, att_h, max_grd_val, grd_val = self.core(
                    xt, fc_feats, conv_feats, p_conv_feats, pool_feats, p_pool_feats,
                    pnt_mask, frm_mask_on_prop, state, sim_mat_static_update)
                frm_mask_output.append(frm_mask_on_prop)
            else:
                output, state, att2_weight, att_h, max_grd_val, grd_val = self.core(
                    xt, fc_feats, conv_feats, p_conv_feats, pool_feats, p_pool_feats,
                    pnt_mask, pnt_mask, state, sim_mat_static_update)

            att2_weights.append(att2_weight)
            h_att_output.append(att_h)  # the hidden state of the attention LSTM
            rnn_output.append(output)
            max_grd_output.append(max_grd_val)

        seq_cnt = len(rnn_output)
        rnn_output = torch.cat([_.unsqueeze(1) for _ in rnn_output], 1)  # seq_batch_size, seq_cnt, vocab
        h_att_output = torch.cat([_.unsqueeze(1) for _ in h_att_output], 1)
        att2_weights = torch.cat([_.unsqueeze(1) for _ in att2_weights], 1)  # seq_batch_size, seq_cnt, att_size
        max_grd_output = torch.cat([_.unsqueeze(1) for _ in max_grd_output], 1)
        if not eval_obj_ground:
            frm_mask_output = torch.cat([_.unsqueeze(1) for _ in frm_mask_output], 1)
            roi_labels = torch.cat([_.unsqueeze(1) for _ in roi_labels], 1)

        decoded = F.log_softmax(self.beta * self.logit(rnn_output), dim=2)  # text word prob
        decoded = decoded.view(seq_cnt * seq_batch_size, -1)

        # object grounding
        h_att_all = h_att_output  # hidden states from the Attention LSTM
        xt_clamp = torch.clamp(input_seq[:, 1:seq_cnt + 1, 0].clone() - self.vocab_size, min=0)
        xt_all = self.vis_embed(xt_clamp)

        if hasattr(self, 'vis_classifiers_bias'):
            bias = self.vis_classifiers_bias[xt_clamp].type(xt_all.type())\
                .unsqueeze(2).expand(seq_batch_size, seq_cnt, rois_num)
        else:
            bias = 0

        if not eval_obj_ground:
            # att2_weights/ground_weights with both proposal mask and frame mask
            ground_weights = self._grounder(xt_all, g_pool_feats, frm_mask_output[:, :, 1:],
                                            bias + att2_weights)
            lm_loss, att2_loss, ground_loss = self.critLM(
                decoded, att2_weights, ground_weights,
                seq[:, 1:seq_cnt + 1].clone(), roi_labels[:, :seq_cnt, :].clone(),
                input_seq[:, 1:seq_cnt + 1, 0].clone())
            return lm_loss.unsqueeze(0), att2_loss.unsqueeze(0), ground_loss.unsqueeze(0), \
                cls_loss.unsqueeze(0)
        else:
            # att2_weights/ground_weights with proposal mask only
            ground_weights = self._grounder(xt_all, g_pool_feats, pnt_mask[:, 1:],
                                            bias + att2_weights)
            return cls_pred, \
                torch.max(att2_weights.view(seq_batch_size, seq_cnt, self.num_sampled_frm,
                                            self.num_prop_per_frm), dim=-1)[1], \
                torch.max(ground_weights.view(seq_batch_size, seq_cnt, self.num_sampled_frm,
                                              self.num_prop_per_frm), dim=-1)[1]
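
# Illustrative call sketch (added for readability, not part of the original code).
# Shapes are assumptions inferred from the indexing and inline shape comments above:
#   segs_feat : (B, T, feat_dim)                      frame-level segment features
#   input_seq : (B, seq_per_img, seq_length + 1, 5)   caption tokens plus visual-word targets
#   gt_seq    : (B, >=seq_per_img, seq_length + 1)    ground-truth captions
#   ppls      : (B, rois_num, >=5)                    proposals (box, frame index, ...)
#   ppls_feat : (B, rois_num, region_feat_dim)        pre-extracted region features
#   pnt_mask  : (B, rois_num + 1)                     proposal mask
#   frm_mask  : (B, rois_num, num_gt)                 proposal/GT frame-mismatch mask
# With eval_obj_ground=False the call returns (lm_loss, att2_loss, ground_loss, cls_loss);
# with eval_obj_ground=True it returns cls_pred plus per-word attention and grounding box indices.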
def _forward_3_loops(self, segs_feat, input_seq, proposals, gt_caption, num, mask_boxes,
                     gt_boxes, region_feats, frm_mask, sample_idx, pnt_mask):
    """
    :param segs_feat:
    :param proposals: (batch_size, max num of proposals allowed (200), 6),
        6-dim: (4 points for the box, proposal index to be used for self.itod to get detection,
        proposal score)
    :param seq: (batch_size, # of captions for this img, max caption length, 4),
        4-dim: (category class, binary class, fine-grain class, from self.wtoi)
    :param num: (batch_size, 3), 3-dim: (# of captions for this img, # of proposals, # of GT bboxes)
    :param mask_boxes: (batch_size, # of captions for this img, max num of GT bbox allowed (100),
        max caption length)
    :param gt_boxes: (batch_size, max num of GT bbox allowed (100), 5),
        5-dim: (4 points for the box, bbox index)
    :return:
    """
    batch_size = segs_feat.size(0)
    num_rois = proposals.size(1)

    gt_caption = gt_caption[:, :self.seq_per_img, :].clone().view(-1, gt_caption.size(2))  # choose the first seq_per_img
    gt_caption = torch.cat((gt_caption.new(gt_caption.size(0), 1).fill_(0), gt_caption), 1)
    # B*self.seq_per_img, self.seq_length+1, 5
    input_seq = input_seq.view(-1, input_seq.size(2), input_seq.size(3))
    input_seq_update = input_seq.data.clone()

    gt_caption_batch_size = gt_caption.size(0)

    decoder_state = self.init_hidden(gt_caption_batch_size, self.decoder_num_layers)

    # calculate the overlaps between the rois and gt_bbox.
    # overlaps, overlaps_no_replicate = utils.bbox_overlaps(proposals.data, gt_boxes.data)
    # calculate the overlaps between the rois/rois and rois/gt_bbox.
    # apply both frame mask and proposal mask
    overlaps = utils.bbox_overlaps(proposals.data, gt_boxes.data,
                                   (frm_mask | pnt_mask[:, 1:].unsqueeze(-1)).data)

    fc_feats, conv_feats, p_conv_feats, pool_feats, \
        p_pool_feats, g_pool_feats, pnt_mask, overlaps_expanded, cls_pred, cls_loss = self.roi_feat_extractor(
            segs_feat, proposals, num, mask_boxes, region_feats, gt_boxes, overlaps, sample_idx)

    # ====================================================================================
    # regular decoder
    # ====================================================================================
    lang_outputs = []
    decoder_roi_attn, decoder_masked_attn, roi_labels = [], [], []
    frm_mask_output = []

    # unfold the GT caption and localize the ROIs word by word
    for word_idx in range(self.seq_length):
        word = gt_caption[:, word_idx].clone()
        embedded_word = self.embed(word)

        roi_label = utils.bbox_target(mask_boxes[:, :, :, word_idx + 1], overlaps,
                                      input_seq[:, word_idx + 1],
                                      input_seq_update[:, word_idx + 1],
                                      self.vocab_size)  # roi_label for the target seq
        roi_labels.append(roi_label.view(gt_caption_batch_size, -1))

        # use frame mask during training
        box_mask = mask_boxes[:, 0, :, word_idx + 1].contiguous().unsqueeze(1)\
            .expand((batch_size, num_rois, mask_boxes.size(2)))
        # frm_mask_on_prop = (torch.sum((1 - (box_mask | frm_mask)), dim=2) <= 0)
        frm_mask_on_prop = (torch.sum((~(box_mask | frm_mask)), dim=2) <= 0)
        frm_mask_on_prop = torch.cat((frm_mask_on_prop.new(batch_size, 1).fill_(0.),
                                      frm_mask_on_prop), dim=1) | pnt_mask.bool()
        frm_mask_output.append(frm_mask_on_prop)

        output, decoder_state, roi_attn, frame_masked_attn, weighted_pool_feat = self.decoder_core(
            embedded_word, fc_feats, conv_feats, p_conv_feats, pool_feats, p_pool_feats,
            pnt_mask[:, 1:], decoder_state,
            proposal_frame_mask=frm_mask_on_prop[:, 1:], with_sentinel=False)
        output = F.log_softmax(self.logit(output), dim=1)

        decoder_roi_attn.append(roi_attn)
        decoder_masked_attn.append(frame_masked_attn)
        lang_outputs.append(output)

    # (batch_size, seq_length - 1, max # of ROIs)
    att2_weights = torch.stack(decoder_masked_attn, dim=1)
    # decoder output for captioning task
    lang_outputs = torch.cat([_.unsqueeze(1) for _ in lang_outputs], dim=1)
    roi_labels = torch.cat([_.unsqueeze(1) for _ in roi_labels], dim=1)
    frm_mask_output = torch.cat([_.unsqueeze(1) for _ in frm_mask_output], 1)

    # object grounding
    xt_clamp = torch.clamp(input_seq[:, 1:self.seq_length + 1, 0].clone() - self.vocab_size, min=0)
    xt_all = self.roi_feat_extractor.vis_embed(xt_clamp)

    if hasattr(self.roi_feat_extractor, 'vis_classifiers_bias'):
        bias = self.roi_feat_extractor.vis_classifiers_bias[xt_clamp].type(xt_all.type())\
            .unsqueeze(2).expand(gt_caption_batch_size, self.seq_length, num_rois)
    else:
        bias = 0

    # att2_weights/ground_weights with both proposal mask and frame mask
    ground_weights = self._grounder(xt_all, g_pool_feats, frm_mask_output[:, :, 1:],
                                    bias + att2_weights)

    # if we are only training the decoder, directly return the language output
    if self.opts.train_decoder_only:
        lang_outputs = lang_outputs.view(-1, lang_outputs.size(2))
        lm_loss, att2_loss, ground_loss = self.critLM(
            lang_outputs, att2_weights, ground_weights,
            gt_caption[:, 1:self.seq_length + 1].clone(),
            roi_labels[:, :self.seq_length, :].clone(),
            input_seq[:, 1:self.seq_length + 1, 0].clone())
        return lm_loss.unsqueeze(0), att2_loss.unsqueeze(0), ground_loss.unsqueeze(0), \
            cls_loss.unsqueeze(0)

    # ====================================================================================
    # greedily select the words with the highest probability and use them for localization
    # ====================================================================================
    # this turns output_seq into a leaf on the computational graph
    _, output_seq = lang_outputs.max(2)

    localizer_state = self.init_hidden(gt_caption_batch_size, self.localizer_num_layers)
    localized_weighted_feats = []
    localized_conv_feats = []
    localized_roi_probs = []
    for word_idx in range(self.seq_length):
        word = output_seq[:, word_idx].clone()
        embedded_word = self.embed(word)

        localized_weighted_feat, localized_conv_feat, localized_roi_prob, localizer_state = self.localizer_core(
            embedded_word, fc_feats, conv_feats, p_conv_feats, pool_feats, p_pool_feats,
            pnt_mask[:, 1:], localizer_state, None,
            proposal_frame_mask=frm_mask_output[:, word_idx, 1:], with_sentinel=False)

        # if self.opts.localizer_only_groundable:
        #     visually_groundable_mask = np.in1d(np.array(word.cpu()), self.i_to_visually_groundable).astype('float')
        #     visually_groundable_mask = torch.from_numpy(visually_groundable_mask).float().to(localized_weighted_feat.device)
        #     localized_weighted_feat = visually_groundable_mask.unsqueeze(1) * localized_weighted_feat
        #     localized_conv_feat = visually_groundable_mask.unsqueeze(1) * localized_conv_feat
        #     localized_roi_prob = visually_groundable_mask.unsqueeze(1) * localized_roi_prob

        localized_weighted_feats.append(localized_weighted_feat)
        localized_conv_feats.append(localized_conv_feat)
        localized_roi_probs.append(localized_roi_prob)

    # localized_roi_probs = torch.stack(localized_roi_probs, dim=1)  # (batch_size, seq_length - 1, max # of ROIs)

    # ====================================================================================
    # decoder uses the localized ROIs to generate the caption again
    # ====================================================================================
    consistent_outputs = []
    consistent_decoder_state = self.init_hidden(gt_caption_batch_size, self.decoder_num_layers)
    for word_idx in range(self.seq_length):
        word = gt_caption[:, word_idx].clone()
        embedded_word = self.embed(word)
        localized_weighted_feat = localized_weighted_feats[word_idx]
        localized_conv_feat = localized_conv_feats[word_idx]

        output, consistent_decoder_state = self.attended_roi_decoder_core(
            embedded_word, fc_feats, localized_weighted_feat, localized_conv_feat,
            consistent_decoder_state, with_sentinel=False)
        output = F.log_softmax(self.logit(output), dim=1)
        consistent_outputs.append(output)

    consistent_outputs = torch.cat([_.unsqueeze(1) for _ in consistent_outputs], dim=1)

    lang_outputs = lang_outputs.view(-1, lang_outputs.size(2))
    lm_loss, att2_loss, ground_loss = self.critLM(
        lang_outputs, att2_weights, ground_weights,
        gt_caption[:, 1:self.seq_length + 1].clone(),
        roi_labels[:, :self.seq_length, :].clone(),
        input_seq[:, 1:self.seq_length + 1, 0].clone())

    # compute another loss for the reconstructor
    consistent_outputs = consistent_outputs.view(-1, consistent_outputs.size(2))
    lm_recon_loss = self.xe_criterion(consistent_outputs,
                                      gt_caption[:, 1:self.seq_length + 1].clone())

    return lm_loss.unsqueeze(0), att2_loss.unsqueeze(0), ground_loss.unsqueeze(0), \
        cls_loss.unsqueeze(0), lm_recon_loss.unsqueeze(0)
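
# Illustrative call sketch (added for readability, not part of the original code).
# The argument layout follows the docstring above; the variable names are placeholders.
#   lm_loss, att2_loss, ground_loss, cls_loss, lm_recon_loss = model._forward_3_loops(
#       segs_feat, input_seq, proposals, gt_caption, num, mask_boxes, gt_boxes,
#       region_feats, frm_mask, sample_idx, pnt_mask)
# The three loops are: (1) a teacher-forced decoder that also produces grounding attention,
# (2) a localizer driven by the greedily decoded words, and (3) a consistency decoder that
# re-generates the caption from the localized regions (lm_recon_loss). When
# self.opts.train_decoder_only is set, the method returns after the first loop with only
# (lm_loss, att2_loss, ground_loss, cls_loss).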