# The snippets in this file are variants of Transformer beam-search decoding.
# They assume module-level imports of `torch` and `torch.nn.functional as F`,
# plus each project's own `Beam` bookkeeping class (a minimal sketch of the
# interface these snippets rely on follows this first function).
import torch
import torch.nn.functional as F


def translate_batch(self, src_seq, src_pos):
    ''' Translation work in one batch '''

    def get_inst_idx_to_tensor_position_map(inst_idx_list):
        ''' Indicate the position of an instance in a tensor. '''
        return {inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)}

    def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
        ''' Collect tensor parts associated to active instances. '''
        _, *d_hs = beamed_tensor.size()
        n_curr_active_inst = len(curr_active_inst_idx)
        new_shape = (n_curr_active_inst * n_bm, *d_hs)

        beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
        beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
        beamed_tensor = beamed_tensor.view(*new_shape)

        return beamed_tensor

    def collate_active_info(src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list):
        # Sentences which are still active are collected,
        # so the decoder will not run on completed sentences.
        n_prev_active_inst = len(inst_idx_to_position_map)
        active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
        active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

        active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm)
        active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm)
        active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

        return active_src_seq, active_src_enc, active_inst_idx_to_position_map

    def beam_decode_step(inst_dec_beams, len_dec_seq, src_seq, enc_output,
                         inst_idx_to_position_map, n_bm):
        ''' Decode and update beam status, and then return active beam idx '''

        def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
            dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
            dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
            dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
            return dec_partial_seq

        def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
            dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device)
            dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1)
            return dec_partial_pos

        def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm):
            dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output)
            dec_output = dec_output[:, -1, :]  # Pick the last step: (n_inst * n_bm) x d_h
            word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1)
            word_prob = word_prob.view(n_active_inst, n_bm, -1)
            return word_prob

        def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
            active_inst_idx_list = []
            for inst_idx, inst_position in inst_idx_to_position_map.items():
                is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
                if not is_inst_complete:
                    active_inst_idx_list += [inst_idx]
            return active_inst_idx_list

        n_active_inst = len(inst_idx_to_position_map)

        dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
        dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
        word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm)

        # Update the beam with predicted word prob information and collect incomplete instances
        active_inst_idx_list = collect_active_inst_idx_list(
            inst_dec_beams, word_prob, inst_idx_to_position_map)

        return active_inst_idx_list

    def collect_hypothesis_and_scores(inst_dec_beams, n_best):
        all_hyp, all_scores = [], []
        for inst_idx in range(len(inst_dec_beams)):
            scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
            all_scores += [scores[:n_best]]

            hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
            all_hyp += [hyps]
        return all_hyp, all_scores

    with torch.no_grad():
        #-- Encode
        src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
        src_enc, *_ = self.model.encoder(src_seq, src_pos)

        #-- Repeat data for beam search
        n_bm = self.opt.beam_size
        n_inst, len_s, d_h = src_enc.size()
        src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
        src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)

        #-- Prepare beams
        inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)]

        #-- Bookkeeping for active or not
        active_inst_idx_list = list(range(n_inst))
        inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

        #-- Decode
        for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1):

            active_inst_idx_list = beam_decode_step(
                inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm)

            if not active_inst_idx_list:
                break  # all instances have finished their path to <EOS>

            src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
                src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list)

    batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best)

    return batch_hyp, batch_scores
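# ---------------------------------------------------------------------------
# Every variant in this file leans on a Beam object with the same small
# interface: `done`, `get_current_state()`, `advance(word_prob)`,
# `sort_scores()` and `get_hypothesis(k)`. The class below is a minimal,
# hedged sketch of that interface, not any project's actual Beam.py: the
# BOS/EOS ids are assumptions, and scoring is plain cumulative log-prob with
# no length penalty.
import torch

BOS, EOS = 2, 3  # assumed special-token ids


class MinimalBeam:

    def __init__(self, size, device=None):
        self.size = size
        self.done = False
        self.device = device
        # cumulative log-prob of each live hypothesis: (size,)
        self.scores = torch.zeros(size, dtype=torch.float, device=device)
        # per-step back-pointers and emitted tokens
        self.prev_ks = []
        self.next_ys = [torch.full((size,), BOS, dtype=torch.long, device=device)]

    def get_current_state(self):
        ''' (size, cur_len) tensor holding the live partial hypotheses. '''
        hyps = [self.get_hypothesis(k) for k in range(self.size)]
        return torch.tensor(hyps, dtype=torch.long, device=self.device)

    def advance(self, word_prob):
        ''' word_prob: (size, n_words) log-probs for the next step. '''
        n_words = word_prob.size(1)
        if self.prev_ks:
            beam_lk = word_prob + self.scores.unsqueeze(1)  # extend each hypothesis
        else:
            beam_lk = word_prob[0]  # first step: all beams are still identical
        self.scores, best_ids = beam_lk.view(-1).topk(self.size, 0, True, True)
        self.prev_ks.append(torch.div(best_ids, n_words, rounding_mode='floor'))
        self.next_ys.append(best_ids % n_words)
        if self.next_ys[-1][0].item() == EOS:  # best hypothesis just ended
            self.done = True
        return self.done

    def sort_scores(self):
        return self.scores.sort(0, descending=True)

    def get_hypothesis(self, k):
        ''' Walk the back-pointers to rebuild hypothesis k (BOS included). '''
        k = int(k)
        hyp = []
        for j in range(len(self.prev_ks) - 1, -1, -1):
            hyp.append(self.next_ys[j + 1][k].item())
            k = int(self.prev_ks[j][k])
        return [BOS] + hyp[::-1]
# ---------------------------------------------------------------------------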
# The older variants below predate torch.no_grad() and use the legacy
# Variable/volatile autograd API.
from torch.autograd import Variable


def translate_batch(self, src_batch):
    ''' Translation work in one batch '''

    # Batch size is in different location depending on data.
    src_seq, src_pos = src_batch
    batch_size = src_seq.size(0)
    beam_size = self.opt.beam_size

    #- Encode
    enc_output, *_ = self.model.encoder(src_seq, src_pos)

    #--- Repeat data for beam
    src_seq = Variable(
        src_seq.data.repeat(1, beam_size).view(
            src_seq.size(0) * beam_size, src_seq.size(1)))
    enc_output = Variable(
        enc_output.data.repeat(1, beam_size, 1).view(
            enc_output.size(0) * beam_size, enc_output.size(1), enc_output.size(2)))

    #--- Prepare beams
    beams = [Beam(beam_size, self.opt.cuda) for _ in range(batch_size)]
    beam_inst_idx_map = {
        beam_idx: inst_idx for inst_idx, beam_idx in enumerate(range(batch_size))}
    n_remaining_sents = batch_size

    #- Decode
    for i in range(self.model_opt.max_token_seq_len):
        len_dec_seq = i + 1

        # -- Preparing decoded data seq -- #
        # size: batch x beam x seq
        dec_partial_seq = torch.stack([
            b.get_current_state() for b in beams if not b.done])
        # size: (batch * beam) x seq
        dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
        # wrap into a Variable
        dec_partial_seq = Variable(dec_partial_seq, volatile=True)

        # -- Preparing decoded pos seq -- #
        # size: 1 x seq
        dec_partial_pos = torch.arange(1, len_dec_seq + 1).unsqueeze(0)
        # size: (batch * beam) x seq
        dec_partial_pos = dec_partial_pos.repeat(n_remaining_sents * beam_size, 1)
        # wrap into a Variable
        dec_partial_pos = Variable(dec_partial_pos.type(torch.LongTensor), volatile=True)

        if self.opt.cuda:
            dec_partial_seq = dec_partial_seq.cuda()
            dec_partial_pos = dec_partial_pos.cuda()

        # -- Decoding -- #
        dec_output, *_ = self.model.decoder(
            dec_partial_seq, dec_partial_pos, src_seq, enc_output)
        dec_output = dec_output[:, -1, :]  # (batch * beam) x d_model
        dec_output = self.model.tgt_word_proj(dec_output)
        out = self.model.prob_projection(dec_output)

        # batch x beam x n_words
        word_lk = out.view(n_remaining_sents, beam_size, -1).contiguous()

        active_beam_idx_list = []
        for beam_idx in range(batch_size):
            if beams[beam_idx].done:
                continue

            inst_idx = beam_inst_idx_map[beam_idx]
            if not beams[beam_idx].advance(word_lk.data[inst_idx]):
                active_beam_idx_list += [beam_idx]

        if not active_beam_idx_list:
            # all instances have finished their path to <EOS>
            break

        # in this section, the sentences that are still active are
        # compacted so that the decoder is not run on completed sentences
        active_inst_idxs = self.tt.LongTensor(
            [beam_inst_idx_map[k] for k in active_beam_idx_list])

        # update the idx mapping
        beam_inst_idx_map = {
            beam_idx: inst_idx for inst_idx, beam_idx in enumerate(active_beam_idx_list)}

        def update_active_seq(seq_var, active_inst_idxs):
            ''' Remove the src sequence of finished instances in one batch. '''

            inst_idx_dim_size, *rest_dim_sizes = seq_var.size()
            inst_idx_dim_size = inst_idx_dim_size * len(active_inst_idxs) // n_remaining_sents
            new_size = (inst_idx_dim_size, *rest_dim_sizes)

            # select the active instances in batch
            original_seq_data = seq_var.data.view(n_remaining_sents, -1)
            active_seq_data = original_seq_data.index_select(0, active_inst_idxs)
            active_seq_data = active_seq_data.view(*new_size)

            return Variable(active_seq_data, volatile=True)

        def update_active_enc_info(enc_info_var, active_inst_idxs):
            ''' Remove the encoder outputs of finished instances in one batch. '''

            inst_idx_dim_size, *rest_dim_sizes = enc_info_var.size()
            inst_idx_dim_size = inst_idx_dim_size * len(active_inst_idxs) // n_remaining_sents
            new_size = (inst_idx_dim_size, *rest_dim_sizes)

            # select the active instances in batch
            original_enc_info_data = enc_info_var.data.view(
                n_remaining_sents, -1, self.model_opt.d_model)
            active_enc_info_data = original_enc_info_data.index_select(0, active_inst_idxs)
            active_enc_info_data = active_enc_info_data.view(*new_size)

            return Variable(active_enc_info_data, volatile=True)

        src_seq = update_active_seq(src_seq, active_inst_idxs)
        enc_output = update_active_enc_info(enc_output, active_inst_idxs)

        #- update the remaining size
        n_remaining_sents = len(active_inst_idxs)

    #- Return useful information
    all_hyp, all_scores = [], []
    n_best = self.opt.n_best

    for beam_idx in range(batch_size):
        scores, tail_idxs = beams[beam_idx].sort_scores()
        all_scores += [scores[:n_best]]

        hyps = [beams[beam_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
        all_hyp += [hyps]

    return all_hyp, all_scores
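# ---------------------------------------------------------------------------
# A standalone sketch of the two beam-replication layouts that appear in this
# file. The variant above tiles instance-major (`repeat(1, beam_size)`), which
# is what its `view(n_remaining_sents, beam_size, -1)` scoring and
# `view(n_remaining_sents, -1)` compaction assume; some variants further down
# (e.g. the use_ctx one) tile beam-major instead and index `word_lk[:, idx]`.
import torch

src = torch.tensor([[10, 11],   # instance 0
                    [20, 21]])  # instance 1
n_bm = 3

# Instance-major: each instance's beams are contiguous rows.
inst_major = src.repeat(1, n_bm).view(src.size(0) * n_bm, src.size(1))
# rows: [10,11] [10,11] [10,11] [20,21] [20,21] [20,21]

# Beam-major: the whole batch is tiled, row i*n_inst + j is beam i of inst j.
beam_major = src.repeat(n_bm, 1)
# rows: [10,11] [20,21] [10,11] [20,21] [10,11] [20,21]

# Instance-major compaction: fold the beams back into one row per instance,
# drop finished instances, then unfold.
active_idx = torch.LongTensor([1])         # keep only instance 1
folded = inst_major.view(2, -1)            # (n_inst, n_bm * len)
kept = folded.index_select(0, active_idx)  # (n_active, n_bm * len)
print(kept.view(-1, src.size(1)))          # the 3 beam rows of instance 1
# ---------------------------------------------------------------------------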
import copy  # the source tensors are re-encoded once per decoder below


def translate_batch(self, raw_src_seq, raw_src_pos, block_list=[]):
    ''' Translation work in one batch '''

    def get_inst_idx_to_tensor_position_map(inst_idx_list):
        ''' Indicate the position of an instance in a tensor. '''
        return {inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)}

    def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
        ''' Collect tensor parts associated to active instances. '''
        _, *d_hs = beamed_tensor.size()
        n_curr_active_inst = len(curr_active_inst_idx)
        new_shape = (n_curr_active_inst * n_bm, *d_hs)

        beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
        beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
        beamed_tensor = beamed_tensor.view(*new_shape)

        return beamed_tensor

    def collate_active_info(src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list):
        # Sentences which are still active are collected,
        # so the decoder will not run on completed sentences.
        n_prev_active_inst = len(inst_idx_to_position_map)
        active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
        active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

        active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm)
        active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm)
        # print(active_src_enc.shape)

        if hasattr(self.model.encoder, 'ntm'):
            # The NTM memory and previous read/controller states must be
            # compacted alongside the batch, or state from finished instances
            # would leak into the remaining ones.
            self.model.encoder.ntm.previous_state = list(
                self.model.encoder.ntm.previous_state)
            self.model.encoder.ntm.previous_state[1] = list(
                self.model.encoder.ntm.previous_state[1])
            memory = self.model.encoder.ntm.memory
            self.model.encoder.ntm.memory.memory = collect_active_part(
                memory.memory.view(n_prev_active_inst * n_bm, -1),
                active_inst_idx, n_prev_active_inst, n_bm).view(-1, memory.N, memory.M)
            self.model.encoder.ntm.memory.batch_size = \
                self.model.encoder.ntm.memory.memory.shape[0]
            # print(self.model.encoder.ntm.memory.memory.shape, self.model.encoder.ntm.memory.batch_size)
            for i in range(len(self.model.encoder.ntm.previous_state)):
                for j, tensor in enumerate(self.model.encoder.ntm.previous_state[i]):
                    # print(i, j, tensor.shape)
                    squeezed = False
                    if len(tensor.shape) == 3:
                        dim0, dim1, dim2 = tensor.shape  # dim1 = n_prev_active_inst * n_bm
                        tensor = torch.transpose(tensor, 0, 1).contiguous().view(dim1, -1)
                        # tensor = tensor.squeeze(0)
                        squeezed = True
                    new_tensor = collect_active_part(
                        tensor, active_inst_idx, n_prev_active_inst, n_bm)
                    if squeezed:
                        new_tensor = torch.transpose(
                            new_tensor.contiguous().view(-1, dim0, dim2), 0, 1).contiguous()
                    # print(new_tensor.shape)
                    self.model.encoder.ntm.previous_state[i][j] = new_tensor

        # active_src_enc.register_hook(print_grad('active src enc'))
        active_src_enc[torch.isnan(active_src_enc)] = 0
        active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
            active_inst_idx_list)

        return active_src_seq, active_src_enc, active_inst_idx_to_position_map

    def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
        dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
        dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
        dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
        return dec_partial_seq

    def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
        dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device)
        dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1)
        return dec_partial_pos

    def predict_word(decoder, tgt_word_prj, dec_seq, dec_pos, src_seq, enc_output,
                     n_active_inst, n_bm):
        """decoder is added as an argument, compared to the original version"""
        # sometimes the output is only [0], the pad token
        dec_output, *_ = decoder(dec_seq, dec_pos, src_seq, enc_output)
        dec_output_last = dec_output[:, -1, :]  # Pick the last step: (n_inst * n_bm) x d_h
        # dec_output[torch.isnan(dec_output)] = 0
        # dec_output.register_hook(print_grad('dec_output in predict_word {} data'.format(dec_seq)))
        # print(dec_output)

        # gcl decoder
        # k = enc_output.clone()
        # k = F.max_pool1d(k.permute(0, 2, 1), k.shape[-2]).squeeze()
        # x = dec_output.clone()
        # x = F.max_pool1d(x.permute(0, 2, 1), x.shape[-2]).squeeze()
        # # gcl_output = self.model.gcl(k.unsqueeze(0).detach(), x.unsqueeze(0), bidirectional=False, save_attn=False)
        # # dec_output_last = torch.cat([dec_output_last, gcl_output.squeeze()], -1)

        word_prob = F.log_softmax(tgt_word_prj(dec_output_last), dim=1)
        # word_prob.register_hook(print_grad('word prob in predict_word'))
        word_prob = word_prob.view(n_active_inst, n_bm, -1)
        if block_list != []:
            for block_tok in block_list:
                word_prob[:, :, block_tok] = -1000.
        return word_prob

    def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
        """get indexes of instances that have not been fully translated yet"""
        active_inst_idx_list = []
        for inst_idx, inst_position in inst_idx_to_position_map.items():
            is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
            if not is_inst_complete:
                active_inst_idx_list += [inst_idx]
        return active_inst_idx_list

    def beam_decode_step(decoder, tgt_word_prj, inst_dec_beams, len_dec_seq,
                         src_seq, enc_output, inst_idx_to_position_map, n_bm):
        ''' Decode and update beam status, and then return active beam idx.
        decoder is added as an argument, compared to the original version. '''
        # enc_output.register_hook(print_grad('enc output'))
        n_active_inst = len(inst_idx_to_position_map)

        dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
        dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
        word_prob = predict_word(decoder, tgt_word_prj, dec_seq, dec_pos,
                                 src_seq, enc_output, n_active_inst, n_bm)
        # word_prob.register_hook(print_grad('word prob in beam decode'))  # grad ok

        # Update the beam with predicted word prob information and collect incomplete instances
        active_inst_idx_list = collect_active_inst_idx_list(
            inst_dec_beams, word_prob, inst_idx_to_position_map)

        return active_inst_idx_list

    def collect_hypothesis_and_scores(inst_dec_beams, n_best):
        all_hyp, all_scores = [], []
        for inst_idx in range(len(inst_dec_beams)):
            scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
            # scores.register_hook(print_grad('scores from collect hypotheses'))
            all_scores = all_scores + [scores[:n_best]]

            hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
            all_hyp += [hyps]
        return all_hyp, all_scores

    if self.opt.bi:
        decoders = [self.model.decoder_lr, self.model.decoder_rl]
        tgt_word_prjs = [self.model.tgt_word_prj_lr, self.model.tgt_word_prj_rl]
    else:
        decoders = [self.model.decoder]
        tgt_word_prjs = [self.model.tgt_word_prj]

    batch_hyp_list = []  # list of results from each decoder
    batch_scores_list = []
    n_bm = self.opt.beam_size

    # -- Decode
    for decoder, tgt_word_prj in zip(decoders, tgt_word_prjs):  # two decoders for bidirectional model
        src_seq = copy.copy(raw_src_seq)
        src_pos = copy.copy(raw_src_pos)

        # -- Encode
        src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
        if hasattr(self.model.encoder, 'ntm'):
            n_inst, len_s = src_seq.size()
            src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
        src_enc, *_ = self.model.encoder(src_seq, src_pos)
        # print(self.model.encoder.ntm.memory.N, self.model.encoder.ntm.memory.memory.shape)
        # print(type(self.model.encoder.ntm.previous_state[0]))
        # print(type(self.model.encoder.ntm.previous_state[1]))
        # print(type(self.model.encoder.ntm.previous_state[2]))

        #-- Repeat data for beam search
        if len(src_enc.size()) == 3:
            n_inst, len_s, d_h = src_enc.size()
            src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
            src_enc = src_enc.unsqueeze(1).expand(-1, n_bm, -1, -1).contiguous().view(
                n_inst * n_bm, len_s, d_h)
        # src_enc.register_hook(print_grad('{}, src_enc'.format(src_enc.size())))

        #-- Prepare beams
        inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)]

        # -- Bookkeeping for active or not
        active_inst_idx_list = list(range(n_inst))
        inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

        for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1):
            # if len_dec_seq > 30:  # abnormally long seq
            #     print(len_dec_seq)
            active_inst_idx_list = beam_decode_step(
                decoder, tgt_word_prj, inst_dec_beams, len_dec_seq,
                src_seq, src_enc, inst_idx_to_position_map, n_bm)
            # src_enc[torch.isnan(src_enc)] = 0

            if not active_inst_idx_list:
                break  # all instances have finished their path to <EOS>

            src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
                src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list)
            # src_enc.register_hook(print_grad('active src enc'))  # GRAD OK HERE

        # batch_hyp is a nested list of [batches [n_best seqs]]
        batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best)
        # print('\n')
        # print(batch_scores)
        batch_hyp_list.append(batch_hyp)
        batch_scores_list.append(batch_scores)

    return batch_hyp_list, batch_scores_list
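# ---------------------------------------------------------------------------
# Standalone sketch of the `block_list` masking used in predict_word above:
# writing a very low log-probability over a vocabulary column bans that token
# id from every instance and beam. The ids below are made up for illustration.
import torch
import torch.nn.functional as F

n_active_inst, n_bm, n_words = 2, 3, 8
logits = torch.randn(n_active_inst * n_bm, n_words)
word_prob = F.log_softmax(logits, dim=1).view(n_active_inst, n_bm, -1)

block_list = [0, 5]  # hypothetical token ids to suppress
for block_tok in block_list:
    word_prob[:, :, block_tok] = -1000.

top = word_prob.argmax(dim=-1)  # the banned ids can no longer be selected
assert all(t not in top.flatten().tolist() for t in block_list)
# ---------------------------------------------------------------------------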
def translate_batch(self, src_batch):
    ''' Translation work in one batch '''

    # Batch size is in different location depending on data.
    src_seq, src_pos = src_batch
    batch_size = src_seq.size(0)
    beam_size = self.opt.beam_size

    #- Encode
    enc_outputs, enc_slf_attns = self.model.encoder(src_seq, src_pos)

    #--- Repeat data for beam
    # NOTE: replication must be instance-major (repeat along dim 1, then view)
    # to match the (n_remaining_sents, beam_size, -1) score view and the
    # compaction below; the beam-major `repeat(beam_size, 1)` previously used
    # here was flagged "# ERROR".
    src_seq = Variable(
        src_seq.data.repeat(1, beam_size).view(
            src_seq.size(0) * beam_size, src_seq.size(1)))
    enc_outputs = [
        Variable(enc_output.data.repeat(1, beam_size, 1).view(
            enc_output.size(0) * beam_size, enc_output.size(1), enc_output.size(2)))
        for enc_output in enc_outputs]

    #--- Prepare beams
    beam = [Beam(beam_size, self.opt.cuda) for k in range(batch_size)]
    batch_idx = list(range(batch_size))
    n_remaining_sents = batch_size

    #- Decode
    for i in range(self.model_opt.max_token_seq_len):
        len_dec_seq = i + 1

        # -- Preparing decode data seq -- #
        input_data = torch.stack([
            b.get_current_state() for b in beam if not b.done])  # size: mb x bm x sq
        input_data = input_data.view(-1, len_dec_seq)  # size: (mb*bm) x sq
        input_data = Variable(input_data, volatile=True)

        # -- Preparing decode pos seq -- #
        # size: 1 x seq
        input_pos = torch.arange(1, len_dec_seq + 1).unsqueeze(0)
        # size: (batch * beam) x seq
        input_pos = input_pos.repeat(n_remaining_sents * beam_size, 1)
        input_pos = Variable(input_pos.type(torch.LongTensor), volatile=True)

        if self.opt.cuda:
            input_pos = input_pos.cuda()
            input_data = input_data.cuda()

        # -- Decoding -- #
        dec_outputs, dec_slf_attns, dec_enc_attns = self.model.decoder(
            input_data, input_pos, src_seq, enc_outputs)
        dec_output = dec_outputs[-1][:, -1, :]  # (batch * beam) x d_model
        dec_output = self.model.tgt_word_proj(dec_output)
        out = self.model.prob_projection(dec_output)

        # batch x beam x n_words
        word_lk = out.view(n_remaining_sents, beam_size, -1).contiguous()

        active = []
        for b in range(batch_size):
            if beam[b].done:
                continue

            idx = batch_idx[b]
            if not beam[b].advance(word_lk.data[idx]):
                active += [b]

        if not active:
            break

        # in this section, the sentences that are still active are
        # compacted so that the decoder is not run on completed sentences
        active_idx = self.tt.LongTensor([batch_idx[k] for k in active])
        batch_idx = {beam: idx for idx, beam in enumerate(active)}

        def update_active_enc_info(tensor_var, active_idx):
            ''' Remove the encoder outputs of finished instances in one batch. '''

            tensor_data = tensor_var.data.view(n_remaining_sents, -1, self.model_opt.d_model)

            new_size = list(tensor_var.size())
            new_size[0] = new_size[0] * len(active_idx) // n_remaining_sents

            # select the active index in batch
            return Variable(
                tensor_data.index_select(0, active_idx).view(*new_size),
                volatile=True)

        def update_active_seq(seq, active_idx):
            ''' Remove the src sequence of finished instances in one batch. '''

            view = seq.data.view(n_remaining_sents, -1)
            new_size = list(seq.size())
            new_size[0] = new_size[0] * len(active_idx) // n_remaining_sents  # trim on batch dim

            # select the active index in batch
            return Variable(view.index_select(0, active_idx).view(*new_size), volatile=True)

        src_seq = update_active_seq(src_seq, active_idx)
        enc_outputs = [
            update_active_enc_info(enc_output, active_idx)
            for enc_output in enc_outputs]
        n_remaining_sents = len(active)

    #- Return useful information
    all_hyp, all_scores = [], []
    n_best = self.opt.n_best

    for b in range(batch_size):
        scores, ks = beam[b].sort_scores()
        all_scores += [scores[:n_best]]
        hyps = [beam[b].get_hypothesis(k) for k in ks[:n_best]]
        all_hyp += [hyps]

    return all_hyp, all_scores
def translate_batch(self, src_batch):
    ''' Translation work in one batch '''

    # Batch size is in different location depending on data.
    if self.model_opt.use_ctx:
        (src_seq, src_pos), (ctx_seq, ctx_pos) = src_batch
    else:
        src_seq, src_pos = src_batch
    batch_size = src_seq.size(0)
    beam_size = self.opt.beam_size

    #- Encode
    enc_outputs, enc_slf_attns = self.model.encoder(src_seq, src_pos)
    enc_output = enc_outputs[-1]

    #--- Repeat data for beam (beam-major: the whole batch is tiled beam_size times)
    src_seq = Variable(src_seq.data.repeat(beam_size, 1))
    enc_output = Variable(enc_output.data.repeat(beam_size, 1, 1))

    if self.model_opt.use_ctx:
        #- Encode context
        ctx_outputs, ctx_slf_attns = self.model.encoder_ctx(ctx_seq, ctx_pos)
        ctx_output = ctx_outputs[-1]

        #--- Repeat data for beam
        ctx_seq = Variable(ctx_seq.data.repeat(beam_size, 1))
        ctx_output = Variable(ctx_output.data.repeat(beam_size, 1, 1))

    #--- Prepare beams
    beam = [Beam(beam_size, self.opt.cuda) for k in range(batch_size)]
    batch_idx = list(range(batch_size))
    n_remaining_sents = batch_size

    #- Decode
    for i in range(self.model_opt.max_token_seq_len):
        len_dec_seq = i + 1

        # -- Preparing decode data seq -- #
        input_data = torch.stack([
            b.get_current_state() for b in beam if not b.done])  # size: mb x bm x sq
        input_data = input_data.permute(1, 0, 2).contiguous()    # beam-major
        input_data = input_data.view(-1, len_dec_seq)            # size: (mb*bm) x sq
        input_data = Variable(input_data, volatile=True)

        # -- Preparing decode pos seq -- #
        # size: 1 x seq
        input_pos = torch.arange(1, len_dec_seq + 1).unsqueeze(0)
        # size: (batch * beam) x seq
        input_pos = input_pos.repeat(n_remaining_sents * beam_size, 1)
        input_pos = Variable(input_pos.type(torch.LongTensor), volatile=True)

        if self.opt.cuda:
            input_pos = input_pos.cuda()
            input_data = input_data.cuda()

        # -- Decoding -- #
        if self.model_opt.use_ctx:
            dec_outputs, dec_slf_attns, dec_enc_attns, dec_ctx_attns = self.model.decoder(
                input_data, input_pos, src_seq, enc_output, ctx_seq, ctx_output)
        else:
            dec_outputs, dec_slf_attns, dec_enc_attns = self.model.decoder(
                input_data, input_pos, src_seq, enc_output)
        dec_output = dec_outputs[-1][:, -1, :]  # (batch * beam) x d_model
        dec_output = self.model.tgt_word_proj(dec_output)
        out = self.model.prob_projection(dec_output)

        # beam x batch x n_words
        word_lk = out.view(beam_size, n_remaining_sents, -1).contiguous()

        active = []
        for b in range(batch_size):
            if beam[b].done:
                continue

            idx = batch_idx[b]
            if not beam[b].advance(word_lk.data[:, idx]):
                active += [b]

        if not active:
            break

        # in this section, the sentences that are still active are
        # compacted so that the decoder is not run on completed sentences
        active_idx = self.tt.LongTensor([batch_idx[k] for k in active])
        batch_idx = {beam: idx for idx, beam in enumerate(active)}

        def update_active_enc_info(tensor_var, active_idx):
            ''' Remove the encoder outputs of finished instances in one batch. '''

            # beam-major layout: the first n_remaining_sents rows are beam 0
            batch = tensor_var.data[:n_remaining_sents]
            selected = batch.index_select(0, active_idx)
            data = selected.repeat(beam_size, 1, 1)
            return Variable(data, volatile=True)

        def update_active_seq(seq, active_idx):
            ''' Remove the src sequence of finished instances in one batch. '''

            batch = seq.data[:n_remaining_sents]
            selected = batch.index_select(0, active_idx)
            data = selected.repeat(beam_size, 1)
            return Variable(data, volatile=True)

        src_seq = update_active_seq(src_seq, active_idx)
        enc_output = update_active_enc_info(enc_output, active_idx)
        if self.model_opt.use_ctx:
            ctx_seq = update_active_seq(ctx_seq, active_idx)
            ctx_output = update_active_enc_info(ctx_output, active_idx)
        n_remaining_sents = len(active)

    #- Return useful information
    all_hyp, all_scores = [], []
    n_best = self.opt.n_best

    for b in range(batch_size):
        # merge the scores of hypotheses that finished early with the live beams
        scores = self.tt.FloatTensor(beam_size + len(beam[b].finish_early_scores)).zero_()
        scores[:beam_size] = beam[b].scores
        for i in range(beam_size, beam_size + len(beam[b].finish_early_scores)):
            scores[i] = beam[b].finish_early_scores[i - beam_size][2]
        beam[b].scores = scores

        scores, ks = beam[b].sort_scores()
        all_scores += [scores[:n_best]]
        hyps = [
            beam[b].get_hypothesis(k) if k < beam_size
            else beam[b].get_early_hypothesis(
                beam[b].finish_early_scores[k - beam_size][0],
                beam[b].finish_early_scores[k - beam_size][1])
            for k in ks[:n_best]]
        all_hyp += [hyps]

    return all_hyp, all_scores
def generate_question_batch(self, src1_seq, src1_pos, src1_emo, src1_bio, src1_bio_pos):
    ''' Generate questions batch by batch. '''

    def get_inst_idx_to_tensor_position_map(inst_idx_list):
        ''' Indicate the position of an instance in a tensor. '''
        return {inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)}

    def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
        ''' Collect tensor parts associated to active instances. '''
        _, *d_hs = beamed_tensor.size()
        n_curr_active_inst = len(curr_active_inst_idx)
        new_shape = (n_curr_active_inst * n_bm, *d_hs)

        beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
        beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
        beamed_tensor = beamed_tensor.view(*new_shape)

        return beamed_tensor

    def collate_active_info(src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list):
        ''' Keep only the src_seq/src_enc rows that belong to active instances. '''
        n_prev_active_inst = len(inst_idx_to_position_map)
        active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
        active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

        active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm)
        active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm)
        active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
            active_inst_idx_list)

        return active_src_seq, active_src_enc, active_inst_idx_to_position_map

    def beam_decode_step(inst_dec_beams, len_dec_seq, src_seq, enc_output,
                         inst_idx_to_position_map, n_bm):
        ''' Decode one step, update beam status, and return the active instance ids. '''

        def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
            dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
            dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
            dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
            return dec_partial_seq

        def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
            dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device)
            dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1)
            return dec_partial_pos

        def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm):
            dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output)
            dec_output = dec_output[:, -1, :]  # Pick the last step: (n_inst * n_bm) x d_h
            word_prob = F.log_softmax(self.model.tgt_word_prj(self.model.fc(dec_output)), dim=1)
            word_prob = word_prob.view(n_active_inst, n_bm, -1)
            return word_prob

        def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
            active_inst_idx_list = []
            for inst_idx, inst_position in inst_idx_to_position_map.items():
                is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
                if not is_inst_complete:
                    active_inst_idx_list += [inst_idx]
            return active_inst_idx_list

        n_active_inst = len(inst_idx_to_position_map)

        dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
        dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
        word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm)

        active_inst_idx_list = collect_active_inst_idx_list(
            inst_dec_beams, word_prob, inst_idx_to_position_map)

        return active_inst_idx_list

    def collect_hypothesis_and_scores(inst_dec_beams, n_best):
        all_hyp, all_scores = [], []
        for inst_idx in range(len(inst_dec_beams)):
            scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
            all_scores += [scores[:n_best]]

            hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
            all_hyp += [hyps]
        return all_hyp, all_scores

    with torch.no_grad():
        src1_seq = src1_seq.to(self.device)
        src1_bio = src1_bio.to(self.device)
        src1_emo = src1_emo.to(self.device)
        # print('src1_seq:', src1_seq.shape)
        # print('src1_emo:', src1_emo.shape)

        # src1_enc, *_ = self.model.encoder1(src1_seq, src1_pos, src1_emo)
        # src2_enc, *_ = self.model.encoder2(src2_seq, src2_pos, src2_emo, src1_enc)
        # src3_enc, *_ = self.model.encoder3(src3_seq, src3_pos, src3_emo, src2_enc)
        src_enc1 = self.model.gcn(src1_seq, src1_bio, src1_emo)  # (batch, 20, 300)
        src_enc2, _ = self.model.encoder1(src1_seq, src1_bio)
        # print("src_enc2.shape:", src_enc2.shape)
        # print("_.shape:", _.shape)
        # src_enc2 = self.model.layer3(src_enc2)
        src_enc = torch.cat((src_enc1, src_enc2), 2)
        src_enc = self.model.layer1(src_enc)
        # src_enc = (0.5*src_enc2 + 0.5*src_enc1)

        n_bm = self.opt.beam_size  # 5
        n_inst, len_s, d_h = src_enc.size()  # (batch, 20, 300)
        src_seq = src1_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)  # (batch * 5, 20)
        # (batch, 20*5, 300) --> (batch * 5, 20, 300)
        src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)

        inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)]

        active_inst_idx_list = list(range(n_inst))  # [0, 1, 2, ..., batch-1]
        inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

        for len_dec_seq in range(1, 30):  # generated questions are capped at 29 steps
            active_inst_idx_list = beam_decode_step(
                inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm)

            if not active_inst_idx_list:
                break

            src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
                src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list)

    batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best)

    return batch_hyp, batch_scores
def translate_batch(model, batch, opt, model_options):
    model.eval()

    # prepare data
    # key = [triples[0] for triples in batch]
    src = [triples[1] for triples in batch]
    tgt = [triples[2] for triples in batch]

    # `instances_handler.pad_to_longest` pads a list of sequences to equal
    # length and returns a mask (a minimal sketch follows this function).
    src_seq, src_pad_mask = instances_handler.pad_to_longest(src)
    tgt_seq, tgt_pad_mask = instances_handler.pad_to_longest(tgt)

    src_seq = Variable(torch.FloatTensor(src_seq))            # batch x max length in batch x padded feature dim
    src_pad_mask = Variable(torch.LongTensor(src_pad_mask))   # batch x max length in batch x bool mask dim
    tgt_seq = Variable(torch.LongTensor(tgt_seq))             # batch x max length in batch x padded index dim
    tgt_pad_mask = Variable(torch.LongTensor(tgt_pad_mask))   # batch x max length in batch x bool mask dim

    if opt.use_gpu:
        src_seq = src_seq.cuda()
        src_pad_mask = src_pad_mask.cuda()
        tgt_seq = tgt_seq.cuda()
        tgt_pad_mask = tgt_pad_mask.cuda()

    goal = tgt_seq[:, 1:]
    tgt_seq = tgt_seq[:, :-1]
    tgt_pad_mask = tgt_pad_mask[:, :-1]

    beam_size = opt.beam_size
    batch_size = src_seq.size(0)

    # ---------------------------------------------------------------------
    #- Encode
    enc_output, *_ = model.encoder(src_seq, src_pad_mask)

    #--- Repeat data for beam
    src_seq = Variable(
        src_seq.data.repeat(1, beam_size, 1).view(
            src_seq.size(0) * beam_size, src_seq.size(1), src_seq.size(2)))
    src_pad_mask = Variable(
        src_pad_mask.data.repeat(1, beam_size).view(
            src_pad_mask.size(0) * beam_size, src_pad_mask.size(1)))
    enc_output = Variable(
        enc_output.data.repeat(1, beam_size, 1).view(
            enc_output.size(0) * beam_size, enc_output.size(1), enc_output.size(2)))

    #--- Prepare beams
    beams = [Beam(beam_size, opt.use_gpu) for _ in range(batch_size)]
    beam_inst_idx_map = {
        beam_idx: inst_idx for inst_idx, beam_idx in enumerate(range(batch_size))}
    n_remaining_sents = batch_size

    #- Decode
    for i in range(opt.max_token_seq_len):
        len_dec_seq = i + 1

        # -- Preparing decoded data seq -- #
        # size: batch x beam x seq
        dec_partial_seq = torch.stack(
            [b.get_current_state() for b in beams if not b.done])
        dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq).cpu().numpy()
        dec_partial_seq, dec_partial_seq_mask = instances_handler.pad_to_longest(dec_partial_seq)

        # size: (batch * beam) x seq
        dec_partial_seq = Variable(torch.LongTensor(dec_partial_seq))
        dec_partial_seq_mask = Variable(torch.LongTensor(dec_partial_seq_mask))

        if opt.use_gpu:
            dec_partial_seq = dec_partial_seq.cuda()
            dec_partial_seq_mask = dec_partial_seq_mask.cuda()

        # -- Decoding -- #
        dec_output, *_ = model.decoder(dec_partial_seq, dec_partial_seq_mask,
                                       src_pad_mask, enc_output)
        dec_output = dec_output[:, -1, :]  # (batch * beam) x d_model
        out = model.prob_projection(dec_output)

        # batch x beam x n_words
        word_lk = out.view(n_remaining_sents, beam_size, -1).contiguous()

        active_beam_idx_list = []
        for beam_idx in range(batch_size):
            if beams[beam_idx].done:
                continue

            inst_idx = beam_inst_idx_map[beam_idx]
            if not beams[beam_idx].advance(word_lk.data[inst_idx]):
                active_beam_idx_list += [beam_idx]

        if not active_beam_idx_list:
            # all instances have finished their path to <EOS>
            break

        # in this section, the sentences that are still active are
        # compacted so that the decoder is not run on completed sentences
        active_inst_idxs = torch.LongTensor(
            [beam_inst_idx_map[k] for k in active_beam_idx_list])
        if opt.use_gpu:
            active_inst_idxs = active_inst_idxs.cuda()

        # update the idx mapping
        beam_inst_idx_map = {
            beam_idx: inst_idx for inst_idx, beam_idx in enumerate(active_beam_idx_list)}

        def update_active_seq(seq_var, active_inst_idxs):
            ''' Remove the src sequence of finished instances in one batch. '''

            inst_idx_dim_size, *rest_dim_sizes = seq_var.size()
            inst_idx_dim_size = inst_idx_dim_size * len(active_inst_idxs) // n_remaining_sents
            new_size = (inst_idx_dim_size, *rest_dim_sizes)

            # select the active instances in batch
            original_seq_data = seq_var.data.view(n_remaining_sents, -1)
            active_seq_data = original_seq_data.index_select(0, active_inst_idxs)
            active_seq_data = active_seq_data.view(*new_size)

            return Variable(active_seq_data, volatile=True)

        def update_active_enc_info(enc_info_var, active_inst_idxs):
            ''' Remove the encoder outputs of finished instances in one batch. '''

            inst_idx_dim_size, *rest_dim_sizes = enc_info_var.size()
            inst_idx_dim_size = inst_idx_dim_size * len(active_inst_idxs) // n_remaining_sents
            new_size = (inst_idx_dim_size, *rest_dim_sizes)

            # select the active instances in batch
            original_enc_info_data = enc_info_var.data.view(
                n_remaining_sents, -1, model_options.d_model)
            active_enc_info_data = original_enc_info_data.index_select(0, active_inst_idxs)
            active_enc_info_data = active_enc_info_data.view(*new_size)

            return Variable(active_enc_info_data, volatile=True)

        src_pad_mask = update_active_seq(src_pad_mask, active_inst_idxs)
        enc_output = update_active_enc_info(enc_output, active_inst_idxs)

        #- update the remaining size
        n_remaining_sents = len(active_inst_idxs)

    #- Return useful information
    all_hyp, all_scores = [], []

    for beam_idx in range(batch_size):
        scores, tail_idxs = beams[beam_idx].sort_scores()
        all_scores += [scores[:opt.n_best]]
        hyps = [beams[beam_idx].get_hypothesis(i) for i in tail_idxs[:opt.n_best]]
        all_hyp += [hyps]

    return all_hyp, all_scores
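# ---------------------------------------------------------------------------
# `instances_handler.pad_to_longest` is a project-specific helper. Below is a
# minimal sketch that matches how it is used above for id sequences (the real
# helper presumably also handles the float feature rows passed in as `src`);
# the PAD id is an assumption.
PAD = 0  # assumed padding id


def pad_to_longest(insts):
    ''' Pad each sequence to the longest one; the mask marks real positions. '''
    max_len = max(len(inst) for inst in insts)
    padded = [list(inst) + [PAD] * (max_len - len(inst)) for inst in insts]
    mask = [[1] * len(inst) + [0] * (max_len - len(inst)) for inst in insts]
    return padded, mask


seqs, masks = pad_to_longest([[2, 7, 9], [2, 5]])
# seqs  -> [[2, 7, 9], [2, 5, 0]]
# masks -> [[1, 1, 1], [1, 1, 0]]
# ---------------------------------------------------------------------------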
def translate_batch(self, src_seq, src_pos, src_sen_pos):
    ''' Translation work in one batch '''

    def get_inst_idx_to_tensor_position_map(inst_idx_list):
        ''' Indicate the position of an instance in a tensor. '''
        return {inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)}

    def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
        ''' Collect tensor parts associated to active instances. '''
        _, *d_hs = beamed_tensor.size()
        n_curr_active_inst = len(curr_active_inst_idx)
        new_shape = (n_curr_active_inst * n_bm, *d_hs)

        beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
        beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
        beamed_tensor = beamed_tensor.view(*new_shape)

        return beamed_tensor

    def collate_active_info(src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list):
        # Sentences which are still active are collected,
        # so the decoder will not run on completed sentences.
        n_prev_active_inst = len(inst_idx_to_position_map)
        active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
        active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

        active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm)
        active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm)
        active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

        return active_src_seq, active_src_enc, active_inst_idx_to_position_map

    def beam_decode_step(inst_dec_beams, len_dec_seq, src_seq, enc_output,
                         inst_idx_to_position_map, n_bm):
        ''' Decode and update beam status, and then return active beam idx '''

        def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
            dec_partial_seq = [b.get_current_seq_state() for b in inst_dec_beams if not b.done]
            dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
            dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
            return dec_partial_seq

        def prepare_beam_dec_pos(inst_dec_beams, len_dec_seq):
            dec_partial_pos = [b.get_current_pos_state() for b in inst_dec_beams if not b.done]
            dec_partial_pos = torch.stack(dec_partial_pos).to(self.device)
            dec_partial_pos = dec_partial_pos.view(-1, len_dec_seq)
            return dec_partial_pos

        def prepare_beam_dec_sen_pos(inst_dec_beams, len_dec_seq):
            # print('inst_dec_beams', inst_dec_beams)
            # print('inst_dec_beams', type(inst_dec_beams))
            # print('inst_dec_beams', inst_dec_beams.size)
            # print('inst_dec_beams[0]', inst_dec_beams[0].get_current_sen_pos_state())
            dec_partial_sen_pos = [b.get_current_sen_pos_state() for b in inst_dec_beams if not b.done]
            # print('dec_partial_sen_pos', dec_partial_sen_pos)
            dec_partial_sen_pos = torch.stack(dec_partial_sen_pos).to(self.device)
            dec_partial_sen_pos = dec_partial_sen_pos.view(-1, len_dec_seq)
            return dec_partial_sen_pos

        def predict_word(dec_seq, dec_pos, dec_sen_pos, src_seq, enc_output, n_active_inst, n_bm):
            # print('dec_seq:', dec_seq.shape)
            # print('dec_pos:', dec_pos.shape)
            # print('dec_sen_pos:', dec_sen_pos.shape)
            # print('src_seq:', src_seq.shape)
            # print('enc_output', enc_output.shape)
            # print('n_active_inst', n_active_inst)
            # print('n_bm:', n_bm)
            dec_output, *_ = self.model.decoder(dec_seq, dec_pos, dec_sen_pos, src_seq, enc_output)
            dec_output = dec_output[:, -1, :]  # Pick the last step: (n_inst * n_bm) x d_h
            logits = self.model.tgt_word_prj(dec_output)
            logits = F.log_softmax(logits, dim=1)  # now log-probs, though still named `logits`

            # UNK mask
            logits[:, Constants.UNK] = -1e19

            # rm_set = set(Constants.BOSs + [13])
            rm_set = set(Constants.BOSs + [19])  # 19 => "."
            # logits[:, 19] -= logits[:, 19].abs() + 1e-10
            for i, (ins, pos, sen_pos) in enumerate(zip(dec_seq, dec_pos, dec_sen_pos)):
                current_sen_pos = sen_pos[-1]
                for token, s_pos in zip(ins.flip(0), sen_pos.flip(0)):
                    length_norm = len(ins)
                    if token.item() not in rm_set and s_pos == current_sen_pos:
                        logits[i, token] -= 5 + 1e-19
                    if s_pos != current_sen_pos:
                        logits[i, token] -= 20 / length_norm - 1e-19
                    # break

            # word_prob = F.log_softmax(logits, dim=1)
            # word_prob = word_prob.view(n_active_inst, n_bm, -1)
            word_prob = logits.view(n_active_inst, n_bm, -1)
            return word_prob

        def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
            active_inst_idx_list = []
            for inst_idx, inst_position in inst_idx_to_position_map.items():
                is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
                if not is_inst_complete:
                    active_inst_idx_list += [inst_idx]
            return active_inst_idx_list

        n_active_inst = len(inst_idx_to_position_map)

        dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
        dec_pos = prepare_beam_dec_pos(inst_dec_beams, len_dec_seq)
        dec_sen_pos = prepare_beam_dec_sen_pos(inst_dec_beams, len_dec_seq)
        # dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
        # dec_pos, dec_sen_pos = prepare_beam_dec_pos(dec_seq)
        word_prob = predict_word(dec_seq, dec_pos, dec_sen_pos, src_seq, enc_output,
                                 n_active_inst, n_bm)

        # Update the beam with predicted word prob information and collect incomplete instances
        active_inst_idx_list = collect_active_inst_idx_list(
            inst_dec_beams, word_prob, inst_idx_to_position_map)

        return active_inst_idx_list

    def collect_hypothesis_and_scores(inst_dec_beams, n_best):
        all_hyp, all_scores = [], []
        for inst_idx in range(len(inst_dec_beams)):
            scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
            all_scores += [scores[:n_best]]
            hyps = [inst_dec_beams[inst_idx].get_seq_hypothesis(i) for i in tail_idxs[:n_best]]
            all_hyp += [hyps]
        return all_hyp, all_scores

    with torch.no_grad():
        #-- Encode
        src_seq, src_pos, src_sen_pos = (src_seq.to(self.device), src_pos.to(self.device),
                                         src_sen_pos.to(self.device))
        src_enc, *_ = self.model.encoder(src_seq, src_pos, src_sen_pos)

        #-- Repeat data for beam search
        n_bm = self.opt.beam_size
        n_inst, len_s, d_h = src_enc.size()
        # print('n_inst', n_inst)
        # print('len_s', len_s)
        # print('d_h', d_h)
        # print('n_bm', n_bm)
        src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
        src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)

        #-- Prepare beams
        inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)]
        # print('inst_dec_beams.shape', len(inst_dec_beams))
        # print('inst_dec_beams.shape', inst_dec_beams[0].size)
        # print('inst_dec_beams[0].getcurrent', inst_dec_beams[0].get_current_sen_pos_state())

        #-- Bookkeeping for active or not
        active_inst_idx_list = list(range(n_inst))
        inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

        #-- Decode
        # for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1):
        for len_dec_seq in range(1, 200):
            active_inst_idx_list = beam_decode_step(
                inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm)

            if not active_inst_idx_list or len_dec_seq > 50:
                break  # all instances have finished their path to <EOS>

            src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
                src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list)

    batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best)

    return batch_hyp, batch_scores
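# ---------------------------------------------------------------------------
# Standalone sketch of the repetition penalty applied in predict_word above:
# tokens already emitted inside the current sentence take a flat hit on their
# log-probability, tokens from earlier sentences a length-normalised one. The
# constants mirror the ones above; the toy ids and sentence flags are made up.
import torch

log_prob = torch.zeros(1, 10)            # (1 hypothesis, 10-word vocab)
generated = torch.tensor([2, 7, 7, 4])   # tokens emitted so far
in_current = torch.tensor([0, 1, 1, 1])  # 1 = same sentence as now

for token, same_sent in zip(generated, in_current):
    if same_sent:
        log_prob[0, token] -= 5.0                    # in-sentence penalty
    else:
        log_prob[0, token] -= 20.0 / len(generated)  # cross-sentence penalty

print(log_prob.argmax(dim=1))  # a previously emitted token no longer wins
# ---------------------------------------------------------------------------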
import torch.nn as nn  # assumed module-level import for nn.LogSoftmax below


def sample(self, src_seq, enc_outputs):
    """Samples captions for given image features (greedy search)."""

    beam_size = 1
    batch_size = src_seq.size(0)
    self.softmax = nn.LogSoftmax()
    self.tt = torch.cuda if torch.cuda.is_available() else torch

    # Repeat data
    src_seq = Variable(
        src_seq.data.repeat(1, beam_size).view(
            src_seq.size(0) * beam_size, src_seq.size(1)))
    for i in range(len(enc_outputs)):
        enc_output = enc_outputs[i]
        enc_outputs[i] = Variable(
            enc_output.data.repeat(1, beam_size, 1).view(
                enc_output.size(0) * beam_size, enc_output.size(1), enc_output.size(2)))

    # --- Prepare beams
    beams = [Beam(beam_size, torch.cuda.is_available()) for _ in range(batch_size)]
    beam_inst_idx_map = {
        beam_idx: inst_idx for inst_idx, beam_idx in enumerate(range(batch_size))}
    n_remaining_sents = batch_size

    # - Decode
    for i in range(20):
        len_dec_seq = i + 1

        # -- Preparing decoded data seq -- #
        # size: batch x beam x seq
        dec_partial_seq = torch.stack(
            [b.get_current_state() for b in beams if not b.done])
        # size: (batch * beam) x seq
        dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
        # wrap into a Variable
        dec_partial_seq = Variable(dec_partial_seq, volatile=True)

        if torch.cuda.is_available():
            dec_partial_seq = dec_partial_seq.cuda()

        # -- Decoding -- #
        dec_output = self(src_seq, dec_partial_seq, enc_outputs,
                          [len_dec_seq] * n_remaining_sents * beam_size, False)
        dec_output = dec_output[:, -1, :]  # (batch * beam) x d_model
        out = self.softmax(dec_output)

        # batch x beam x n_words
        word_lk = out.view(n_remaining_sents, beam_size, -1).contiguous()

        active_beam_idx_list = []
        for beam_idx in range(batch_size):
            if beams[beam_idx].done:
                continue

            inst_idx = beam_inst_idx_map[beam_idx]
            if not beams[beam_idx].advance(word_lk.data[inst_idx]):
                active_beam_idx_list += [beam_idx]

        if not active_beam_idx_list:
            # all instances have finished their path to <EOS>
            break

        # in this section, the sentences that are still active are
        # compacted so that the decoder is not run on completed sentences
        active_inst_idxs = self.tt.LongTensor(
            [beam_inst_idx_map[k] for k in active_beam_idx_list])

        # update the idx mapping
        beam_inst_idx_map = {
            beam_idx: inst_idx for inst_idx, beam_idx in enumerate(active_beam_idx_list)}

        def update_active_seq(seq_var, active_inst_idxs):
            ''' Remove the src sequence of finished instances in one batch. '''

            inst_idx_dim_size, b = seq_var.size()
            inst_idx_dim_size = inst_idx_dim_size * len(active_inst_idxs) // n_remaining_sents
            new_size = inst_idx_dim_size, b

            # select the active instances in batch
            original_seq_data = seq_var.data.view(n_remaining_sents, -1)
            active_seq_data = original_seq_data.index_select(0, active_inst_idxs)
            active_seq_data = active_seq_data.view(*new_size)

            return Variable(active_seq_data, volatile=True)

        def update_active_enc_info(enc_info_var, active_inst_idxs):
            ''' Remove the encoder outputs of finished instances in one batch. '''

            inst_idx_dim_size, b, c = enc_info_var.size()
            inst_idx_dim_size = inst_idx_dim_size * len(active_inst_idxs) // n_remaining_sents
            new_size = inst_idx_dim_size, b, c

            # select the active instances in batch
            original_enc_info_data = enc_info_var.data.view(n_remaining_sents, -1, self.d_model)
            active_enc_info_data = original_enc_info_data.index_select(0, active_inst_idxs)
            active_enc_info_data = active_enc_info_data.view(*new_size)

            return Variable(active_enc_info_data, volatile=True)

        src_seq = update_active_seq(src_seq, active_inst_idxs)
        for j in range(len(enc_outputs)):
            enc_outputs[j] = update_active_enc_info(enc_outputs[j], active_inst_idxs)

        # - update the remaining size
        n_remaining_sents = len(active_inst_idxs)

    # - Return useful information
    all_hyp, all_scores = [], []
    n_best = 1

    for beam_idx in range(batch_size):
        scores, tail_idxs = beams[beam_idx].sort_scores()
        all_scores += [scores[:n_best]]
        hyps = [beams[beam_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
        all_hyp += [hyps]

    return all_hyp
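# ---------------------------------------------------------------------------
# With beam_size = 1 the search in sample() degenerates to greedy decoding:
# topk(1) over the flattened beam scores is just an argmax at every step.
import torch

word_prob = torch.log_softmax(torch.randn(1, 9), dim=1)  # (1 beam, vocab)
score, idx = word_prob.view(-1).topk(1)
assert idx.item() == word_prob.argmax().item()
# ---------------------------------------------------------------------------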
def decode_batch(self, src_batch):
    ''' Translation work in one batch '''

    # Batch size is in different location depending on data.
    src_seq = src_batch
    batch_size = src_seq.size(0)
    beam_size = self.opt.beam_size

    # - Encode
    enc_output, src_mask = self.model.encoder(src_seq)
    # print('enc_output.size', enc_output.size())  # (batch, length, d_model)
    # print('src_mask.size()', src_mask.size())    # (batch, 1, length)

    # (batch * beam_size, length, d_model)
    enc_output = Variable(
        enc_output.data.repeat(1, beam_size, 1).view(
            enc_output.size(0) * beam_size, enc_output.size(1), enc_output.size(2)))
    # (batch * beam_size, 1, length)
    src_mask = src_mask.repeat(1, beam_size, 1).view(
        src_mask.size(0) * beam_size, src_mask.size(1), src_mask.size(2))

    # --- Prepare beams
    beams = [Beam(beam_size, self.device) for _ in range(batch_size)]
    # print('beams:', beams)
    beam_inst_idx_map = {
        beam_idx: inst_idx for inst_idx, beam_idx in enumerate(range(batch_size))}
    # print('beam_inst_idx_map:', beam_inst_idx_map)
    n_remaining_sents = batch_size

    # - Decode
    for i in range(self.model_opt.label_max_len):
        # print('-' * 20)
        len_dec_seq = i + 1
        # print(len_dec_seq)

        # -- Preparing decoded data seq -- #
        # size: (batch, beam, len_dec_seq)
        dec_partial_seq = torch.stack(
            [b.get_current_state() for b in beams if not b.done])
        # print('dec_partial_seq 1', dec_partial_seq.size())
        # size: (batch * beam, len_dec_seq)
        dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
        # print('dec_partial_seq 2:\n', dec_partial_seq)
        dec_partial_seq = dec_partial_seq.to(self.device)

        # -- Decoding -- #
        # (batch * beam, len_dec_seq, d_model)
        dec_output = self.model.decoder(dec_partial_seq, enc_output, src_mask)
        # print('dec_output:', dec_output.size())
        # (batch * beam, d_model)
        dec_output = dec_output[:, -1, :]
        # (batch * beam, vocab_size)
        dec_output = self.model.final_proj(dec_output)
        # print('decoder output shape:', dec_output.size())
        # (batch * beam, vocab_size), log-softmax
        out = self.model.log_softmax(dec_output)

        # (batch, beam, vocab_size)
        word_lk = out.view(n_remaining_sents, beam_size, -1).contiguous()

        active_beam_idx_list = []
        for beam_idx in range(batch_size):
            # this instance of the batch has already predicted <EOS>
            if beams[beam_idx].done:
                # print('continue')
                continue

            inst_idx = beam_inst_idx_map[beam_idx]
            # print('word_lk.data[%d]' % (inst_idx), word_lk.data[inst_idx])
            # word_lk.data[inst_idx] is (beam_size, vocab_size) for the current instance
            if not beams[beam_idx].advance(word_lk.data[inst_idx]):
                active_beam_idx_list += [beam_idx]

        if not active_beam_idx_list:
            # all instances have finished their path to <EOS>
            break

        # in this section, the sentences that are still active are
        # compacted so that the decoder is not run on completed sentences
        active_inst_idxs = torch.LongTensor(
            [beam_inst_idx_map[k] for k in active_beam_idx_list])

        # update the idx mapping
        beam_inst_idx_map = {
            beam_idx: inst_idx for inst_idx, beam_idx in enumerate(active_beam_idx_list)}
        # print('beam_inst_idx_map2:\n', beam_inst_idx_map)

        def update_active_seq(seq_var, active_inst_idxs):
            ''' Remove the src sequence of finished instances in one batch. '''

            inst_idx_dim_size, *rest_dim_sizes = seq_var.size()
            inst_idx_dim_size = inst_idx_dim_size * len(active_inst_idxs) // n_remaining_sents
            new_size = (inst_idx_dim_size, *rest_dim_sizes)

            # select the active instances in batch
            original_seq_data = seq_var.data.view(n_remaining_sents, -1)
            active_seq_data = original_seq_data.index_select(0, active_inst_idxs)
            active_seq_data = active_seq_data.view(*new_size)
            with torch.no_grad():
                return Variable(active_seq_data)

        def update_active_enc_info(enc_info_var, active_inst_idxs):
            ''' Remove the encoder outputs of finished instances in one batch. '''

            # (batch * beam, length, d_model)
            inst_idx_dim_size, *rest_dim_sizes = enc_info_var.size()
            inst_idx_dim_size = inst_idx_dim_size * len(active_inst_idxs) // n_remaining_sents
            new_size = (inst_idx_dim_size, *rest_dim_sizes)
            # print('new_size:\n', new_size)
            # print(n_remaining_sents)

            # select the active instances in batch
            # (batch, beam * length, d_model)
            original_enc_info_data = enc_info_var.data.view(
                n_remaining_sents, -1, enc_info_var.size(2))
            # (new_batch, beam * length, d_model)
            active_enc_info_data = original_enc_info_data.index_select(0, active_inst_idxs)
            active_enc_info_data = active_enc_info_data.view(*new_size)
            with torch.no_grad():
                return Variable(active_enc_info_data)

        enc_output = update_active_enc_info(enc_output, active_inst_idxs.to(self.device))
        src_mask = update_active_enc_info(src_mask, active_inst_idxs.to(self.device))

        # - update the remaining size
        n_remaining_sents = len(active_inst_idxs)

    # - Return useful information
    all_hyp, all_scores = [], []
    n_best = self.opt.n_best

    for beam_idx in range(batch_size):
        scores, tail_idxs = beams[beam_idx].sort_scores()
        all_scores += [scores[:n_best]]
        # hyps1 = [beams[beam_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
        # print(torch.LongTensor(hyps1))
        hyps = torch.LongTensor(beams[beam_idx].bestpath)[:n_best, 1:]
        # print(hyps)
        # assert torch.LongTensor(hyps1).equal(hyps)
        all_hyp += [hyps]

    return all_hyp, all_scores
import numpy as np  # the copy-mask below is built with numpy-style indexing


def translate_batch(self, src_seq, src_pos):
    """ Translation work in one batch """

    def get_inst_idx_to_tensor_position_map(inst_idx_list):
        """ Indicate the position of an instance in a tensor. """
        return {inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)}

    def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
        """ Collect tensor parts associated to active instances. """
        _, *d_hs = beamed_tensor.size()
        n_curr_active_inst = len(curr_active_inst_idx)
        new_shape = (n_curr_active_inst * n_bm, *d_hs)

        beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
        beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
        beamed_tensor = beamed_tensor.view(*new_shape)

        return beamed_tensor

    def collate_active_info(src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list):
        # Sentences which are still active are collected,
        # so the decoder will not run on completed sentences.
        n_prev_active_inst = len(inst_idx_to_position_map)
        active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
        active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

        active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm)
        active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm)
        active_inst_idx_to_position_map = \
            get_inst_idx_to_tensor_position_map(active_inst_idx_list)

        return active_src_seq, active_src_enc, active_inst_idx_to_position_map

    def beam_decode_step(inst_dec_beams, len_dec_seq, src_seq, enc_output,
                         inst_idx_to_position_map, n_bm):
        """ Decode and update beam status, and then return active beam idx """

        def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
            dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
            dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
            dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
            return dec_partial_seq

        def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
            dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device)
            dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1)
            return dec_partial_pos

        def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm):
            ##############################################################
            # Make the copy mask: flag the vocabulary ids that appear in
            # the src sequence, so the copy distribution can favour them.
            uniques = torch.unique(src_seq, dim=1).cpu().detach().numpy()
            # print(uniques.shape)
            p_gen_mask_shape = (src_seq.shape[0], self.n_voca)
            p_gen_mask = torch.tensor(np.zeros(p_gen_mask_shape), dtype=torch.float)
            batch_size = src_seq.shape[0]
            p_gen_mask[np.arange(batch_size)[:, None], uniques] = 1
            p_gen_mask = p_gen_mask.to(self.device)
            # print(p_gen_mask.shape)
            ############################################################

            dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output)
            print("dec_output.shape | before reshape", dec_output.shape)
            p_gen, *_ = self.model.p_generator(dec_seq, dec_pos, src_seq, enc_output)
            p_gen = p_gen[:, -1, :]  # keep only the last decoding step, like dec_output below
            # print("p_gen", p_gen.shape)
            p_gen = self.model.p_gen_linear(p_gen)
            p_gen = self.model.p_gen_sig(p_gen)

            dec_output = dec_output[:, -1, :]
            print("dec_output.shape | after reshape", dec_output.shape)
            seq_logit = self.model.tgt_word_prj(dec_output)
            # print("seq_logit.shape | word_prj", seq_logit.shape)
            seq_max_len = 1
            print("seq_max_len", seq_max_len)
            p_gen_mask = p_gen_mask[:, None, :]
            p_gen_mask = p_gen_mask[:, -1, :]
            print("p_gen_mask", p_gen_mask.shape)
            p_gen_mask = torch.repeat_interleave(p_gen_mask, seq_max_len, dim=1)
            masked_seq_logit = seq_logit * p_gen_mask
            print("masked_seq_logit.shape", masked_seq_logit.shape)

            ###########################################################
            softmax = torch.nn.Softmax(dim=1)
            prb_gen = softmax(seq_logit)        # generation distribution
            prb_cp = softmax(masked_seq_logit)  # copy-biased distribution

            exclusive_copy_or_gen = True
            if exclusive_copy_or_gen:
                p_gen = p_gen > 1 / 2  # harden the gate to a 0/1 decision
                p_gen = p_gen.to(torch.float)
            prb = prb_gen * p_gen + prb_cp * (1 - p_gen)
            word_prob = prb.log()
            # Pick the last step: (n_inst * n_bm) x d_h
            # word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1)
            word_prob = word_prob.view(n_active_inst, n_bm, -1)
            return word_prob

        def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
            active_inst_idx_list = []
            for inst_idx, inst_position in inst_idx_to_position_map.items():
                is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
                if not is_inst_complete:
                    active_inst_idx_list += [inst_idx]
            return active_inst_idx_list

        n_active_inst = len(inst_idx_to_position_map)  # instances still decoding

        dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
        dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
        word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output,
                                 n_active_inst,  # number of still-active instances
                                 n_bm)           # beam width

        # Update the beam with predicted word prob information and collect incomplete instances
        active_inst_idx_list = collect_active_inst_idx_list(
            inst_dec_beams, word_prob, inst_idx_to_position_map)

        return active_inst_idx_list

    def collect_hypothesis_and_scores(inst_dec_beams, n_best):
        all_hyp, all_scores = [], []
        for inst_idx in range(len(inst_dec_beams)):
            scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
            all_scores += [scores[:n_best]]
            hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
            all_hyp += [hyps]
        return all_hyp, all_scores

    with torch.no_grad():
        # -- Encode
        src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
        src_enc, *_ = self.model.encoder(src_seq, src_pos)

        # -- Repeat data for beam search
        n_bm = self.opt.beam_size
        n_inst, len_s, d_h = src_enc.size()
        src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
        src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)

        # -- Prepare beams
        inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)]

        # -- Bookkeeping for active or not
        active_inst_idx_list = list(range(n_inst))
        inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

        # -- Decode
        # seq_logit_group = []
        for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1):

            active_inst_idx_list = beam_decode_step(
                inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm)
            # seq_logit_group.append(seq_logit)

            if not active_inst_idx_list:
                break  # all instances have finished their path to <EOS>

            src_seq, src_enc, inst_idx_to_position_map = \
                collate_active_info(src_seq, src_enc, inst_idx_to_position_map,
                                    active_inst_idx_list)

    batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best)

    return batch_hyp, batch_scores  # , seq_logit_group
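# ---------------------------------------------------------------------------
# Standalone sketch of the copy/generate mixture computed in predict_word
# above: a gate p_gen blends the full generation distribution with one biased
# toward tokens present in the source. Sizes and values are toy assumptions.
# Note that multiplying a logit by 0, as the code above does, only
# down-weights a token rather than removing it from the copy distribution.
import torch

n_voca = 6
seq_logit = torch.randn(2, n_voca)  # (batch, vocab) logits for the last step
src_token_mask = torch.tensor([[1., 0., 1., 0., 0., 0.],
                               [0., 1., 0., 1., 0., 0.]])  # 1 = occurs in src

softmax = torch.nn.Softmax(dim=1)
prb_gen = softmax(seq_logit)                  # generate from the full vocab
prb_cp = softmax(seq_logit * src_token_mask)  # favour source-side tokens

p_gen = torch.tensor([[0.9], [0.2]])  # per-instance gate in [0, 1]
p_gen = (p_gen > 0.5).float()         # hard gate, as in exclusive_copy_or_gen

prb = prb_gen * p_gen + prb_cp * (1 - p_gen)
word_prob = prb.log()                 # the beams consume log-probabilities
print(word_prob.shape)                # torch.Size([2, 6])
# ---------------------------------------------------------------------------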