def greedy_search(self, examples, to_word=True):
    """Greedily decode target sequences for a batch of examples.

    Args:
        examples: a single example or a list of examples, each exposing a
            ``.src`` token list.
        to_word: when True, convert predicted token ids back to words.

    Returns:
        A list with one ``[[words], [length]]`` entry per example
        (empty when ``to_word`` is False, matching the original contract).
    """
    args = self.args
    if not isinstance(examples, list):
        examples = [examples]
    src_words = [e.src for e in examples]
    src_var = to_input_variable(src_words, self.vocab.src,
                                cuda=args.cuda, batch_first=True)
    src_length = [len(c) for c in src_words]
    encoder_outputs, encoder_hidden = self.encode_var(
        src_var=src_var, src_length=src_length)
    decoder_output, decoder_hidden, ret_dict, _ = self.decoder.forward(
        encoder_hidden=encoder_hidden,
        encoder_outputs=encoder_outputs,
        teacher_forcing_ratio=0.0)
    result = torch.stack(ret_dict['sequence']).squeeze()
    # BUG FIX: squeeze() also drops the batch dimension when there is a
    # single example, which would break result[:, i] below.  Restore it,
    # mirroring the guard already used in decode_to_sentence.
    if result.dim() < 2:
        result = result.unsqueeze(1)
    final_result = []
    example_nums = result.size(1)
    if to_word:
        for i in range(example_nums):
            hyp = result[:, i].data.tolist()
            # BUG FIX: decode with the target-side vocabulary, consistent
            # with predict()/beam_search(); the original passed the whole
            # vocab container (self.vocab) instead of self.vocab.tgt.
            res = id2word(hyp, self.vocab.tgt)
            seems = [[res], [len(res)]]
            final_result.append(seems)
    return final_result
def beam_search(self, src_sent, beam_size=5, dmts=None):
    """Beam-search decode a single source sentence.

    Args:
        src_sent: list of source tokens (one sentence, not a batch).
        beam_size: number of hypotheses kept per decoding step.
        dmts: decode max time step; defaults to ``args.tgt_max_time_step``.

    Returns:
        Tuple ``(final_result, final_scores)`` where ``final_result`` is a
        list of decoded word sequences and ``final_scores`` their scores,
        both ordered as returned by the beam decoder.
    """
    if dmts is None:
        dmts = self.args.tgt_max_time_step
    src_var = to_input_variable(src_sent, self.vocab.src,
                                cuda=self.args.cuda, training=False,
                                append_boundary_sym=False, batch_first=True)
    src_length = [len(src_sent)]
    encoder_outputs, encoder_hidden = self.encode_var(
        src_var=src_var, src_length=src_length)
    meta_data = self.beam_decoder.beam_search(
        encoder_hidden=encoder_hidden,
        encoder_outputs=encoder_outputs,
        beam_size=beam_size,
        decode_max_time_step=dmts)
    topk_sequence = meta_data['sequence']
    topk_score = meta_data['score'].squeeze()
    # BUG FIX: squeeze() collapses a single-hypothesis score tensor to
    # 0-dim, which would make topk_score[i] below raise.  Restore the
    # hypothesis dimension in that case.
    if topk_score.dim() == 0:
        topk_score = topk_score.unsqueeze(0)
    completed_hypotheses = torch.cat(topk_sequence, dim=-1)
    number_return = completed_hypotheses.size(0)
    final_result = []
    final_scores = []
    for i in range(number_return):
        hyp = completed_hypotheses[i, :].data.tolist()
        res = id2word(hyp, self.vocab.tgt)
        final_result.append(res)
        final_scores.append(topk_score[i].item())
    return final_result, final_scores
def predict_syntax(self, hidden, predictor):
    """Run ``predictor`` on a hidden state and convert ids to target words.

    Returns a list with one ``[[words], [length]]`` entry per predicted
    column (one column per example in the batch).
    """
    predictions = predictor.predict(hidden)
    outputs = []
    for col in range(predictions.size(1)):
        token_ids = predictions[:, col].data.tolist()
        words = id2word(token_ids, self.vocab.tgt)
        outputs.append([[words], [len(words)]])
    return outputs
def predict(self, examples, to_word=True):
    """Encode ``examples`` and greedily decode target sequences.

    Args:
        examples: batch of examples accepted by ``self.encode``.
        to_word: when True, convert predicted token ids back to words.

    Returns:
        A list with one ``[[words], [length]]`` entry per example
        (empty when ``to_word`` is False, matching the original contract).
    """
    hidden = self.encode(examples)
    decoder_output, decoder_hidden, ret_dict, _ = self.decoder.forward(
        encoder_hidden=hidden,
        encoder_outputs=None,
        teacher_forcing_ratio=0.0)
    result = torch.stack(ret_dict['sequence']).squeeze()
    # BUG FIX: squeeze() also drops the batch dimension when there is a
    # single example, which would break result[:, i] below.  Restore it,
    # mirroring the guard already used in decode_to_sentence.
    if result.dim() < 2:
        result = result.unsqueeze(1)
    final_result = []
    example_nums = result.size(1)
    if to_word:
        for i in range(example_nums):
            hyp = result[:, i].data.tolist()
            res = id2word(hyp, self.vocab.tgt)
            seems = [[res], [len(res)]]
            final_result.append(seems)
    return final_result
def decode_to_sentence(self, ret):
    """Decode ``ret['decode_init']`` back into source-side word sequences.

    The init state is passed through the bridger, decoded without teacher
    forcing, and each decoded column is mapped to words via the source
    vocabulary.  Returns one ``[[words], [length]]`` entry per example.
    """
    init_state = self.bridger.forward(input_tensor=ret['decode_init'])
    _, _, ret_dict, _ = self.decode(
        inputs=None,
        encoder_outputs=None,
        encoder_hidden=init_state,
    )
    sequence = torch.stack(ret_dict['sequence']).squeeze()
    # A single-example batch squeezes down to 1-D; put the batch
    # dimension back so column indexing works uniformly.
    if sequence.dim() < 2:
        sequence = sequence.unsqueeze(1)
    decoded = []
    for col in range(sequence.size(1)):
        words = id2word(sequence[:, col].data.tolist(), self.vocab.src)
        decoded.append([[words], [len(words)]])
    return decoded
def recovery(seqs, vocab, keep_origin=False):
    """Round-trip *seqs* through *vocab* (word -> id -> word).

    Out-of-vocabulary tokens come back as the unknown word, so the result
    reflects what the model actually sees.  ``keep_origin`` is accepted
    for API compatibility but currently has no effect.
    """
    ids = word2id(seqs, vocab)
    return id2word(ids, vocab)