Example No. 1
    def generate(
            self,
            src: batchers.Batch,
            normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
        if not batchers.is_batched(src):
            src = batchers.mark_as_batch([src])
        assert src.batch_size() == 1, "batch size > 1 not properly tested"

        batch_size, encodings, outputs, seq_len = self._encode_src(src)

        best_words, best_scores = self.scorer.best_k(
            outputs, k=1, normalize_scores=normalize_scores)
        best_words = best_words[0, :]
        score = np.sum(best_scores, axis=1)

        outputs = [
            sent.SimpleSentence(
                words=best_words,
                idx=src[0].idx,
                vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                output_procs=self.trg_reader.output_procs,
                score=score)
        ]

        return outputs
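
For orientation, here is a minimal NumPy sketch of what the k=1 case of a best-k selection like the one above boils down to conceptually: take the highest-scoring word id at every position and sum the per-position scores into a single sentence score. The matrix is toy data, not part of xnmt.

    import numpy as np

    # Toy (seq_len x vocab_size) matrix of log-probabilities; illustrative only.
    log_probs = np.log(np.array([
        [0.70, 0.10, 0.10, 0.05, 0.05],
        [0.05, 0.80, 0.05, 0.05, 0.05],
        [0.10, 0.10, 0.60, 0.10, 0.10],
        [0.25, 0.25, 0.25, 0.15, 0.10],
    ]))

    best_words = np.argmax(log_probs, axis=1)    # best word id per position
    best_scores = np.max(log_probs, axis=1)      # its log-probability per position
    sentence_score = float(np.sum(best_scores))  # analogous to `score` above

    print(best_words, sentence_score)
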
Example No. 2
    def generate(
            self,
            src: batchers.Batch,
            forced_trg_ids: Sequence[numbers.Integral] = None,
            normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
        if not batchers.is_batched(src):
            src = batchers.mark_as_batch([src])
            if forced_trg_ids:
                forced_trg_ids = batchers.mark_as_batch([forced_trg_ids])
        assert src.batch_size() == 1, "batch size > 1 not properly tested"

        batch_size, encodings, outputs, seq_len = self._encode_src(src)
        score_expr = self.scorer.calc_log_softmax(
            outputs) if normalize_scores else self.scorer.calc_scores(outputs)
        scores = score_expr.npvalue()  # vocab_size x seq_len

        if forced_trg_ids:
            output_actions = forced_trg_ids
        else:
            output_actions = [np.argmax(scores[:, j]) for j in range(seq_len)]
        score = np.sum([scores[output_actions[j], j] for j in range(seq_len)])

        outputs = [
            sent.SimpleSentence(
                words=output_actions,
                idx=src[0].idx,
                vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                output_procs=self.trg_reader.output_procs,
                score=score)
        ]

        return outputs
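
Example No. 2 scores either the forced target ids or the per-position argmax from a vocab_size x seq_len score matrix. A standalone sketch of that indexing, on toy data and without any xnmt classes:

    import numpy as np

    # Toy (vocab_size x seq_len) log-probability matrix, matching the
    # scores[:, j] layout above; 4 vocabulary items, 3 time steps.
    scores = np.log(np.array([
        [0.6, 0.1, 0.2],
        [0.2, 0.7, 0.1],
        [0.1, 0.1, 0.6],
        [0.1, 0.1, 0.1],
    ]))

    def decode(scores, forced_trg_ids=None):
        # Follow the forced ids when given, otherwise take the argmax per step.
        seq_len = scores.shape[1]
        if forced_trg_ids is not None:
            actions = list(forced_trg_ids)
        else:
            actions = [int(np.argmax(scores[:, j])) for j in range(seq_len)]
        score = float(sum(scores[actions[j], j] for j in range(seq_len)))
        return actions, score

    print(decode(scores))             # greedy: ([0, 1, 2], summed log-prob)
    print(decode(scores, [1, 1, 1]))  # forced decoding of a given sequence
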
Example No. 3
 def _generate_one_batch(self,
                         generator: 'models.GeneratorModel',
                         batcher: Optional[batchers.Batcher] = None,
                         src_batch: batchers.Batch = None,
                         max_src_len: Optional[int] = None,
                         fp: TextIO = None):
     """
 Generate outputs for a single batch and write them to the output file.
 """
     batch_size = len(src_batch)
     src_batches = batcher.pack(src_batch, None)
     src_batch = src_batches[0]
     src_len = src_batch.sent_len()
     if max_src_len is not None and src_len > max_src_len:
         output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
         fp.write(f"{output_txt}\n")
     else:
         with utils.ReportOnException({
                 "src": src_batch,
                 "graph": utils.print_cg_conditional
         }):
             dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                         check_validity=settings.CHECK_VALIDITY)
             outputs = self.generate_one(generator, src_batch)
             if self.reporter: self._create_sent_report()
             for i in range(len(outputs)):
                 output_txt = outputs[i].sent_str(
                     custom_output_procs=self.post_processor)
                 fp.write(f"{output_txt}\n")
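
Example No. 3 skips sources longer than max_src_len but still writes one placeholder line per skipped sentence, so the hypothesis file stays line-aligned with the input. A self-contained sketch of that pattern; the marker string and the "decoder" are stand-ins, not xnmt's:

    NO_DECODING_ATTEMPTED = "@@NO_DECODING_ATTEMPTED@@"  # stand-in marker
    MAX_SRC_LEN = 5

    sources = [["a", "b"], ["x"] * 10, ["c", "d", "e"]]

    with open("hyps.txt", "w") as fp:
        for src in sources:
            if len(src) > MAX_SRC_LEN:
                fp.write(f"{NO_DECODING_ATTEMPTED}\n")  # keeps line numbering intact
            else:
                hyp = " ".join(reversed(src))           # stand-in for real decoding
                fp.write(f"{hyp}\n")
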
Example No. 4
    def _forced_decode_one_batch(self,
                                 generator: 'models.GeneratorModel',
                                 batcher: Optional[batchers.Batcher] = None,
                                 src_batch: batchers.Batch = None,
                                 ref_batch: batchers.Batch = None,
                                 assert_scores: batchers.Batch = None,
                                 max_src_len: Optional[int] = None):
        """
    Performs forced decoding for a single batch.
    """
        batch_size = len(src_batch)
        src_batches, ref_batches = batcher.pack(src_batch, ref_batch)
        src_batch = src_batches[0]
        src_len = src_batch.sent_len()

        if max_src_len is None or src_len <= max_src_len:
            with utils.ReportOnException({
                    "src": src_batch,
                    "graph": utils.print_cg_conditional
            }):
                tt.reset_graph()
                outputs = self.generate_one(generator, src_batch)
                if self.reporter: self._create_sent_report()
                for i in range(len(outputs)):
                    if assert_scores is not None:
                        # If debugging forced decoding, make sure it matches
                        assert batch_size == len(
                            outputs
                        ), "debug forced decoding not supported with nbest inference"
                        if (abs(outputs[i].score - assert_scores[i]) /
                                abs(assert_scores[i])) > 1e-5:
                            raise ValueError(
                                f'Forced decoding score {outputs[i].score} and loss {assert_scores[i]} do not match at '
                                f'sentence {i}')
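
The assertion in Example No. 4 compares the forced-decoding score against the recorded loss with a relative tolerance rather than exact equality, since the two are computed along different code paths and differ by floating-point noise. A small helper expressing the same check (illustrative only; math.isclose would work just as well):

    def scores_match(decode_score: float, loss_score: float,
                     rel_tol: float = 1e-5) -> bool:
        # Relative-error check as in the snippet above; assumes loss_score != 0.
        return abs(decode_score - loss_score) / abs(loss_score) <= rel_tol

    assert scores_match(-12.345678, -12.345679)
    assert not scores_match(-12.3, -14.0)
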
Example No. 5
 def _generate_one_batch(self,
                         generator: 'models.GeneratorModel',
                         batcher: Optional[batchers.Batcher] = None,
                         src_batch: batchers.Batch = None,
                         max_src_len: Optional[int] = None,
                         fp: TextIO = None):
     """
 Generate outputs for a single batch and write them to the output file.
 """
     batch_size = len(src_batch)
     src_batches = batcher.pack(src_batch, None)
     src_batch = src_batches[0]
     src_len = src_batch.sent_len()
     if max_src_len is not None and src_len > max_src_len:
         output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
         fp.write(f"{output_txt}\n")
     else:
         with utils.ReportOnException({
                 "src": src_batch,
                 "graph": utils.print_cg_conditional
         }):
             tt.reset_graph()
             with torch.no_grad(
             ) if xnmt.backend_torch else utils.dummy_context_mgr():
                 outputs = self.generate_one(generator, src_batch)
             if self.reporter: self._create_sent_report()
             for i in range(len(outputs)):
                 output_txt = outputs[i].sent_str(
                     custom_output_procs=self.post_processor)
                 fp.write(f"{output_txt}\n")
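
Example No. 5 is the backend-agnostic variant of No. 3: with the Torch backend, decoding runs under torch.no_grad() so no gradients are tracked; otherwise a do-nothing context manager is used. A minimal sketch of that conditional context, using contextlib.nullcontext as a stand-in for utils.dummy_context_mgr:

    import contextlib

    try:
        import torch
        HAVE_TORCH = True
    except ImportError:
        HAVE_TORCH = False

    def inference_context():
        # Disable gradient tracking on a Torch backend; no-op otherwise.
        return torch.no_grad() if HAVE_TORCH else contextlib.nullcontext()

    with inference_context():
        pass  # run the decoder here
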
Example No. 6
  def generate_search_output(self,
                             src: batchers.Batch,
                             search_strategy: search_strategies.SearchStrategy,
                             forced_trg_ids: batchers.Batch=None) -> List[search_strategies.SearchOutput]:
    """
    Takes in a batch of source sentences and outputs a list of search outputs.
    Args:
      src: The source sentences
      search_strategy: The strategy with which to perform the search
      forced_trg_ids: The target IDs to generate if performing forced decoding
    Returns:
      A list of search outputs including scores, etc.
    """
    if src.batch_size() != 1:
      raise NotImplementedError("batched decoding not implemented for DefaultTranslator. "
                                "Specify inference batcher with batch size 1.")
    event_trigger.start_sent(src)
    all_src = src
    if isinstance(src, batchers.CompoundBatch): src = src.batches[0]
    # Generating outputs
    cur_forced_trg = None
    src_sent = src[0]  # checkme
    sent_mask = None
    if src.mask: sent_mask = batchers.Mask(np_arr=src.mask.np_arr[0:1])
    sent_batch = batchers.mark_as_batch([src_sent], mask=sent_mask)

    # Encode the sentence
    initial_state = self._encode_src(all_src)

    if forced_trg_ids is not None: cur_forced_trg = forced_trg_ids[0]
    search_outputs = search_strategy.generate_output(self, initial_state,
                                                     src_length=[src_sent.sent_len()],
                                                     forced_trg_ids=cur_forced_trg)
    return search_outputs
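
In Example No. 6 the single-sentence mask is built with the slice np_arr[0:1] rather than np_arr[0], which keeps the leading batch dimension. A tiny NumPy illustration of the difference:

    import numpy as np

    mask = np.array([[0, 0, 1, 1],
                     [0, 0, 0, 1]])

    print(mask[0:1].shape)  # (1, 4) - still a batch of one sentence
    print(mask[0].shape)    # (4,)   - batch dimension dropped
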
Example No. 7
 def generate_search_output(
     self, src: batchers.Batch,
     search_strategy: search_strategies.SearchStrategy
 ) -> List[search_strategies.SearchOutput]:
     """
 Takes in a batch of source sentences and outputs a list of search outputs.
 Args:
   src: The source sentences
   search_strategy: The strategy with which to perform the search
 Returns:
   A list of search outputs including scores, etc.
 """
     if src.batch_size() != 1:
         raise NotImplementedError(
             "batched decoding not implemented for DefaultTranslator. "
             "Specify inference batcher with batch size 1.")
     if isinstance(src, batchers.CompoundBatch):
         src = src.batches[0]
     search_outputs = search_strategy.generate_output(
         self, self._initial_state(src), src_length=src.sent_len())
     return search_outputs
Example No. 8
    def generate(
            self,
            src: batchers.Batch,
            search_strategy: search_strategies.SearchStrategy,
            forced_trg_ids: batchers.Batch = None) -> Sequence[sent.Sentence]:
        """
    Takes in a batch of source sentences and outputs a list of search outputs.
    Args:
      src: The source sentences
      search_strategy: The strategy with which to perform the search
      forced_trg_ids: The target IDs to generate if performing forced decoding
    Returns:
      A list of search outputs including scores, etc.
    """
        assert src.batch_size() == 1
        search_outputs = self.generate_search_output(src, search_strategy,
                                                     forced_trg_ids)
        if isinstance(src, batchers.CompoundBatch): src = src.batches[0]
        sorted_outputs = sorted(search_outputs,
                                key=lambda x: x.score[0],
                                reverse=True)
        assert len(sorted_outputs) >= 1
        outputs = []
        for curr_output in sorted_outputs:
            output_actions = [x for x in curr_output.word_ids[0]]
            attentions = [x for x in curr_output.attentions[0]]
            score = curr_output.score[0]
            out_sent = sent.SimpleSentence(
                idx=src[0].idx,
                words=output_actions,
                vocab=getattr(self.trg_reader, "vocab", None),
                output_procs=self.trg_reader.output_procs,
                score=score)
            if len(sorted_outputs) == 1:
                outputs.append(out_sent)
            else:
                outputs.append(
                    sent.NbestSentence(base_sent=out_sent,
                                       nbest_id=src[0].idx))

        if self.is_reporting():
            attentions = np.concatenate([x.npvalue() for x in attentions],
                                        axis=1)
            self.report_sent_info({
                "attentions": attentions,
                "src": src[0],
                "output": outputs[0]
            })

        return outputs
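
Example No. 8 (and its variant in No. 10 below) sorts the search outputs by score, keeps a single hypothesis as a plain sentence, and wraps multiple hypotheses into an n-best list. The ranking step in isolation, with a hypothetical Hypothesis record standing in for a search output:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Hypothesis:          # hypothetical stand-in for a search output
        words: List[int]
        score: float

    def rank(hyps: List[Hypothesis]) -> List[Hypothesis]:
        # Highest-scoring hypothesis first, as in the sorted(...) call above.
        return sorted(hyps, key=lambda h: h.score, reverse=True)

    nbest = rank([Hypothesis([4, 7, 2], -3.1), Hypothesis([4, 9, 2], -2.4)])
    print([h.score for h in nbest])  # [-2.4, -3.1]
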
Example No. 9
 def _cut_or_pad_targets(self, seq_len: numbers.Integral,
                         trg: batchers.Batch) -> batchers.Batch:
     old_mask = trg.mask
     if len(trg[0]) > seq_len:
         trunc_len = len(trg[0]) - seq_len
         trg = batchers.mark_as_batch([
             trg_sent.get_truncated_sent(trunc_len=trunc_len)
             for trg_sent in trg
         ])
         if old_mask:
             trg.mask = batchers.Mask(
                 np_arr=old_mask.np_arr[:, :-trunc_len])
     else:
         pad_len = seq_len - len(trg[0])
         trg = batchers.mark_as_batch([
             trg_sent.create_padded_sent(pad_len=pad_len)
             for trg_sent in trg
         ])
         if old_mask:
             trg.mask = batchers.Mask(
                 np_arr=np.pad(old_mask.np_arr,
                               pad_width=((0, 0), (0, pad_len)),
                               mode="constant",
                               constant_values=1))
     return trg
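
Example No. 9 forces the target batch to a fixed length: longer targets are truncated and shorter ones padded, with the mask updated to mark the padded positions. The same idea on plain Python lists, where PAD is a hypothetical padding symbol:

    PAD = 0  # hypothetical padding symbol

    def cut_or_pad(seq, seq_len):
        # Truncate sequences that are too long, pad the ones that are too short.
        if len(seq) > seq_len:
            return seq[:seq_len]
        return seq + [PAD] * (seq_len - len(seq))

    print(cut_or_pad([5, 6, 7, 8], 3))  # [5, 6, 7]
    print(cut_or_pad([5, 6], 4))        # [5, 6, 0, 0]
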
Example No. 10
    def generate(
        self, src: batchers.Batch,
        search_strategy: search_strategies.SearchStrategy
    ) -> Sequence[sent.Sentence]:
        """
    Takes in a batch of source sentences and outputs a list of search outputs.
    Args:
      src: The source sentences
      search_strategy: The strategy with which to perform the search
    Returns:
      A list of search outputs including scores, etc.
    """
        assert src.batch_size() == 1
        event_trigger.start_sent(src)
        search_outputs = self.generate_search_output(src, search_strategy)
        if isinstance(src, batchers.CompoundBatch): src = src.batches[0]
        sorted_outputs = sorted(search_outputs,
                                key=lambda x: x.score[0],
                                reverse=True)
        assert len(sorted_outputs) >= 1
        outputs = []
        for curr_output in sorted_outputs:
            output_actions = [x for x in curr_output.word_ids[0]]
            attentions = [x for x in curr_output.attentions[0]]
            score = curr_output.score[0]
            out_sent = self._emit_translation(src, output_actions, score)
            if len(sorted_outputs) == 1:
                outputs.append(out_sent)
            else:
                outputs.append(
                    sent.NbestSentence(base_sent=out_sent,
                                       nbest_id=src[0].idx))

        if self.is_reporting():
            attentions = np.concatenate([x.npvalue() for x in attentions],
                                        axis=1)
            self.report_sent_info({
                "attentions": attentions,
                "src": src[0],
                "output": outputs[0]
            })

        return outputs
Example No. 11
 def _generate_one_batch(self, generator: 'models.GeneratorModel',
                               batcher: Optional[batchers.Batcher] = None,
                               src_batch: batchers.Batch = None,
                               ref_batch: Optional[batchers.Batch] = None,
                               assert_scores: Optional[List[int]] = None,
                               max_src_len: Optional[int] = None,
                               fp: TextIO = None):
   """
   Generate outputs for a single batch and write them to the output file.
   """
   batch_size = len(src_batch)
   if ref_batch[0] is not None:
     src_batches, ref_batches = batcher.pack(src_batch, ref_batch)
     ref_batch = ref_batches[0]
   else:
     src_batches = batcher.pack(src_batch, None)
     ref_batch = None
   src_batch = src_batches[0]
   src_len = src_batch.sent_len()
   if max_src_len is not None and src_len > max_src_len:
     output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
     fp.write(f"{output_txt}\n")
   else:
     with utils.ReportOnException({"src": src_batch, "graph": utils.print_cg_conditional}):
       dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
       outputs = self.generate_one(generator, src_batch, ref_batch)
       if self.reporter: self._create_sent_report()
       for i in range(len(outputs)):
         if assert_scores[0] is not None:
           # If debugging forced decoding, make sure it matches
           assert batch_size == len(outputs), "debug forced decoding not supported with nbest inference"
           if (abs(outputs[i].score - assert_scores[i]) / abs(assert_scores[i])) > 1e-5:
             raise ValueError(
               f'Forced decoding score {outputs[i].score} and loss {assert_scores[i]} do not match at '
               f'sentence {i}')
         output_txt = outputs[i].sent_str(custom_output_procs=self.post_processor)
         fp.write(f"{output_txt}\n")