def _translate_batch(
        self,
        batch,
        src_vocabs,
        max_length,
        min_length=0,
        ratio=0.,
        n_best=1,
        return_attention=False):
    """Translate one batch with beam search.

    Returns a dict with keys "predictions", "scores", "attention",
    the input "batch", and "gold_score" (score of the reference target).
    """
    # TODO: support these blacklisted features.
    assert not self.dump_beam

    # (0) Prep the components of the search.
    use_src_map = self.copy_attn
    beam_size = self.beam_size
    batch_size = batch.batch_size

    # (1) Encode the source once and seed the decoder state from it.
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
    self.model.decoder.init_state(src, memory_bank, enc_states)

    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": self._gold_score(
            batch, memory_bank, src_lengths, src_vocabs, use_src_map,
            enc_states, batch_size, src),
    }

    # (2) Repeat src objects `beam_size` times so every hypothesis in the
    # flattened (batch_size x beam_size) layout sees its own copy.
    src_map = tile(batch.src_map, beam_size, dim=1) if use_src_map else None
    self.model.decoder.map_state(
        lambda state, dim: tile(state, beam_size, dim=dim))

    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(mb, beam_size, dim=1) for mb in memory_bank)
        mb_device = memory_bank[0].device
    else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
        mb_device = memory_bank.device
    memory_lengths = tile(src_lengths, beam_size)

    # (0) pt 2, build the beam-search driver.
    beam = BeamSearch(
        beam_size,
        n_best=n_best,
        batch_size=batch_size,
        global_scorer=self.global_scorer,
        pad=self._tgt_pad_idx,
        eos=self._tgt_eos_idx,
        bos=self._tgt_bos_idx,
        min_length=min_length,
        ratio=ratio,
        max_length=max_length,
        mb_device=mb_device,
        return_attention=return_attention,
        stepwise_penalty=self.stepwise_penalty,
        block_ngram_repeat=self.block_ngram_repeat,
        exclusion_tokens=self._exclusion_idxs,
        memory_lengths=memory_lengths)

    for step in range(max_length):
        dec_in = beam.current_predictions.view(1, -1, 1)

        log_probs, attn = self._decode_and_generate(
            dec_in,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=beam._batch_offset)

        beam.advance(log_probs, attn)
        any_finished = beam.is_finished.any()
        if any_finished:
            beam.update_finished()
            if beam.done:
                break

        select_indices = beam.current_origin

        if any_finished:
            # Reorder cached tensors down to the surviving hypotheses.
            if isinstance(memory_bank, tuple):
                memory_bank = tuple(mb.index_select(1, select_indices)
                                    for mb in memory_bank)
            else:
                memory_bank = memory_bank.index_select(1, select_indices)

            memory_lengths = memory_lengths.index_select(0, select_indices)

            if src_map is not None:
                src_map = src_map.index_select(1, select_indices)

        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    results["scores"] = beam.scores
    results["predictions"] = beam.predictions
    results["attention"] = beam.attention
    return results
