def sample_G(self,
             mbsize,
             z,
             c,
             sample_mode='categorical',
             temp=1.0,
             gumbel_temp=1.0,
             prepend_start_idx=True,
             prevent_empty=False,
             min_length=1,
             beam_size=5,
             n_best=3):
    """
    Samples a minibatch of mbsize sequences from the decoder, given a (z, c) input.
    sample_mode determines hard sampling (categorical / greedy / gumbel_max)
        vs soft sampling (gumbel_soft, gumbel_ST, XX_softmax).
    prepend_start_idx will prepend a dummy <start> token, matching the dataloader format.
    prevent_empty will modify the probabilities before hard sampling from them.
    min_length does not modify sampling; the output is just at least this long,
        even if it is all padding.
    """
    sample_mode_soft = sample_mode in [
        'gumbel_soft', 'gumbel_ST', 'greedy_softmax', 'categorical_softmax',
        'none_softmax'
    ]
    assert not (sample_mode_soft and prevent_empty), \
        'cannot prevent_empty when soft sampling; we do not want to modify the softmax ' \
        'in place before feeding it back into the next timestep'
    assert beam_size >= n_best, 'cannot return more hypotheses than the beam size'
    assert mbsize == z.size(0) == c.size(0), \
        'sizes do not match: {} {} {}'.format(mbsize, z.size(0), c.size(0))
    assert (not self.use_flow) or z.flowed, 'BUG: flow>0 but z.flowed=False'

    # Collecting sampled sequences - Note: does not work for beam search
    seqIx = []
    seqSoftIx = []

    # to mask out after EOS
    finished = torch.zeros(mbsize, dtype=torch.bool).to(self.device)

    if sample_mode == 'beam':

        def unbottle(m):
            return m.view(beam_size, mbsize, -1)

        # Repeat inputs beam_size times
        z = z.repeat(beam_size, 1)
        c = c.repeat(beam_size, 1)
        # Initialize beams
        beam = [
            Beam(beam_size,
                 n_best=n_best,
                 device=self.device,
                 pad=PAD_IDX,
                 bos=START_IDX,
                 eos=EOS_IDX,
                 min_length=min_length) for ___ in range(mbsize)
        ]
        # Start: first beam BOS, rest PAD.
        sampleIx = torch.stack([b.get_current_state() for b in beam]) \
            .t().contiguous().view(-1)
    else:
        # Start: all BOS.
        sampleIx = torch.LongTensor(mbsize).to(self.device).fill_(START_IDX)
    sampleSoftIx = None

    # RNN state
    h = self.decoder.init_hidden(z, c)  # [mbsize x z,c]
    h = h.unsqueeze(0)  # prepend 1 = num_layers * num_directions
    # seqLogProbs = []  # unused for now
    # collecting sampled logprobs would be the basis for policy gradient algorithms (seqGAN etc.)

    # include start_idx in the output
    if prepend_start_idx:
        seqIx.append(sampleIx)
        if sample_mode_soft:
            seqSoftIx.append(onehot_embed(sampleIx, self.n_vocab).detach())

    for i in range(self.MAX_SEQ_LEN):
        ### 1) FORWARD PASS THIS TIMESTEP
        logits, h = self.decoder.forward_sample(sampleSoftIx, sampleIx, z, c, h)
        # END TODO use forward_decoder()
        if prevent_empty and i == 0:
            # somewhat hacky: force the first token to be a real character by masking
            # out the logits corresponding to pad/start/eos.
            # do not throw off downstream softmaxes by just putting -inf
            large_neg = -2 * torch.abs(logits.min())
            for maskix in [PAD_IDX, START_IDX, EOS_IDX]:
                logits[:, maskix] = large_neg

        ### 2) GIVEN LOGITS, SAMPLE -> sampleIx, sampleLogProbs, sampleSoftIx
        if sample_mode == 'categorical':
            sampleIx = torch.distributions.Categorical(logits=logits / temp).sample()
        elif sample_mode == 'greedy':
            sampleIx = torch.argmax(logits, 1)
        elif sample_mode == 'gumbel_max':
            # hard decision, same distribution as categorical sampling.
            # NOTE: this branch was a placeholder in the original; minimal Gumbel-max sketch.
            gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
            sampleIx = torch.argmax(logits / temp + gumbel, 1)
        elif sample_mode == 'beam':
            logits = unbottle(logits)
            # Update the beams
            for j, b in enumerate(beam):
                if not b.done():
                    logprobs = F.log_softmax(logits[:, j], dim=1)
                    b.advance(logprobs)
                # Update corresponding hidden states.
                # NOTE if not advanced, the hidden will be reset and sampleIx will remain.
                self._update_hidden(h, j, b.get_current_origin(), beam_size)
            # Get the current predictions
            sampleIx = torch.stack([b.get_current_state() for b in beam]) \
                .t().contiguous().view(-1)
        # ABOVE: HARD SAMPLING, BELOW: SOFT SAMPLING
        elif sample_mode == 'gumbel_soft':
            # keep the softmax as seqSoftIx, not straight-through.
            # NOTE: this branch was a placeholder in the original; minimal Gumbel-softmax sketch.
            sampleSoftIx = F.gumbel_softmax(logits, tau=gumbel_temp, hard=False)
            sampleIx = torch.argmax(sampleSoftIx, 1)
        elif sample_mode == 'gumbel_ST':
            # sampleSoftIx are straight-through onehot(argmax(gumbel_softmax)),
            # which passes through biased gradients.
            # NOTE: this branch was a placeholder in the original; minimal straight-through sketch.
            sampleSoftIx = F.gumbel_softmax(logits, tau=gumbel_temp, hard=True)
            sampleIx = torch.argmax(sampleSoftIx, 1)
        # below: sampleIx from none/greedy/categorical, softmax for softIx. Return seqIx, seqSoftIx.
        # The hard sample mode matters for when we run into EOS and mask out all subsequent softmaxes.
        elif sample_mode == 'none_softmax':
            sampleSoftIx = F.softmax(logits / temp, dim=1)
        elif sample_mode == 'greedy_softmax':
            sampleIx = torch.argmax(logits, 1)
            sampleSoftIx = F.softmax(logits / temp, dim=1)
        elif sample_mode == 'categorical_softmax':
            sampleIx = torch.distributions.Categorical(logits=logits / temp).sample()
            sampleSoftIx = F.softmax(logits / temp, dim=1)
        else:
            raise Exception('Sample mode {} not implemented.'.format(sample_mode))

        ### 3) FINISHED SENTENCES: MASK OUT sampleIx, sampleLogProbs, sampleSoftIx
        # Not needed for beam search: implemented inside Beam.py
        if not sample_mode == "beam":
            sampleIx.masked_fill_(finished, PAD_IDX)  # (mask, value)
            finished[sampleIx == EOS_IDX] = True  # new EOS reached, mask out in the future.
            seqIx.append(sampleIx)
            if sample_mode_soft:
                # set "one-hots" to 0, which will embed to the 0 vector. Note this is not
                # exactly the same as sampleIx=0, which would map to embedweight[0, :].
                sampleSoftIx = sampleSoftIx.masked_fill(finished.unsqueeze(1).clone(), 0)
                seqSoftIx.append(sampleSoftIx)

        ### 4) UPDATE MASK FOR NEXT ITERATION; BREAK (if all done)
        if finished.sum() == mbsize and len(seqIx) >= min_length:
            break  # everyone is done
        if sample_mode == "beam":
            if all(b.done() for b in beam):
                break

    if sample_mode == "beam":
        seqIx = []
        for b in beam:
            scores, ks = b.sort_finished(minimum=n_best)
            hyps = []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp = b.get_hyp(times, k)
                hyps.append(hyp)
            seqIx.append(hyps)
        return seqIx

    # End of loop. Assemble seqIx, seqSoftIx into tensors.
    seqIx = torch.stack(seqIx, dim=1)  # bs x seqlen
    if sample_mode_soft:
        # bs x seqlen x vocab. Note the seqlen dim is inserted in the middle.
        seqSoftIx = torch.stack(seqSoftIx, dim=1)
        assert seqIx.size(1) == seqSoftIx.size(1), 'mix-up when prepending the start index?'
        return seqIx, seqSoftIx
    else:
        return seqIx  # only hard sampling.
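# `onehot_embed` is used by sample_G above but is not defined in this snippet. A minimal
# sketch, assuming it simply maps a LongTensor of token indices to float one-hot rows
# (so they can be fed through the soft-embedding path and later masked to zero):
import torch
import torch.nn.functional as F


def onehot_embed(ix, n_vocab):
    # (mbsize,) int64 -> (mbsize, n_vocab) float32 one-hot
    return F.one_hot(ix, num_classes=n_vocab).float()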
def translate(model, opt, src_batch, adj):
    ''' Translation work in one batch '''
    tt = torch.cuda if opt.cuda else torch

    # Batch size is in a different location depending on the data.
    src_seq, src_pos = src_batch
    batch_size = src_seq.size(0)
    beam_size = opt.beam_size

    #- Encode
    enc_output, *_ = model.encoder(src_seq, adj, src_pos)

    #--- Repeat data for beam
    src_seq = src_seq.data.repeat(1, beam_size).view(
        src_seq.size(0) * beam_size, src_seq.size(1))
    enc_output = enc_output.data.repeat(1, beam_size, 1).view(
        enc_output.size(0) * beam_size, enc_output.size(1), enc_output.size(2))

    #--- Prepare beams
    beams = [Beam(beam_size, opt.cuda) for _ in range(batch_size)]
    beam_inst_idx_map = {
        beam_idx: inst_idx
        for inst_idx, beam_idx in enumerate(range(batch_size))
    }
    n_remaining_sents = batch_size

    if opt.decoder == 'rnn_m':
        decoder_hidden = enc_output.mean(1)

    #- Decode
    for i in range(opt.max_token_seq_len_d):
        len_dec_seq = i + 1

        # -- Preparing decoded data seq -- #
        # size: batch x beam x seq
        dec_partial_seq = torch.stack(
            [b.get_current_state() for b in beams if not b.done])
        # size: (batch * beam) x seq
        dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)

        # -- Preparing decoded pos seq -- #
        # size: 1 x seq
        # dec_partial_pos = torch.arange(1, len_dec_seq + 1).unsqueeze(0)
        # size: (batch * beam) x seq
        # dec_partial_pos = dec_partial_pos.repeat(n_remaining_sents * beam_size, 1)
        # dec_partial_pos = dec_partial_pos.type(torch.LongTensor)

        if opt.cuda:
            dec_partial_seq = dec_partial_seq.cuda()
            # dec_partial_pos = dec_partial_pos.cuda()

        # -- Decoding -- #
        # print(dec_partial_seq)
        if opt.decoder == 'rnn_m':
            dec_enc_attn_pad_mask = get_attn_padding_mask(dec_partial_seq,
                                                          src_seq,
                                                          unsqueeze=False)
            dec_output, decoder_hidden, _ = model.decoder.forward_step(
                dec_partial_seq[:, -1].unsqueeze(1),
                decoder_hidden.squeeze(),
                enc_output,
                dec_enc_attn_pad_mask=dec_enc_attn_pad_mask)
            dec_output = dec_output[-1, :, :]
        else:
            # dec_output, *_ = model.decoder(dec_partial_seq, dec_partial_pos, src_seq, enc_output)
            dec_output, *_ = model.decoder(dec_partial_seq, src_seq, enc_output)
            dec_output = dec_output[:, -1, :]  # (batch * beam) x d_model

        dec_output = model.tgt_word_proj(dec_output)
        # dec_output += model.U(enc_output.mean(1))  # .unsqueeze(1)

        # Mask previously predicted labels
        for J in range(dec_output.size(0)):
            dec_output.data[J].index_fill_(0, dec_partial_seq.data[J],
                                           -float('inf'))

        out = F.log_softmax(dec_output, dim=1)

        # batch x beam x n_words
        word_lk = out.view(n_remaining_sents, beam_size, -1).contiguous()

        active_beam_idx_list = []
        for beam_idx in range(batch_size):
            if beams[beam_idx].done:
                continue
            inst_idx = beam_inst_idx_map[beam_idx]
            if not beams[beam_idx].advance(word_lk.data[inst_idx]):
                active_beam_idx_list += [beam_idx]

        if not active_beam_idx_list:
            # all instances have finished their path to <EOS>
            break

        # In this section, the sentences that are still active are compacted
        # so that the decoder is not run on completed sentences.
        active_inst_idxs = tt.LongTensor(
            [beam_inst_idx_map[k] for k in active_beam_idx_list])

        # update the idx mapping
        beam_inst_idx_map = {
            beam_idx: inst_idx
            for inst_idx, beam_idx in enumerate(active_beam_idx_list)
        }

        def update_active_seq(seq_var, active_inst_idxs):
            ''' Remove the src sequence of finished instances in one batch. '''
            inst_idx_dim_size, *rest_dim_sizes = seq_var.size()
            inst_idx_dim_size = inst_idx_dim_size * len(
                active_inst_idxs) // n_remaining_sents
            new_size = (inst_idx_dim_size, *rest_dim_sizes)

            # select the active instances in batch
            original_seq_data = seq_var.data.view(n_remaining_sents, -1)
            active_seq_data = original_seq_data.index_select(0, active_inst_idxs)
            active_seq_data = active_seq_data.view(*new_size)
            return active_seq_data

        def update_active_enc_info(enc_info_var, active_inst_idxs):
            ''' Remove the encoder outputs of finished instances in one batch. '''
            inst_idx_dim_size, *rest_dim_sizes = enc_info_var.size()
            inst_idx_dim_size = inst_idx_dim_size * len(
                active_inst_idxs) // n_remaining_sents
            new_size = (inst_idx_dim_size, *rest_dim_sizes)

            # select the active instances in batch
            original_enc_info_data = enc_info_var.data.view(
                n_remaining_sents, -1, opt.d_model)
            active_enc_info_data = original_enc_info_data.index_select(
                0, active_inst_idxs)
            active_enc_info_data = active_enc_info_data.view(*new_size)
            return active_enc_info_data

        src_seq = update_active_seq(src_seq, active_inst_idxs)
        enc_output = update_active_enc_info(enc_output, active_inst_idxs)
        if opt.decoder == 'rnn_m':
            decoder_hidden = update_active_enc_info(
                decoder_hidden.transpose(0, 1), active_inst_idxs)
            decoder_hidden = decoder_hidden.transpose(0, 1)

        #- update the remaining size
        n_remaining_sents = len(active_inst_idxs)

    #- Return useful information
    all_hyp, all_hyp_scores, all_scores = [], [], []
    n_best = opt.n_best
    # for i in range(batch_size): print(len(beams[i].all_scores))
    for beam_idx in range(batch_size):
        scores, tail_idxs = beams[beam_idx].sort_scores()
        all_scores += [scores[:n_best]]
        hyps = [beams[beam_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
        all_hyp += [hyps]
        # stop()
        all_hyp_scores += [[
            torch.exp(i)[0] for i in beams[beam_idx].all_scores
        ]]
        # if beam_idx == 2:
        #     stop()

    return all_hyp, all_hyp_scores  # , all_scores
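# `get_attn_padding_mask` is referenced in the rnn_m branch above but not defined in this
# snippet. A rough sketch under assumptions: PAD id 0, and an `unsqueeze` flag that
# controls whether the key-side padding mask is expanded over the query length.
import torch

PAD = 0  # assumed padding index


def get_attn_padding_mask(seq_q, seq_k, unsqueeze=True):
    # seq_q: (batch, len_q), seq_k: (batch, len_k)
    pad_mask = seq_k.eq(PAD)  # (batch, len_k), True at padding positions
    if unsqueeze:
        # (batch, len_q, len_k) mask for the decoder-encoder attention
        pad_mask = pad_mask.unsqueeze(1).expand(-1, seq_q.size(1), -1)
    return pad_mask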
def decode_beam_search(self, hidden_states, memory, src_mask, vocab,
                       copy_tokens=None, beam_size=5, n_best=1, alpha=0.6,
                       length_pen='avg'):
    """ Beam search decoding """
    results = {"scores": [], "predictions": []}

    # Construct beams; we do not use a stepwise coverage penalty nor n-gram blocking.
    remaining_sents = memory.size(0)
    global_scorer = GNMTGlobalScorer(alpha, length_pen)
    beam = [
        Beam(beam_size, vocab, global_scorer=global_scorer, device=memory.device)
        for _ in range(remaining_sents)
    ]

    # repeat beam_size times
    memory, src_mask, copy_tokens = tile([memory, src_mask, copy_tokens],
                                         beam_size, dim=0)
    hidden_states = tile(hidden_states, beam_size, dim=1)
    h_c = type(hidden_states) in [list, tuple]
    batch_idx = list(range(remaining_sents))

    for i in range(MAX_DECODE_LENGTH):
        # (a) construct beam_size * remaining_sents next words
        ys = torch.stack([
            b.get_current_state() for b in beam if not b.done()
        ]).contiguous().view(-1, 1)

        # (b) pass through the decoder network
        out, hidden_states = self.decode_one_step(ys, hidden_states, memory,
                                                  src_mask, copy_tokens)
        out = out.contiguous().view(remaining_sents, beam_size, -1)

        # (c) advance each beam
        active, select_indices_array = [], []
        # Loop over the remaining_sents beams
        for b in range(remaining_sents):
            idx = batch_idx[b]  # idx is the original position in the minibatch
            beam[idx].advance(out[b])
            if not beam[idx].done():
                active.append((idx, b))
            select_indices_array.append(beam[idx].get_current_origin() +
                                        b * beam_size)

        # (d) update the hidden_states history
        select_indices_array = torch.cat(select_indices_array, dim=0)
        if h_c:
            hidden_states = (
                hidden_states[0].index_select(1, select_indices_array),
                hidden_states[1].index_select(1, select_indices_array))
        else:
            hidden_states = hidden_states.index_select(1, select_indices_array)

        if not active:
            break

        # (e) keep only the un-finished batches
        active_idx = torch.tensor([item[1] for item in active],
                                  dtype=torch.long,
                                  device=memory.device)  # position in the remaining batch
        batch_idx = {idx: item[0] for idx, item in enumerate(active)}  # order for the next remaining batch

        def update_active(t):
            if t is None:
                return t
            t_reshape = t.contiguous().view(remaining_sents, beam_size, -1)
            new_size = list(t.size())
            new_size[0] = -1
            return t_reshape.index_select(0, active_idx).view(*new_size)

        if h_c:
            hidden_states = (
                update_active(hidden_states[0].transpose(0, 1)).transpose(0, 1).contiguous(),
                update_active(hidden_states[1].transpose(0, 1)).transpose(0, 1).contiguous())
        else:
            hidden_states = update_active(hidden_states.transpose(0, 1)).transpose(0, 1).contiguous()
        memory = update_active(memory)
        src_mask = update_active(src_mask)
        copy_tokens = update_active(copy_tokens)
        remaining_sents = len(active)

    for b in beam:
        scores, ks = b.sort_finished(minimum=n_best)
        hyps = []
        for i, (times, k) in enumerate(ks[:n_best]):
            hyp = b.get_hyp(times, k)
            hyps.append(hyp.tolist())  # hyp contains </s> but does not contain <s>
        results["predictions"].append(hyps)  # batch list of variable_tgt_len
        results["scores"].append(torch.stack(scores)[:n_best])  # list of [n_best], torch.FloatTensor
    results["scores"] = torch.stack(results["scores"])
    return results
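# `tile` above is not defined in this snippet. A sketch of the assumed OpenNMT-style
# helper: repeat every example `count` times along `dim`, accepting either a single
# tensor or a list of tensors (and passing None through, e.g. an absent copy_tokens).
import torch


def tile(x, count, dim=0):
    if isinstance(x, (list, tuple)):
        return type(x)(tile(t, count, dim=dim) for t in x)
    if x is None:
        return None
    # repeat_interleave keeps each example's beam copies contiguous, which matches the
    # view(remaining_sents, beam_size, -1) reshapes used in decode_beam_search.
    return torch.repeat_interleave(x, count, dim=dim)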
def translate_batch_ENSEMBLE(self, enc_output, enc_hidden, category):
    ''' Translation work in one batch '''

    def beam_decode_step(inst_dec_beams, enc_output, enc_hidden,
                         inst_idx_to_position_map, n_bm, category):
        ''' Decode and update beam status, and then return active beam idx '''

        def prepare_beam_dec_seq(inst_dec_beams):
            dec_partial_seq = [
                b.get_lastest_state() for b in inst_dec_beams if not b.done
            ]
            dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
            dec_partial_seq = dec_partial_seq.view(-1)
            # print(dec_partial_seq)
            return dec_partial_seq

        def predict_word(dec_seq, enc_output, enc_hidden, n_active_inst, n_bm,
                         category):
            word_prob = []
            for i in range(len(enc_output)):
                res = self.model[i].decoder(it=dec_seq,
                                            encoder_outputs=enc_output[i],
                                            category=category,
                                            decoder_hidden=enc_hidden[i])
                dec_output, enc_hidden[i] = res['dec_outputs'], res['dec_hidden']
                tmp = F.log_softmax(self.model[i].tgt_word_prj(dec_output), dim=1)
                tmp = tmp.view(n_active_inst, n_bm, -1)
                word_prob.append(tmp)
            word_prob = torch.stack(word_prob, dim=0).mean(0)
            return word_prob, enc_hidden

        def collect_active_hidden_single(inst_beams, inst_idx_to_position_map,
                                         enc_hidden, n_bm):
            if isinstance(enc_hidden, tuple):
                tmp1, tmp2 = enc_hidden
                _, *d_hs = tmp1.size()
                n_curr_active_inst = len(inst_idx_to_position_map)
                new_shape = (n_curr_active_inst * n_bm, *d_hs)
                tmp1 = tmp1.view(n_curr_active_inst, n_bm, -1)
                tmp2 = tmp2.view(n_curr_active_inst, n_bm, -1)
                # print('hidden:', tmp1)
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    _prev_ks = inst_beams[inst_idx].get_current_origin()
                    tmp1[inst_position] = tmp1[inst_position].index_select(0, _prev_ks)
                    tmp2[inst_position] = tmp2[inst_position].index_select(0, _prev_ks)
                    # print("PREV_KS:", _prev_ks)
                # print('after h:', tmp1)
                tmp1 = tmp1.view(*new_shape)
                tmp2 = tmp2.view(*new_shape)
                enc_hidden = (tmp1, tmp2)
            else:
                _, *d_hs = enc_hidden.size()
                n_curr_active_inst = len(inst_idx_to_position_map)
                new_shape = (n_curr_active_inst * n_bm, *d_hs)
                enc_hidden = enc_hidden.view(n_curr_active_inst, n_bm, -1)
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    _prev_ks = inst_beams[inst_idx].get_current_origin()
                    enc_hidden[inst_position] = enc_hidden[inst_position].index_select(0, _prev_ks)
                enc_hidden = enc_hidden.view(*new_shape)
            return enc_hidden

        def collect_active_hidden(inst_beams, inst_idx_to_position_map,
                                  enc_hidden, n_bm):
            if enc_hidden is None:
                return None
            if isinstance(enc_hidden, list):
                hidden = []
                for item in enc_hidden:
                    hidden.append(
                        collect_active_hidden_single(inst_beams,
                                                     inst_idx_to_position_map,
                                                     item, n_bm))
            else:
                hidden = collect_active_hidden_single(inst_beams,
                                                      inst_idx_to_position_map,
                                                      enc_hidden, n_bm)
            return hidden

        n_active_inst = len(inst_idx_to_position_map)
        dec_seq = prepare_beam_dec_seq(inst_dec_beams)
        # print(dec_seq)
        # print('before:', enc_hidden[0])
        word_prob, enc_hidden = predict_word(dec_seq, enc_output, enc_hidden,
                                             n_active_inst, n_bm, category)
        # print('after:', enc_hidden[0])

        # Update the beam with predicted word prob information and collect incomplete instances
        active_inst_idx_list = self.collect_active_inst_idx_list(
            inst_dec_beams, word_prob, inst_idx_to_position_map)

        # print(type(enc_hidden))
        # print(type(enc_hidden[0]))
        # print(type(enc_hidden[0][0]))
        # print(type(enc_hidden[0][0][0]))
        enc_hidden = [
            collect_active_hidden(inst_dec_beams, inst_idx_to_position_map,
                                  item, n_bm) for item in enc_hidden
        ]
        return active_inst_idx_list, enc_hidden

    with torch.no_grad():
        assert isinstance(enc_output, list)
        assert isinstance(enc_hidden, list)
        assert len(enc_output) == len(self.model)
        assert len(enc_output) == len(enc_hidden)
        for i in range(len(enc_output)):
            if not isinstance(enc_output[i], list):
                enc_output[i] = [enc_output[i]]

        n_bm = self.opt["beam_size"]
        n_inst, len_s, d_h = enc_output[0][0].size()

        #-- Repeat data for beam search
        category = category.unsqueeze(1).repeat(1, n_bm, 1).view(
            n_inst * n_bm, self.opt['num_category'])
        enc_output = [[
            tmp.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
            for tmp in item
        ] for item in enc_output]
        for i in range(len(enc_hidden)):
            if isinstance(enc_hidden[i], tuple):
                enc_hidden[i] = (enc_hidden[i][0].unsqueeze(1).repeat(
                    1, n_bm, 1).view(n_inst * n_bm, d_h),
                                 enc_hidden[i][1].unsqueeze(1).repeat(
                                     1, n_bm, 1).view(n_inst * n_bm, d_h))
            else:
                enc_hidden[i] = enc_hidden[i].unsqueeze(1).repeat(
                    1, n_bm, 1).view(n_inst * n_bm, d_h)

        #-- initialize hidden state
        for i in range(len(enc_output)):
            enc_hidden[i] = self.model[i].decoder.init_hidden(enc_hidden[i])

        #-- Prepare beams
        inst_dec_beams = [
            Beam(n_bm, self.opt["max_len"], device=self.device)
            for _ in range(n_inst)
        ]

        #-- Bookkeeping for active or not
        active_inst_idx_list = list(range(n_inst))
        inst_idx_to_position_map = self.get_inst_idx_to_tensor_position_map(
            active_inst_idx_list)

        #-- Decode
        for t in range(1, self.opt["max_len"]):
            active_inst_idx_list, enc_hidden = beam_decode_step(
                inst_dec_beams, enc_output, enc_hidden,
                inst_idx_to_position_map, n_bm, category)

            if not active_inst_idx_list:
                break  # all instances have finished their path to <EOS>

            enc_output, enc_hidden, category, inst_idx_to_position_map = self.collate_active_info(
                enc_output,
                inst_idx_to_position_map,
                active_inst_idx_list,
                category,
                n_bm,
                enc_hidden=enc_hidden)

        batch_hyp, batch_scores = self.collect_hypothesis_and_scores(
            inst_dec_beams, self.opt.get("topk", 1))

    return batch_hyp, batch_scores
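# Several of the translate_batch_* methods above call
# self.get_inst_idx_to_tensor_position_map but its body is not part of this snippet.
# In the Transformer beam-search code this pattern comes from, it is usually a
# one-liner mapping each still-active instance index to its row in the packed tensors;
# a sketch (written as a free function here) under that assumption:
def get_inst_idx_to_tensor_position_map(inst_idx_list):
    # {original instance index: position in the currently active (compacted) batch}
    return {inst_idx: tensor_position
            for tensor_position, inst_idx in enumerate(inst_idx_list)}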
def translate_batch_LSTM(self, encoder_outputs, category):
    ''' Translation work in one batch '''

    def beam_decode_step(inst_dec_beams, enc_output, enc_hidden,
                         inst_idx_to_position_map, n_bm, category, tag):
        ''' Decode and update beam status, and then return active beam idx '''

        def prepare_beam_dec_seq(inst_dec_beams):
            dec_partial_seq = [
                b.get_lastest_state() for b in inst_dec_beams if not b.done
            ]
            dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
            dec_partial_seq = dec_partial_seq.view(-1)
            # print(dec_partial_seq)
            return dec_partial_seq

        def predict_word(dec_seq, enc_output, enc_hidden, n_active_inst, n_bm,
                         category, tag):
            res = self.model.decoder(it=dec_seq,
                                     encoder_outputs=enc_output,
                                     category=category,
                                     decoder_hidden=enc_hidden,
                                     tag=tag)
            dec_output, enc_hidden, tag = (res['dec_outputs'],
                                           res['dec_hidden'],
                                           res.get('pred_tag', None))
            word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1)
            word_prob = word_prob.view(n_active_inst, n_bm, -1)
            return word_prob, enc_hidden, tag.argmax(1) if tag is not None else None

        def collect_active_hidden_single(inst_beams, inst_idx_to_position_map,
                                         enc_hidden, n_bm):
            if isinstance(enc_hidden, tuple):
                tmp1, tmp2 = enc_hidden
                _, *d_hs = tmp1.size()
                n_curr_active_inst = len(inst_idx_to_position_map)
                new_shape = (n_curr_active_inst * n_bm, *d_hs)
                tmp1 = tmp1.view(n_curr_active_inst, n_bm, -1)
                tmp2 = tmp2.view(n_curr_active_inst, n_bm, -1)
                # print('hidden:', tmp1)
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    _prev_ks = inst_beams[inst_idx].get_current_origin()
                    tmp1[inst_position] = tmp1[inst_position].index_select(0, _prev_ks)
                    tmp2[inst_position] = tmp2[inst_position].index_select(0, _prev_ks)
                    # print("PREV_KS:", _prev_ks)
                # print('after h:', tmp1)
                tmp1 = tmp1.view(*new_shape)
                tmp2 = tmp2.view(*new_shape)
                enc_hidden = (tmp1, tmp2)
            else:
                _, *d_hs = enc_hidden.size()
                n_curr_active_inst = len(inst_idx_to_position_map)
                new_shape = (n_curr_active_inst * n_bm, *d_hs)
                enc_hidden = enc_hidden.view(n_curr_active_inst, n_bm, -1)
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    _prev_ks = inst_beams[inst_idx].get_current_origin()
                    enc_hidden[inst_position] = enc_hidden[inst_position].index_select(0, _prev_ks)
                enc_hidden = enc_hidden.view(*new_shape)
            return enc_hidden

        def collect_active_hidden(inst_beams, inst_idx_to_position_map,
                                  enc_hidden, n_bm):
            if enc_hidden is None:
                return None
            if isinstance(enc_hidden, list):
                hidden = []
                for item in enc_hidden:
                    hidden.append(
                        collect_active_hidden_single(inst_beams,
                                                     inst_idx_to_position_map,
                                                     item, n_bm))
            else:
                hidden = collect_active_hidden_single(inst_beams,
                                                      inst_idx_to_position_map,
                                                      enc_hidden, n_bm)
            return hidden

        '''
        _, *d_hs = beamed_tensor.size()
        n_curr_active_inst = len(curr_active_inst_idx)
        new_shape = (n_curr_active_inst * n_bm, *d_hs)

        print('n_prev_active:', n_prev_active_inst)
        print('n_curr_active:', curr_active_inst_idx)
        beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
        beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
        beamed_tensor = beamed_tensor.view(*new_shape)

        return beamed_tensor
        '''

        n_active_inst = len(inst_idx_to_position_map)
        dec_seq = prepare_beam_dec_seq(inst_dec_beams)
        # print(dec_seq)
        # print('before:', enc_hidden[0])
        word_prob, enc_hidden, tag = predict_word(dec_seq, enc_output,
                                                  enc_hidden, n_active_inst,
                                                  n_bm, category, tag)
        # print('after:', enc_hidden[0])

        # Update the beam with predicted word prob information and collect incomplete instances
        active_inst_idx_list = self.collect_active_inst_idx_list(
            inst_dec_beams, word_prob, inst_idx_to_position_map)

        enc_hidden = collect_active_hidden(inst_dec_beams,
                                           inst_idx_to_position_map,
                                           enc_hidden, n_bm)
        tag = collect_active_hidden(inst_dec_beams, inst_idx_to_position_map,
                                    tag, n_bm)
        return active_inst_idx_list, enc_hidden, tag

    with torch.no_grad():
        enc_output, enc_hidden = (encoder_outputs['enc_output'],
                                  encoder_outputs['enc_hidden'])
        if not isinstance(enc_output, list):
            enc_output = [enc_output]

        n_bm = self.opt["beam_size"]
        n_inst, len_s, _ = enc_output[0].shape

        #-- Repeat data for beam search
        enc_output = [
            item.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, -1)
            for item in enc_output
        ]
        if isinstance(enc_hidden, tuple):
            n_inst, d_h = enc_hidden[0].size()
            enc_hidden = (enc_hidden[0].unsqueeze(1).repeat(1, n_bm, 1).view(n_inst * n_bm, d_h),
                          enc_hidden[1].unsqueeze(1).repeat(1, n_bm, 1).view(n_inst * n_bm, d_h))
        elif isinstance(enc_hidden, list):
            n_inst, d_h = enc_hidden[0].size()
            enc_hidden = [
                item.unsqueeze(1).repeat(1, n_bm, 1).view(n_inst * n_bm, d_h)
                for item in enc_hidden
            ]
        else:
            n_inst, d_h = enc_hidden.size()
            enc_hidden = enc_hidden.unsqueeze(1).repeat(1, n_bm, 1).view(n_inst * n_bm, d_h)

        enc_hidden = self.model.decoder.init_hidden(enc_hidden)

        if encoder_outputs.get('obj_emb', None) is not None:
            if self.opt['with_category']:
                category = torch.cat([category, encoder_outputs['obj_emb']], dim=1)
            else:
                category = encoder_outputs['obj_emb']
        category = category.unsqueeze(1).repeat(1, n_bm, 1).view(n_inst * n_bm, -1)

        #-- Prepare beams
        inst_dec_beams = [
            Beam(n_bm, self.opt["max_len"], device=self.device)
            for _ in range(n_inst)
        ]
        if self.opt['use_tag']:
            tag = category.new(n_inst, n_bm).fill_(Constants.BOS).view(n_inst * n_bm).long()
        else:
            tag = None

        #-- Bookkeeping for active or not
        active_inst_idx_list = list(range(n_inst))
        inst_idx_to_position_map = self.get_inst_idx_to_tensor_position_map(
            active_inst_idx_list)

        #-- Decode
        for t in range(1, self.opt["max_len"]):
            active_inst_idx_list, enc_hidden, tag = beam_decode_step(
                inst_dec_beams, enc_output, enc_hidden,
                inst_idx_to_position_map, n_bm, category, tag)

            if not active_inst_idx_list:
                break  # all instances have finished their path to <EOS>

            enc_output, enc_hidden, category, inst_idx_to_position_map, tag = self.collate_active_info(
                enc_output,
                inst_idx_to_position_map,
                active_inst_idx_list,
                category,
                n_bm,
                enc_hidden=enc_hidden,
                tag=tag)

        batch_hyp, batch_scores = self.collect_hypothesis_and_scores(
            inst_dec_beams, self.opt.get("topk", 1))

    return batch_hyp, batch_scores
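# self.collect_active_inst_idx_list is called throughout the translate_batch_* methods
# but not defined in this snippet. In the beam-search code these methods follow, it is
# typically the routine sketched below (written as a free function): advance each still
# active beam with its slice of word probabilities and keep the indices of beams that
# have not yet finished.
def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
    active_inst_idx_list = []
    for inst_idx, inst_position in inst_idx_to_position_map.items():
        is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
        if not is_inst_complete:
            active_inst_idx_list += [inst_idx]
    return active_inst_idx_list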
def translate_batch_ARFormer(self, encoder_outputs, category):
    ''' Translation work in one batch '''

    def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
                         inst_idx_to_position_map, n_bm, category, attribute):
        ''' Decode and update beam status, and then return active beam idx '''

        def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
            dec_partial_seq = [
                b.get_current_state() for b in inst_dec_beams if not b.done
            ]
            dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
            dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
            # print(dec_partial_seq)
            return dec_partial_seq

        def predict_word(dec_seq, enc_output, n_active_inst, n_bm, category,
                         attribute):
            dec_output, *_ = self.model.decoder(dec_seq, enc_output, category,
                                                tags=attribute)
            if isinstance(dec_output, list):
                dec_output = dec_output[-1]
            dec_output = dec_output[:, -1, :]  # pick the last step: (bh * bm) x d_h
            word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1)
            word_prob = word_prob.view(n_active_inst, n_bm, -1)
            return word_prob

        n_active_inst = len(inst_idx_to_position_map)
        dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
        word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
                                 category, attribute)

        # Update the beam with predicted word prob information and collect incomplete instances
        active_inst_idx_list = self.collect_active_inst_idx_list(
            inst_dec_beams, word_prob, inst_idx_to_position_map)

        return active_inst_idx_list

    '''
    def collect_hypothesis_and_scores(inst_dec_beams, n_best):
        all_hyp, all_scores = [], []
        for inst_idx in range(len(inst_dec_beams)):
            scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
            print(113, scores, tail_idxs)
            all_scores += [scores[:n_best]]
            hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
            all_hyp += [hyps]
        return all_hyp, all_scores
    '''

    with torch.no_grad():
        enc_output = encoder_outputs['enc_output']
        if isinstance(enc_output, list):
            assert len(enc_output) == 1
            enc_output = enc_output[0]

        #-- Repeat data for beam search
        n_bm = self.opt["beam_size"]
        n_inst, len_s, d_h = enc_output.size()
        enc_output = enc_output.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
        category = category.repeat(1, n_bm).view(n_inst * n_bm, 1)
        e = enc_output.clone()
        c = category.clone()

        attribute = encoder_outputs.get(Constants.mapping['attr'][0], None)
        if attribute is not None:
            attribute = attribute.unsqueeze(1).repeat(1, n_bm, 1).view(n_inst * n_bm, -1)

        #-- Prepare beams
        inst_dec_beams = [
            Beam(n_bm,
                 self.opt["max_len"],
                 device=self.device,
                 specific_nums_of_sents=self.opt.get('topk', 1))
            for _ in range(n_inst)
        ]

        #-- Bookkeeping for active or not
        active_inst_idx_list = list(range(n_inst))
        inst_idx_to_position_map = self.get_inst_idx_to_tensor_position_map(
            active_inst_idx_list)

        #-- Decode
        for len_dec_seq in range(1, self.opt["max_len"]):
            active_inst_idx_list = beam_decode_step(inst_dec_beams,
                                                    len_dec_seq, enc_output,
                                                    inst_idx_to_position_map,
                                                    n_bm, category, attribute)

            if not active_inst_idx_list:
                break  # all instances have finished their path to <EOS>

            enc_output, category, inst_idx_to_position_map, attribute = self.collate_active_info(
                enc_output,
                inst_idx_to_position_map,
                active_inst_idx_list,
                category,
                n_bm,
                tag=attribute)

        if self.opt.get('use_beam_decoder', False):
            batch_hyp, batch_scores = self.collect_hypothesis_and_scores_bd(
                inst_dec_beams, self.opt.get("topk", 1), e, c)
        else:
            batch_hyp, batch_scores = self.collect_hypothesis_and_scores(
                inst_dec_beams, self.opt.get("topk", 1))

    return batch_hyp, batch_scores
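# For reference: the commented-out collect_hypothesis_and_scores carried inside
# translate_batch_ARFormer above, cleaned up (debug print removed) so it is clear
# what the self.collect_hypothesis_and_scores calls are assumed to return:
# one list of n_best hypotheses and one list of n_best scores per instance.
def collect_hypothesis_and_scores(inst_dec_beams, n_best):
    all_hyp, all_scores = [], []
    for inst_idx in range(len(inst_dec_beams)):
        scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
        all_scores += [scores[:n_best]]
        hyps = [inst_dec_beams[inst_idx].get_hypothesis(i)
                for i in tail_idxs[:n_best]]
        all_hyp += [hyps]
    return all_hyp, all_scores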
def decode_beam_search(self, word_seqs, lengths, beam_size, tag2idx,
                       extFeats=None, with_snt_classifier=False,
                       masked_output=None):
    minibatch_size = len(lengths)  # word_seqs.size(0) if self.encoder.batch_first else word_seqs.size(1)
    max_length = max(lengths)  # word_seqs.size(1) if self.encoder.batch_first else word_seqs.size(0)

    # encoder
    embeds = self.get_token_embeddings(word_seqs, lengths)
    if extFeats is not None:
        concat_input = torch.cat((embeds, self.extFeats_linear(extFeats)), 2)
    else:
        concat_input = embeds
    concat_input = self.dropout_layer(concat_input)
    packed_word_embeds = rnn_utils.pack_padded_sequence(concat_input, lengths, batch_first=True)
    packed_word_lstm_out, (enc_h_t, enc_c_t) = self.encoder(packed_word_embeds)  # bsize x seqlen x dim
    enc_word_lstm_out, unpacked_len = rnn_utils.pad_packed_sequence(packed_word_lstm_out, batch_first=True)

    # decoder
    if self.bidirectional:
        index_slices = [2 * i + 1 for i in range(self.num_layers)]  # generated from the reversed path
        index_slices = torch.tensor(index_slices, dtype=torch.long, device=self.device)
        h_t = torch.index_select(enc_h_t, 0, index_slices)
        c_t = torch.index_select(enc_c_t, 0, index_slices)
    else:
        h_t = enc_h_t
        c_t = enc_c_t

    h_t = h_t.repeat(1, beam_size, 1)
    c_t = c_t.repeat(1, beam_size, 1)
    word_lstm_out = enc_word_lstm_out.repeat(beam_size, 1, 1)

    beam = [
        Beam(beam_size, tag2idx, device=self.device)
        for k in range(minibatch_size)
    ]
    batch_idx = list(range(minibatch_size))
    remaining_sents = minibatch_size

    top_dec_h_t, top_dec_c_t = [0] * minibatch_size, [0] * minibatch_size
    for i in range(max_length):
        last_tags = torch.stack([
            b.get_current_state() for b in beam if not b.done
        ]).t().contiguous().view(-1, 1)  # after t() -> beam_size * batch_size
        last_tags = last_tags.to(self.device)
        tag_embeds = self.dropout_layer(self.tag_embeddings(last_tags))
        decode_inputs = torch.cat(
            (self.dropout_layer(word_lstm_out[:, i:i + 1]), tag_embeds), 2)  # (batch*beam) x 1 x insize
        tag_lstm_out, (dec_h_t, dec_c_t) = self.decoder(
            decode_inputs, (h_t, c_t))  # (batch*beam) x 1 x insize => (batch*beam) x 1 x hsize
        tag_lstm_out_reshape = tag_lstm_out.contiguous().view(
            tag_lstm_out.size(0) * tag_lstm_out.size(1), tag_lstm_out.size(2))
        tag_space = self.hidden2tag(self.dropout_layer(tag_lstm_out_reshape))
        out = F.log_softmax(tag_space, dim=1)  # (batch*beam) x outsize

        word_lk = out.view(beam_size, remaining_sents, -1).transpose(0, 1).contiguous()

        active = []
        for b in range(minibatch_size):
            if beam[b].done:
                continue
            if lengths[b] == i + 1:
                beam[b].done = True
                top_dec_h_t[b] = dec_h_t[:, b:b + beam_size, :]
                top_dec_c_t[b] = dec_c_t[:, b:b + beam_size, :]
            idx = batch_idx[b]
            beam[b].advance(word_lk.data[idx])
            if not beam[b].done:
                active.append(b)
            for dec_state in (dec_h_t, dec_c_t):  # (layer*direction) x beam*sent x Hdim
                sent_states = dec_state.view(-1, beam_size, remaining_sents,
                                             dec_state.size(2))[:, :, idx]
                sent_states.data.copy_(
                    sent_states.data.index_select(1, beam[b].get_current_origin()))

        if not active:
            break

        active_idx = torch.tensor([batch_idx[k] for k in active],
                                  dtype=torch.long, device=self.device)
        batch_idx = {b: idx for idx, b in enumerate(active)}

        def update_active(t, hidden_dim):
            # t_reshape = t.data.view(-1, remaining_sents, hidden_dim)
            t_reshape = t.contiguous().view(-1, remaining_sents, hidden_dim)
            new_size = list(t.size())
            new_size[-2] = new_size[-2] * len(active_idx) // remaining_sents  # beam * len(active_idx)
            return t_reshape.index_select(1, active_idx).view(*new_size)

        h_t = update_active(dec_h_t, self.hidden_dim)
        c_t = update_active(dec_c_t, self.hidden_dim)
        word_lstm_out = update_active(
            word_lstm_out.transpose(0, 1),
            self.num_directions * self.hidden_dim).transpose(0, 1)
        remaining_sents = len(active)

    allHyp, allScores = [], []
    n_best = 1
    for b in range(minibatch_size):
        scores, ks = beam[b].sort_best()
        allScores += [scores[:n_best]]
        hyps = zip(*[beam[b].get_hyp(k) for k in ks[:n_best]])
        allHyp += [hyps]
        top_dec_h_t[b] = top_dec_h_t[b].data.index_select(1, ks[:n_best])
        top_dec_c_t[b] = top_dec_c_t[b].data.index_select(1, ks[:n_best])
    top_dec_h_t = torch.cat(top_dec_h_t, 1)
    top_dec_c_t = torch.cat(top_dec_c_t, 1)
    allScores = torch.cat(allScores)

    if with_snt_classifier:
        return allScores, allHyp, ((enc_h_t, enc_c_t), enc_word_lstm_out, lengths)
    else:
        return allScores, allHyp
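# Usage note (assumed): each allHyp[b] returned above is a zip object that iterates over
# timesteps, yielding one tuple with the tag chosen by each of the n_best hypotheses at
# that step. A small post-processing sketch to recover one flat tag sequence per hypothesis:
def unpack_hypotheses(all_hyp):
    unpacked = []
    for hyp_zip in all_hyp:
        steps = list(hyp_zip)  # one tuple of length n_best per timestep
        unpacked.append([list(seq) for seq in zip(*steps)])  # n_best tag sequences
    return unpacked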
def translate_batch_ARFormer(self, encoder_outputs, category):
    ''' Translation work in one batch '''

    def beam_decode_step(inst_dec_beams, len_dec_seq, inputs_for_decoder,
                         inst_idx_to_position_map, n_bm):
        ''' Decode and update beam status, and then return active beam idx '''

        def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
            dec_partial_seq = [
                b.get_current_state() for b in inst_dec_beams if not b.done
            ]
            dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
            dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
            return dec_partial_seq

        def predict_word(dec_seq, inputs_for_decoder, n_active_inst, n_bm):
            dec_output, *_ = self.model.decoder(dec_seq, **inputs_for_decoder)
            if isinstance(dec_output, list):
                dec_output = dec_output[-1]
            dec_output = dec_output[:, -1, :]  # pick the last step: (bh * bm) x d_h
            word_prob = self.model.tgt_word_prj(dec_output)
            word_prob = F.log_softmax(word_prob, dim=1)
            # print(word_prob[0, :10])
            word_prob = word_prob.view(n_active_inst, n_bm, -1)
            return word_prob

        n_active_inst = len(inst_idx_to_position_map)
        dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
        word_prob = predict_word(dec_seq, inputs_for_decoder, n_active_inst, n_bm)

        # Update the beam with predicted word prob information and collect incomplete instances
        active_inst_idx_list = self.collect_active_inst_idx_list(
            inst_dec_beams, word_prob, inst_idx_to_position_map)

        return active_inst_idx_list

    with torch.no_grad():
        inputs_for_decoder = self.model.prepare_inputs_for_decoder(encoder_outputs, category)

        #-- Repeat data for beam search
        n_bm = self.opt["beam_size"]
        n_inst = inputs_for_decoder['enc_output'].size(0)
        for key in inputs_for_decoder:
            inputs_for_decoder[key] = auto_enlarge(inputs_for_decoder[key], n_bm)

        #-- Prepare beams
        inst_dec_beams = [
            Beam(n_bm,
                 self.opt["max_len"],
                 device=self.device,
                 specific_nums_of_sents=self.opt.get('topk', 1))
            for _ in range(n_inst)
        ]

        #-- Bookkeeping for active or not
        active_inst_idx_list = list(range(n_inst))
        inst_idx_to_position_map = self.get_inst_idx_to_tensor_position_map(
            active_inst_idx_list)

        #-- Decode
        for len_dec_seq in range(1, self.opt["max_len"]):
            active_inst_idx_list = beam_decode_step(inst_dec_beams,
                                                    len_dec_seq,
                                                    inputs_for_decoder,
                                                    inst_idx_to_position_map,
                                                    n_bm)

            if not active_inst_idx_list:
                break  # all instances have finished their path to <EOS>

            inputs_for_decoder, inst_idx_to_position_map = self.collate_active_info(
                inputs_for_decoder, inst_idx_to_position_map,
                active_inst_idx_list, n_bm)

        batch_hyp, batch_scores = self.collect_hypothesis_and_scores(
            inst_dec_beams, self.opt.get("topk", 1))

    return batch_hyp, batch_scores
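# `auto_enlarge` above is not defined in this snippet. A plausible sketch, assuming it
# simply repeats each encoder-side tensor beam_size times (recursing into lists/tuples,
# mirroring the hand-written repeat code in the other translate_batch_* variants):
import torch


def auto_enlarge(x, n_bm):
    if isinstance(x, (list, tuple)):
        return type(x)(auto_enlarge(t, n_bm) for t in x)
    if x is None:
        return None
    n_inst, *rest = x.size()
    # (n_inst, ...) -> (n_inst * n_bm, ...), with each instance's beam copies contiguous
    return x.unsqueeze(1).repeat(1, n_bm, *([1] * len(rest))).view(n_inst * n_bm, *rest)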