Exemplo n.º 1
0
    def generate(
            self,
            src: batchers.Batch,
            search_strategy: search_strategies.SearchStrategy,
            forced_trg_ids: batchers.Batch = None) -> Sequence[sent.Sentence]:
        """
        Takes in a batch holding one source sentence and outputs its hypotheses.

        Args:
          src: The source sentences; the batch size must be exactly 1. If a
            compound batch is passed, its first sub-batch is used to look up
            the source sentence after the search has run.
          search_strategy: The strategy with which to perform the search
          forced_trg_ids: The target IDs to generate if performing forced
            decoding

        Returns:
          The hypotheses ordered by descending score. A single hypothesis is
          returned as a plain sentence; several hypotheses are each wrapped
          in an n-best sentence.
        """
        assert src.batch_size() == 1
        search_outputs = self.generate_search_output(src, search_strategy,
                                                     forced_trg_ids)
        if isinstance(src, batchers.CompoundBatch):
            src = src.batches[0]
        # Best hypothesis first; sorted() is stable, so ties keep search order.
        sorted_outputs = sorted(search_outputs,
                                key=lambda x: x.score[0],
                                reverse=True)
        assert len(sorted_outputs) >= 1

        src_sent = src[0]
        is_nbest = len(sorted_outputs) > 1
        outputs = []
        for curr_output in sorted_outputs:
            out_sent = sent.SimpleSentence(
                idx=src_sent.idx,
                words=list(curr_output.word_ids[0]),
                vocab=getattr(self.trg_reader, "vocab", None),
                output_procs=self.trg_reader.output_procs,
                score=curr_output.score[0])
            if is_nbest:
                out_sent = sent.NbestSentence(base_sent=out_sent,
                                              nbest_id=src_sent.idx)
            outputs.append(out_sent)

        if self.is_reporting():
            # Report the attention of the best hypothesis so it matches the
            # reported "output" (outputs[0]). Previously the attention of the
            # last, lowest-scoring hypothesis leaked out of the loop.
            best_attentions = sorted_outputs[0].attentions[0]
            attentions = np.concatenate(
                [x.npvalue() for x in best_attentions], axis=1)
            self.report_sent_info({
                "attentions": attentions,
                "src": src_sent,
                "output": outputs[0]
            })

        return outputs
Exemplo n.º 2
0
    def generate(
        self, src: batchers.Batch,
        search_strategy: search_strategies.SearchStrategy
    ) -> Sequence[sent.Sentence]:
        """
        Takes in a batch holding one source sentence and outputs its hypotheses.

        Args:
          src: The source sentences; the batch size must be exactly 1. If a
            compound batch is passed, its first sub-batch is used to build
            the output sentences after the search has run.
          search_strategy: The strategy with which to perform the search

        Returns:
          The hypotheses ordered by descending score. A single hypothesis is
          returned as produced by ``_emit_translation``; several hypotheses
          are each wrapped in an n-best sentence.
        """
        assert src.batch_size() == 1
        # The start-of-sentence event receives the batch as passed in, before
        # any compound batch is unwrapped below.
        event_trigger.start_sent(src)
        search_outputs = self.generate_search_output(src, search_strategy)
        if isinstance(src, batchers.CompoundBatch):
            src = src.batches[0]
        # Best hypothesis first; sorted() is stable, so ties keep search order.
        sorted_outputs = sorted(search_outputs,
                                key=lambda x: x.score[0],
                                reverse=True)
        assert len(sorted_outputs) >= 1

        is_nbest = len(sorted_outputs) > 1
        outputs = []
        for curr_output in sorted_outputs:
            out_sent = self._emit_translation(src,
                                              list(curr_output.word_ids[0]),
                                              curr_output.score[0])
            if is_nbest:
                out_sent = sent.NbestSentence(base_sent=out_sent,
                                              nbest_id=src[0].idx)
            outputs.append(out_sent)

        if self.is_reporting():
            # Report the attention of the best hypothesis so it matches the
            # reported "output" (outputs[0]). Previously the attention of the
            # last, lowest-scoring hypothesis leaked out of the loop.
            best_attentions = sorted_outputs[0].attentions[0]
            attentions = np.concatenate(
                [x.npvalue() for x in best_attentions], axis=1)
            self.report_sent_info({
                "attentions": attentions,
                "src": src[0],
                "output": outputs[0]
            })

        return outputs