def _translate_batch(
        self,
        batch,
        src_vocabs,
        max_length,
        min_length=0,
        ratio=0.,
        n_best=1,
        return_attention=False):
    """Beam-search translation of a single batch.

    The result dict carries "predictions", "scores", "attention", the
    "batch" itself and a "gold_score" for the reference target.
    """
    # TODO: support these blacklisted features.
    assert not self.dump_beam

    # (0) Prep the components of the search.
    use_src_map = self.copy_attn
    beam_size = self.beam_size        # e.g. 5 in a typical run
    batch_size = batch.batch_size     # e.g. 30 in a typical run

    # (1) Run the encoder on the src.
    # Example shapes observed while debugging (TODO confirm, they depend
    # on the model): src [src_len, batch, 1]; enc_states[0/1] [2, batch, 500];
    # memory_bank [src_len, batch, 500]; src_lengths [batch].
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
    self.model.decoder.init_state(src, memory_bank, enc_states)

    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": self._gold_score(
            batch, memory_bank, src_lengths, src_vocabs, use_src_map,
            enc_states, batch_size, src)}

    # (2) Repeat src objects `beam_size` times; the flattened layout is
    # batch_size x beam_size along the batch dimension (dim=1).
    src_map = (tile(batch.src_map, beam_size, dim=1)
               if use_src_map else None)
    self.model.decoder.map_state(
        lambda state, dim: tile(state, beam_size, dim=dim))

    if isinstance(memory_bank, tuple):
        # Tile each tensor beam_size times along dim=1 (the batch dim).
        memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        mb_device = memory_bank[0].device
    else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
        mb_device = memory_bank.device
    memory_lengths = tile(src_lengths, beam_size)

    # (0) pt 2, prep the beam object.
    beam = BeamSearch(
        beam_size,
        n_best=n_best,
        batch_size=batch_size,
        global_scorer=self.global_scorer,
        pad=self._tgt_pad_idx,
        eos=self._tgt_eos_idx,
        bos=self._tgt_bos_idx,
        min_length=min_length,
        ratio=ratio,
        max_length=max_length,  # maximum prediction length
        mb_device=mb_device,
        return_attention=return_attention,
        stepwise_penalty=self.stepwise_penalty,
        block_ngram_repeat=self.block_ngram_repeat,
        exclusion_tokens=self._exclusion_idxs,
        memory_lengths=memory_lengths)

    # Each step advances batch_size * beam_size hypotheses at once.
    for step in range(max_length):
        # beam.current_predictions is alive_seq[:, -1]; viewed here as
        # [1, batch*beam, 1] (e.g. [1, 150, 1] for 30 x 5).
        decoder_input = beam.current_predictions.view(1, -1, 1)

        # log_probs: [batch*beam, vocab] (vocab includes the special
        # tokens such as <s>, </s>, <unk>, <pad> — TODO confirm);
        # attn: [1, batch*beam, max_src_len].
        log_probs, attn = self._decode_and_generate(
            decoder_input,
            memory_bank,  # e.g. [src_len, batch*beam, 500]
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=beam._batch_offset)

        # advance() folds the batch*beam candidates back into the
        # per-sentence beam structure.
        beam.advance(log_probs, attn)
        any_beam_is_finished = beam.is_finished.any()
        if any_beam_is_finished:
            beam.update_finished()
            if beam.done:
                break

        select_indices = beam.current_origin

        if any_beam_is_finished:
            # Reorder states.
            if isinstance(memory_bank, tuple):
                memory_bank = tuple(x.index_select(1, select_indices)
                                    for x in memory_bank)
            else:
                memory_bank = memory_bank.index_select(1, select_indices)

            memory_lengths = memory_lengths.index_select(0, select_indices)

            if src_map is not None:
                src_map = src_map.index_select(1, select_indices)

        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    results["scores"] = beam.scores
    results["predictions"] = beam.predictions
    results["attention"] = beam.attention
    return results
