def generate(self, src, idx, forced_trg_ids=None, normalize_scores=False):
  if not batcher.is_batched(src):
    src = batcher.mark_as_batch([src])
    if forced_trg_ids:
      forced_trg_ids = batcher.mark_as_batch([forced_trg_ids])
  assert src.batch_size() == 1, "batch size > 1 not properly tested"
  batch_size, encodings, outputs, seq_len = self._encode_src(src)
  score_expr = self.scorer.calc_log_softmax(outputs) if normalize_scores \
          else self.scorer.calc_scores(outputs)
  scores = score_expr.npvalue()  # vocab_size x seq_len
  if forced_trg_ids:
    output_actions = forced_trg_ids
  else:
    output_actions = [np.argmax(scores[:, j]) for j in range(seq_len)]
  score = np.sum([scores[output_actions[j], j] for j in range(seq_len)])
  outputs = [output.TextOutput(actions=output_actions,
                               vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                               score=score)]
  return outputs

def assert_single_loss_equals_batch_loss(self, model, batch_size=5):
  """
  Tests whether the summed single-sentence losses equal the batch loss.
  Here we don't truncate the target side and instead use masking.
  """
  src_sents = self.training_corpus.train_src_data[:batch_size]
  src_min = min([len(x) for x in src_sents])
  src_sents_trunc = [s[:src_min] for s in src_sents]
  for single_sent in src_sents_trunc:
    single_sent[src_min - 1] = Vocab.ES

  trg_sents = self.training_corpus.train_trg_data[:batch_size]
  trg_max = max([len(x) for x in trg_sents])
  trg_masks = Mask(np.zeros([batch_size, trg_max]))
  for i in range(batch_size):
    for j in range(len(trg_sents[i]), trg_max):
      trg_masks.np_arr[i, j] = 1.0  # mark padded target positions
  trg_sents_padded = [[w for w in s] + [Vocab.ES] * (trg_max - len(s)) for s in trg_sents]

  single_loss = 0.0
  for sent_id in range(batch_size):
    dy.renew_cg()
    train_loss = model.calc_loss(src=src_sents_trunc[sent_id],
                                 trg=trg_sents[sent_id]).value()
    single_loss += train_loss

  dy.renew_cg()
  batched_loss = model.calc_loss(src=mark_as_batch(src_sents_trunc),
                                 trg=mark_as_batch(trg_sents_padded, trg_masks)).value()
  self.assertAlmostEqual(single_loss, sum(batched_loss), places=4)

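# The test above relies on xnmt's masking convention: Mask.np_arr has shape
# (batch_size, seq_len) and holds 1.0 at padded positions and 0.0 at real
# tokens. A minimal sketch of that convention in isolation; build_trg_mask is
# a hypothetical helper for illustration, not part of xnmt:
import numpy as np

def build_trg_mask(sent_lens, max_len):
  """Return a (batch, max_len) array with 1.0 marking padded positions."""
  np_arr = np.zeros([len(sent_lens), max_len])
  for i, sent_len in enumerate(sent_lens):
    np_arr[i, sent_len:] = 1.0
  return np_arr

# build_trg_mask([3, 5], 5) ->
# [[0., 0., 0., 1., 1.],
#  [0., 0., 0., 0., 0.]]
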
def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
  """
  Tests whether the summed single-sentence losses equal the batch loss.
  Src/trg sentences are truncated to the same length, so no masking is necessary.
  """
  batch_size = 5
  src_sents = self.src_data[:batch_size]
  src_min = min([x.sent_len() for x in src_sents])
  src_sents_trunc = [s.words[:src_min] for s in src_sents]
  for single_sent in src_sents_trunc:
    single_sent[src_min - 1] = Vocab.ES
    while len(single_sent) % pad_src_to_multiple != 0:
      single_sent.append(Vocab.ES)

  trg_sents = self.trg_data[:batch_size]
  trg_min = min([x.sent_len() for x in trg_sents])
  trg_sents_trunc = [s.words[:trg_min] for s in trg_sents]
  for single_sent in trg_sents_trunc:
    single_sent[trg_min - 1] = Vocab.ES

  src_sents_trunc = [SimpleSentenceInput(s) for s in src_sents_trunc]
  trg_sents_trunc = [SimpleSentenceInput(s) for s in trg_sents_trunc]

  single_loss = 0.0
  for sent_id in range(batch_size):
    dy.renew_cg()
    train_loss = model.calc_loss(src=src_sents_trunc[sent_id],
                                 trg=trg_sents_trunc[sent_id],
                                 loss_calculator=AutoRegressiveMLELoss()).value()
    single_loss += train_loss

  dy.renew_cg()
  batched_loss = model.calc_loss(src=mark_as_batch(src_sents_trunc),
                                 trg=mark_as_batch(trg_sents_trunc),
                                 loss_calculator=AutoRegressiveMLELoss()).value()
  self.assertAlmostEqual(single_loss, np.sum(batched_loss), places=4)

def assert_single_loss_equals_batch_loss(self, model, batch_size=5):
  """
  Tests whether the summed single-sentence losses equal the batch loss.
  Src/trg sentences are truncated to the same length, so no masking is necessary.
  """
  src_sents = self.training_corpus.train_src_data[:batch_size]
  src_min = min([len(x) for x in src_sents])
  src_sents_trunc = [s[:src_min] for s in src_sents]
  for single_sent in src_sents_trunc:
    single_sent[src_min - 1] = Vocab.ES

  trg_sents = self.training_corpus.train_trg_data[:batch_size]
  trg_min = min([len(x) for x in trg_sents])
  trg_sents_trunc = [s[:trg_min] for s in trg_sents]
  for single_sent in trg_sents_trunc:
    single_sent[trg_min - 1] = Vocab.ES

  single_loss = 0.0
  for sent_id in range(batch_size):
    dy.renew_cg()
    train_loss = model.calc_loss(src=src_sents_trunc[sent_id],
                                 trg=trg_sents_trunc[sent_id]).value()
    single_loss += train_loss

  dy.renew_cg()
  batched_loss = model.calc_loss(src=mark_as_batch(src_sents_trunc),
                                 trg=mark_as_batch(trg_sents_trunc)).value()
  self.assertAlmostEqual(single_loss, sum(batched_loss), places=4)

def calc_loss(self, src, trg, loss_calculator):
  if not batcher.is_batched(src):
    src = batcher.ListBatch([src])

  # Shift the sequence by one position: predict each token from its predecessor.
  src_inputs = batcher.ListBatch([s[:-1] for s in src],
                                 mask=batcher.Mask(src.mask.np_arr[:, :-1]) if src.mask else None)
  src_targets = batcher.ListBatch([s[1:] for s in src],
                                  mask=batcher.Mask(src.mask.np_arr[:, 1:]) if src.mask else None)

  self.start_sent(src)
  embeddings = self.src_embedder.embed_sent(src_inputs)
  encodings = self.rnn.transduce(embeddings)
  encodings_tensor = encodings.as_tensor()
  ((hidden_dim, seq_len), batch_size) = encodings.dim()

  # Fold the time dimension into DyNet's batch dimension so that all positions
  # are scored with a single call.
  encoding_reshaped = dy.reshape(encodings_tensor, (hidden_dim,), batch_size=batch_size * seq_len)
  outputs = self.transform(encoding_reshaped)

  ref_action = np.asarray([sent.words for sent in src_targets]).reshape((seq_len * batch_size,))
  loss_expr_perstep = self.scorer.calc_loss(outputs, batcher.mark_as_batch(ref_action))
  loss_expr_perstep = dy.reshape(loss_expr_perstep, (seq_len,), batch_size=batch_size)
  if src_targets.mask:
    loss_expr_perstep = dy.cmult(loss_expr_perstep,
                                 dy.inputTensor(1.0 - src_targets.mask.np_arr.T, batched=True))
  loss_expr = dy.sum_elems(loss_expr_perstep)

  model_loss = loss.FactoredLossExpr()
  model_loss.add_loss("mle", loss_expr)
  return model_loss

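# The input/target shift above, shown on a toy sentence (illustration only):
# the model is trained to predict each token from the prefix ending at its
# predecessor, i.e. standard next-token language-model training.
sent = ["<s>", "a", "b", "</s>"]
inputs, targets = sent[:-1], sent[1:]
# inputs:  ["<s>", "a", "b"]
# targets: ["a",   "b", "</s>"]
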
def generate(self, src, idx, src_mask=None, forced_trg_ids=None):
  if not xnmt.batcher.is_batched(src):
    src = xnmt.batcher.mark_as_batch([src])
  else:
    assert src_mask is not None
  outputs = []
  for sents in src:
    self.start_sent(src)
    embeddings = self.src_embedder.embed_sent(src)
    encodings = self.encoder(embeddings)
    self.attender.init_sent(encodings)
    ss = mark_as_batch([Vocab.SS] * len(src)) if is_batched(src) else Vocab.SS
    dec_state = self.decoder.initial_state(self.encoder.get_final_states(),
                                           self.trg_embedder.embed(ss))
    output_actions, score = self.search_strategy.generate_output(self.decoder, self.attender,
                                                                 self.trg_embedder, dec_state,
                                                                 src_length=len(sents),
                                                                 forced_trg_ids=forced_trg_ids)
    # In case of reporting
    if self.report_path is not None:
      src_words = [self.reporting_src_vocab[w] for w in sents]
      trg_words = [self.trg_vocab[w] for w in output_actions[1:]]
      attentions = self.attender.attention_vecs
      self.set_report_input(idx, src_words, trg_words, attentions)
      self.set_report_resource("src_words", src_words)
      self.set_report_path('{}.{}'.format(self.report_path, str(idx)))
      self.generate_report(self.report_type)
    # Append output to the outputs
    outputs.append(TextOutput(actions=output_actions,
                              vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                              score=score))
  return outputs

def calc_loss(self, src, trg, loss_calculator):
  self.start_sent(src)
  embeddings = self.src_embedder.embed_sent(src)
  encodings = self.encoder(embeddings)
  self.attender.init_sent(encodings)
  # Initialize the hidden state from the encoder
  ss = mark_as_batch([Vocab.SS] * len(src)) if is_batched(src) else Vocab.SS
  dec_state = self.decoder.initial_state(self.encoder.get_final_states(),
                                         self.trg_embedder.embed(ss))
  # Compose losses
  model_loss = LossBuilder()
  model_loss.add_loss("mle", loss_calculator(self, dec_state, src, trg))

  if self.calc_global_fertility or self.calc_attention_entropy:
    # philip30: I assume that attention_vecs is already masked src-wise.
    # Now applying the mask to the target
    masked_attn = self.attender.attention_vecs
    if trg.mask is not None:
      trg_mask = trg.mask.get_active_one_mask().transpose()
      masked_attn = [dy.cmult(attn, dy.inputTensor(mask, batched=True))
                     for attn, mask in zip(masked_attn, trg_mask)]
    if self.calc_global_fertility:
      model_loss.add_loss("fertility", self.global_fertility(masked_attn))
    if self.calc_attention_entropy:
      model_loss.add_loss("H(attn)", self.attention_entropy(masked_attn))

  return model_loss

def calc_loss(self, src, trg, loss_calculator):
  assert batcher.is_batched(src) and batcher.is_batched(trg)
  batch_size, encodings, outputs, seq_len = self._encode_src(src)

  if trg.sent_len() != seq_len:
    if self.auto_cut_pad:
      trg = self._cut_or_pad_targets(seq_len, trg)
    else:
      raise ValueError(f"src/trg lengths do not match: {seq_len} != {trg.sent_len()}")

  ref_action = np.asarray([sent.words for sent in trg]).reshape((seq_len * batch_size,))
  loss_expr_perstep = self.scorer.calc_loss(outputs, batcher.mark_as_batch(ref_action))
  # Fold the per-(position, sentence) losses back into (seq_len,) x batch_size.
  loss_expr_perstep = dy.reshape(loss_expr_perstep, (seq_len,), batch_size=batch_size)
  if trg.mask:
    loss_expr_perstep = dy.cmult(loss_expr_perstep,
                                 dy.inputTensor(1.0 - trg.mask.np_arr.T, batched=True))
  loss_expr = dy.sum_elems(loss_expr_perstep)

  model_loss = loss.FactoredLossExpr()
  model_loss.add_loss("mle", loss_expr)
  return model_loss

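# The time-into-batch folding used in the two calc_loss implementations above,
# in isolation. A minimal runnable sketch with made-up dimensions, assuming
# only plain DyNet; names like `enc` and `refs` are illustrative:
import dynet as dy
import numpy as np

dy.renew_cg()
hidden_dim, seq_len, batch_size, vocab_size = 4, 3, 2, 5
m = dy.ParameterCollection()
W = m.add_parameters((vocab_size, hidden_dim))
# encodings as a single tensor: ((hidden_dim, seq_len), batch_size)
enc = dy.inputTensor(np.random.rand(hidden_dim, seq_len, batch_size), batched=True)
# fold seq_len into the batch dimension: ((hidden_dim,), batch_size * seq_len)
enc_flat = dy.reshape(enc, (hidden_dim,), batch_size=batch_size * seq_len)
logits = W * enc_flat  # ((vocab_size,), batch_size * seq_len)
refs = np.random.randint(vocab_size, size=seq_len * batch_size).tolist()
loss_flat = dy.pickneglogsoftmax_batch(logits, refs)  # one loss per (pos, sent)
loss_per_step = dy.reshape(loss_flat, (seq_len,), batch_size=batch_size)
print(loss_per_step.npvalue().shape)  # -> (3, 2)
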
def generate(self, src, idx, forced_trg_ids=None, normalize_scores=False):
  if not batcher.is_batched(src):
    src = batcher.mark_as_batch([src])
    if forced_trg_ids:
      forced_trg_ids = batcher.mark_as_batch([forced_trg_ids])
  h = self._encode_src(src)
  scores = self.scorer.calc_log_probs(h) if normalize_scores else self.scorer.calc_scores(h)
  np_scores = scores.npvalue()
  if forced_trg_ids:
    output_action = forced_trg_ids
  else:
    output_action = np.argmax(np_scores, axis=0)
  outputs = []
  for batch_i in range(src.batch_size()):
    score = np_scores[:, batch_i][output_action[batch_i]]
    # one scalar prediction per batch element
    outputs.append(output.ScalarOutput(actions=[output_action[batch_i]],
                                       vocab=None,
                                       score=score))
  return outputs

def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
  """
  Tests whether the summed single-sentence losses equal the batch loss.
  Here we don't truncate the target side and instead use masking.
  """
  batch_size = 5
  src_sents = self.src_data[:batch_size]
  src_min = min([x.sent_len() for x in src_sents])
  src_sents_trunc = [s.words[:src_min] for s in src_sents]
  for single_sent in src_sents_trunc:
    single_sent[src_min - 1] = Vocab.ES
    while len(single_sent) % pad_src_to_multiple != 0:
      single_sent.append(Vocab.ES)

  trg_sents = sorted(self.trg_data[:batch_size], key=lambda x: x.sent_len(), reverse=True)
  trg_max = max([x.sent_len() for x in trg_sents])
  np_arr = np.zeros([batch_size, trg_max])
  for i in range(batch_size):
    for j in range(trg_sents[i].sent_len(), trg_max):
      np_arr[i, j] = 1.0  # mark padded target positions
  trg_masks = Mask(np_arr)
  trg_sents_padded = [[w for w in s] + [Vocab.ES] * (trg_max - s.sent_len()) for s in trg_sents]

  src_sents_trunc = [SimpleSentenceInput(s) for s in src_sents_trunc]
  trg_sents_padded = [SimpleSentenceInput(s) for s in trg_sents_padded]

  single_loss = 0.0
  for sent_id in range(batch_size):
    dy.renew_cg()
    train_loss = model.calc_loss(src=src_sents_trunc[sent_id],
                                 trg=trg_sents[sent_id],
                                 loss_calculator=AutoRegressiveMLELoss()).value()
    single_loss += train_loss

  dy.renew_cg()
  batched_loss = model.calc_loss(src=mark_as_batch(src_sents_trunc),
                                 trg=mark_as_batch(trg_sents_padded, trg_masks),
                                 loss_calculator=AutoRegressiveMLELoss()).value()
  self.assertAlmostEqual(single_loss, np.sum(batched_loss), places=4)

def _cut_or_pad_targets(self, seq_len, trg):
  old_mask = trg.mask
  if len(trg[0]) > seq_len:
    trunc_len = len(trg[0]) - seq_len
    trg = batcher.mark_as_batch([trg_sent.get_truncated_sent(trunc_len=trunc_len)
                                 for trg_sent in trg])
    if old_mask:
      trg.mask = batcher.Mask(np_arr=old_mask.np_arr[:, :-trunc_len])
  else:
    pad_len = seq_len - len(trg[0])
    trg = batcher.mark_as_batch([trg_sent.get_padded_sent(token=vocab.Vocab.ES, pad_len=pad_len)
                                 for trg_sent in trg])
    if old_mask:
      trg.mask = batcher.Mask(np_arr=np.pad(old_mask.np_arr,
                                            pad_width=((0, 0), (0, pad_len)),
                                            mode="constant", constant_values=1))
  return trg

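# How the mask bookkeeping in _cut_or_pad_targets behaves on a toy batch,
# using the same 1.0-means-padded convention (illustration only):
import numpy as np

old = np.array([[0., 0., 1.],   # sentence 0: real length 2, padded at position 2
                [0., 0., 0.]])  # sentence 1: real length 3
padded = np.pad(old, pad_width=((0, 0), (0, 2)), mode="constant", constant_values=1)
# -> length 5; both appended positions are masked out for every sentence
truncated = old[:, :-1]
# -> length 2; trailing mask columns are dropped along with the tokens
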
def calc_loss(self, src, trg, loss_calculator):
  """
  :param src: source sequence (unbatched, or batched + padded)
  :param trg: target sequence (unbatched, or batched + padded); losses will be
              accumulated only if trg_mask[batch,pos]==0, or if no mask is set
  :param loss_calculator:
  :returns: (possibly batched) loss expression
  """
  self.start_sent(src)
  embeddings = self.src_embedder.embed_sent(src)
  encodings = self.encoder(embeddings)
  self.attender.init_sent(encodings)
  # Initialize the hidden state from the encoder
  ss = mark_as_batch([Vocab.SS] * len(src)) if is_batched(src) else Vocab.SS
  dec_state = self.decoder.initial_state(self.encoder.get_final_states(),
                                         self.trg_embedder.embed(ss))
  return loss_calculator(self, dec_state, src, trg)

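# A hypothetical call site for calc_loss, matching the contract documented
# above; `model`, `src_batch` and `trg_batch` are illustrative names, and the
# loss calculator mirrors the one used in the tests in this file:
dy.renew_cg()
loss_expr = model.calc_loss(src=src_batch, trg=trg_batch,
                            loss_calculator=AutoRegressiveMLELoss())
loss_value = loss_expr.value()  # per-sentence losses if the inputs are batched
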
def generate(self, src, idx, search_strategy, src_mask=None, forced_trg_ids=None):
  if not xnmt.batcher.is_batched(src):
    src = xnmt.batcher.mark_as_batch([src])
  else:
    assert src_mask is not None
  # Generating outputs
  outputs = []
  for sents in src:
    self.start_sent(src)
    embeddings = self.src_embedder.embed_sent(src)
    encodings = self.encoder(embeddings)
    self.attender.init_sent(encodings)
    ss = mark_as_batch([Vocab.SS] * len(src)) if is_batched(src) else Vocab.SS
    initial_state = self.decoder.initial_state(self.encoder.get_final_states(),
                                               self.trg_embedder.embed(ss))
    search_outputs = search_strategy.generate_output(self, initial_state,
                                                     src_length=[len(sents)],
                                                     forced_trg_ids=forced_trg_ids)
    best_output = sorted(search_outputs, key=lambda x: x.score[0], reverse=True)[0]
    output_actions = [x for x in best_output.word_ids[0]]
    attentions = [x for x in best_output.attentions[0]]
    score = best_output.score[0]
    # In case of reporting
    if self.report_path is not None:
      if self.reporting_src_vocab:
        src_words = [self.reporting_src_vocab[w] for w in sents]
      else:
        src_words = ['' for w in sents]
      trg_words = [self.trg_vocab[w] for w in output_actions]
      # Attentions
      attentions = np.concatenate([x.npvalue() for x in attentions], axis=1)
      # Segmentation
      segment = self.get_report_resource("segmentation")
      if segment is not None:
        segment = [int(x[0]) for x in segment]
        src_inp = [x[0] for x in self.encoder.apply_segmentation(src_words, segment)]
      else:
        src_inp = src_words
      # Other resources
      self.set_report_input(idx, src_inp, trg_words, attentions)
      self.set_report_resource("src_words", src_words)
      self.set_report_path('{}.{}'.format(self.report_path, str(idx)))
      self.generate_report(self.report_type)
    # Append output to the outputs
    outputs.append(TextOutput(actions=output_actions,
                              vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                              score=score))
  self.outputs = outputs
  return outputs

def generate(self, src, idx, src_mask=None, forced_trg_ids=None):
  if not xnmt.batcher.is_batched(src):
    src = xnmt.batcher.mark_as_batch([src])
  else:
    assert src_mask is not None
  outputs = []
  for sents in src:
    self.start_sent(src)
    embeddings = self.src_embedder.embed_sent(src)
    encodings = self.encoder(embeddings)
    self.attender.init_sent(encodings)
    ss = mark_as_batch([Vocab.SS] * len(src)) if is_batched(src) else Vocab.SS
    dec_state = self.decoder.initial_state(self.encoder.get_final_states(),
                                           self.trg_embedder.embed(ss))
    output_actions, score = self.search_strategy.generate_output(self.decoder, self.attender,
                                                                 self.trg_embedder, dec_state,
                                                                 src_length=len(sents),
                                                                 forced_trg_ids=forced_trg_ids)
    # In case of reporting
    if self.report_path is not None:
      if self.reporting_src_vocab:
        src_words = [self.reporting_src_vocab[w] for w in sents]
      else:
        src_words = ['' for w in sents]
      trg_words = [self.trg_vocab[w] for w in output_actions.word_ids]
      # Attentions
      attentions = output_actions.attentions
      if type(attentions) == dy.Expression:
        attentions = attentions.npvalue()
      elif type(attentions) == list:
        attentions = np.concatenate([x.npvalue() for x in attentions], axis=1)
      elif type(attentions) != np.ndarray:
        raise RuntimeError("Illegal type for attentions in translator report: {}".format(type(attentions)))
      # Segmentation
      segment = self.get_report_resource("segmentation")
      if segment is not None:
        segment = [int(x[0]) for x in segment]
        src_inp = [x[0] for x in self.encoder.apply_segmentation(src_words, segment)]
      else:
        src_inp = src_words
      # Other resources
      self.set_report_input(idx, src_inp, trg_words, attentions)
      self.set_report_resource("src_words", src_words)
      self.set_report_path('{}.{}'.format(self.report_path, str(idx)))
      self.generate_report(self.report_type)
    # Append output to the outputs
    outputs.append(TextOutput(actions=output_actions.word_ids,
                              vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                              score=score))
  self.outputs = outputs
  return outputs

def _encode_src(self, src):
  embeddings = self.src_embedder.embed_sent(src)
  # We assume that the encoder can generate multiple possible encodings.
  encodings = self.encoder.transduce(embeddings)
  # In most cases the encoder generates a single encoding, and we fall here.
  if type(encodings) != CompoundSeqExpression:
    encodings = CompoundSeqExpression([encodings])
    final_states = [self.encoder.get_final_states()]
  else:
    final_states = self.encoder.get_final_states()
  initial_states = []
  for encoding, final_state in zip(encodings, final_states):
    self.attender.init_sent(encoding)
    ss = mark_as_batch([Vocab.SS] * src.batch_size()) if is_batched(src) else Vocab.SS
    initial_states.append(self.decoder.initial_state(final_state, self.trg_embedder.embed(ss)))
  return CompoundSeqExpression(initial_states)

def calc_loss(self, src, trg, loss_calculator):
  self.start_sent(src)
  tokens = [x[0] for x in src]
  transitions = [x[1] for x in src]
  print("Current Batch: " + str(len(tokens)) + " pairs.\n")
  tokens = xnmt.batcher.mark_as_batch(tokens)
  embeddings = self.src_embedder.embed_sent(tokens)
  encodings = self.encoder(embeddings, transitions)
  self.attender.init_sent(encodings)
  # Initialize the hidden state from the encoder
  ss = mark_as_batch([Vocab.SS] * len(tokens)) if xnmt.batcher.is_batched(src) else Vocab.SS
  dec_state = self.decoder.initial_state(self.encoder._final_states,
                                         self.trg_embedder.embed(ss))
  # Compose losses
  model_loss = LossBuilder()
  loss, wer = loss_calculator(self, dec_state, src, trg)
  model_loss.add_loss("mle", loss)
  print("wer_b:" + str(wer))

  if self.calc_global_fertility or self.calc_attention_entropy:
    # philip30: I assume that attention_vecs is already masked src-wise.
    # Now applying the mask to the target
    masked_attn = self.attender.attention_vecs
    if trg.mask is not None:
      trg_mask = trg.mask.get_active_one_mask().transpose()
      masked_attn = [dy.cmult(attn, dy.inputTensor(mask, batched=True))
                     for attn, mask in zip(masked_attn, trg_mask)]
    if self.calc_global_fertility:
      model_loss.add_loss("fertility", self.global_fertility(masked_attn))
    if self.calc_attention_entropy:
      model_loss.add_loss("H(attn)", self.attention_entropy(masked_attn))

  return model_loss

def generate(self, src: Batch, idx: Sequence[int], search_strategy: SearchStrategy,
             forced_trg_ids: Batch = None):
  if src.batch_size() != 1:
    raise NotImplementedError("batched decoding not implemented for DefaultTranslator. "
                              "Specify inference batcher with batch size 1.")
  assert src.batch_size() == len(idx), f"src: {src.batch_size()}, idx: {len(idx)}"
  # Generating outputs
  self.start_sent(src)
  outputs = []
  cur_forced_trg = None
  sent = src[0]
  sent_mask = None
  if src.mask:
    sent_mask = Mask(np_arr=src.mask.np_arr[0:1])
  sent_batch = mark_as_batch([sent], mask=sent_mask)
  # TODO: MBR can be implemented here. It takes only the first result from the encoder.
  # To further implement MBR, we need to handle generation over multiple encoder outputs.
  initial_state = self._encode_src(sent_batch)[0]
  if forced_trg_ids is not None:
    cur_forced_trg = forced_trg_ids[0]
  search_outputs = search_strategy.generate_output(self, initial_state,
                                                   src_length=[sent.sent_len()],
                                                   forced_trg_ids=cur_forced_trg)
  sorted_outputs = sorted(search_outputs, key=lambda x: x.score[0], reverse=True)
  assert len(sorted_outputs) >= 1
  for curr_output in sorted_outputs:
    output_actions = [x for x in curr_output.word_ids[0]]
    attentions = [x for x in curr_output.attentions[0]]
    score = curr_output.score[0]
    if len(sorted_outputs) == 1:
      outputs.append(TextOutput(actions=output_actions,
                                vocab=getattr(self.trg_reader, "vocab", None),
                                score=score))
    else:
      outputs.append(NbestOutput(TextOutput(actions=output_actions,
                                            vocab=getattr(self.trg_reader, "vocab", None),
                                            score=score),
                                 nbest_id=idx[0]))
  if self.compute_report:
    attentions = np.concatenate([x.npvalue() for x in attentions], axis=1)
    self.add_sent_for_report({"idx": idx[0],
                              "attentions": attentions,
                              "src": sent,
                              "src_vocab": getattr(self.src_reader, "vocab", None),
                              "trg_vocab": getattr(self.trg_reader, "vocab", None),
                              "output": outputs[0]})
  return outputs