def _forward(self, sg_data, fc_feats, att_feats, seq, weak_rela, att_masks=None):
    core_args = self.prepare_core_args(sg_data, fc_feats, att_feats, att_masks)
    # make seq_per_img copies of the encoded inputs:
    # shape: (B, ...) => (B*seq_per_img, ...)
    core_args = expand_feats(core_args, self.seq_per_img)
    weak_rela = expand_feats([weak_rela], self.seq_per_img)[0]

    batch_size = fc_feats.size(0) * self.seq_per_img
    state = self.init_hidden(batch_size, weak_rela)
    outputs = fc_feats.new_zeros(batch_size, seq.size(1) - 1, self.vocab_size + 1)
    outputs_tag = fc_feats.new_zeros(batch_size, seq.size(1) - 1, 3)

    # teacher forcing over the ground-truth sequence
    for i in range(seq.size(1) - 1):
        # scheduled sampling: with probability ss_prob, feed the model's own
        # prediction from step i-1 instead of the ground-truth word
        if self.training and i >= 1 and self.ss_prob > 0.0:
            sample_prob = fc_feats.new(batch_size).uniform_(0, 1)
            sample_mask = sample_prob < self.ss_prob
            if sample_mask.sum() == 0:
                it = seq[:, i].clone()
            else:
                sample_ind = sample_mask.nonzero().view(-1)
                it = seq[:, i].data.clone()
                # fetch prev distribution: shape Nx(M+1)
                prob_prev = torch.exp(outputs[:, i - 1].detach())
                it.index_copy_(
                    0, sample_ind,
                    torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
        else:
            it = seq[:, i].clone()
        # break if all the sequences end
        if i >= 1 and seq[:, i].sum() == 0:
            break

        output, output_tag, state = self.get_logprobs_state(it, state, core_args)
        outputs[:, i] = output
        outputs_tag[:, i] = output_tag

    return outputs, outputs_tag
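# Usage sketch (not part of the original file): the pair returned by _forward
# feeds a masked cross-entropy loss during XE training. `masked_xe_loss` below
# is a hypothetical minimal criterion, assuming `masks` flags the valid
# timesteps of `seq`; the same form applies to `outputs_tag` with its own targets.
import torch

def masked_xe_loss(logprobs, seq, masks):
    # logprobs: (B*seq_per_img, L-1, V+1) word log-probs from _forward
    # seq:      (B*seq_per_img, L) ground-truth word indices (0 = padding)
    # masks:    (B*seq_per_img, L) 1.0 on valid positions, 0.0 on padding
    target = seq[:, 1:]                               # predict words 1..L-1
    mask = masks[:, 1:]
    loss = -logprobs.gather(2, target.unsqueeze(2)).squeeze(2) * mask
    return loss.sum() / mask.sum()                    # mean over valid tokens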
def _sample_beam(self, sg_data, fc_feats, att_feats, att_masks=None, opt={}):
    beam_size = opt.get('beam_size', 10)
    batch_size = fc_feats.size(0)
    core_args = self.prepare_core_args(sg_data, fc_feats, att_feats, att_masks)

    assert beam_size <= self.vocab_size + 1, \
        'assume beam_size <= vocab_size + 1 for now; larger beams would need special handling'
    seq = torch.LongTensor(self.seq_length, batch_size).zero_()
    seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
    # process every image independently for now, for simplicity
    self.done_beams = [[] for _ in range(batch_size)]
    for k in range(batch_size):
        state = self.init_hidden(beam_size)
        # slice out the k-th image's features and replicate them beam_size times
        sample_core_args = []
        for item in core_args:
            if isinstance(item, list) or item is None:
                sample_core_args.append(item)
            else:
                sample_core_args.append(item[k:k + 1])
        sample_core_args = expand_feats(sample_core_args, beam_size)

        # feed <bos> to get the initial log-probs and state, then run beam search
        it = fc_feats.new_zeros([beam_size], dtype=torch.long)
        logprobs, state = self.get_logprobs_state(it, state, sample_core_args)

        self.done_beams[k] = self.beam_search(state, logprobs, sample_core_args, opt=opt)
        # the first beam has the highest cumulative score
        seq[:, k] = self.done_beams[k][0]['seq']
        seqLogprobs[:, k] = self.done_beams[k][0]['logps']
    # return the samples and their log likelihoods
    return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
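# `expand_feats` is defined elsewhere in the repo. A minimal sketch consistent
# with how it is called here (repeat every tensor n times along the batch
# dimension, pass lists and None through untouched) could look like this:
import torch

def expand_feats(items, n):
    expanded = []
    for item in items:
        if item is None or isinstance(item, list):
            expanded.append(item)  # non-tensor entries pass through unchanged
        else:
            # (B, d1, ...) -> (B*n, d1, ...), each row repeated n consecutive times
            sizes = list(item.shape)
            rep = item.unsqueeze(1).expand(sizes[0], n, *sizes[1:])
            expanded.append(rep.reshape(sizes[0] * n, *sizes[1:]))
    return expanded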
def get_self_critical_reward(model, core_args, sg_data, fc_feats, att_feats,
                             weak_relas, att_masks, data, gen_result, opt):
    batch_size = gen_result.size(0)  # batch_size = image count * seq_per_img
    seq_per_img = batch_size // len(data['gts'])

    # get greedy decoding baseline
    model.eval()
    with torch.no_grad():
        greedy_res, _ = model(sg_data, fc_feats, att_feats, weak_relas,
                              att_masks=att_masks, _core_args=core_args,
                              opt={'expand_features': False}, mode='sample')
    model.train()
    greedy_res = expand_feats([greedy_res], seq_per_img)[0]

    res = OrderedDict()
    gen_result = gen_result.data.cpu().numpy()
    greedy_res = greedy_res.data.cpu().numpy()
    # sampled captions fill slots [0, batch_size), greedy ones the next batch_size
    for i in range(batch_size):
        res[i] = [array_to_str(gen_result[i])]
    for i in range(batch_size):
        res[batch_size + i] = [array_to_str(greedy_res[i])]

    gts = OrderedDict()
    for i in range(len(data['gts'])):
        gts[i] = [
            array_to_str(data['gts'][i][j]) for j in range(len(data['gts'][i]))
        ]

    res_ = [{'image_id': i, 'caption': res[i]} for i in range(2 * batch_size)]
    res__ = {i: res[i] for i in range(2 * batch_size)}
    gts = {
        i: gts[(i % batch_size) // seq_per_img]
        for i in range(2 * batch_size)
    }

    if opt.cider_reward_weight > 0:
        _, cider_scores = CiderD_scorer.compute_score(gts, res_)
        print('Cider scores:', _)
    else:
        cider_scores = 0
    if opt.bleu_reward_weight > 0:
        _, bleu_scores = Bleu_scorer.compute_score(gts, res__)
        bleu_scores = np.array(bleu_scores[3])
        print('Bleu scores:', _[3])
    else:
        bleu_scores = 0
    scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores

    # advantage: sampled caption score minus its greedy baseline,
    # broadcast over all timesteps -> shape (batch_size, seq_length)
    scores = scores[:batch_size] - scores[batch_size:]
    rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)

    return rewards
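# Sketch of how these rewards are typically consumed (assumed here, mirroring
# the standard self-critical RewardCriterion rather than quoting this repo):
# each sampled word's log-prob is weighted by its advantage, so captions that
# beat the greedy baseline are reinforced and worse ones are suppressed.
import torch

def reward_criterion(sample_logprobs, gen_result, rewards):
    # sample_logprobs: (B, L) log-prob of each sampled word (from _sample)
    # gen_result:      (B, L) sampled word indices (0 after a caption ends)
    # rewards:         (B, L) numpy array from get_self_critical_reward
    rewards = torch.from_numpy(rewards).float().to(sample_logprobs.device)
    mask = (gen_result > 0).float()
    # shift the mask so the step emitting the first 0 (<eos>) is still counted
    mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], dim=1)
    loss = -sample_logprobs * rewards * mask
    return loss.sum() / mask.sum()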
def _sample(self, sg_data, fc_feats, att_feats, att_masks=None, opt={}, _core_args=None):
    sample_max = opt.get('sample_max', 1)
    beam_size = opt.get('beam_size', 1)
    temperature = opt.get('temperature', 1.0)
    decoding_constraint = opt.get('decoding_constraint', 0)
    return_core_args = opt.get('return_core_args', False)
    expand_features = opt.get('expand_features', True)

    if beam_size > 1:
        return self._sample_beam(sg_data, fc_feats, att_feats, att_masks, opt)
    if _core_args is not None:
        # reuse the core_args computed while generating the sampled captions
        # when generating greedy captions for SCST
        core_args = _core_args
    else:
        core_args = self.prepare_core_args(sg_data, fc_feats, att_feats, att_masks)

    # make seq_per_img copies of the encoded inputs: (B, ...) => (B*seq_per_img, ...)
    # expand_features should be True when training (XE or SCST), False at evaluation
    if expand_features:
        if return_core_args:
            _core_args = core_args
        core_args = expand_feats(core_args, self.seq_per_img)
        batch_size = fc_feats.size(0) * self.opt.seq_per_img
    else:
        batch_size = fc_feats.size(0)

    state = self.init_hidden(batch_size)
    seq = fc_feats.new_zeros((batch_size, self.seq_length), dtype=torch.long)
    seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
    for t in range(self.seq_length + 1):
        if t == 0:  # input <bos>
            it = fc_feats.new_zeros(batch_size, dtype=torch.long)

        logprobs, state = self.get_logprobs_state(it, state, core_args)

        if decoding_constraint and t > 0:
            # forbid emitting the same word twice in a row
            tmp = logprobs.new_zeros(logprobs.size())
            tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
            logprobs = logprobs + tmp

        # sample the next word
        if t == self.seq_length:  # stop once maximum length is reached
            break
        if sample_max:
            sampleLogprobs, it = torch.max(logprobs.data, 1)
            it = it.view(-1).long()
        else:
            if temperature == 1.0:
                # fetch prev distribution: shape Nx(M+1)
                prob_prev = torch.exp(logprobs.data)
            else:
                # scale logprobs by temperature
                prob_prev = torch.exp(torch.div(logprobs.data, temperature))
            it = torch.multinomial(prob_prev, 1)
            # gather the logprobs at sampled positions
            sampleLogprobs = logprobs.gather(1, it)
            # and flatten indices for downstream processing
            it = it.view(-1).long()

        # stop when all finished (a 0 marks the end of a sequence)
        if t == 0:
            unfinished = it > 0
        else:
            unfinished = unfinished * (it > 0)
        it = it * unfinished.type_as(it)
        seq[:, t] = it
        seqLogprobs[:, t] = sampleLogprobs.view(-1)
        # quit loop if all sequences have finished
        if unfinished.sum() == 0:
            break

    returns = [seq, seqLogprobs]
    if return_core_args:
        returns.append(_core_args)
    return returns
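# Standalone illustration of the per-step decision above: greedy argmax when
# sample_max is set, otherwise temperature-scaled multinomial sampling.
# Dividing normalized log-probs by T and exponentiating yields p^(1/T), which
# torch.multinomial renormalizes, i.e. standard temperature sampling.
import torch

def sample_next_word(logprobs, sample_max=1, temperature=1.0):
    # logprobs: (N, V+1) normalized log-probabilities over the vocabulary
    if sample_max:
        sampleLogprobs, it = torch.max(logprobs, 1)   # greedy decoding
        return it.long(), sampleLogprobs
    prob_prev = torch.exp(logprobs / temperature)     # unnormalized weights
    it = torch.multinomial(prob_prev, 1)              # multinomial renormalizes
    sampleLogprobs = logprobs.gather(1, it)           # log-prob of each pick
    return it.view(-1).long(), sampleLogprobs.view(-1)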