def _generate_one_batch(self,
                        generator: 'models.GeneratorModel',
                        batcher: Optional[batchers.Batcher] = None,
                        src_batch: batchers.Batch = None,
                        max_src_len: Optional[int] = None,
                        fp: TextIO = None):
  """
  Decode one batch of source sentences and write the results to ``fp``.

  Sentences whose packed source length exceeds ``max_src_len`` are not
  decoded; one placeholder line per sentence is written instead.

  Args:
    generator: model used to produce outputs.
    batcher: batcher used to (re)pack the source sentences.
    src_batch: batch of source sentences to decode.
    max_src_len: if given, skip decoding for batches longer than this.
    fp: open text file the output lines are written to.
  """
  batch_size = len(src_batch)
  src_batch = batcher.pack(src_batch, None)[0]
  src_len = src_batch.sent_len()
  too_long = max_src_len is not None and src_len > max_src_len
  if too_long:
    # Emit one placeholder line per sentence in the batch.
    output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
    fp.write(f"{output_txt}\n")
  else:
    with utils.ReportOnException({"src": src_batch,
                                  "graph": utils.print_cg_conditional}):
      # Start a fresh DyNet computation graph for this batch.
      dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                  check_validity=settings.CHECK_VALIDITY)
      outputs = self.generate_one(generator, src_batch)
      if self.reporter:
        self._create_sent_report()
      for output in outputs:
        output_txt = output.sent_str(custom_output_procs=self.post_processor)
        fp.write(f"{output_txt}\n")
def _forced_decode_one_batch(self,
                             generator: 'models.GeneratorModel',
                             batcher: Optional[batchers.Batcher] = None,
                             src_batch: batchers.Batch = None,
                             ref_batch: batchers.Batch = None,
                             assert_scores: batchers.Batch = None,
                             max_src_len: Optional[int] = None):
  """
  Perform forced decoding for a single batch.

  If ``assert_scores`` is given, the score of each forced-decoded output is
  checked against the corresponding expected score (relative tolerance 1e-5)
  and a ``ValueError`` is raised on mismatch.

  Args:
    generator: model used to produce outputs.
    batcher: batcher used to (re)pack source and reference sentences.
    src_batch: batch of source sentences.
    ref_batch: batch of reference sentences to force-decode against.
    assert_scores: expected per-sentence scores, for debugging forced decoding.
    max_src_len: if given, batches longer than this are skipped entirely.

  Raises:
    ValueError: if a forced-decoding score disagrees with ``assert_scores``.
  """
  batch_size = len(src_batch)
  # NOTE(review): ref_batches is unpacked but not used below — presumably
  # generate_one obtains references elsewhere; confirm against callers.
  src_batches, ref_batches = batcher.pack(src_batch, ref_batch)
  src_batch = src_batches[0]
  src_len = src_batch.sent_len()
  # BUG FIX: the original condition read
  #   max_src_len is None or src_len <= max_src_len is not None and src_len > max_src_len
  # Python's comparison chaining makes the second disjunct require both
  # src_len <= max_src_len and src_len > max_src_len, which is always false,
  # so nothing was ever force-decoded when a length limit was set. Decode
  # whenever there is no limit or the source fits within it.
  if max_src_len is None or src_len <= max_src_len:
    with utils.ReportOnException({"src": src_batch,
                                  "graph": utils.print_cg_conditional}):
      # Start a fresh computation graph for this batch.
      tt.reset_graph()
      outputs = self.generate_one(generator, src_batch)
      if self.reporter:
        self._create_sent_report()
      for i in range(len(outputs)):
        if assert_scores is not None:
          # If debugging forced decoding, make sure it matches
          assert batch_size == len(
              outputs
          ), "debug forced decoding not supported with nbest inference"
          if (abs(outputs[i].score - assert_scores[i]) /
              abs(assert_scores[i])) > 1e-5:
            raise ValueError(
                f'Forced decoding score {outputs[i].score} and loss {assert_scores[i]} do not match at '
                f'sentence {i}')
def _generate_one_batch(self,
                        generator: 'models.GeneratorModel',
                        batcher: Optional[batchers.Batcher] = None,
                        src_batch: batchers.Batch = None,
                        max_src_len: Optional[int] = None,
                        fp: TextIO = None):
  """
  Decode one batch of source sentences and write the results to ``fp``.

  Sentences whose packed source length exceeds ``max_src_len`` are not
  decoded; one placeholder line per sentence is written instead. Under the
  torch backend, generation runs inside ``torch.no_grad()``.

  Args:
    generator: model used to produce outputs.
    batcher: batcher used to (re)pack the source sentences.
    src_batch: batch of source sentences to decode.
    max_src_len: if given, skip decoding for batches longer than this.
    fp: open text file the output lines are written to.
  """
  batch_size = len(src_batch)
  src_batch = batcher.pack(src_batch, None)[0]
  src_len = src_batch.sent_len()
  too_long = max_src_len is not None and src_len > max_src_len
  if too_long:
    # Emit one placeholder line per sentence in the batch.
    output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
    fp.write(f"{output_txt}\n")
  else:
    with utils.ReportOnException({"src": src_batch,
                                  "graph": utils.print_cg_conditional}):
      # Start a fresh computation graph for this batch.
      tt.reset_graph()
      # Disable autograd under torch; no-op context manager otherwise.
      grad_guard = torch.no_grad() if xnmt.backend_torch else utils.dummy_context_mgr()
      with grad_guard:
        outputs = self.generate_one(generator, src_batch)
        if self.reporter:
          self._create_sent_report()
        for output in outputs:
          output_txt = output.sent_str(custom_output_procs=self.post_processor)
          fp.write(f"{output_txt}\n")
