def generate(self,
             src: batchers.Batch,
             normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
  if not batchers.is_batched(src):
    src = batchers.mark_as_batch([src])
  assert src.batch_size() == 1, "batch size > 1 not properly tested"
  batch_size, encodings, outputs, seq_len = self._encode_src(src)
  best_words, best_scores = self.scorer.best_k(outputs, k=1, normalize_scores=normalize_scores)
  best_words = best_words[0, :]
  score = np.sum(best_scores, axis=1)
  outputs = [sent.SimpleSentence(words=best_words,
                                 idx=src[0].idx,
                                 vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                                 output_procs=self.trg_reader.output_procs,
                                 score=score)]
  return outputs

def generate(self,
             src: batchers.Batch,
             forced_trg_ids: Sequence[numbers.Integral] = None,
             normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
  if not batchers.is_batched(src):
    src = batchers.mark_as_batch([src])
    if forced_trg_ids:
      forced_trg_ids = batchers.mark_as_batch([forced_trg_ids])
  assert src.batch_size() == 1, "batch size > 1 not properly tested"
  batch_size, encodings, outputs, seq_len = self._encode_src(src)
  score_expr = self.scorer.calc_log_softmax(outputs) if normalize_scores else self.scorer.calc_scores(outputs)
  scores = score_expr.npvalue()  # vocab_size x seq_len
  if forced_trg_ids:
    output_actions = forced_trg_ids
  else:
    output_actions = [np.argmax(scores[:, j]) for j in range(seq_len)]
  score = np.sum([scores[output_actions[j], j] for j in range(seq_len)])
  outputs = [sent.SimpleSentence(words=output_actions,
                                 idx=src[0].idx,
                                 vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                                 output_procs=self.trg_reader.output_procs,
                                 score=score)]
  return outputs

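# Minimal standalone numpy sketch (illustrative only, not the xnmt scorer API) of the greedy
# path in generate() above: given a vocab_size x seq_len score matrix, take the argmax word
# at each position and sum the scores of the chosen words to get the sentence score.
import numpy as np

example_scores = np.array([[0.1, 2.0, 0.3],
                           [1.5, 0.2, 0.9]])  # vocab_size=2, seq_len=3
example_seq_len = example_scores.shape[1]
example_actions = [int(np.argmax(example_scores[:, j])) for j in range(example_seq_len)]  # [1, 0, 1]
example_score = float(np.sum([example_scores[example_actions[j], j] for j in range(example_seq_len)]))  # 4.4
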
def _generate_one_batch(self, generator: 'models.GeneratorModel',
                        batcher: Optional[batchers.Batcher] = None,
                        src_batch: batchers.Batch = None,
                        max_src_len: Optional[int] = None,
                        fp: TextIO = None):
  """
  Generate outputs for a single batch and write them to the output file.
  """
  batch_size = len(src_batch)
  src_batches = batcher.pack(src_batch, None)
  src_batch = src_batches[0]
  src_len = src_batch.sent_len()
  if max_src_len is not None and src_len > max_src_len:
    output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
    fp.write(f"{output_txt}\n")
  else:
    with utils.ReportOnException({"src": src_batch, "graph": utils.print_cg_conditional}):
      dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
      outputs = self.generate_one(generator, src_batch)
      if self.reporter:
        self._create_sent_report()
      for i in range(len(outputs)):
        output_txt = outputs[i].sent_str(custom_output_procs=self.post_processor)
        fp.write(f"{output_txt}\n")

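# Standalone sketch of the skip path above: when the source exceeds max_src_len, one
# placeholder line per sentence is written so the output file stays aligned with the input.
# The marker value below is an assumption; the real constant comes from the surrounding module.
NO_DECODING_ATTEMPTED_EXAMPLE = "@@NO_DECODING_ATTEMPTED@@"
example_batch_size = 3
example_output_txt = "\n".join([NO_DECODING_ATTEMPTED_EXAMPLE] * example_batch_size)
# -> three marker lines, one per skipped sentence
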
def _forced_decode_one_batch(self, generator: 'models.GeneratorModel',
                             batcher: Optional[batchers.Batcher] = None,
                             src_batch: batchers.Batch = None,
                             ref_batch: batchers.Batch = None,
                             assert_scores: batchers.Batch = None,
                             max_src_len: Optional[int] = None):
  """
  Performs forced decoding for a single batch.
  """
  batch_size = len(src_batch)
  src_batches, ref_batches = batcher.pack(src_batch, ref_batch)
  src_batch = src_batches[0]
  src_len = src_batch.sent_len()
  if max_src_len is None or src_len <= max_src_len:
    with utils.ReportOnException({"src": src_batch, "graph": utils.print_cg_conditional}):
      tt.reset_graph()
      outputs = self.generate_one(generator, src_batch)
      if self.reporter:
        self._create_sent_report()
      for i in range(len(outputs)):
        if assert_scores is not None:
          # If debugging forced decoding, make sure it matches
          assert batch_size == len(outputs), "debug forced decoding not supported with nbest inference"
          if (abs(outputs[i].score - assert_scores[i]) / abs(assert_scores[i])) > 1e-5:
            raise ValueError(f'Forced decoding score {outputs[i].score} and loss {assert_scores[i]} do not match at '
                             f'sentence {i}')

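# Standalone sketch of the consistency check above: the forced-decoding score must agree with
# the training-loss score up to a relative tolerance of 1e-5. The helper name is illustrative,
# not part of xnmt.
def scores_match(decode_score: float, loss_score: float, rel_tol: float = 1e-5) -> bool:
  return abs(decode_score - loss_score) / abs(loss_score) <= rel_tol

assert scores_match(-12.345678, -12.345679)   # tiny numerical difference is accepted
assert not scores_match(-12.0, -13.0)         # a real mismatch would raise ValueError above
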
def _generate_one_batch(self, generator: 'models.GeneratorModel',
                        batcher: Optional[batchers.Batcher] = None,
                        src_batch: batchers.Batch = None,
                        max_src_len: Optional[int] = None,
                        fp: TextIO = None):
  """
  Generate outputs for a single batch and write them to the output file.
  """
  batch_size = len(src_batch)
  src_batches = batcher.pack(src_batch, None)
  src_batch = src_batches[0]
  src_len = src_batch.sent_len()
  if max_src_len is not None and src_len > max_src_len:
    output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
    fp.write(f"{output_txt}\n")
  else:
    with utils.ReportOnException({"src": src_batch, "graph": utils.print_cg_conditional}):
      tt.reset_graph()
      with torch.no_grad() if xnmt.backend_torch else utils.dummy_context_mgr():
        outputs = self.generate_one(generator, src_batch)
        if self.reporter:
          self._create_sent_report()
        for i in range(len(outputs)):
          output_txt = outputs[i].sent_str(custom_output_procs=self.post_processor)
          fp.write(f"{output_txt}\n")

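# Standalone sketch of the conditional no-grad pattern above: when the torch backend is not
# active, a no-op context manager is used instead. contextlib.nullcontext stands in for
# utils.dummy_context_mgr here; the real helper is assumed to behave the same way.
import contextlib

try:
  import torch
  backend_torch = True
except ImportError:
  torch, backend_torch = None, False

with torch.no_grad() if backend_torch else contextlib.nullcontext():
  pass  # inference would run here; gradient tracking is disabled only when torch is active
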
def generate_search_output(self,
                           src: batchers.Batch,
                           search_strategy: search_strategies.SearchStrategy,
                           forced_trg_ids: batchers.Batch = None) -> List[search_strategies.SearchOutput]:
  """
  Takes in a batch of source sentences and outputs a list of search outputs.

  Args:
    src: The source sentences
    search_strategy: The strategy with which to perform the search
    forced_trg_ids: The target IDs to generate if performing forced decoding
  Returns:
    A list of search outputs including scores, etc.
  """
  if src.batch_size() != 1:
    raise NotImplementedError("batched decoding not implemented for DefaultTranslator. "
                              "Specify inference batcher with batch size 1.")
  event_trigger.start_sent(src)
  all_src = src
  if isinstance(src, batchers.CompoundBatch):
    src = src.batches[0]
  # Generating outputs
  cur_forced_trg = None
  src_sent = src[0]
  sent_mask = None
  if src.mask:
    sent_mask = batchers.Mask(np_arr=src.mask.np_arr[0:1])
  sent_batch = batchers.mark_as_batch([src_sent], mask=sent_mask)
  # Encode the sentence
  initial_state = self._encode_src(all_src)
  if forced_trg_ids is not None:
    cur_forced_trg = forced_trg_ids[0]
  search_outputs = search_strategy.generate_output(self, initial_state,
                                                   src_length=[src_sent.sent_len()],
                                                   forced_trg_ids=cur_forced_trg)
  return search_outputs

def generate_search_output(self,
                           src: batchers.Batch,
                           search_strategy: search_strategies.SearchStrategy) -> List[search_strategies.SearchOutput]:
  """
  Takes in a batch of source sentences and outputs a list of search outputs.

  Args:
    src: The source sentences
    search_strategy: The strategy with which to perform the search
  Returns:
    A list of search outputs including scores, etc.
  """
  if src.batch_size() != 1:
    raise NotImplementedError("batched decoding not implemented for DefaultTranslator. "
                              "Specify inference batcher with batch size 1.")
  if isinstance(src, batchers.CompoundBatch):
    src = src.batches[0]
  search_outputs = search_strategy.generate_output(self, self._initial_state(src),
                                                   src_length=src.sent_len())
  return search_outputs

def generate(self,
             src: batchers.Batch,
             search_strategy: search_strategies.SearchStrategy,
             forced_trg_ids: batchers.Batch = None) -> Sequence[sent.Sentence]:
  """
  Takes in a batch of source sentences and outputs a list of search outputs.

  Args:
    src: The source sentences
    search_strategy: The strategy with which to perform the search
    forced_trg_ids: The target IDs to generate if performing forced decoding
  Returns:
    A list of search outputs including scores, etc.
  """
  assert src.batch_size() == 1
  search_outputs = self.generate_search_output(src, search_strategy, forced_trg_ids)
  if isinstance(src, batchers.CompoundBatch):
    src = src.batches[0]
  sorted_outputs = sorted(search_outputs, key=lambda x: x.score[0], reverse=True)
  assert len(sorted_outputs) >= 1
  outputs = []
  for curr_output in sorted_outputs:
    output_actions = [x for x in curr_output.word_ids[0]]
    attentions = [x for x in curr_output.attentions[0]]
    score = curr_output.score[0]
    out_sent = sent.SimpleSentence(idx=src[0].idx,
                                   words=output_actions,
                                   vocab=getattr(self.trg_reader, "vocab", None),
                                   output_procs=self.trg_reader.output_procs,
                                   score=score)
    if len(sorted_outputs) == 1:
      outputs.append(out_sent)
    else:
      outputs.append(sent.NbestSentence(base_sent=out_sent, nbest_id=src[0].idx))
  if self.is_reporting():
    attentions = np.concatenate([x.npvalue() for x in attentions], axis=1)
    self.report_sent_info({"attentions": attentions, "src": src[0], "output": outputs[0]})
  return outputs

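# Standalone sketch (illustrative stand-in objects, not xnmt classes) of the hypothesis
# handling above: search outputs are sorted by the score of their first entry, highest first,
# before being wrapped as plain or n-best sentences.
from collections import namedtuple

ExampleHyp = namedtuple("ExampleHyp", ["word_ids", "score"])
example_hyps = [ExampleHyp(word_ids=[[4, 7, 2]], score=[-3.2]),
                ExampleHyp(word_ids=[[4, 9, 2]], score=[-1.1])]
example_sorted = sorted(example_hyps, key=lambda x: x.score[0], reverse=True)
# example_sorted[0].score[0] == -1.1, i.e. the best-scoring hypothesis comes first
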
def _cut_or_pad_targets(self, seq_len: numbers.Integral, trg: batchers.Batch) -> batchers.Batch:
  old_mask = trg.mask
  if len(trg[0]) > seq_len:
    trunc_len = len(trg[0]) - seq_len
    trg = batchers.mark_as_batch([trg_sent.get_truncated_sent(trunc_len=trunc_len) for trg_sent in trg])
    if old_mask:
      trg.mask = batchers.Mask(np_arr=old_mask.np_arr[:, :-trunc_len])
  else:
    pad_len = seq_len - len(trg[0])
    trg = batchers.mark_as_batch([trg_sent.create_padded_sent(pad_len=pad_len) for trg_sent in trg])
    if old_mask:
      # padded positions are marked with 1 in the mask
      trg.mask = batchers.Mask(np_arr=np.pad(old_mask.np_arr,
                                             pad_width=((0, 0), (0, pad_len)),
                                             mode="constant",
                                             constant_values=1))
  return trg

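# Standalone numpy sketch (not xnmt API) of how the target mask is adjusted above: truncation
# drops the trailing mask columns, padding appends columns of 1s, since 1 marks padded positions.
import numpy as np

example_old_mask = np.array([[0, 0, 0],
                             [0, 0, 1]])          # batch of 2 sentences, length 3
example_truncated = example_old_mask[:, :-1]      # length 2 after truncating one step
example_padded = np.pad(example_old_mask, pad_width=((0, 0), (0, 2)),
                        mode="constant", constant_values=1)  # length 5, new steps masked
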
def generate(self,
             src: batchers.Batch,
             search_strategy: search_strategies.SearchStrategy) -> Sequence[sent.Sentence]:
  """
  Takes in a batch of source sentences and outputs a list of search outputs.

  Args:
    src: The source sentences
    search_strategy: The strategy with which to perform the search
  Returns:
    A list of search outputs including scores, etc.
  """
  assert src.batch_size() == 1
  event_trigger.start_sent(src)
  search_outputs = self.generate_search_output(src, search_strategy)
  if isinstance(src, batchers.CompoundBatch):
    src = src.batches[0]
  sorted_outputs = sorted(search_outputs, key=lambda x: x.score[0], reverse=True)
  assert len(sorted_outputs) >= 1
  outputs = []
  for curr_output in sorted_outputs:
    output_actions = [x for x in curr_output.word_ids[0]]
    attentions = [x for x in curr_output.attentions[0]]
    score = curr_output.score[0]
    out_sent = self._emit_translation(src, output_actions, score)
    if len(sorted_outputs) == 1:
      outputs.append(out_sent)
    else:
      outputs.append(sent.NbestSentence(base_sent=out_sent, nbest_id=src[0].idx))
  if self.is_reporting():
    attentions = np.concatenate([x.npvalue() for x in attentions], axis=1)
    self.report_sent_info({"attentions": attentions, "src": src[0], "output": outputs[0]})
  return outputs

def _generate_one_batch(self, generator: 'models.GeneratorModel',
                        batcher: Optional[batchers.Batcher] = None,
                        src_batch: batchers.Batch = None,
                        ref_batch: Optional[batchers.Batch] = None,
                        assert_scores: Optional[List[int]] = None,
                        max_src_len: Optional[int] = None,
                        fp: TextIO = None):
  """
  Generate outputs for a single batch and write them to the output file.
  """
  batch_size = len(src_batch)
  if ref_batch[0] is not None:
    src_batches, ref_batches = batcher.pack(src_batch, ref_batch)
    ref_batch = ref_batches[0]
  else:
    src_batches = batcher.pack(src_batch, None)
    ref_batch = None
  src_batch = src_batches[0]
  src_len = src_batch.sent_len()
  if max_src_len is not None and src_len > max_src_len:
    output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
    fp.write(f"{output_txt}\n")
  else:
    with utils.ReportOnException({"src": src_batch, "graph": utils.print_cg_conditional}):
      dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
      outputs = self.generate_one(generator, src_batch, ref_batch)
      if self.reporter:
        self._create_sent_report()
      for i in range(len(outputs)):
        if assert_scores[0] is not None:
          # If debugging forced decoding, make sure it matches
          assert batch_size == len(outputs), "debug forced decoding not supported with nbest inference"
          if (abs(outputs[i].score - assert_scores[i]) / abs(assert_scores[i])) > 1e-5:
            raise ValueError(f'Forced decoding score {outputs[i].score} and loss {assert_scores[i]} do not match at '
                             f'sentence {i}')
        output_txt = outputs[i].sent_str(custom_output_procs=self.post_processor)
        fp.write(f"{output_txt}\n")