def _forward(self, ofc_feats, oatt_feats, densecap, seq, att_masks=None, personality=None):
    """Teacher-forced decoding pass that returns per-step log-probabilities.

    ``densecap`` is encoded token by token with ``self.encoder``; the stacked
    cell states become the attention features and the final hidden state
    becomes the flat "fc" feature given to ``self.get_logprobs_state``.  The
    image features are only run through ``self._prepare_feature`` and then
    forwarded unchanged.

    Args:
        ofc_feats, oatt_feats: image features passed to
            ``self._prepare_feature`` (layout not visible here).
        densecap: rank-3 integer tensor, indexed as ``densecap[:, :, k]``.
            Assumed to be (batch, 5 captions, tokens) -- TODO confirm.
        seq: (batch * seq_per_img, length) token ids; column ``i`` is the
            decoder input at step ``i``.
        att_masks: forwarded to ``self._prepare_feature`` and to the decoder.
        personality: forwarded to ``self.get_logprobs_state``.

    Returns:
        Float tensor of shape
        (batch * seq_per_img, seq.size(1) - 1, self.vocab_size + 1).
        Steps skipped by the early break keep their zero initialisation.
    """
    # Use the real batch dimension instead of a configured ``self.batch_size``
    # so a short final batch still works (same convention as ``_samplen``).
    batch_size = densecap.size(0)
    # The encoder state is laid out as (batch * 5, hidden).
    caps_per_img = 5
    seq_per_img = seq.shape[0] // batch_size
    n_rows = batch_size * seq_per_img

    # Allocate next to the inputs rather than forcing the default CUDA device.
    outputs = torch.zeros(n_rows, seq.size(1) - 1, self.vocab_size + 1,
                          dtype=torch.float, device=seq.device)

    # Prepare the image features.
    rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks = self._prepare_feature(
        ofc_feats, oatt_feats, att_masks)

    # Encode the dense captions one token position at a time.
    encodestate = self.enc_init_hidden(batch_size * caps_per_img)
    encoder_cells = []
    for k in range(densecap.size(-1)):
        w = densecap[:, :, k].clone()
        embedw = self.embed(w)
        embedw = embedw.contiguous().view(-1, embedw.size(-1)).contiguous()
        encodestate = self.encoder(embedw, (encodestate[0], encodestate[1]))
        cell = encodestate[1]
        encoder_cells.append(
            cell.contiguous().view(batch_size, caps_per_img, cell.size(-1)))
    hstate = encodestate[0]

    # (tokens, batch, 5, hidden) -> (batch, 5, tokens, hidden)
    att_feats = torch.stack(encoder_cells).contiguous().permute(1, 2, 0, 3)
    fc_feats = hstate.contiguous().view(batch_size, -1)
    # Projected attention features are computed once and reused at every step.
    p_att_feats = self.ctx2att_t(att_feats)

    decodestate = self.init_hidden(n_rows)
    if seq_per_img > 1:
        fc_feats, att_feats, p_att_feats, att_masks = utils.repeat_tensors(
            seq_per_img, [fc_feats, att_feats, p_att_feats, att_masks])
        rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks = utils.repeat_tensors(
            seq_per_img, [rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks])

    for i in range(seq.size(1) - 1):
        it = seq[:, i].clone()
        # Scheduled sampling: otherwise there is no need to sample.
        if self.training and i >= 1 and self.ss_prob > 0.0:
            sample_prob = fc_feats.new(n_rows).uniform_(0, 1)
            sample_mask = sample_prob < self.ss_prob
            if sample_mask.sum() != 0:
                sample_ind = sample_mask.nonzero().view(-1)
                # Previous step's distribution, shape N x (vocab_size + 1).
                prob_prev = torch.exp(outputs[:, i - 1].detach())
                it.index_copy_(
                    0, sample_ind,
                    torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
        # Break if all the sequences end.
        if i >= 1 and seq[:, i].sum() == 0:
            break
        output, decodestate = self.get_logprobs_state(
            it, personality, fc_feats, att_feats, p_att_feats, att_masks,
            rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks, decodestate)
        outputs[:, i] = output
    return outputs
def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
    """Embed attention features and build the masks used by the forward pass.

    Returns ``(att_feats, seq, att_masks, seq_mask)``.  ``att_masks`` gains
    an axis at position -2.  When ``seq`` is given its last column is dropped
    and ``seq_mask`` combines "token id > 0" (column 0 always kept) with
    ``subsequent_mask``; otherwise ``seq_mask`` is ``None``.
    """
    att_feats, att_masks = self.clip_att(att_feats, att_masks)
    att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)

    # Without a mask, every attention position is treated as valid.
    if att_masks is None:
        att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
    att_masks = att_masks.unsqueeze(-2)

    if seq is None:
        return att_feats, seq, att_masks, None

    # Crop the last token: it is never used as a decoder input.
    seq = seq[:, :-1]
    seq_mask = seq.data > 0
    seq_mask[:, 0] = 1  # bos
    seq_mask = seq_mask.unsqueeze(-2)
    seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)

    # Several captions per image: repeat the image-side tensors to match.
    captions_per_image = seq.shape[0] // att_feats.shape[0]
    if captions_per_image > 1:
        att_feats, att_masks = utils.repeat_tensors(
            captions_per_image, [att_feats, att_masks])

    return att_feats, seq, att_masks, seq_mask
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
    """Teacher-forced pass: feed ``seq[:, i]`` and record each step's output.

    Returns a tensor of shape
    (batch * seq_per_img, seq.size(1) - 1, self.vocab_size + 1); steps after
    the early break stay zero.
    """
    n_images = fc_feats.size(0)
    captions_per_image = seq.shape[0] // n_images
    n_rows = n_images * captions_per_image
    n_steps = seq.size(1) - 1

    state = self.init_hidden(n_rows)
    outputs = fc_feats.new_zeros(n_rows, n_steps, self.vocab_size + 1)

    # The projected attention features (third value) are cached here once
    # so they are not recomputed at every step.
    p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(
        fc_feats, att_feats, att_masks)
    if captions_per_image > 1:
        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(
            captions_per_image,
            [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks])

    for step in range(n_steps):
        tokens = seq[:, step]
        # Stop once every sequence has ended (an all-zero column after step 0).
        if step >= 1 and tokens.sum() == 0:
            break
        output, state = self.get_logprobs_state(
            tokens, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
        outputs[:, step] = output

    return outputs
def _sample_beam(self, semmantic_feat, semantic1_feat, att_feats, att1_feat, box_feat, box1_feat, opt={}):
    """Beam-search decoding on the fused semantic features.

    Returns ``(seq, seqLogprobs)`` with ``batch * sample_n`` rows.  When
    ``sample_n == beam_size`` every finished beam is written to its own row;
    otherwise only the first beam of each image is kept.  The finished beams
    are also stored on ``self.done_beams``.
    """
    beam_size = opt.get('beam_size', 10)
    group_size = opt.get('group_size', 1)
    sample_n = opt.get('sample_n', 10)
    # When sample_n == beam_size, each beam is a sample.
    assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
    batch_size = att_feats.size(0)

    fused_a, fused_b = self.att_feat(
        semmantic_feat, semantic1_feat, att_feats, att1_feat, box_feat, box1_feat)

    assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
    n_rows = batch_size * sample_n
    seq = att_feats.new_zeros((n_rows, self.seq_length), dtype=torch.long)
    seqLogprobs = att_feats.new_zeros(n_rows, self.seq_length, self.vocab_size + 1)

    def store(row, beam):
        # Copy one finished beam into its output row; the tail stays zero.
        length = beam['seq'].shape[0]
        seq[row, :length] = beam['seq']
        seqLogprobs[row, :length] = beam['logps']

    self.done_beams = [[] for _ in range(batch_size)]

    state = self.init_hidden(batch_size)
    # First step: feed <bos> (index 0) for every image.
    bos = att_feats.new_zeros([batch_size], dtype=torch.long)
    logprobs, state = self.get_logprobs_state(bos, fused_a, fused_b, state)

    fused_a, fused_b = utils.repeat_tensors(beam_size, [fused_a, fused_b])
    self.done_beams = self.beam_search(state, logprobs, fused_a, fused_b, opt=opt)

    keep_every_beam = sample_n == beam_size
    for k in range(batch_size):
        if keep_every_beam:
            for n in range(sample_n):
                store(k * sample_n + n, self.done_beams[k][n])
        else:
            # The first beam has the highest cumulative score.
            store(k, self.done_beams[k][0])

    # Return the samples and their log likelihoods.
    return seq, seqLogprobs
def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
    """Legacy beam search that decodes one image at a time.

    Returns ``(seq, seqLogprobs)`` with ``batch * sample_n`` rows and stores
    each image's finished beams in ``self.done_beams[k]``.
    """
    beam_size = opt.get('beam_size', 10)
    group_size = opt.get('group_size', 1)
    sample_n = opt.get('sample_n', 10)
    # When sample_n == beam_size, each beam is a sample.
    assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
    batch_size = fc_feats.size(0)

    p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(
        fc_feats, att_feats, att_masks)

    assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
    n_rows = batch_size * sample_n
    seq = fc_feats.new_zeros((n_rows, self.seq_length), dtype=torch.long)
    seqLogprobs = fc_feats.new_zeros(n_rows, self.seq_length, self.vocab_size + 1)

    # Process every image independently, for simplicity.
    self.done_beams = [[] for _ in range(batch_size)]
    has_masks = att_masks is not None
    for k in range(batch_size):
        one = slice(k, k + 1)
        state = self.init_hidden(beam_size)
        # Replicate this image's features once per beam.
        tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks = utils.repeat_tensors(
            beam_size,
            [
                p_fc_feats[one],
                p_att_feats[one],
                pp_att_feats[one],
                p_att_masks[one] if has_masks else None,
            ])

        # First step: feed <bos> (index 0) to every beam.
        it = fc_feats.new_zeros([beam_size], dtype=torch.long)
        logprobs, state = self.get_logprobs_state(
            it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)

        self.done_beams[k] = self.old_beam_search(
            state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats,
            tmp_att_masks, opt=opt)

        if sample_n == beam_size:
            for n in range(sample_n):
                row = k * sample_n + n
                seq[row, :] = self.done_beams[k][n]['seq']
                seqLogprobs[row, :] = self.done_beams[k][n]['logps']
        else:
            # The first beam has the highest cumulative score.
            seq[k, :] = self.done_beams[k][0]['seq']
            seqLogprobs[k, :] = self.done_beams[k][0]['logps']

    # Return the samples and their log likelihoods.
    return seq, seqLogprobs
def get_logprobs_state(self, it, personality, fc_feats, att_feats, p_att_feats, att_masks,
                       rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks, state):
    """Run one decoder step and return ``(log-probabilities, new state)``.

    Args:
        it: tensor of word indices, one per sequence.
        personality: optional 2-D tensor; the column index of each non-zero
            entry is fed to ``self.personality_encoder`` (presumably a
            one-hot row per image -- TODO confirm).  When ``None`` the word
            embedding is used on its own.
        fc_feats, att_feats, p_att_feats, att_masks, rp_fc_feats,
        rp_att_feats, rpp_att_feats, rp_att_masks: forwarded unchanged to
            ``self.core``.
        state: recurrent state forwarded to and returned by ``self.core``.

    Returns:
        Tuple of the log-softmax (over dim 1) of ``self.logit``'s output and
        the updated state.
    """
    xt = self.embed(it)

    # Only touch ``personality`` inside the guard: the original read
    # ``personality.size(0)`` first, so ``personality=None`` always crashed.
    if personality is not None:
        batch_size = personality.size(0)
        seq_per_img = xt.size(0) // batch_size
        pers_encoded = self.personality_encoder(personality.nonzero(as_tuple=True)[1])
        # One encoding per image -> one per sequence.
        pers_encoded = utils.repeat_tensors(seq_per_img, pers_encoded)
        xt = torch.cat((xt, pers_encoded), 1)

    output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks,
                              rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks)
    finallogprobs = F.log_softmax(self.logit(output), dim=1)
    return finallogprobs, state
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
    """Teacher-forced pass with optional scheduled sampling.

    ``seq`` may be (rows, length) or (batch, seq_per_img, length); the latter
    is flattened first.  Returns a tensor of shape
    (rows, length - 1, self.vocab_size + 1); steps after the early break stay
    zero.
    """
    n_images = fc_feats.size(0)
    if seq.ndim == 3:  # B * seq_per_img * seq_len
        seq = seq.reshape(-1, seq.shape[2])
    captions_per_image = seq.shape[0] // n_images
    n_rows = n_images * captions_per_image
    n_steps = seq.size(1) - 1

    state = self.init_hidden(n_rows)
    outputs = fc_feats.new_zeros(n_rows, n_steps, self.vocab_size + 1)

    # The projected attention features (third value) are cached here once
    # so they are not recomputed at every step.
    p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(
        fc_feats, att_feats, att_masks)
    if captions_per_image > 1:
        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(
            captions_per_image,
            [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks])

    for step in range(n_steps):
        # Start from the ground-truth token; scheduled sampling may overwrite
        # some rows with tokens drawn from the previous step's distribution.
        it = seq[:, step].clone()
        if self.training and step >= 1 and self.ss_prob > 0.0:
            sample_prob = fc_feats.new(n_rows).uniform_(0, 1)
            sample_mask = sample_prob < self.ss_prob
            if sample_mask.any():
                sample_ind = sample_mask.nonzero().view(-1)
                # Previous distribution, shape N x (vocab_size + 1).
                prob_prev = torch.exp(outputs[:, step - 1].detach())
                drawn = torch.multinomial(prob_prev, 1).view(-1)
                it.index_copy_(0, sample_ind, drawn.index_select(0, sample_ind))

        # Break if all the sequences end.
        if step >= 1 and seq[:, step].sum() == 0:
            break

        output, state = self.get_logprobs_state(
            it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
        outputs[:, step] = output

    return outputs
def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
    """Decode step by step with ``self.sample_next_word``.

    Delegates to ``self._sample_beam`` when ``opt['beam_size'] > 1``.
    Otherwise returns ``(seq, seqLogprobs)`` with ``batch * sample_n`` rows;
    a row emits zeros once it has produced token 0, and decoding stops when
    every row has done so.
    """
    sample_method = opt.get('sample_method', 'greedy')
    beam_size = opt.get('beam_size', 1)
    sample_n = int(opt.get('sample_n', 1))

    if beam_size > 1:
        return self._sample_beam(fc_feats, att_feats, att_masks, opt)

    n_rows = fc_feats.size(0) * sample_n
    state = self.init_hidden(n_rows)

    p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(
        fc_feats, att_feats, att_masks)
    if sample_n > 1:
        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(
            sample_n, [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks])

    seq = fc_feats.new_zeros((n_rows, self.seq_length), dtype=torch.long)
    seqLogprobs = fc_feats.new_zeros(n_rows, self.seq_length, self.vocab_size + 1)

    # The first input is <bos> (index 0) for every row.
    it = fc_feats.new_zeros(n_rows, dtype=torch.long)
    unfinished = None
    for t in range(self.seq_length):
        logprobs, state = self.get_logprobs_state(
            it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)

        # Sample the next word.
        it, _ = self.sample_next_word(logprobs, sample_method)

        # A row stays alive only while it keeps emitting non-zero tokens.
        alive = it > 0
        unfinished = alive if unfinished is None else unfinished * alive
        it = it * unfinished.type_as(it)

        seq[:, t] = it
        seqLogprobs[:, t] = logprobs

        # Quit the loop if all sequences have finished.
        if unfinished.sum() == 0:
            break

    return seq.detach(), seqLogprobs
def _forward(self, semantic_feat, semantic1_feat, att_feats, att1_feats, box_feat, box1_feat, seq):
    """Teacher-forced pass over the fused semantic features.

    The six feature inputs are fused by ``self.att_feat`` into two tensors
    that are passed to ``self.get_logprobs_state`` at every step.  Scheduled
    sampling applies when training with ``self.ss_prob > 0``.  Returns a
    tensor of shape (rows, seq.size(1) - 1, self.vocab_size + 1); steps after
    the early break stay zero.
    """
    n_images = att_feats.size(0)
    if seq.ndim == 3:
        seq = seq.reshape(-1, seq.shape[2])
    captions_per_image = seq.shape[0] // n_images
    n_rows = n_images * captions_per_image
    n_steps = seq.size(1) - 1

    state = self.init_hidden(n_rows)
    outputs = att_feats.new_zeros(n_rows, n_steps, self.vocab_size + 1)

    fused_a, fused_b = self.att_feat(
        semantic_feat, semantic1_feat, att_feats, att1_feats, box_feat, box1_feat)
    if captions_per_image > 1:
        fused_a, fused_b = utils.repeat_tensors(
            captions_per_image, [fused_a, fused_b])

    for step in range(n_steps):
        # Start from the ground-truth token; scheduled sampling may overwrite
        # some rows with tokens drawn from the previous step's distribution.
        it = seq[:, step].clone()
        if self.training and step >= 1 and self.ss_prob > 0.0:
            sample_prob = att_feats.new(n_rows).uniform_(0, 1)
            sample_mask = sample_prob < self.ss_prob
            if sample_mask.any():
                sample_ind = sample_mask.nonzero().view(-1)
                # Previous distribution, shape N x (vocab_size + 1).
                prob_prev = torch.exp(outputs[:, step - 1].detach())
                drawn = torch.multinomial(prob_prev, 1).view(-1)
                it.index_copy_(0, sample_ind, drawn.index_select(0, sample_ind))

        # Break if all the sequences end.
        if step >= 1 and seq[:, step].sum() == 0:
            break

        output, state = self.get_logprobs_state(it, fused_a, fused_b, state)
        outputs[:, step] = output

    return outputs
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
    """Teacher-forced pass where the image embedding is the first input.

    Step 0 feeds ``self.img_embed(fc_feats)``; step ``i >= 1`` feeds the
    embedding of ``seq[:, i - 1]`` (or a scheduled-sampling replacement).
    The output of step 0 is dropped, and the rest are stacked along dim 1.
    ``att_feats`` and ``att_masks`` are not used.
    """
    n_images = fc_feats.size(0)
    captions_per_image = seq.shape[0] // n_images
    n_rows = n_images * captions_per_image
    state = self.init_hidden(n_rows)
    outputs = []

    if captions_per_image > 1:
        fc_feats = utils.repeat_tensors(captions_per_image, fc_feats)

    for i in range(seq.size(1)):
        if i == 0:
            xt = self.img_embed(fc_feats)
        else:
            prev = i - 1
            # Start from the ground-truth token; scheduled sampling may
            # overwrite some rows with tokens drawn from the last output.
            it = seq[:, prev].clone()
            if self.training and i >= 2 and self.ss_prob > 0.0:
                sample_prob = fc_feats.data.new(n_rows).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.any():
                    sample_ind = sample_mask.nonzero().view(-1)
                    # Previous distribution, shape N x (M + 1).
                    prob_prev = torch.exp(outputs[-1].data)
                    drawn = torch.multinomial(prob_prev, 1).view(-1)
                    it.index_copy_(0, sample_ind, drawn.index_select(0, sample_ind))

            # Break if all the sequences end.
            if i >= 2 and seq[:, prev].data.sum() == 0:
                break
            xt = self.embed(it)

        hidden, state = self.core(xt.unsqueeze(0), state)
        logits = self.logit(self.dropout(hidden.squeeze(0)))
        outputs.append(F.log_softmax(logits, dim=1))

    # Skip the image step and join the remaining steps along a new dim 1.
    return torch.stack(outputs[1:], 1).contiguous()
def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
    """Penalise words already picked by earlier beam groups at this step.

    For group ``divm``, counts how often each word was chosen at local time
    ``t - divm`` by the first ``bdash`` beams of every group before it, and
    subtracts ``diversity_lambda`` times that count from ``logprobs``.

    Returns ``(logprobs, unaug_logprobs)``; the second value is a clone of
    the input taken before any penalty.  For ``divm == 0`` the first value is
    the input tensor itself.
    """
    local_time = t - divm
    unaug_logprobs = logprobs.clone()
    batch_size = beam_seq_table[0].shape[0]

    # The first group has no predecessors, so nothing is penalised.
    if divm <= 0:
        return logprobs, unaug_logprobs

    # counts[n, w] = number of earlier-group beams of item n that picked w.
    counts = logprobs.new_zeros(batch_size, logprobs.shape[-1])
    one = counts.new_ones(batch_size, 1)
    for group in range(divm):
        chosen = beam_seq_table[group][:, :, local_time]  # N x b
        for beam in range(bdash):
            counts.scatter_add_(1, chosen[:, beam].unsqueeze(-1), one)

    # After the first local step logprobs has one row per beam, so the
    # per-item counts are repeated bdash times.
    if local_time != 0:
        counts = utils.repeat_tensors(bdash, counts)
    logprobs = logprobs - counts * diversity_lambda

    return logprobs, unaug_logprobs
def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
    """Beam-search decoding that keeps the first finished beam per image.

    Returns ``(seq, seqLogprobs)`` with one row per image; positions beyond
    the beam's length stay zero.
    """
    beam_size = opt.get('beam_size', 10)
    n_images = fc_feats.size(0)

    p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(
        fc_feats, att_feats, att_masks)

    # Let's assume this for now.
    assert beam_size <= self.vocab_size + 1

    seq = fc_feats.new_zeros((n_images, self.seq_length), dtype=torch.long)
    seqLogprobs = fc_feats.new_zeros(n_images, self.seq_length, self.vocab_size + 1)

    state = self.init_hidden(n_images)

    # First step: feed <bos> (index 0).
    # logprobs has shape batch_size x (vocab_size + 1).
    bos = fc_feats.new_zeros([n_images], dtype=torch.long)
    logprobs, state = self.get_logprobs_state(
        bos, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)

    # Every image's features are repeated once per beam for the search.
    p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(
        beam_size, [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks])
    done_beams = self.beam_search(
        state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt)

    for k in range(n_images):
        # The first beam has the highest cumulative score.
        best = done_beams[k][0]
        length = best['seq'].shape[0]
        seq[k, :length] = best['seq']
        seqLogprobs[k, :length] = best['logps']

    return seq, seqLogprobs
def _sample(self, fc_feats, att_feats, topic_vec, att_masks=None, opt={}):
    """Sample captions step by step, with optional decoding constraints.

    Options read from ``opt``: ``sample_method``, ``beam_size``,
    ``temperature``, ``sample_n``, ``group_size``, ``output_logsoftmax``,
    ``decoding_constraint`` (forbid repeating the previous word),
    ``remove_bad_endings`` (forbid token 0 right after a word in
    ``self.bad_endings_ix``) and ``block_trigrams`` (penalise words that
    would repeat an already generated trigram).

    Delegates to ``self._sample_beam`` when ``beam_size > 1`` and to
    ``self._diverse_sample`` when ``group_size > 1``.

    Returns:
        ``(seq, seqLogprobs, decoder_states)``: token ids of shape
        (batch * sample_n, seq_length), the (possibly penalised) log-probs
        of every step, and the list of third values returned by
        ``self.get_logprobs_state``, one per completed step.
    """
    sample_method = opt.get('sample_method', 'greedy')
    beam_size = opt.get('beam_size', 1)
    temperature = opt.get('temperature', 1.0)
    sample_n = int(opt.get('sample_n', 1))
    group_size = opt.get('group_size', 1)
    output_logsoftmax = opt.get('output_logsoftmax', 1)
    decoding_constraint = opt.get('decoding_constraint', 0)
    block_trigrams = opt.get('block_trigrams', 0)
    remove_bad_endings = opt.get('remove_bad_endings', 0)

    # NOTE(review): topic_vec is not forwarded to either delegate -- confirm
    # against their signatures.
    if beam_size > 1:
        return self._sample_beam(fc_feats, att_feats, att_masks, opt)
    if group_size > 1:
        return self._diverse_sample(fc_feats, att_feats, att_masks, opt)

    batch_size = fc_feats.size(0)
    n_rows = batch_size * sample_n
    state = self.init_hidden(n_rows)

    p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(
        fc_feats, att_feats, att_masks)
    if sample_n > 1:
        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(
            sample_n, [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks])

    # One dict per output row: (w1, w2) -> list of words that followed them.
    trigrams = []

    seq = fc_feats.new_zeros((n_rows, self.seq_length), dtype=torch.long)
    seqLogprobs = fc_feats.new_zeros(n_rows, self.seq_length, self.vocab_size + 1)
    decoder_states = []

    for t in range(self.seq_length + 1):
        if t == 0:  # input <bos>
            it = fc_feats.new_zeros(n_rows, dtype=torch.long)

        logprobs, state, output = self.get_logprobs_state(
            it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state,
            topic_vec, output_logsoftmax=output_logsoftmax)

        if decoding_constraint and t > 0:
            # Forbid emitting the same word twice in a row.
            tmp = logprobs.new_zeros(logprobs.size())
            tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
            logprobs = logprobs + tmp

        if remove_bad_endings and t > 0:
            tmp = logprobs.new_zeros(logprobs.size())
            prev_bad = np.isin(seq[:, t - 1].data.cpu().numpy(), self.bad_endings_ix)
            # Make it impossible to end right after a "bad ending" word.
            # A boolean mask replaces the deprecated uint8 mask index.
            tmp[torch.from_numpy(prev_bad.astype(np.bool_)), 0] = float('-inf')
            logprobs = logprobs + tmp

        # Trigram blocking.
        # Copied from https://github.com/lukemelas/image-paragraph-captioning
        if block_trigrams and t >= 3:
            # Cover every output row: with sample_n > 1 there are more rows
            # than images, and the original only tracked the first batch_size.
            tracked_rows = seq.size(0)
            if not trigrams:
                trigrams = [{} for _ in range(tracked_rows)]

            # Store the trigram completed at the last step.
            prev_two_batch = seq[:, t - 3:t - 1]
            for i in range(tracked_rows):
                prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                trigrams[i].setdefault(prev_two, []).append(seq[i][t - 1].item())

            # Penalise words that would repeat a stored trigram.
            prev_two_batch = seq[:, t - 2:t]
            # Allocate on the log-probs' device instead of forcing CUDA.
            mask = torch.zeros(logprobs.size(), device=logprobs.device)
            for i in range(tracked_rows):
                prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                for j in trigrams[i].get(prev_two, ()):
                    mask[i, j] += 1

            alpha = 2.0
            # ln(1/2) * alpha per previous occurrence.
            logprobs = logprobs + (mask * -0.693 * alpha)

        # Skip sampling once the maximum length is reached.
        if t == self.seq_length:
            break
        it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)

        # A row stays alive only while it keeps emitting non-zero tokens.
        if t == 0:
            unfinished = it > 0
        else:
            unfinished = unfinished * (it > 0)
        it = it * unfinished.type_as(it)
        seq[:, t] = it
        seqLogprobs[:, t] = logprobs

        # Quit the loop if all sequences have finished.
        if unfinished.sum() == 0:
            break
        decoder_states.append(output)

    return seq, seqLogprobs, decoder_states
