def generate(self,
             src: batchers.Batch,
             forced_trg_ids: Sequence[numbers.Integral] = None,
             normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
  """Generate outputs for a single-sentence batch by per-position scoring.

  Args:
    src: The source sentences (batched or a single sentence, which is
      wrapped into a batch of one).
    forced_trg_ids: Optional target IDs for forced decoding; when given,
      they are used verbatim as the output actions.
    normalize_scores: If ``True``, score with log-softmax instead of raw
      scores.

  Returns:
    A list containing one scored ``sent.SimpleSentence``.
  """
  if not batchers.is_batched(src):
    src = batchers.mark_as_batch([src])
  # BUG FIX: compare against None explicitly — an empty forced-target
  # sequence is falsy and was previously silently treated as "no forcing".
  if forced_trg_ids is not None:
    forced_trg_ids = batchers.mark_as_batch([forced_trg_ids])
  assert src.batch_size() == 1, "batch size > 1 not properly tested"
  batch_size, encodings, outputs, seq_len = self._encode_src(src)
  score_expr = self.scorer.calc_log_softmax(outputs) if normalize_scores \
               else self.scorer.calc_scores(outputs)
  scores = score_expr.npvalue()  # vocab_size x seq_len
  if forced_trg_ids is not None:
    # NOTE(review): forced_trg_ids was wrapped into a batch of one above;
    # confirm whether the raw ids (element [0]) are intended here.
    output_actions = forced_trg_ids
  else:
    # Greedy: pick the highest-scoring word at every position.
    output_actions = [np.argmax(scores[:, j]) for j in range(seq_len)]
  # Sentence score = sum of the per-position scores of the chosen words.
  score = np.sum([scores[output_actions[j], j] for j in range(seq_len)])
  outputs = [sent.SimpleSentence(words=output_actions,
                                 idx=src[0].idx,
                                 vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                                 output_procs=self.trg_reader.output_procs,
                                 score=score)]
  return outputs
def generate(self,
             src: batchers.Batch,
             normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
  """Generate the best output for a single-sentence batch via ``best_k``.

  Args:
    src: The source sentences (a single sentence is wrapped into a batch).
    normalize_scores: If ``True``, scores are normalized before taking
      the best candidate.

  Returns:
    A list containing one scored ``sent.SimpleSentence``.
  """
  if not batchers.is_batched(src):
    src = batchers.mark_as_batch([src])
  assert src.batch_size() == 1, "batch size > 1 not properly tested"
  batch_size, encodings, outputs, seq_len = self._encode_src(src)
  # Ask the scorer for the single best word at each position.
  top_words, top_scores = self.scorer.best_k(outputs,
                                             k=1,
                                             normalize_scores=normalize_scores)
  chosen_words = top_words[0, :]
  total_score = np.sum(top_scores, axis=1)
  return [sent.SimpleSentence(words=chosen_words,
                              idx=src[0].idx,
                              vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                              output_procs=self.trg_reader.output_procs,
                              score=total_score)]
def generate_search_output(self,
                           src: batchers.Batch,
                           search_strategy: search_strategies.SearchStrategy,
                           forced_trg_ids: batchers.Batch = None) -> List[search_strategies.SearchOutput]:
  """
  Takes in a batch of source sentences and outputs a list of search outputs.

  Args:
    src: The source sentences
    search_strategy: The strategy with which to perform the search
    forced_trg_ids: The target IDs to generate if performing forced decoding

  Returns:
    A list of search outputs including scores, etc.
  """
  if src.batch_size() != 1:
    raise NotImplementedError("batched decoding not implemented for DefaultTranslator. "
                              "Specify inference batcher with batch size 1.")
  event_trigger.start_sent(src)
  all_src = src
  if isinstance(src, batchers.CompoundBatch):
    src = src.batches[0]
  # Generating outputs
  cur_forced_trg = None
  src_sent = src[0]
  sent_mask = None
  if src.mask:
    sent_mask = batchers.Mask(np_arr=src.mask.np_arr[0:1])
  # BUG FIX: was `mark_as_batch([sent], ...)`, which batched the imported
  # `sent` *module* rather than the first source sentence.
  # NOTE(review): `sent_batch` is not used below — confirm whether it was
  # meant to feed the encoder instead of `all_src`.
  sent_batch = batchers.mark_as_batch([src_sent], mask=sent_mask)
  # Encode the sentence
  initial_state = self._encode_src(all_src)
  if forced_trg_ids is not None:
    cur_forced_trg = forced_trg_ids[0]
  search_outputs = search_strategy.generate_output(self,
                                                   initial_state,
                                                   src_length=[src_sent.sent_len()],
                                                   forced_trg_ids=cur_forced_trg)
  return search_outputs
def generate(self,
             src: batchers.Batch,
             search_strategy: search_strategies.SearchStrategy,
             forced_trg_ids: batchers.Batch = None) -> Sequence[sent.Sentence]:
  """
  Takes in a batch of source sentences and outputs a list of search outputs.

  Args:
    src: The source sentences
    search_strategy: The strategy with which to perform the search
    forced_trg_ids: The target IDs to generate if performing forced decoding

  Returns:
    A list of search outputs including scores, etc.
  """
  assert src.batch_size() == 1
  search_outputs = self.generate_search_output(src, search_strategy, forced_trg_ids)
  if isinstance(src, batchers.CompoundBatch):
    src = src.batches[0]
  # Rank hypotheses best-first by their first score component.
  ranked = sorted(search_outputs, key=lambda out: out.score[0], reverse=True)
  assert len(ranked) >= 1
  single_best = len(ranked) == 1
  outputs = []
  attentions = None
  for hyp in ranked:
    word_ids = list(hyp.word_ids[0])
    attentions = list(hyp.attentions[0])
    hyp_score = hyp.score[0]
    translated = sent.SimpleSentence(idx=src[0].idx,
                                     words=word_ids,
                                     vocab=getattr(self.trg_reader, "vocab", None),
                                     output_procs=self.trg_reader.output_procs,
                                     score=hyp_score)
    if single_best:
      outputs.append(translated)
    else:
      # Wrap each hypothesis so the n-best list stays associated with its source.
      outputs.append(sent.NbestSentence(base_sent=translated, nbest_id=src[0].idx))
  if self.is_reporting():
    # Report attentions of the last-processed hypothesis alongside the best output.
    attentions = np.concatenate([a.npvalue() for a in attentions], axis=1)
    self.report_sent_info({"attentions": attentions,
                           "src": src[0],
                           "output": outputs[0]})
  return outputs
def generate(self,
             src: batchers.Batch,
             search_strategy: search_strategies.SearchStrategy) -> Sequence[sent.Sentence]:
  """
  Takes in a batch of source sentences and outputs a list of search outputs.

  Args:
    src: The source sentences
    search_strategy: The strategy with which to perform the search

  Returns:
    A list of search outputs including scores, etc.
  """
  assert src.batch_size() == 1
  event_trigger.start_sent(src)
  search_outputs = self.generate_search_output(src, search_strategy)
  if isinstance(src, batchers.CompoundBatch):
    src = src.batches[0]
  # Rank hypotheses best-first by their first score component.
  ranked = sorted(search_outputs, key=lambda out: out.score[0], reverse=True)
  assert len(ranked) >= 1
  single_best = len(ranked) == 1
  outputs = []
  attentions = None
  for hyp in ranked:
    word_ids = list(hyp.word_ids[0])
    attentions = list(hyp.attentions[0])
    hyp_score = hyp.score[0]
    translated = self._emit_translation(src, word_ids, hyp_score)
    if single_best:
      outputs.append(translated)
    else:
      # Wrap each hypothesis so the n-best list stays associated with its source.
      outputs.append(sent.NbestSentence(base_sent=translated, nbest_id=src[0].idx))
  if self.is_reporting():
    # Report attentions of the last-processed hypothesis alongside the best output.
    attentions = np.concatenate([a.npvalue() for a in attentions], axis=1)
    self.report_sent_info({"attentions": attentions,
                           "src": src[0],
                           "output": outputs[0]})
  return outputs
def generate_search_output(self,
                           src: batchers.Batch,
                           search_strategy: search_strategies.SearchStrategy
                           ) -> List[search_strategies.SearchOutput]:
  """
  Takes in a batch of source sentences and outputs a list of search outputs.

  Args:
    src: The source sentences
    search_strategy: The strategy with which to perform the search

  Returns:
    A list of search outputs including scores, etc.
  """
  # Guard clause: only single-sentence batches are supported here.
  if src.batch_size() != 1:
    raise NotImplementedError(
        "batched decoding not implemented for DefaultTranslator. "
        "Specify inference batcher with batch size 1.")
  if isinstance(src, batchers.CompoundBatch):
    src = src.batches[0]
  return search_strategy.generate_output(self,
                                         self._initial_state(src),
                                         src_length=src.sent_len())