def _cut_or_pad_targets(self, seq_len: numbers.Integral,
                        trg: batchers.Batch) -> batchers.Batch:
  """
  Force the target batch to exactly ``seq_len`` tokens per sentence.

  Sentences longer than ``seq_len`` are truncated; shorter ones are padded.
  An existing mask is adjusted accordingly (padded positions are masked
  with 1).

  Args:
    seq_len: desired sentence length.
    trg: batch of target sentences.

  Returns:
    A new batch whose sentences all have length ``seq_len``.
  """
  old_mask = trg.mask
  if trg.sent_len() > seq_len:
    trunc_len = trg.sent_len() - seq_len
    trg = batchers.mark_as_batch([
        trg_sent.create_truncated_sent(trunc_len=trunc_len)
        for trg_sent in trg
    ])
    if old_mask:
      trg.mask = batchers.Mask(np_arr=old_mask.np_arr[:, :-trunc_len])
  else:
    pad_len = seq_len - trg.sent_len()
    trg = batchers.mark_as_batch([
        trg_sent.create_padded_sent(pad_len=pad_len) for trg_sent in trg
    ])
    if old_mask:
      # BUG FIX: previously the raw ndarray from np.pad was assigned to
      # trg.mask, while the truncation branch assigns a batchers.Mask.
      # Wrap the padded array so both branches produce the same mask type.
      trg.mask = batchers.Mask(np_arr=np.pad(old_mask.np_arr,
                                             pad_width=((0, 0), (0, pad_len)),
                                             mode="constant",
                                             constant_values=1))
  return trg
def generate_search_output(
    self, src: batchers.Batch,
    search_strategy: search_strategies.SearchStrategy
) -> List[search_strategies.SearchOutput]:
  """
  Takes in a batch of source sentences and outputs a list of search outputs.

  Args:
    src: The source sentences
    search_strategy: The strategy with which to perform the search

  Returns:
    A list of search outputs including scores, etc.
  """
  # Only singleton batches are supported here; larger batches must be
  # split upstream by an inference batcher of size 1.
  if src.batch_size() != 1:
    raise NotImplementedError(
        "batched decoding not implemented for DefaultTranslator. "
        "Specify inference batcher with batch size 1.")
  # A compound batch carries several factors; search runs on the first one.
  if isinstance(src, batchers.CompoundBatch):
    src = src.batches[0]
  initial_state = self._initial_state(src)
  return search_strategy.generate_output(self,
                                         initial_state,
                                         src_length=src.sent_len())
def _generate_one_batch(self,
                        generator: 'models.GeneratorModel',
                        batcher: Optional[batchers.Batcher] = None,
                        src_batch: batchers.Batch = None,
                        ref_batch: Optional[batchers.Batch] = None,
                        assert_scores: Optional[List[int]] = None,
                        max_src_len: Optional[int] = None,
                        fp: TextIO = None):
  """
  Generate outputs for a single batch and write them to the output file.

  Sentences whose packed source length exceeds ``max_src_len`` are not
  decoded; one placeholder line per sentence is written instead. When
  ``assert_scores`` is given, each output's score is checked against the
  expected score (relative tolerance 1e-5).

  Args:
    generator: model used to produce outputs.
    batcher: batcher used to (re)pack the input sentences.
    src_batch: batch of source sentences.
    ref_batch: optional batch of reference sentences for forced decoding.
    assert_scores: optional expected per-sentence scores for debugging.
    max_src_len: if given, skip decoding for batches longer than this.
    fp: open text file the output lines are written to.

  Raises:
    ValueError: if an output score disagrees with ``assert_scores``.
  """
  batch_size = len(src_batch)
  # BUG FIX: ref_batch defaults to None, so subscripting it unguarded
  # (ref_batch[0]) raised TypeError whenever no reference was supplied.
  if ref_batch is not None and ref_batch[0] is not None:
    src_batches, ref_batches = batcher.pack(src_batch, ref_batch)
    ref_batch = ref_batches[0]
  else:
    src_batches = batcher.pack(src_batch, None)
    ref_batch = None
  src_batch = src_batches[0]
  src_len = src_batch.sent_len()
  if max_src_len is not None and src_len > max_src_len:
    # Emit one placeholder line per sentence in the batch.
    output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
    fp.write(f"{output_txt}\n")
  else:
    with utils.ReportOnException({"src": src_batch,
                                  "graph": utils.print_cg_conditional}):
      # Start a fresh DyNet computation graph for this batch.
      dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                  check_validity=settings.CHECK_VALIDITY)
      outputs = self.generate_one(generator, src_batch, ref_batch)
      if self.reporter:
        self._create_sent_report()
      for i in range(len(outputs)):
        # BUG FIX: same unguarded subscript for assert_scores, which also
        # defaults to None.
        if assert_scores is not None and assert_scores[i] is not None:
          # If debugging forced decoding, make sure it matches
          assert batch_size == len(
              outputs), "debug forced decoding not supported with nbest inference"
          if (abs(outputs[i].score - assert_scores[i]) /
              abs(assert_scores[i])) > 1e-5:
            raise ValueError(
                f'Forced decoding score {outputs[i].score} and loss {assert_scores[i]} do not match at '
                f'sentence {i}')
        output_txt = outputs[i].sent_str(custom_output_procs=self.post_processor)
        fp.write(f"{output_txt}\n")