def initialize_generator(self, **kwargs):
  if kwargs.get("len_norm_type", None) is None:
    len_norm = xnmt.length_normalization.NoNormalization()
  else:
    len_norm = xnmt.serializer.YamlSerializer().initialize_if_needed(kwargs["len_norm_type"])
  search_args = {}
  if kwargs.get("max_len", None) is not None:
    search_args["max_len"] = kwargs["max_len"]
  if kwargs.get("beam", None) is None:
    self.search_strategy = GreedySearch(**search_args)
  else:
    search_args["beam_size"] = kwargs.get("beam", 1)
    search_args["len_norm"] = len_norm
    self.search_strategy = BeamSearch(**search_args)
  self.report_path = kwargs.get("report_path", None)
  self.report_type = kwargs.get("report_type", None)
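# Usage sketch (assumes `model` is an already-constructed DefaultTranslator):
# the search strategy is chosen implicitly from the kwargs. Omitting `beam`
# selects GreedySearch; passing it selects BeamSearch, with len_norm_type=None
# falling back to NoNormalization.
model.initialize_generator(max_len=100)           # -> GreedySearch(max_len=100)
model.initialize_generator(beam=5, max_len=100)   # -> BeamSearch(beam_size=5, max_len=100, len_norm=NoNormalization())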
def assert_forced_decoding(self, sent_id):
  dy.renew_cg()
  outputs = self.model.generate_output(self.src_data[sent_id], sent_id, BeamSearch(),
                                       forced_trg_ids=self.trg_data[sent_id])
  self.assertItemsEqual(self.trg_data[sent_id], outputs[0].actions)
def assert_forced_decoding(self, sent_id):
  dy.renew_cg()
  outputs = self.model.generate(xnmt.batcher.mark_as_batch([self.src_data[sent_id]]),
                                [sent_id], BeamSearch(),
                                forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[sent_id]]))
  self.assertCountEqual(self.trg_data[sent_id].words, outputs[0].actions)
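# Both test variants above check the same property: with forced_trg_ids set,
# beam search is constrained to follow the reference tokens, so the returned
# actions must reproduce the forced target exactly. A sketch of the newer
# batched call, following the call pattern of the test (model, src_sent and
# trg_sent are assumed to exist as in the test fixture):
outputs = model.generate(xnmt.batcher.mark_as_batch([src_sent]), [0], BeamSearch(),
                         forced_trg_ids=xnmt.batcher.mark_as_batch([trg_sent]))
assert list(outputs[0].actions) == list(trg_sent.words)  # decoding was forced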
def test_single(self):
  dy.renew_cg()
  train_loss = self.model.calc_loss(src=self.src_data[0],
                                    trg=self.trg_data[0],
                                    loss_calculator=MLELoss()).value()
  dy.renew_cg()
  self.model.initialize_generator()
  outputs = self.model.generate_output(self.src_data[0], 0, BeamSearch(beam_size=1),
                                       forced_trg_ids=self.trg_data[0])
  self.assertAlmostEqual(-outputs[0].score, train_loss, places=4)
def test_single(self):
  dy.renew_cg()
  train_loss = self.model.calc_loss(src=self.src_data[0],
                                    trg=self.trg_data[0],
                                    loss_calculator=AutoRegressiveMLELoss()).value()
  dy.renew_cg()
  outputs = self.model.generate(xnmt.batcher.mark_as_batch([self.src_data[0]]), [0],
                                BeamSearch(beam_size=1),
                                forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[0]]))
  self.assertAlmostEqual(-outputs[0].score, train_loss, places=4)
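# The equality asserted in both versions of test_single is that training and
# forced decoding sum the same per-token log-probabilities: the MLE loss is
# -sum(log p(y_t | y_<t, x)) while a forced beam hypothesis scores
# +sum(log p(y_t | y_<t, x)) over identical tokens. A self-contained numeric
# illustration (toy probabilities, not model outputs):
import math

token_probs = [0.7, 0.5, 0.9, 0.8]                 # p(y_t | y_<t, x) for a forced target
mle_loss = -sum(math.log(p) for p in token_probs)  # what calc_loss computes
hyp_score = sum(math.log(p) for p in token_probs)  # what the forced hypothesis scores
assert abs(-hyp_score - mle_loss) < 1e-12          # -score == loss, as the tests assert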
def test_greedy_vs_beam(self):
  dy.renew_cg()
  outputs = self.model.generate(xnmt.batcher.mark_as_batch([self.src_data[0]]), [0],
                                BeamSearch(beam_size=1),
                                forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[0]]))
  output_score1 = outputs[0].score
  dy.renew_cg()
  outputs = self.model.generate(xnmt.batcher.mark_as_batch([self.src_data[0]]), [0],
                                GreedySearch(),
                                forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[0]]))
  output_score2 = outputs[0].score
  self.assertAlmostEqual(output_score1, output_score2)
def test_greedy_vs_beam(self):
  dy.renew_cg()
  self.model.initialize_generator()
  outputs = self.model.generate_output(self.src_data[0], 0, BeamSearch(beam_size=1),
                                       forced_trg_ids=self.trg_data[0])
  output_score1 = outputs[0].score
  dy.renew_cg()
  self.model.initialize_generator()
  outputs = self.model.generate_output(self.src_data[0], 0, GreedySearch(),
                                       forced_trg_ids=self.trg_data[0])
  output_score2 = outputs[0].score
  self.assertAlmostEqual(output_score1, output_score2)
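# The invariant behind both versions of test_greedy_vs_beam is general: beam
# search with beam_size=1 keeps exactly one hypothesis per step (the greedy
# choice), so the two strategies must return the same score. A self-contained
# toy decoder demonstrating the equivalence (hypothetical transition
# log-probabilities, no xnmt dependencies):
import math

LOGPROBS = {None: {"a": math.log(0.6), "b": math.log(0.4)},
            "a":  {"a": math.log(0.2), "b": math.log(0.8)},
            "b":  {"a": math.log(0.5), "b": math.log(0.5)}}

def greedy(steps):
  prev, score = None, 0.0
  for _ in range(steps):
    tok, lp = max(LOGPROBS[prev].items(), key=lambda kv: kv[1])
    prev, score = tok, score + lp
  return score

def beam(steps, beam_size):
  hyps = [(0.0, None)]  # (cumulative log-prob, last token)
  for _ in range(steps):
    expanded = [(s + lp, tok) for s, prev in hyps for tok, lp in LOGPROBS[prev].items()]
    hyps = sorted(expanded, reverse=True)[:beam_size]
  return hyps[0][0]

assert math.isclose(greedy(3), beam(3, beam_size=1))  # beam of size 1 == greedy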
class DefaultTranslator(Translator, Serializable, Reportable):
  '''
  A default translator based on attentional sequence-to-sequence models.

  Args:
    src_reader (InputReader): A reader for the source side.
    trg_reader (InputReader): A reader for the target side.
    src_embedder (Embedder): A word embedder for the input language
    encoder (Transducer): An encoder to generate encoded inputs
    attender (Attender): An attention module
    trg_embedder (Embedder): A word embedder for the output language
    decoder (Decoder): A decoder
    inference (SimpleInference): The default inference strategy used for this model
    calc_global_fertility (bool): whether to add a global fertility term to the loss
    calc_attention_entropy (bool): whether to add an attention entropy term (H(attn)) to the loss
  '''

  yaml_tag = '!DefaultTranslator'

  @register_xnmt_handler
  @serializable_init
  def __init__(self, src_reader, trg_reader, src_embedder=bare(SimpleWordEmbedder),
               encoder=bare(BiLSTMSeqTransducer), attender=bare(MlpAttender),
               trg_embedder=bare(SimpleWordEmbedder), decoder=bare(MlpSoftmaxDecoder),
               inference=bare(SimpleInference), calc_global_fertility=False,
               calc_attention_entropy=False):
    self.src_reader = src_reader
    self.trg_reader = trg_reader
    self.src_embedder = src_embedder
    self.encoder = encoder
    self.attender = attender
    self.trg_embedder = trg_embedder
    self.decoder = decoder
    self.calc_global_fertility = calc_global_fertility
    self.calc_attention_entropy = calc_attention_entropy
    self.inference = inference

  def shared_params(self):
    return [set([".src_embedder.emb_dim", ".encoder.input_dim"]),
            set([".encoder.hidden_dim", ".attender.input_dim", ".decoder.input_dim"]),
            set([".attender.state_dim", ".decoder.rnn_layer.hidden_dim"]),
            set([".trg_embedder.emb_dim", ".decoder.trg_embed_dim"])]

  def initialize_generator(self, **kwargs):
    # TODO: refactor?
    if kwargs.get("len_norm_type", None) is None:
      len_norm = xnmt.length_normalization.NoNormalization()
    else:
      len_norm = initialize_if_needed(kwargs["len_norm_type"])
    search_args = {}
    if kwargs.get("max_len", None) is not None:
      search_args["max_len"] = kwargs["max_len"]
    if kwargs.get("beam", None) is None:
      self.search_strategy = GreedySearch(**search_args)
    else:
      search_args["beam_size"] = kwargs.get("beam", 1)
      search_args["len_norm"] = len_norm
      self.search_strategy = BeamSearch(**search_args)
    self.report_path = kwargs.get("report_path", None)
    self.report_type = kwargs.get("report_type", None)

  def calc_loss(self, src, trg, loss_calculator):
    self.start_sent(src)
    embeddings = self.src_embedder.embed_sent(src)
    encodings = self.encoder(embeddings)
    self.attender.init_sent(encodings)
    # Initialize the hidden state from the encoder
    ss = mark_as_batch([Vocab.SS] * len(src)) if is_batched(src) else Vocab.SS
    dec_state = self.decoder.initial_state(self.encoder.get_final_states(),
                                           self.trg_embedder.embed(ss))
    # Compose losses
    model_loss = LossBuilder()
    model_loss.add_loss("mle", loss_calculator(self, dec_state, src, trg))

    if self.calc_global_fertility or self.calc_attention_entropy:
      # philip30: I assume that attention_vecs is already masked src wisely.
      # Now applying the mask to the target
      masked_attn = self.attender.attention_vecs
      if trg.mask is not None:
        trg_mask = trg.mask.get_active_one_mask().transpose()
        masked_attn = [dy.cmult(attn, dy.inputTensor(mask, batched=True))
                       for attn, mask in zip(masked_attn, trg_mask)]

    if self.calc_global_fertility:
      model_loss.add_loss("fertility", self.global_fertility(masked_attn))
    if self.calc_attention_entropy:
      model_loss.add_loss("H(attn)", self.attention_entropy(masked_attn))

    return model_loss

  def generate(self, src, idx, src_mask=None, forced_trg_ids=None):
    if not xnmt.batcher.is_batched(src):
      src = xnmt.batcher.mark_as_batch([src])
    else:
      assert src_mask is not None
    outputs = []
    for sents in src:
      self.start_sent(src)
      embeddings = self.src_embedder.embed_sent(src)
      encodings = self.encoder(embeddings)
      self.attender.init_sent(encodings)
      ss = mark_as_batch([Vocab.SS] * len(src)) if is_batched(src) else Vocab.SS
      dec_state = self.decoder.initial_state(self.encoder.get_final_states(),
                                             self.trg_embedder.embed(ss))
      output_actions, score = self.search_strategy.generate_output(self.decoder, self.attender,
                                                                   self.trg_embedder, dec_state,
                                                                   src_length=len(sents),
                                                                   forced_trg_ids=forced_trg_ids)
      # In case of reporting
      if self.report_path is not None:
        if self.reporting_src_vocab:
          src_words = [self.reporting_src_vocab[w] for w in sents]
        else:
          src_words = ['' for w in sents]
        trg_words = [self.trg_vocab[w] for w in output_actions.word_ids]
        # Attentions
        attentions = output_actions.attentions
        if type(attentions) == dy.Expression:
          attentions = attentions.npvalue()
        elif type(attentions) == list:
          attentions = np.concatenate([x.npvalue() for x in attentions], axis=1)
        elif type(attentions) != np.ndarray:
          raise RuntimeError("Illegal type for attentions in translator report: {}".format(type(attentions)))
        # Segmentation
        segment = self.get_report_resource("segmentation")
        if segment is not None:
          segment = [int(x[0]) for x in segment]
          src_inp = [x[0] for x in self.encoder.apply_segmentation(src_words, segment)]
        else:
          src_inp = src_words
        # Other Resources
        self.set_report_input(idx, src_inp, trg_words, attentions)
        self.set_report_resource("src_words", src_words)
        self.set_report_path('{}.{}'.format(self.report_path, str(idx)))
        self.generate_report(self.report_type)
      # Append output to the outputs
      outputs.append(TextOutput(actions=output_actions.word_ids,
                                vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                                score=score))
    self.outputs = outputs
    return outputs

  def global_fertility(self, a):
    return dy.sum_elems(dy.square(1 - dy.esum(a)))

  def attention_entropy(self, a):
    EPS = 1e-10
    entropy = []
    for a_i in a:
      a_i += EPS
      entropy.append(dy.cmult(a_i, dy.log(a_i)))
    return -dy.sum_elems(dy.esum(entropy))

  def set_reporting_src_vocab(self, src_vocab):
    """
    Sets source vocab for reporting purposes.

    Args:
      src_vocab (Vocab):
    """
    self.reporting_src_vocab = src_vocab

  @register_xnmt_event_assign
  def html_report(self, context=None):
    assert context is None
    idx, src, trg, att = self.get_report_input()
    path_to_report = self.get_report_path()
    html = etree.Element('html')
    head = etree.SubElement(html, 'head')
    title = etree.SubElement(head, 'title')
    body = etree.SubElement(html, 'body')
    report = etree.SubElement(body, 'h1')
    if idx is not None:
      title.text = report.text = 'Translation Report for Sentence %d' % (idx)
    else:
      title.text = report.text = 'Translation Report'
    main_content = etree.SubElement(body, 'div', name='main_content')

    # Generating main content
    captions = ["Source Words", "Target Words"]
    inputs = [src, trg]
    for caption, inp in zip(captions, inputs):
      if inp is None: continue
      sent = ' '.join(inp)
      p = etree.SubElement(main_content, 'p')
      p.text = f"{caption}: {sent}"

    # Generating attention
    if not any([src is None, trg is None, att is None]):
      attention = etree.SubElement(main_content, 'p')
      att_text = etree.SubElement(attention, 'b')
      att_text.text = "Attention:"
      etree.SubElement(attention, 'br')
      attention_file = f"{path_to_report}.attention.png"
      xnmt.plot.plot_attention(src, trg, att, file_name=attention_file)

    # return the parent context to be used as child context
    return html

  @handle_xnmt_event
  def on_file_report(self):
    idx, src, trg, attn = self.get_report_input()
    assert attn.shape == (len(src), len(trg))
    col_length = []
    for word in trg:
      col_length.append(max(len(word), 6))
    col_length.append(max(len(x) for x in src))
    with open(self.get_report_path() + ".attention.txt", encoding='utf-8', mode='w') as attn_file:
      # First row: the target words as column headers; then one row per source
      # word with its attention weights across all target positions.
      for i in range(len(src) + 1):
        if i == 0:
          words = trg + [""]
        else:
          words = ["%.4f" % (f) for f in attn[i - 1]] + [src[i - 1]]
        str_format = ""
        for length in col_length:
          str_format += "{:%ds}" % (length + 2)
        print(str_format.format(*words), file=attn_file)
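# End-to-end usage sketch for the class above (hedged: src_reader/trg_reader,
# the data batches, and loss_calculator are assumed to be constructed
# elsewhere; all other components fall back to the bare(...) defaults):
model = DefaultTranslator(src_reader=src_reader, trg_reader=trg_reader)

loss = model.calc_loss(src=src_batch, trg=trg_batch,  # training side
                       loss_calculator=loss_calculator)

model.initialize_generator(beam=5, max_len=100)       # pick a search strategy
outputs = model.generate(src_sent, 0)                 # unbatched input is auto-batched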
class DefaultTranslator(Translator, Serializable, Reportable):
  '''
  A default translator based on attentional sequence-to-sequence models.
  '''

  yaml_tag = u'!DefaultTranslator'

  def __init__(self, src_embedder, encoder, attender, trg_embedder, decoder):
    '''Constructor.

    :param src_embedder: A word embedder for the input language
    :param encoder: An encoder to generate encoded inputs
    :param attender: An attention module
    :param trg_embedder: A word embedder for the output language
    :param decoder: A decoder
    '''
    register_handler(self)
    self.src_embedder = src_embedder
    self.encoder = encoder
    self.attender = attender
    self.trg_embedder = trg_embedder
    self.decoder = decoder

  def shared_params(self):
    return [set(["src_embedder.emb_dim", "encoder.input_dim"]),
            set(["encoder.hidden_dim", "attender.input_dim", "decoder.input_dim"]),
            set(["attender.state_dim", "decoder.lstm_dim"]),
            set(["trg_embedder.emb_dim", "decoder.trg_embed_dim"])]

  def initialize_generator(self, **kwargs):
    if kwargs.get("len_norm_type", None) is None:
      len_norm = xnmt.length_normalization.NoNormalization()
    else:
      len_norm = xnmt.serializer.YamlSerializer().initialize_if_needed(kwargs["len_norm_type"])
    search_args = {}
    if kwargs.get("max_len", None) is not None:
      search_args["max_len"] = kwargs["max_len"]
    if kwargs.get("beam", None) is None:
      self.search_strategy = GreedySearch(**search_args)
    else:
      search_args["beam_size"] = kwargs.get("beam", 1)
      search_args["len_norm"] = len_norm
      self.search_strategy = BeamSearch(**search_args)
    self.report_path = kwargs.get("report_path", None)
    self.report_type = kwargs.get("report_type", None)

  def initialize_training_strategy(self, training_strategy):
    self.loss_calculator = training_strategy

  def calc_loss(self, src, trg):
    """
    :param src: source sequence (unbatched, or batched + padded)
    :param trg: target sequence (unbatched, or batched + padded); losses will be accumulated only if
                trg_mask[batch,pos]==0, or no mask is set
    :returns: (possibly batched) loss expression
    """
    assert hasattr(self, "loss_calculator")
    self.start_sent(src)
    embeddings = self.src_embedder.embed_sent(src)
    encodings = self.encoder(embeddings)
    self.attender.init_sent(encodings)
    # Initialize the hidden state from the encoder
    ss = mark_as_batch([Vocab.SS] * len(src)) if is_batched(src) else Vocab.SS
    dec_state = self.decoder.initial_state(self.encoder.get_final_states(),
                                           self.trg_embedder.embed(ss))
    return self.loss_calculator(self, dec_state, src, trg)

  def generate(self, src, idx, src_mask=None, forced_trg_ids=None):
    if not xnmt.batcher.is_batched(src):
      src = xnmt.batcher.mark_as_batch([src])
    else:
      assert src_mask is not None
    outputs = []
    for sents in src:
      self.start_sent(src)
      embeddings = self.src_embedder.embed_sent(src)
      encodings = self.encoder(embeddings)
      self.attender.init_sent(encodings)
      ss = mark_as_batch([Vocab.SS] * len(src)) if is_batched(src) else Vocab.SS
      dec_state = self.decoder.initial_state(self.encoder.get_final_states(),
                                             self.trg_embedder.embed(ss))
      output_actions, score = self.search_strategy.generate_output(self.decoder, self.attender,
                                                                   self.trg_embedder, dec_state,
                                                                   src_length=len(sents),
                                                                   forced_trg_ids=forced_trg_ids)
      # In case of reporting
      if self.report_path is not None:
        src_words = [self.reporting_src_vocab[w] for w in sents]
        trg_words = [self.trg_vocab[w] for w in output_actions[1:]]
        attentions = self.attender.attention_vecs
        self.set_report_input(idx, src_words, trg_words, attentions)
        self.set_report_resource("src_words", src_words)
        self.set_report_path('{}.{}'.format(self.report_path, str(idx)))
        self.generate_report(self.report_type)
      # Append output to the outputs
      outputs.append(TextOutput(actions=output_actions,
                                vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                                score=score))
    return outputs

  def set_reporting_src_vocab(self, src_vocab):
    """
    Sets source vocab for reporting purposes.
    """
    self.reporting_src_vocab = src_vocab

  @register_xnmt_event_assign
  def html_report(self, context=None):
    assert context is None
    idx, src, trg, att = self.get_report_input()
    path_to_report = self.get_report_path()
    html = etree.Element('html')
    head = etree.SubElement(html, 'head')
    title = etree.SubElement(head, 'title')
    body = etree.SubElement(html, 'body')
    report = etree.SubElement(body, 'h1')
    if idx is not None:
      title.text = report.text = 'Translation Report for Sentence %d' % (idx)
    else:
      title.text = report.text = 'Translation Report'
    main_content = etree.SubElement(body, 'div', name='main_content')

    # Generating main content
    captions = [u"Source Words", u"Target Words"]
    inputs = [src, trg]
    for caption, inp in six.moves.zip(captions, inputs):
      if inp is None: continue
      sent = ' '.join(inp)
      p = etree.SubElement(main_content, 'p')
      p.text = u"{}: {}".format(caption, sent)

    # Generating attention
    if not any([src is None, trg is None, att is None]):
      attention = etree.SubElement(main_content, 'p')
      att_text = etree.SubElement(attention, 'b')
      att_text.text = "Attention:"
      etree.SubElement(attention, 'br')
      attention_file = u"{}.attention.png".format(path_to_report)
      if type(att) == dy.Expression:
        attentions = att.npvalue()
      elif type(att) == list:
        attentions = np.concatenate([x.npvalue() for x in att], axis=1)
      elif type(att) != np.ndarray:
        raise RuntimeError("Illegal type for attentions in translator report: {}".format(type(att)))
      plot.plot_attention(src, trg, attentions, file_name=attention_file)

    # return the parent context to be used as child context
    return html
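# In this earlier API the loss calculator is not passed to calc_loss directly;
# it must be installed first via initialize_training_strategy, as the
# hasattr(self, "loss_calculator") assertion enforces. Hedged sketch of the
# older flow (component objects and data assumed constructed elsewhere):
model = DefaultTranslator(src_embedder, encoder, attender, trg_embedder, decoder)

model.initialize_training_strategy(loss_calculator)  # required before calc_loss
loss = model.calc_loss(src_batch, trg_batch)         # uses the installed calculator

model.initialize_generator(beam=1)                   # BeamSearch(beam_size=1)
outputs = model.generate(src_sent, 0)                # unbatched input is auto-batched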