def _translate_batch(
        self,
        batch,
        src_vocabs,
        max_length,
        min_length=0,
        ratio=0.,
        n_best=1,
        return_attention=False,
        tags=None):
    """Beam-search translate one batch, switching decoders by language tag.

    The ``tags`` parameter is kept for interface compatibility but is
    ignored: the tags actually used come from ``self.ctags``.  The last
    tag must be a language id ("EN" or "DE"); for "EN" the secondary
    decoder (``self.model.decoder2``) is swapped in for the duration of
    the search and restored before returning.

    Raises:
        AssertionError: if the last tag is neither "EN" nor "DE".
    """
    # TODO: support these blacklisted features.
    assert not self.dump_beam

    # (0) Prep the components of the search.
    use_src_map = self.copy_attn
    beam_size = self.beam_size
    batch_size = batch.batch_size

    # NOTE: the `tags` argument is deliberately ignored (it was `tags=[]`
    # before — a mutable default — and was always overwritten here).
    tags = self.ctags
    if tags[-1] == "EN" or tags[-1] == "DE":
        lang = tags[-1]
        tags = tags[:-1]
    else:
        # Fail loudly on an unknown language tag.  Raise explicitly
        # instead of `assert False` so the check survives `python -O`.
        raise AssertionError(
            "last tag must be 'EN' or 'DE', got %r" % (tags[-1],))

    # Remember both decoders so the swap can be undone on exit.
    enc1 = self.model.decoder
    enc2 = self.model.decoder2
    if lang == "EN":
        self.model.decoder = enc2

    # (1) Run the encoder on the src.
    allstuff = self._run_encoder(batch, tags=tags)
    if len(allstuff) == 3:
        # NOTE(review): this path leaves `src_lengths` unbound and would
        # crash below in `_gold_score`/`tile` — confirm whether the
        # encoder can actually return three values.
        src, enc_states, memory_bank = allstuff
    elif len(allstuff) == 4:
        src, enc_states, memory_bank, src_lengths = allstuff

    self.model.decoder.init_state(src, memory_bank, enc_states)

    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": self._gold_score(
            batch, memory_bank, src_lengths, src_vocabs, use_src_map,
            enc_states, batch_size, src)}

    # (2) Repeat src objects `beam_size` times.
    # We use batch_size x beam_size.
    src_map = (tile(batch.src_map, beam_size, dim=1)
               if use_src_map else None)
    self.model.decoder.map_state(
        lambda state, dim: tile(state, beam_size, dim=dim))

    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        mb_device = memory_bank[0].device
    else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
        mb_device = memory_bank.device
    memory_lengths = tile(src_lengths, beam_size)

    # (0) pt 2, prep the beam object.  (Leftover debug code that printed
    # and recomputed an unused `tgt` array was removed here.)
    beam = BeamSearch(
        beam_size,
        n_best=n_best,
        batch_size=batch_size,
        global_scorer=self.global_scorer,
        pad=self._tgt_pad_idx,
        eos=self._tgt_eos_idx,
        bos=self._tgt_bos_idx,
        min_length=min_length,
        ratio=ratio,
        max_length=max_length,
        mb_device=mb_device,
        return_attention=return_attention,
        stepwise_penalty=self.stepwise_penalty,
        block_ngram_repeat=self.block_ngram_repeat,
        exclusion_tokens=self._exclusion_idxs,
        memory_lengths=memory_lengths,
        i2w=self._tgt_vocab.itos,
        batch=batch)

    for step in range(max_length):
        decoder_input = beam.current_predictions.view(1, -1, 1)

        log_probs, attn = self._decode_and_generate(
            decoder_input,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=beam._batch_offset)

        beam.advance(log_probs, attn)
        any_beam_is_finished = beam.is_finished.any()
        if any_beam_is_finished:
            beam.update_finished()
            if beam.done:
                break

        select_indices = beam.current_origin

        if any_beam_is_finished:
            # Reorder states.
            if isinstance(memory_bank, tuple):
                memory_bank = tuple(x.index_select(1, select_indices)
                                    for x in memory_bank)
            else:
                memory_bank = memory_bank.index_select(1, select_indices)

            memory_lengths = memory_lengths.index_select(0, select_indices)

            if src_map is not None:
                src_map = src_map.index_select(1, select_indices)

        # A dead `decoder2.map_state` branch guarded by `and False` was
        # removed: the active decoder is always the one in
        # `self.model.decoder` (already swapped above for "EN").
        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    results["scores"] = beam.scores
    results["predictions"] = beam.predictions
    results["attention"] = beam.attention
    results["maxvecs"] = []

    # BUG FIX: the original assigned `self.decoder = enc1`, which set a
    # brand-new attribute on the translator and left `self.model.decoder`
    # pointing at decoder2 after an "EN" batch.  Restore the real slot.
    self.model.decoder = enc1
    return results
