Example #1
 def _generate_one_batch(self,
                         generator: 'models.GeneratorModel',
                         batcher: Optional[batchers.Batcher] = None,
                         src_batch: batchers.Batch = None,
                         max_src_len: Optional[int] = None,
                         fp: TextIO = None):
     """
     Generate outputs for a single batch and write them to the output file.
     """
     batch_size = len(src_batch)
     src_batches = batcher.pack(src_batch, None)
     src_batch = src_batches[0]
     src_len = src_batch.sent_len()
     if max_src_len is not None and src_len > max_src_len:
         output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
         fp.write(f"{output_txt}\n")
     else:
         with utils.ReportOnException({
                 "src": src_batch,
                 "graph": utils.print_cg_conditional
         }):
             dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                         check_validity=settings.CHECK_VALIDITY)
             outputs = self.generate_one(generator, src_batch)
             if self.reporter: self._create_sent_report()
             for i in range(len(outputs)):
                 output_txt = outputs[i].sent_str(
                     custom_output_procs=self.post_processor)
                 fp.write(f"{output_txt}\n")
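
This example (like Examples #3 and #6 below) follows a simple length gate: the packed source batch is decoded only if it does not exceed max_src_len; otherwise one NO_DECODING_ATTEMPTED placeholder line is written per sentence so the output file stays aligned with the input. A minimal self-contained sketch of that gate, with hypothetical names (write_batch_outputs, decode_batch) standing in for the xnmt machinery:

import sys
from typing import Callable, List, Optional, TextIO

NO_DECODING_ATTEMPTED = "@@NO_DECODING_ATTEMPTED@@"  # placeholder; xnmt defines its own constant

def write_batch_outputs(sentences: List[List[str]],
                        max_src_len: Optional[int],
                        fp: TextIO,
                        decode_batch: Callable[[List[List[str]]], List[str]]) -> None:
    """Write one output line per sentence, skipping batches whose source is too long."""
    src_len = max(len(s) for s in sentences)  # a packed batch shares one (padded) source length
    if max_src_len is not None and src_len > max_src_len:
        fp.write("\n".join([NO_DECODING_ATTEMPTED] * len(sentences)) + "\n")
    else:
        for output_txt in decode_batch(sentences):  # decode_batch stands in for generate_one()
            fp.write(f"{output_txt}\n")

# Toy usage: "decode" by joining tokens.
write_batch_outputs([["hello", "world"]], max_src_len=10, fp=sys.stdout,
                    decode_batch=lambda batch: [" ".join(s) for s in batch])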
Example #2
    def _forced_decode_one_batch(self,
                                 generator: 'models.GeneratorModel',
                                 batcher: Optional[batchers.Batcher] = None,
                                 src_batch: batchers.Batch = None,
                                 ref_batch: batchers.Batch = None,
                                 assert_scores: batchers.Batch = None,
                                 max_src_len: Optional[int] = None):
        """
        Performs forced decoding for a single batch.
        """
        batch_size = len(src_batch)
        src_batches, ref_batches = batcher.pack(src_batch, ref_batch)
        src_batch = src_batches[0]
        src_len = src_batch.sent_len()

        if max_src_len is None or src_len <= max_src_len:
            with utils.ReportOnException({
                    "src": src_batch,
                    "graph": utils.print_cg_conditional
            }):
                tt.reset_graph()
                outputs = self.generate_one(generator, src_batch)
                if self.reporter: self._create_sent_report()
                for i in range(len(outputs)):
                    if assert_scores is not None:
                        # If debugging forced decoding, make sure it matches
                        assert batch_size == len(
                            outputs
                        ), "debug forced decoding not supported with nbest inference"
                        if (abs(outputs[i].score - assert_scores[i]) /
                                abs(assert_scores[i])) > 1e-5:
                            raise ValueError(
                                f'Forced decoding score {outputs[i].score} and loss {assert_scores[i]} do not match at '
                                f'sentence {i}')
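
The tolerance check above compares the score returned by forced decoding against a precomputed loss using a relative difference of 1e-5. A stand-alone illustration of the same check (the function name is hypothetical, not from xnmt; like the original it assumes the reference score is non-zero):

def scores_match(decoded_score: float, reference_score: float, rel_tol: float = 1e-5) -> bool:
    """Relative-difference test: |decoded - reference| / |reference| must not exceed rel_tol."""
    return abs(decoded_score - reference_score) / abs(reference_score) <= rel_tol

assert scores_match(-12.345678, -12.345679)   # negligible relative difference: accepted
assert not scores_match(-12.3, -12.5)         # ~1.6% relative difference: rejected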
Example #3
 def _generate_one_batch(self,
                         generator: 'models.GeneratorModel',
                         batcher: Optional[batchers.Batcher] = None,
                         src_batch: batchers.Batch = None,
                         max_src_len: Optional[int] = None,
                         fp: TextIO = None):
     """
     Generate outputs for a single batch and write them to the output file.
     """
     batch_size = len(src_batch)
     src_batches = batcher.pack(src_batch, None)
     src_batch = src_batches[0]
     src_len = src_batch.sent_len()
     if max_src_len is not None and src_len > max_src_len:
         output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
         fp.write(f"{output_txt}\n")
     else:
         with utils.ReportOnException({
                 "src": src_batch,
                 "graph": utils.print_cg_conditional
         }):
             tt.reset_graph()
             with torch.no_grad(
             ) if xnmt.backend_torch else utils.dummy_context_mgr():
                 outputs = self.generate_one(generator, src_batch)
             if self.reporter: self._create_sent_report()
             for i in range(len(outputs)):
                 output_txt = outputs[i].sent_str(
                     custom_output_procs=self.post_processor)
                 fp.write(f"{output_txt}\n")
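
Example #3 differs from Example #1 mainly in how it prepares the computation: it resets the tensor graph via tt.reset_graph() and wraps generation in torch.no_grad() when the PyTorch backend is active, falling back to a dummy context manager otherwise. The same backend switch can be sketched with contextlib.nullcontext; torch_available below is a hypothetical flag standing in for xnmt.backend_torch:

import contextlib

try:
    import torch
    torch_available = True
except ImportError:
    torch_available = False

def inference_context():
    """Disable gradient tracking under the torch backend; otherwise do nothing."""
    return torch.no_grad() if torch_available else contextlib.nullcontext()

with inference_context():
    pass  # run the forward pass / decoding here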
Example #4
 def _cut_or_pad_targets(self, seq_len: numbers.Integral,
                         trg: batchers.Batch) -> batchers.Batch:
     old_mask = trg.mask
     if trg.sent_len() > seq_len:
         trunc_len = trg.sent_len() - seq_len
         trg = batchers.mark_as_batch([
             trg_sent.create_truncated_sent(trunc_len=trunc_len)
             for trg_sent in trg
         ])
         if old_mask:
             trg.mask = batchers.Mask(
                 np_arr=old_mask.np_arr[:, :-trunc_len])
     else:
         pad_len = seq_len - trg.sent_len()
         trg = batchers.mark_as_batch([
             trg_sent.create_padded_sent(pad_len=pad_len)
             for trg_sent in trg
         ])
         if old_mask:
             # wrap in batchers.Mask so both branches produce the same mask type
             trg.mask = batchers.Mask(
                 np_arr=np.pad(old_mask.np_arr,
                               pad_width=((0, 0), (0, pad_len)),
                               mode="constant",
                               constant_values=1))
     return trg
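
The padding branch grows the mask along the time axis with np.pad, filling the new positions with 1 (which marks padded, i.e. ignored, positions in xnmt masks). A numpy-only sketch of what that call produces:

import numpy as np

# Mask for a batch of 2 sentences with max length 3; 1 marks padded positions.
old_mask = np.array([[0, 0, 0],
                     [0, 0, 1]])

pad_len = 2
new_mask = np.pad(old_mask,
                  pad_width=((0, 0), (0, pad_len)),  # extend only the time axis, at the end
                  mode="constant",
                  constant_values=1)                 # newly added positions are masked out

print(new_mask)
# [[0 0 0 1 1]
#  [0 0 1 1 1]]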
Example #5
 def generate_search_output(
     self, src: batchers.Batch,
     search_strategy: search_strategies.SearchStrategy
 ) -> List[search_strategies.SearchOutput]:
     """
     Takes in a batch of source sentences and outputs a list of search outputs.

     Args:
       src: The source sentences
       search_strategy: The strategy with which to perform the search
     Returns:
       A list of search outputs including scores, etc.
     """
     if src.batch_size() != 1:
         raise NotImplementedError(
             "batched decoding not implemented for DefaultTranslator. "
             "Specify inference batcher with batch size 1.")
     if isinstance(src, batchers.CompoundBatch):
         src = src.batches[0]
     search_outputs = search_strategy.generate_output(
         self, self._initial_state(src), src_length=src.sent_len())
     return search_outputs
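
generate_search_output() rejects batches larger than one sentence, so a caller decodes a corpus by wrapping each sentence in its own singleton batch. A hypothetical calling loop (translator, make_singleton_batch, and the assumption that the first returned output is the best hypothesis are illustrative, not taken from xnmt):

def decode_corpus(translator, sentences, search_strategy, make_singleton_batch):
    """Decode each sentence separately, since batched decoding is not supported here."""
    hyps = []
    for sent in sentences:
        batch = make_singleton_batch(sent)  # must satisfy batch.batch_size() == 1
        outputs = translator.generate_search_output(batch, search_strategy)
        hyps.append(outputs[0])             # assumed: first output is the best hypothesis
    return hyps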
Example #6
 def _generate_one_batch(self, generator: 'models.GeneratorModel',
                               batcher: Optional[batchers.Batcher] = None,
                               src_batch: batchers.Batch = None,
                               ref_batch: Optional[batchers.Batch] = None,
                               assert_scores: Optional[List[int]] = None,
                               max_src_len: Optional[int] = None,
                               fp: TextIO = None):
   """
   Generate outputs for a single batch and write them to the output file.
   """
   batch_size = len(src_batch)
   if ref_batch[0] is not None:
     src_batches, ref_batches = batcher.pack(src_batch, ref_batch)
     ref_batch = ref_batches[0]
   else:
     src_batches = batcher.pack(src_batch, None)
     ref_batch = None
   src_batch = src_batches[0]
   src_len = src_batch.sent_len()
   if max_src_len is not None and src_len > max_src_len:
     output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
     fp.write(f"{output_txt}\n")
   else:
     with utils.ReportOnException({"src": src_batch, "graph": utils.print_cg_conditional}):
       dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
       outputs = self.generate_one(generator, src_batch, ref_batch)
       if self.reporter: self._create_sent_report()
       for i in range(len(outputs)):
         if assert_scores[0] is not None:
           # If debugging forced decoding, make sure it matches
           assert batch_size == len(outputs), "debug forced decoding not supported with nbest inference"
           if (abs(outputs[i].score - assert_scores[i]) / abs(assert_scores[i])) > 1e-5:
             raise ValueError(
               f'Forced decoding score {outputs[i].score} and loss {assert_scores[i]} do not match at '
               f'sentence {i}')
         output_txt = outputs[i].sent_str(custom_output_procs=self.post_processor)
         fp.write(f"{output_txt}\n")