def _samplen(self, ofc_feats, oatt_feats, densecap, att_masks=None, personality=None, opt={}):
    """Sample ``sample_n`` captions per image from the dense-caption decoder.

    ``block_trigrams`` is forced to 0 and ``remove_bad_endings`` /
    ``decoding_constraint`` are forced to 1, whatever the caller passed.
    The forced values are written to a private copy of ``opt``, so neither
    the caller's dict nor the shared default is modified.  The last
    vocabulary index is always suppressed (``no_unk``) and the log-probs are
    re-normalised before sampling.

    Delegates to ``self._sample_beam`` (passing the adjusted options) when
    ``opt['beam_size'] > 1``.

    Args:
        ofc_feats, oatt_feats: image features for ``self._prepare_feature``.
        densecap: rank-3 integer tensor, indexed as ``densecap[:, :, k]``.
            Assumed to be (batch, 5 captions, tokens) -- TODO confirm.
        att_masks: forwarded to ``self._prepare_feature`` and to the decoder.
        personality: forwarded to ``self.get_logprobs_state``.
        opt: decoding options (``sample_method``, ``beam_size``,
            ``temperature``, ``sample_n`` with default 3).

    Returns:
        ``(seq, seqLogprobs, alogprobs)`` with ``batch * sample_n`` rows:
        token ids of shape (rows, seq_length), initialised to
        ``self.pad_idx``; the sampled word's log-prob per step, shape
        (rows, seq_length); and the full per-step distributions, shape
        (rows, seq_length + 1, vocab_size + 1), initialised to uniform
        log-probs.
    """
    # Copy first: the forced options below must not leak into the caller's
    # dict or into the mutable default argument.
    opt = dict(opt)
    sample_method = opt.get('sample_method', 'greedy')
    beam_size = opt.get('beam_size', 1)
    temperature = opt.get('temperature', 1.0)
    opt['block_trigrams'] = 0
    opt['remove_bad_endings'] = 1
    opt['decoding_constraint'] = 1
    decoding_constraint = opt.get('decoding_constraint', 1)
    block_trigrams = opt.get('block_trigrams', 1)
    remove_bad_endings = opt.get('remove_bad_endings', 1)
    sample_n = int(opt.get('sample_n', 3))
    no_unk = 1

    if beam_size > 1:
        return self._sample_beam(ofc_feats, oatt_feats, densecap, att_masks, personality, opt)

    batch_size = densecap.size(0)
    n_rows = batch_size * sample_n

    # Prepare the image features.
    rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks = self._prepare_feature(
        ofc_feats, oatt_feats, att_masks)
    if sample_n > 1:
        personality, densecap, rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks = utils.repeat_tensors(
            sample_n,
            [personality, densecap, rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks])

    # Encode the dense captions one token position at a time.
    # The encoder state is laid out as (rows * 5, hidden).
    encodestate = self.enc_init_hidden(n_rows * 5)
    encoder_cells = []
    for k in range(densecap.size(-1)):
        w = densecap[:, :, k].clone()
        embedw = self.embed(w)
        embedw = embedw.contiguous().view(-1, embedw.size(-1)).contiguous()
        encodestate = self.encoder(embedw, (encodestate[0], encodestate[1]))
        cell = encodestate[1]
        encoder_cells.append(cell.contiguous().view(n_rows, 5, cell.size(-1)))
    hstate = encodestate[0]

    # (tokens, rows, 5, hidden) -> (rows, 5, tokens, hidden)
    p_att_feats = torch.stack(encoder_cells).contiguous().permute(1, 2, 0, 3)
    fc_feats = hstate.contiguous().view(n_rows, 5, hstate.size(-1))
    p_fc_feats = fc_feats.contiguous().view(n_rows, -1)
    pp_att_feats = self.ctx2att_t(p_att_feats)
    p_att_masks = att_masks
    decodestate = self.init_hidden(n_rows)

    # Allocate next to the encoder output rather than forcing CUDA.
    device = fc_feats.device
    # One dict per output row: (w1, w2) -> list of words that followed them.
    trigrams = []
    # Uniform log-probs as the initial value; computed in one call instead
    # of one LogSoftmax per row.
    alogprobs = F.log_softmax(
        torch.zeros(n_rows, self.seq_length + 1, self.vocab_size + 1, device=device),
        dim=-1)
    seq = fc_feats.new_full((n_rows, self.seq_length), self.pad_idx, dtype=torch.long)
    seqLogprobs = torch.zeros(n_rows, self.seq_length, dtype=torch.float, device=device)

    for t in range(self.seq_length + 1):
        if t == 0:  # input <bos>
            it = fc_feats.new_zeros(n_rows, dtype=torch.long)

        logprobs, decodestate = self.get_logprobs_state(
            it, personality, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks,
            rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks, decodestate)

        if decoding_constraint and t > 0:
            # Forbid emitting the same word twice in a row.
            tmp = logprobs.new_zeros(logprobs.size())
            tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-10e20'))
            logprobs = logprobs + tmp

        if remove_bad_endings and t > 0:
            tmp = logprobs.new_zeros(logprobs.size())
            prev_bad = np.isin(seq[:, t - 1].data.cpu().numpy(), self.bad_endings_ix)
            # Make it impossible to end right after a "bad ending" word.
            tmp[torch.from_numpy(prev_bad.astype(np.bool_)), 0] = float('-10e20')
            logprobs = logprobs + tmp

        # Trigram blocking (currently disabled by the forced option above).
        if block_trigrams and t >= 3:
            # Cover every output row, not only the first batch_size.
            tracked_rows = seq.size(0)
            if not trigrams:
                trigrams = [{} for _ in range(tracked_rows)]

            # Store the trigram completed at the last step.
            prev_two_batch = seq[:, t - 3:t - 1]
            for i in range(tracked_rows):
                prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                trigrams[i].setdefault(prev_two, []).append(seq[i][t - 1].item())

            # Penalise words that would repeat a stored trigram.
            prev_two_batch = seq[:, t - 2:t]
            mask = torch.zeros(logprobs.size(), device=logprobs.device)
            for i in range(tracked_rows):
                prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                for j in trigrams[i].get(prev_two, ()):
                    mask[i, j] += 1

            alpha = 10e20
            # ln(1/2) * alpha per previous occurrence.
            logprobs = logprobs + (mask * -0.693 * alpha)

        if no_unk:
            # Suppress the last vocabulary index.
            mask2 = torch.zeros(logprobs.size(), device=logprobs.device)
            mask2[:, mask2.size(1) - 1] = -10e20
            logprobs = logprobs + mask2

        # Re-normalise after the additive penalties.
        logprobs = F.log_softmax(logprobs, dim=-1)

        # Skip sampling once the maximum length is reached.
        if t == self.seq_length:
            break
        it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)

        # A row stays alive only while it keeps emitting non-zero tokens.
        if t == 0:
            unfinished = it > 0
        else:
            unfinished = unfinished * (it > 0)
        it = it * unfinished.type_as(it)
        seq[:, t] = it
        seqLogprobs[:, t] = sampleLogprobs.view(-1)
        alogprobs[:, t] = logprobs

        # Quit the loop if all sequences have finished.
        if unfinished.sum() == 0:
            break

    return seq, seqLogprobs, alogprobs