Example #1
    def generate(
            self,
            src: batchers.Batch,
            forced_trg_ids: Sequence[numbers.Integral] = None,
            normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
        if not batchers.is_batched(src):
            src = batchers.mark_as_batch([src])
            if forced_trg_ids:
                forced_trg_ids = batchers.mark_as_batch([forced_trg_ids])
        assert src.batch_size() == 1, "batch size > 1 not properly tested"

        batch_size, encodings, outputs, seq_len = self._encode_src(src)
        score_expr = self.scorer.calc_log_softmax(
            outputs) if normalize_scores else self.scorer.calc_scores(outputs)
        scores = score_expr.npvalue()  # vocab_size x seq_len

        if forced_trg_ids:
            output_actions = forced_trg_ids
        else:
            output_actions = [np.argmax(scores[:, j]) for j in range(seq_len)]
        score = np.sum([scores[output_actions[j], j] for j in range(seq_len)])

        outputs = [
            sent.SimpleSentence(
                words=output_actions,
                idx=src[0].idx,
                vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                output_procs=self.trg_reader.output_procs,
                score=score)
        ]

        return outputs
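
A minimal driver sketch for this variant. `model`, `src_sent`, and `ref_ids` are hypothetical stand-ins for an instance of this class, an already-constructed source sentence, and a list of reference target IDs; only `batchers.mark_as_batch` and the `generate` signature are taken from the example itself:

    # Hypothetical usage: greedy labeling of one sentence.
    src = batchers.mark_as_batch([src_sent])
    outputs = model.generate(src, normalize_scores=True)
    print(outputs[0].score)  # summed per-position log-probabilities

    # Forced decoding: score a given label sequence instead of the argmax.
    outputs = model.generate(src_sent, forced_trg_ids=ref_ids)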
Example #2
    def generate(
            self,
            src: batchers.Batch,
            normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
        if not batchers.is_batched(src):
            src = batchers.mark_as_batch([src])
        assert src.batch_size() == 1, "batch size > 1 not properly tested"

        batch_size, encodings, outputs, seq_len = self._encode_src(src)

        best_words, best_scores = self.scorer.best_k(
            outputs, k=1, normalize_scores=normalize_scores)
        best_words = best_words[0, :]
        score = np.sum(best_scores, axis=1)

        outputs = [
            sent.SimpleSentence(
                words=best_words,
                idx=src[0].idx,
                vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                output_procs=self.trg_reader.output_procs,
                score=score)
        ]

        return outputs
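
Here `best_k` replaces the explicit per-position argmax of Example #1. A self-contained NumPy check of the intended `k=1` equivalence (the `scores` array is fabricated for illustration):

    import numpy as np

    vocab_size, seq_len = 8, 5
    scores = np.random.randn(vocab_size, seq_len)  # vocab_size x seq_len

    # Per-position argmax, as written in Example #1 ...
    greedy = np.array([np.argmax(scores[:, j]) for j in range(seq_len)])
    # ... is the same as taking the single best word at every position.
    assert np.array_equal(greedy, np.argmax(scores, axis=0))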
Example #3
  def generate_search_output(self,
                             src: batchers.Batch,
                             search_strategy: search_strategies.SearchStrategy,
                             forced_trg_ids: batchers.Batch=None) -> List[search_strategies.SearchOutput]:
    """
    Takes in a batch of source sentences and outputs a list of search outputs.
    Args:
      src: The source sentences
      search_strategy: The strategy with which to perform the search
      forced_trg_ids: The target IDs to generate if performing forced decoding
    Returns:
      A list of search outputs including scores, etc.
    """
    if src.batch_size() != 1:
      raise NotImplementedError("batched decoding not implemented for DefaultTranslator. "
                                "Specify inference batcher with batch size 1.")
    event_trigger.start_sent(src)
    all_src = src
    if isinstance(src, batchers.CompoundBatch): src = src.batches[0]
    # Generating outputs
    cur_forced_trg = None
    src_sent = src[0]
    sent_mask = None
    if src.mask: sent_mask = batchers.Mask(np_arr=src.mask.np_arr[0:1])
    sent_batch = batchers.mark_as_batch([src_sent], mask=sent_mask)

    # Encode the sentence
    initial_state = self._encode_src(all_src)

    if forced_trg_ids is not None: cur_forced_trg = forced_trg_ids[0]
    search_outputs = search_strategy.generate_output(self, initial_state,
                                                     src_length=[src_sent.sent_len()],
                                                     forced_trg_ids=cur_forced_trg)
    return search_outputs
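
A hedged sketch of how this method would be called. The `BeamSearch` constructor arguments and the `model`/`src_sent` objects are assumptions about the surrounding toolkit, not taken from the source:

    # Hypothetical caller: decode one sentence with a beam of size 4.
    strategy = search_strategies.BeamSearch(beam_size=4)
    src = batchers.mark_as_batch([src_sent])
    search_outputs = model.generate_search_output(src, strategy)
    best = max(search_outputs, key=lambda out: out.score[0])
    print(best.word_ids[0], best.score[0])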
Example #4
    def generate(
            self,
            src: batchers.Batch,
            search_strategy: search_strategies.SearchStrategy,
            forced_trg_ids: batchers.Batch = None) -> Sequence[sent.Sentence]:
        """
    Takes in a batch of source sentences and outputs a list of search outputs.
    Args:
      src: The source sentences
      search_strategy: The strategy with which to perform the search
      forced_trg_ids: The target IDs to generate if performing forced decoding
    Returns:
      A list of search outputs including scores, etc.
    """
        assert src.batch_size() == 1
        search_outputs = self.generate_search_output(src, search_strategy,
                                                     forced_trg_ids)
        if isinstance(src, batchers.CompoundBatch): src = src.batches[0]
        sorted_outputs = sorted(search_outputs,
                                key=lambda x: x.score[0],
                                reverse=True)
        assert len(sorted_outputs) >= 1
        outputs = []
        for curr_output in sorted_outputs:
            output_actions = [x for x in curr_output.word_ids[0]]
            attentions = [x for x in curr_output.attentions[0]]
            score = curr_output.score[0]
            out_sent = sent.SimpleSentence(
                idx=src[0].idx,
                words=output_actions,
                vocab=getattr(self.trg_reader, "vocab", None),
                output_procs=self.trg_reader.output_procs,
                score=score)
            if len(sorted_outputs) == 1:
                outputs.append(out_sent)
            else:
                outputs.append(
                    sent.NbestSentence(base_sent=out_sent,
                                       nbest_id=src[0].idx))

        if self.is_reporting():
            attentions = np.concatenate([x.npvalue() for x in attentions],
                                        axis=1)
            self.report_sent_info({
                "attentions": attentions,
                "src": src[0],
                "output": outputs[0]
            })

        return outputs
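
Since this variant wraps hypotheses in `sent.NbestSentence` whenever the search returns more than one output, a caller that wants plain sentences back can unwrap them. A sketch, assuming `NbestSentence` exposes the `base_sent` it was constructed with:

    # Hypothetical caller: flatten an n-best list to plain sentences.
    outputs = model.generate(src, search_strategy)
    for hyp in outputs:
        base = hyp.base_sent if isinstance(hyp, sent.NbestSentence) else hyp
        print(base.score, base.words)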
Example #5
    def generate(
        self, src: batchers.Batch,
        search_strategy: search_strategies.SearchStrategy
    ) -> Sequence[sent.Sentence]:
        """
    Takes in a batch of source sentences and outputs a list of search outputs.
    Args:
      src: The source sentences
      search_strategy: The strategy with which to perform the search
    Returns:
      A list of search outputs including scores, etc.
    """
        assert src.batch_size() == 1
        event_trigger.start_sent(src)
        search_outputs = self.generate_search_output(src, search_strategy)
        if isinstance(src, batchers.CompoundBatch): src = src.batches[0]
        sorted_outputs = sorted(search_outputs,
                                key=lambda x: x.score[0],
                                reverse=True)
        assert len(sorted_outputs) >= 1
        outputs = []
        for curr_output in sorted_outputs:
            output_actions = [x for x in curr_output.word_ids[0]]
            attentions = [x for x in curr_output.attentions[0]]
            score = curr_output.score[0]
            out_sent = self._emit_translation(src, output_actions, score)
            if len(sorted_outputs) == 1:
                outputs.append(out_sent)
            else:
                outputs.append(
                    sent.NbestSentence(base_sent=out_sent,
                                       nbest_id=src[0].idx))

        if self.is_reporting():
            attentions = np.concatenate([x.npvalue() for x in attentions],
                                        axis=1)
            self.report_sent_info({
                "attentions": attentions,
                "src": src[0],
                "output": outputs[0]
            })

        return outputs
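
Example #5 factors the `SimpleSentence` construction of Example #4 out into `self._emit_translation`, whose body is not shown in the source. A plausible sketch, mirroring the inline construction of Example #4 (this is an assumption, not the actual helper):

    def _emit_translation(self, src, output_actions, score):
        # Assumed body: mirrors the inline construction in Example #4.
        return sent.SimpleSentence(
            idx=src[0].idx,
            words=output_actions,
            vocab=getattr(self.trg_reader, "vocab", None),
            output_procs=self.trg_reader.output_procs,
            score=score)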
Example #6
    def generate_search_output(
        self, src: batchers.Batch,
        search_strategy: search_strategies.SearchStrategy
    ) -> List[search_strategies.SearchOutput]:
        """
        Takes in a batch of source sentences and outputs a list of search outputs.
        Args:
          src: The source sentences
          search_strategy: The strategy with which to perform the search
        Returns:
          A list of search outputs including scores, etc.
        """
        if src.batch_size() != 1:
            raise NotImplementedError(
                "batched decoding not implemented for DefaultTranslator. "
                "Specify inference batcher with batch size 1.")
        if isinstance(src, batchers.CompoundBatch):
            src = src.batches[0]
        search_outputs = search_strategy.generate_output(
            self, self._initial_state(src), src_length=src.sent_len())
        return search_outputs
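
This simplified variant drops forced decoding and explicit mask handling, computing the initial state and source length directly from the batch. A usage sketch analogous to the one after Example #3 (the `GreedySearch` strategy and its no-argument constructor are assumptions):

    # Hypothetical caller: greedy decoding of a single sentence.
    strategy = search_strategies.GreedySearch()
    src = batchers.mark_as_batch([src_sent])
    outputs = model.generate_search_output(src, strategy)
    print(outputs[0].word_ids[0])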