def _translate_batch(
        self,
        batch,
        src_vocabs,
        max_length,
        min_length=0,
        ratio=0.,
        n_best=1,
        return_attention=False):
    """Beam-search translate one batch with optional constrained decoding.

    When ``self.constraint`` is set, a ``BB_sequence_state`` tracker
    mirrors the beam: it receives every chosen token, its beam origin and
    score, and is reordered alongside the beam whenever hypotheses finish.

    Returns a dict with "predictions", "scores", "attention", the input
    "batch" and "gold_score".
    """
    # TODO: support these blacklisted features.
    assert not self.dump_beam

    # (0) Prep the components of the search.
    use_src_map = self.copy_attn
    beam_size = self.beam_size
    batch_size = batch.batch_size

    # (1) Run the encoder on the src.
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
    self.model.decoder.init_state(src, memory_bank, enc_states)

    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": self._gold_score(
            batch, memory_bank, src_lengths, src_vocabs, use_src_map,
            enc_states, batch_size, src)}

    # (2) Repeat src objects `beam_size` times.
    # We use batch_size x beam_size.
    src_map = (tile(batch.src_map, beam_size, dim=1)
               if use_src_map else None)
    self.model.decoder.map_state(
        lambda state, dim: tile(state, beam_size, dim=dim))

    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        mb_device = memory_bank[0].device
    else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
        mb_device = memory_bank.device
    memory_lengths = tile(src_lengths, beam_size)

    # (0) pt 2, prep the beam object.
    beam = BeamSearch(
        beam_size,
        n_best=n_best,
        batch_size=batch_size,
        global_scorer=self.global_scorer,
        pad=self._tgt_pad_idx,
        eos=self._tgt_eos_idx,
        bos=self._tgt_bos_idx,
        min_length=min_length,
        ratio=ratio,
        max_length=max_length,
        mb_device=mb_device,
        return_attention=return_attention,
        stepwise_penalty=self.stepwise_penalty,
        block_ngram_repeat=self.block_ngram_repeat,
        exclusion_tokens=self._exclusion_idxs,
        memory_lengths=memory_lengths)

    # Constraint tracker (one logical entry per live hypothesis), only
    # built when constrained decoding is enabled.
    states = None
    if self.constraint:
        itos = self.fields["tgt"].base_field.vocab.itos
        stoi = self.fields["tgt"].base_field.vocab.stoi
        states = BB_sequence_state(
            itos, stoi, mb_device, batch_size, beam_size,
            eos=self._tgt_eos_idx)

    for step in range(max_length):
        decoder_input = beam.current_predictions.view(1, -1, 1)

        # NOTE: this variant's _decode_and_generate takes the constraint
        # `states` as its second positional argument.
        log_probs, attn = self._decode_and_generate(
            decoder_input,
            states,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=beam._batch_offset)

        beam.advance(log_probs, attn)

        if states is not None:
            # Feed the tracker the chosen tokens, their beam origins and
            # scores.  (These lists were previously built every step even
            # when unused, under the misspelled names `lastest_*`.)
            latest_action = beam.current_predictions.data.tolist()
            latest_score = beam.current_scores.view(-1).data.tolist()
            states.update_beam(latest_action,
                               beam.current_origin.data.tolist(),
                               latest_score)

        any_beam_is_finished = beam.is_finished.any()
        if any_beam_is_finished:
            # (A dead `see`/`finished_batch` debug counter that was always
            # zero, plus large blocks of commented-out prints, were
            # removed from this branch.)
            beam.update_finished()
            if beam.done:
                break

        select_indices = beam.current_origin

        if any_beam_is_finished:
            # Reorder states.
            if states is not None:
                states.index_select(select_indices)

            if isinstance(memory_bank, tuple):
                memory_bank = tuple(x.index_select(1, select_indices)
                                    for x in memory_bank)
            else:
                memory_bank = memory_bank.index_select(1, select_indices)

            memory_lengths = memory_lengths.index_select(0, select_indices)

            if src_map is not None:
                src_map = src_map.index_select(1, select_indices)

        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    results["scores"] = beam.scores
    results["predictions"] = beam.predictions
    results["attention"] = beam.attention
    return results