def embed_sent(self, x: sent.Sentence) -> expression_seqs.ExpressionSequence:
    """Embed a (possibly batched) sentence.

    Two kinds of input are handled:

    * sentences exposing ``get_array()`` are wrapped lazily in a
      :class:`expression_seqs.LazyNumpyExpressionSequence`;
    * all other sentences are embedded word by word via ``self.embed``.

    Args:
      x: a single sentence, or a batch of sentences as recognised by
        ``batchers.is_batched``.

    Returns:
      An expression sequence with one entry per word position. The mask is
      taken from ``x`` when ``x`` is a batch and is ``None`` otherwise.
    """
    # TODO refactor: seems a bit too many special cases that need to be distinguished
    batched = batchers.is_batched(x)
    # Only batches are known to carry a ``mask``; a single sentence object is
    # not guaranteed to have that attribute, so never read it when unbatched.
    mask = x.mask if batched else None
    first_sent = x[0] if batched else x
    if hasattr(first_sent, "get_array"):
        if not batched:
            return expression_seqs.LazyNumpyExpressionSequence(lazy_data=x.get_array())
        return expression_seqs.LazyNumpyExpressionSequence(
            lazy_data=batchers.mark_as_batch(list(x)), mask=mask)
    if not batched:
        embeddings = [self.embed(word) for word in x]
    else:
        # One batched lookup per word position across all sentences.
        embeddings = [
            self.embed(batchers.mark_as_batch([single_sent[word_i] for single_sent in x]))
            for word_i in range(x.sent_len())
        ]
    return expression_seqs.ExpressionSequence(expr_list=embeddings, mask=mask)
def generate(
        self,
        src: batchers.Batch,
        forced_trg_ids: Sequence[numbers.Integral] = None,
        normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
    """Assign one label per source position (argmax, or forced labels).

    Args:
      src: a source sentence or a batch holding exactly one sentence.
      forced_trg_ids: optional label sequence to use instead of the argmax
        labels. May be a plain sequence or a batch holding one sequence.
      normalize_scores: if True, score with ``calc_log_softmax`` instead of
        ``calc_scores``.

    Returns:
      A one-element list with a :class:`sent.SimpleSentence` whose ``score``
      is the sum over positions of the chosen labels' scores.
    """
    if not batchers.is_batched(src):
        src = batchers.mark_as_batch([src])
    assert src.batch_size() == 1, "batch size > 1 not properly tested"
    # Labels are looked up per position below (``output_actions[j]``), so a
    # batched forced sequence must be unwrapped to its single element;
    # indexing the batch itself would return whole sentences and overrun.
    if forced_trg_ids is not None and batchers.is_batched(forced_trg_ids):
        forced_trg_ids = forced_trg_ids[0]

    batch_size, encodings, outputs, seq_len = self._encode_src(src)
    if normalize_scores:
        score_expr = self.scorer.calc_log_softmax(outputs)
    else:
        score_expr = self.scorer.calc_scores(outputs)
    scores = score_expr.npvalue()  # vocab_size x seq_len

    # ``is not None`` rather than truthiness: numpy arrays raise on ``bool()``.
    if forced_trg_ids is not None:
        output_actions = forced_trg_ids
    else:
        output_actions = [np.argmax(scores[:, j]) for j in range(seq_len)]
    score = np.sum([scores[output_actions[j], j] for j in range(seq_len)])
    return [
        sent.SimpleSentence(
            words=output_actions,
            idx=src[0].idx,
            vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
            output_procs=self.trg_reader.output_procs,
            score=score)
    ]
def generate(self, src: Union[batchers.Batch, sent.Sentence],
             forced_trg_ids: Optional[Sequence[numbers.Integral]] = None,
             normalize_scores: bool = False):
    """Predict (or force) one class per batch element.

    Args:
      src: a source sentence or a batch of sentences.
      forced_trg_ids: optional class ids to score instead of the argmax
        prediction; one id per batch element.
      normalize_scores: if True, score with ``calc_log_probs`` instead of
        ``calc_scores``.

    Returns:
      A list with one :class:`sent.ScalarSentence` per batch element, holding
      the chosen class id and its score.
    """
    if not batchers.is_batched(src):
        src = batchers.mark_as_batch([src])
    h = self._encode_src(src)
    scores = self.scorer.calc_log_probs(h) if normalize_scores else self.scorer.calc_scores(h)
    np_scores = scores.npvalue()
    # ``is not None``: a forced class id of 0 is falsy but still a valid target.
    if forced_trg_ids is not None:
        output_action = forced_trg_ids
    else:
        output_action = np.argmax(np_scores, axis=0)
    batch_size = src.batch_size()
    outputs = []
    for batch_i in range(batch_size):
        if batch_size > 1:
            my_action = output_action[batch_i]
            score = np_scores[:, batch_i][my_action]
        else:
            # A forced id given as a one-element batch/list is unwrapped so the
            # action is a scalar, as it is on the argmax path.
            if batchers.is_batched(output_action) or isinstance(output_action, (list, tuple)):
                my_action = output_action[0]
            else:
                my_action = output_action
            score = np_scores[my_action]
        outputs.append(sent.ScalarSentence(value=my_action, score=score))
    return outputs
def assert_forced_decoding(self, sent_id):
    """Check that forced decoding reproduces the reference target.

    Args:
      sent_id: index into ``self.src_data`` / ``self.trg_data``.
    """
    dy.renew_cg()
    outputs = self.model.generate(
        batchers.mark_as_batch([self.src_data[sent_id]]),
        BeamSearch(),
        forced_trg_ids=batchers.mark_as_batch([self.trg_data[sent_id]]))
    # ``assertItemsEqual`` was removed from unittest in Python 3. Forced
    # decoding must reproduce the reference in order, so compare sequences.
    self.assertSequenceEqual(list(self.trg_data[sent_id].words), list(outputs[0].words))
def generate(self, src, forced_trg_ids=None, search_strategy=None):
    """Greedily decode ``src`` one token at a time.

    Every step rebuilds the computation graph and re-runs ``calc_loss`` in
    prediction mode on the prefix generated so far. Decoding stops at
    ``Vocab.ES`` (which is kept in the output) or after ``self.max_len`` steps.

    NOTE: ``forced_trg_ids`` and ``search_strategy`` are accepted but not used.

    Returns:
      A one-element list: a :class:`sent.SimpleSentence` if ``self.trg_vocab``
      is set, otherwise an ``(output_actions, score)`` tuple with score 0.0.
    """
    event_trigger.start_sent(src)
    if not batchers.is_batched(src):
        src = batchers.mark_as_batch([src])
    outputs = []

    # The trailing 0 is a placeholder token appended after the current prefix.
    trg = batchers.mark_as_batch([sent.SimpleSentence([0])])

    output_actions = []
    score = 0.

    # TODO Fix this with generate_one_step and use the appropriate search_strategy
    self.max_len = 100  # This is a temporary hack
    for _ in range(self.max_len):
        dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
        # calc_loss takes (src, trg, infer_prediction); passing the former
        # ``loss_cal=None`` keyword raised a TypeError.
        log_prob_tail = self.calc_loss(src, trg, infer_prediction=True)
        ys = np.argmax(log_prob_tail.npvalue(), axis=0).astype('i')
        output_actions.append(ys)
        if ys == Vocab.ES:
            break
        trg = batchers.mark_as_batch([sent.SimpleSentence(words=output_actions + [0])])

    # Append output to the outputs
    if hasattr(self, "trg_vocab") and self.trg_vocab is not None:
        outputs.append(sent.SimpleSentence(words=output_actions, vocab=self.trg_vocab))
    else:
        outputs.append((output_actions, score))
    return outputs
def _batch_max_action(self, batch_size, current_state, pos):
    """Return the argmax of ``current_state.out_prob`` wrapped as a batch.

    At position 0 there is no previous action and ``None`` is returned. For a
    batch size above one, the argmax over axis 0 is passed to
    ``mark_as_batch`` directly (presumably one id per batch element). For a
    single sentence, the scalar argmax is first wrapped in a one-element list.
    """
    if pos == 0:
        return None
    best = np.argmax(current_state.out_prob.npvalue(), axis=0)
    if batch_size <= 1:
        best = [best]
    return batchers.mark_as_batch(best)
def test_single(self):
    """Forced decoding with beam size 1 must score the reference at -NLL."""
    dy.renew_cg()
    nll = self.model.calc_nll(src=self.src_data[0], trg=self.trg_data[0]).value()

    dy.renew_cg()
    src_batch = batchers.mark_as_batch([self.src_data[0]])
    strategy = BeamSearch(beam_size=1)
    forced = batchers.mark_as_batch([self.trg_data[0]])
    hyps = self.model.generate(src_batch, strategy, forced_trg_ids=forced)
    self.assertAlmostEqual(-hyps[0].score, nll, places=4)
def calc_loss(self, src, trg, infer_prediction=False):
    """Compute the transformer MLE loss, or next-token logits at inference.

    Masks use 1 for padded positions and 0 for real tokens. The source gets
    an extra leading ``Vocab.SS`` column. The decoder input is
    ``[SS] + words[:-1]`` for each target.

    Args:
      src: source sentence or batch.
      trg: target sentence or batch.
      infer_prediction: if True, return the decoder output logits for the
        last target position instead of a loss.

    Returns:
      The logits expression when ``infer_prediction`` is True, otherwise a
      :class:`losses.FactoredLossExpr` with the loss under key ``"mle"``.
    """
    event_trigger.start_sent(src)
    if not batchers.is_batched(src):
        src = batchers.mark_as_batch([src])
    if not batchers.is_batched(trg):
        trg = batchers.mark_as_batch([trg])

    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
    # dtype it used to alias.
    src_words = np.array([[vocabs.Vocab.SS] + x.words for x in src])
    batch_size, src_len = src_words.shape
    if src.mask is None:
        src_mask = np.zeros((batch_size, src_len), dtype=int)
    else:
        # Prepend an unmasked column for the SS token added above.
        src_mask = np.concatenate(
            [np.zeros((batch_size, 1), dtype=int), src.mask.np_arr.astype(int)], axis=1)
    src_embeddings = self.sentence_block_embed(self.src_embedder.embeddings, src_words, src_mask)
    src_embeddings = self.make_input_embedding(src_embeddings, src_len)

    trg_words = np.array([[vocabs.Vocab.SS] + x.words[:-1] for x in trg])
    batch_size, trg_len = trg_words.shape
    if trg.mask is None:
        trg_mask = np.zeros((batch_size, trg_len), dtype=int)
    else:
        trg_mask = trg.mask.np_arr.astype(int)
    trg_embeddings = self.sentence_block_embed(self.trg_embedder.embeddings, trg_words, trg_mask)
    trg_embeddings = self.make_input_embedding(trg_embeddings, trg_len)

    xx_mask = self.make_attention_mask(src_mask, src_mask)
    xy_mask = self.make_attention_mask(trg_mask, src_mask)
    yy_mask = self.make_attention_mask(trg_mask, trg_mask)
    yy_mask *= self.make_history_mask(trg_mask)

    z_blocks = self.encoder.transduce(src_embeddings, xx_mask)
    h_block = self.decoder(trg_embeddings, z_blocks, xy_mask, yy_mask)

    if infer_prediction:
        y_len = h_block.dim()[0][1]
        last_col = dy.pick(h_block, dim=1, index=y_len - 1)
        return self.decoder.output(last_col)

    ref_list = list(itertools.chain.from_iterable(x.words for x in trg))
    # Zero out reference ids at padded positions.
    concat_t_block = (1 - trg_mask.ravel()) * np.array(ref_list)
    loss = self.decoder.output_and_loss(h_block, concat_t_block)
    return losses.FactoredLossExpr({"mle": loss})
def calc_loss(
        self,
        model: 'model_base.ConditionedModel',
        src: Union[sent.Sentence, 'batchers.Batch'],
        trg: Union[sent.Sentence, 'batchers.Batch']) -> losses.FactoredLossExpr:
    """Batch the inputs, fire the sentence-start event, and delegate.

    Single sentences are wrapped in one-element batches. The (batched) source
    is passed to ``event_trigger.start_sent``, then ``_perform_calc_loss``
    computes the actual loss.
    """
    src_batch = src if batchers.is_batched(src) else batchers.mark_as_batch([src])
    trg_batch = trg if batchers.is_batched(trg) else batchers.mark_as_batch([trg])
    event_trigger.start_sent(src_batch)
    return self._perform_calc_loss(model, src_batch, trg_batch)
def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
    """
    Check that the summed per-sentence MLE loss equals the batched loss.

    Sources are truncated to the shortest source, end in ``Vocab.ES``, and
    are padded with ``Vocab.ES`` to a multiple of ``pad_src_to_multiple``.
    Targets are NOT truncated: they are sorted by decreasing length, padded
    with ``Vocab.ES``, and the padding is masked out (mask value 1).
    """
    batch_size = 5
    src_sents = self.src_data[:batch_size]
    src_min = min(x.sent_len() for x in src_sents)
    src_word_lists = []
    for s in src_sents:
        words = s.words[:src_min]
        words[src_min - 1] = Vocab.ES
        while len(words) % pad_src_to_multiple != 0:
            words.append(Vocab.ES)
        src_word_lists.append(words)

    trg_sents = sorted(self.trg_data[:batch_size], key=lambda x: x.sent_len(), reverse=True)
    trg_max = max(x.sent_len() for x in trg_sents)
    np_arr = np.zeros([batch_size, trg_max])
    for row in range(batch_size):
        np_arr[row, trg_sents[row].sent_len():] = 1.0
    trg_masks = Mask(np_arr)

    trg_padded = [
        sent.SimpleSentence(words=list(s) + [Vocab.ES] * (trg_max - s.sent_len()))
        for s in trg_sents
    ]
    src_trunc = [sent.SimpleSentence(words=w) for w in src_word_lists]

    single_loss = 0.0
    for idx in range(batch_size):
        dy.renew_cg()
        loss_expr, _ = MLELoss().calc_loss(
            model=model, src=src_trunc[idx], trg=trg_sents[idx]).compute()
        single_loss += loss_expr.value()

    dy.renew_cg()
    batched_loss, _ = MLELoss().calc_loss(
        model=model,
        src=mark_as_batch(src_trunc),
        trg=mark_as_batch(trg_padded, trg_masks)).compute()

    self.assertAlmostEqual(single_loss, np.sum(batched_loss.value()), places=4)
def _cut_or_pad_targets(self, seq_len, trg):
    """Truncate or pad every target sentence so its length equals ``seq_len``.

    The length is taken from the first sentence. If ``trg`` has a mask, it is
    cut or padded alongside. Padded columns get mask value 1 (masked out).

    Args:
      seq_len: desired target length.
      trg: batch of target sentences.

    Returns:
      A new batch with adjusted sentences and, if the input had a mask, an
      adjusted :class:`batchers.Mask`.
    """
    old_mask = trg.mask
    cur_len = len(trg[0])
    if cur_len > seq_len:
        trunc_len = cur_len - seq_len
        trg = batchers.mark_as_batch(
            [trg_sent.get_truncated_sent(trunc_len=trunc_len) for trg_sent in trg])
        if old_mask is not None:
            trg.mask = batchers.Mask(np_arr=old_mask.np_arr[:, :-trunc_len])
    else:
        pad_len = seq_len - cur_len
        trg = batchers.mark_as_batch(
            [trg_sent.create_padded_sent(pad_len=pad_len) for trg_sent in trg])
        if old_mask is not None:
            # Wrap in a Mask, as the truncation branch does: callers read
            # ``trg.mask.np_arr``, which a bare ndarray does not provide.
            trg.mask = batchers.Mask(np_arr=np.pad(old_mask.np_arr,
                                                   pad_width=((0, 0), (0, pad_len)),
                                                   mode="constant",
                                                   constant_values=1))
    return trg
def test_greedy_vs_beam(self):
    """Beam search with beam size 1 must reach the same score as greedy search."""
    scores = []
    for make_strategy in (lambda: BeamSearch(beam_size=1), GreedySearch):
        dy.renew_cg()
        hyps = self.model.generate(
            batchers.mark_as_batch([self.src_data[0]]), make_strategy())
        scores.append(hyps[0].score)
    beam_score, greedy_score = scores
    self.assertAlmostEqual(beam_score, greedy_score)
def embed_sent(self, x) -> expression_seqs.ExpressionSequence:
    """Embed a whole sentence by embedding each word position in turn.

    Args:
      x: usually a sequence of word IDs, though other per-word formats (for
        example strings) are possible. It may also be a (possibly masked)
        :class:`batchers.Batch`, in which case all sentences must have the
        same ``sent_len()``.

    Returns:
      An expression sequence with one vector per word position. It carries
      the batch mask when ``x`` is batched and no mask otherwise.
    """
    if not batchers.is_batched(x):
        return expression_seqs.ExpressionSequence(
            expr_list=[self.embed(word) for word in x], mask=None)

    seq_len = x.sent_len()
    for single_sent in x:
        assert single_sent.sent_len() == seq_len
    embeddings = [
        self.embed(batchers.mark_as_batch([single_sent[pos] for single_sent in x]))
        for pos in range(seq_len)
    ]
    return expression_seqs.ExpressionSequence(expr_list=embeddings, mask=x.mask)
def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
        -> dy.Expression:
    """Sum of per-position negative log-likelihoods of ``trg`` given ``src``.

    Single sentences are wrapped in one-element batches, as the annotations
    allow. If the target length differs from the encoded source length, the
    target is cut or padded when ``self.auto_cut_pad`` is set. Otherwise a
    ValueError is raised. Positions whose mask value is 1 contribute nothing.

    Raises:
      ValueError: source and target lengths differ and ``auto_cut_pad`` is off.
    """
    if not batchers.is_batched(src):
        src = batchers.mark_as_batch([src])
    if not batchers.is_batched(trg):
        trg = batchers.mark_as_batch([trg])
    batch_size, encodings, outputs, seq_len = self._encode_src(src)

    if trg.sent_len() != seq_len:
        if self.auto_cut_pad:
            trg = self._cut_or_pad_targets(seq_len, trg)
        else:
            raise ValueError(
                f"src/trg length do not match: {seq_len} != {len(trg[0])}")

    ref_action = np.asarray([trg_sent.words for trg_sent in trg]).reshape((seq_len * batch_size, ))
    loss_expr_perstep = self.scorer.calc_loss(outputs, batchers.mark_as_batch(ref_action))
    loss_expr_perstep = dy.reshape(loss_expr_perstep, (seq_len, ), batch_size=batch_size)
    if trg.mask is not None:
        # Zero out losses at padded positions (mask value 1).
        loss_expr_perstep = dy.cmult(
            loss_expr_perstep, dy.inputTensor(1.0 - trg.mask.np_arr.T, batched=True))
    return dy.sum_elems(loss_expr_perstep)
def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
        -> tt.Tensor:
    """Language-model NLL of ``src``; ``trg`` is not used.

    Each sentence is split into inputs ``s[:-1]`` and next-token targets
    ``s[1:]`` (the mask is sliced the same way). The inputs are embedded and
    run through ``self.rnn``, the targets are scored, and the masked per-step
    losses are aggregated by ``tt.aggregate_masked_loss``.
    """
    if not batchers.is_batched(src):
        src = batchers.ListBatch([src])

    mask = src.mask
    src_inputs = batchers.ListBatch(
        [s[:-1] for s in src],
        mask=batchers.Mask(mask.np_arr[:, :-1]) if mask else None)
    src_targets = batchers.ListBatch(
        [s[1:] for s in src],
        mask=batchers.Mask(mask.np_arr[:, 1:]) if mask else None)

    event_trigger.start_sent(src)
    encodings = self.rnn.transduce(self.src_embedder.embed_sent(src_inputs))
    enc_tensor = encodings.as_tensor()
    merged = tt.merge_time_batch_dims(enc_tensor)
    seq_len = tt.sent_len(enc_tensor)
    batch_size = tt.batch_size(enc_tensor)
    outputs = self.transform.transform(merged)

    # Loop variable deliberately not named ``sent`` to avoid shadowing the module.
    ref_action = np.asarray([target.words for target in src_targets]).reshape((seq_len * batch_size, ))
    per_step = self.scorer.calc_loss(outputs, batchers.mark_as_batch(ref_action))
    per_step = tt.unmerge_time_batch_dims(per_step, batch_size)
    return tt.aggregate_masked_loss(per_step, src_targets.mask)
def generate(
        self,
        src: batchers.Batch,
        normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
    """Label the source with the 1-best label at each position.

    Only batch size 1 is supported. The score is the sum of the 1-best scores
    along axis 1, as returned by ``scorer.best_k``.

    Returns:
      A one-element list with the labelled :class:`sent.SimpleSentence`.
    """
    if not batchers.is_batched(src):
        src = batchers.mark_as_batch([src])
    assert src.batch_size() == 1, "batch size > 1 not properly tested"

    _, _, outputs, _ = self._encode_src(src)
    top_words, top_scores = self.scorer.best_k(
        outputs, k=1, normalize_scores=normalize_scores)
    hypothesis = sent.SimpleSentence(
        words=top_words[0, :],
        idx=src[0].idx,
        vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
        output_procs=self.trg_reader.output_procs,
        score=np.sum(top_scores, axis=1))
    return [hypothesis]
def generate_search_output(self, src: batchers.Batch,
                           search_strategy: search_strategies.SearchStrategy,
                           forced_trg_ids: batchers.Batch = None) -> List[search_strategies.SearchOutput]:
    """
    Takes in a batch of source sentences and outputs a list of search outputs.

    Args:
      src: The source sentences; must be a batch of size 1.
      search_strategy: The strategy with which to perform the search
      forced_trg_ids: The target IDs to generate if performing forced decoding

    Returns:
      A list of search outputs including scores, etc.

    Raises:
      NotImplementedError: if ``src`` holds more than one sentence.
    """
    if src.batch_size() != 1:
        raise NotImplementedError("batched decoding not implemented for DefaultTranslator. "
                                  "Specify inference batcher with batch size 1.")
    event_trigger.start_sent(src)
    all_src = src
    if isinstance(src, batchers.CompoundBatch):
        # Only the first sub-batch determines the reported source length.
        src = src.batches[0]
    src_sent = src[0]  # checkme
    # The removed single-sentence batch built here was never used, and it
    # wrapped the ``sent`` module instead of ``src_sent``.

    # Encode the full (possibly compound) input.
    initial_state = self._encode_src(all_src)
    cur_forced_trg = forced_trg_ids[0] if forced_trg_ids is not None else None
    return search_strategy.generate_output(self, initial_state,
                                           src_length=[src_sent.sent_len()],
                                           forced_trg_ids=cur_forced_trg)
def calc_nll(self, src, trg):
    """Language-model NLL of ``src``; ``trg`` is not used.

    Each sentence is split into inputs ``s[:-1]`` and next-token targets
    ``s[1:]``. The time and batch dimensions of the RNN encodings are folded
    together for scoring, then unfolded. Per-step losses at masked positions
    (mask value 1) are zeroed before summing.
    """
    if not batchers.is_batched(src):
        src = batchers.ListBatch([src])

    mask = src.mask
    src_inputs = batchers.ListBatch(
        [s[:-1] for s in src],
        mask=batchers.Mask(mask.np_arr[:, :-1]) if mask else None)
    src_targets = batchers.ListBatch(
        [s[1:] for s in src],
        mask=batchers.Mask(mask.np_arr[:, 1:]) if mask else None)

    event_trigger.start_sent(src)
    encodings = self.rnn.transduce(self.src_embedder.embed_sent(src_inputs))
    enc_tensor = encodings.as_tensor()
    ((hidden_dim, seq_len), batch_size) = encodings.dim()
    folded = dy.reshape(enc_tensor, (hidden_dim, ), batch_size=batch_size * seq_len)
    outputs = self.transform.transform(folded)

    # Loop variable deliberately not named ``sent`` to avoid shadowing the module.
    ref_action = np.asarray([target.words for target in src_targets]).reshape((seq_len * batch_size, ))
    per_step = self.scorer.calc_loss(outputs, batchers.mark_as_batch(ref_action))
    per_step = dy.reshape(per_step, (seq_len, ), batch_size=batch_size)
    if src_targets.mask:
        keep = dy.inputTensor(1.0 - src_targets.mask.np_arr.T, batched=True)
        per_step = dy.cmult(per_step, keep)
    return dy.sum_elems(per_step)
def test_composite(self):
    """A composite MLE + policy-MLE loss computes, and decoding still works afterwards."""
    event_trigger.set_train(True)
    parts = [loss_calculators.MLELoss(), loss_calculators.PolicyMLELoss()]
    combined = loss_calculators.CompositeLoss(parts)
    combined.calc_loss(self.model, self.src[0], self.trg[0])
    event_trigger.set_train(False)
    src_batch = batchers.mark_as_batch([self.src_data[0]])
    self.model.generate(src_batch, GreedySearch())
def _embed_word(self, word: sent.SegmentedWord, is_batched: bool = False):
    """Build a word vector by composing its character embeddings.

    All characters are looked up together as one batch, then split back into
    per-character expressions for the composer.

    Args:
      word: the segmented word whose ``chars`` are embedded.
      is_batched: accepted for interface compatibility; not used here.
    """
    chars = word.chars
    batched_embeds = self.embeddings.batch(batchers.mark_as_batch(chars))
    per_char = []
    for i in range(len(chars)):
        per_char.append(dy.pick_batch_elem(batched_embeds, i))
    return self.composer.compose(per_char)
def test_train_mle_only(self):
    """Without a policy network, MLE training and greedy decoding still run."""
    self.model.policy_network = None
    event_trigger.set_train(True)
    loss_calculators.MLELoss().calc_loss(self.model, self.src[0], self.trg[0])
    event_trigger.set_train(False)
    src_batch = batchers.mark_as_batch([self.src_data[0]])
    self.model.generate(src_batch, GreedySearch())
def _encode_src(self, src: Union[batchers.Batch, sent.Sentence]):
    """Encode ``src``, prime the attender, and return the decoder's initial state.

    The decoder is seeded with the encoder's final states and the embedding
    of ``Vocab.SS``. When ``src`` is batched, ``Vocab.SS`` is replicated once
    per batch element.
    """
    encoding = self.encoder.transduce(self.src_embedder.embed_sent(src))
    final_state = self.encoder.get_final_states()
    self.attender.init_sent(encoding)
    if batchers.is_batched(src):
        start_token = batchers.mark_as_batch([Vocab.SS] * src.batch_size())
    else:
        start_token = Vocab.SS
    return self.decoder.initial_state(final_state, self.trg_embedder.embed(start_token))
def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
    """
    Check that the summed per-sentence MLE loss equals the batched loss.

    Sources and targets are truncated to their shortest member and end in
    ``Vocab.ES``, so no masking is needed. Sources are additionally padded
    with ``Vocab.ES`` to a multiple of ``pad_src_to_multiple``.
    """
    batch_size = 5

    def _truncate(sents, length, pad_multiple):
        # Cut each sentence to ``length``, force a final ES, then pad with ES.
        result = []
        for s in sents:
            words = s.words[:length]
            words[length - 1] = Vocab.ES
            while len(words) % pad_multiple != 0:
                words.append(Vocab.ES)
            result.append(sent.SimpleSentence(words=words))
        return result

    src_sents = self.src_data[:batch_size]
    src_trunc = _truncate(src_sents, min(x.sent_len() for x in src_sents), pad_src_to_multiple)
    trg_sents = self.trg_data[:batch_size]
    trg_trunc = _truncate(trg_sents, min(x.sent_len() for x in trg_sents), 1)

    single_loss = 0.0
    for idx in range(batch_size):
        dy.renew_cg()
        loss_expr, _ = MLELoss().calc_loss(
            model=model, src=src_trunc[idx], trg=trg_trunc[idx]).compute()
        single_loss += loss_expr.value()

    dy.renew_cg()
    batched_loss, _ = MLELoss().calc_loss(
        model=model,
        src=mark_as_batch(src_trunc),
        trg=mark_as_batch(trg_trunc)).compute()

    self.assertAlmostEqual(single_loss, np.sum(batched_loss.value()), places=4)
def _select_ref_words(sent, index, truncate_masked=False):
    """Pick the reference word at ``index`` from a sentence or batch.

    Unbatched input returns ``sent[index]`` directly. For a batch, the word at
    ``index`` is collected from every sentence. With ``truncate_masked``,
    sentences masked at ``index`` are dropped, unless every sentence is
    masked there. Masked sentences must come last; the assertion enforces
    decreasing target length.
    """
    if not batchers.is_batched(sent):
        return sent[index]
    if not truncate_masked:
        return batchers.mark_as_batch([single_trg[index] for single_trg in sent])

    mask = sent.mask
    ret = []
    found_masked = False
    for row, single_trg in enumerate(sent):
        unmasked = (mask is None
                    or mask.np_arr[row, index] == 0
                    or np.sum(mask.np_arr[:, index]) == mask.np_arr.shape[0])
        if not unmasked:
            found_masked = True
            continue
        assert not found_masked, "sentences must be sorted by decreasing target length"
        ret.append(single_trg[index])
    return batchers.mark_as_batch(ret)
def _batch_ref_action(self, pos):
    """Batch the reference actions at ``pos`` from ``self.cur_src.batches[1]``.

    A ``None`` entry at ``pos`` is replaced by ``vocabs.Vocab.ES``.
    """
    return batchers.mark_as_batch([
        vocabs.Vocab.ES if src_sent[pos] is None else src_sent[pos]
        for src_sent in self.cur_src.batches[1]
    ])
def test_train_nll(self):
    """MLE and policy-MLE losses compute in training mode, and decoding works afterwards."""
    event_trigger.set_train(True)
    loss_calculators.MLELoss().calc_loss(self.model, self.src[0], self.trg[0])
    loss_calculators.PolicyMLELoss()._perform_calc_loss(self.model, self.src[0], self.trg[0])
    event_trigger.set_train(False)
    src_batch = batchers.mark_as_batch([self.src_data[0]])
    self.model.generate(src_batch, GreedySearch())
def test_single(self):
    """The greedy hypothesis score must equal minus its NLL from calc_nll."""
    dy.renew_cg()
    hyp = self.model.generate(batchers.mark_as_batch([self.src_data[0]]), GreedySearch())[0]
    hyp_score = hyp.score
    dy.renew_cg()
    nll = self.model.calc_nll(src=self.src_data[0], trg=hyp).value()
    self.assertAlmostEqual(-hyp_score, nll, places=3)
def test_single(self):
    """The greedy hypothesis score must equal minus its NLL (tensor-tools backend)."""
    tt.reset_graph()
    hyp = self.model.generate(batchers.mark_as_batch([self.src_data[0]]), GreedySearch())[0]
    hyp_score = hyp.score
    tt.reset_graph()
    nll = tt.npvalue(self.model.calc_nll(src=self.src_data[0], trg=hyp))
    self.assertAlmostEqual(-hyp_score, nll[0], places=3)
def calc_loss(self, dec_state, ref_action):
    """Loss of the reference RNNG actions for the transformed decoder state.

    The action-type loss is always computed. The action type of the first
    reference selects an auxiliary content loss: NT uses ``nt_scorer``, GEN
    uses ``term_scorer``, and REDUCE_LEFT/REDUCE_RIGHT use ``edge_scorer``.
    That loss is added on top.
    """
    state = self._calc_transform(dec_state)
    action_type = ref_action[0].action_type
    loss = self.action_scorer.calc_loss(
        state, batchers.mark_as_batch([x.action_type.value for x in ref_action]))

    action_types = sent.RNNGAction.Type
    if action_type == action_types.NT:
        aux_scorer = self.nt_scorer
    elif action_type == action_types.GEN:
        aux_scorer = self.term_scorer
    elif action_type == action_types.REDUCE_LEFT or action_type == action_types.REDUCE_RIGHT:
        aux_scorer = self.edge_scorer
    else:
        aux_scorer = None

    if aux_scorer is not None:
        content_batch = batchers.mark_as_batch([x.action_content for x in ref_action])
        loss += aux_scorer.calc_loss(state, content_batch)
    return loss
def generate_one_step(self, current_word: Any, current_state: AutoRegressiveDecoderState) -> TranslatorOutput:
    """Advance the decoder by one word and return the next-step distribution.

    A plain ``int`` is wrapped in a list. A list or ndarray is turned into a
    batch, embedded, and fed to the decoder. When ``current_word`` is None,
    the state is used as is. In both cases the attention context is
    recomputed before the log-probabilities are taken.
    """
    if current_word is None:
        next_state = current_state
    else:
        if type(current_word) is int:
            current_word = [current_word]
        if type(current_word) in (list, np.ndarray):
            current_word = batchers.mark_as_batch(current_word)
        next_state = self.decoder.add_input(current_state, self.trg_embedder.embed(current_word))
    next_state.context = self.attender.calc_context(next_state.rnn_state.output())
    next_logsoftmax = self.decoder.calc_log_probs(next_state)
    return TranslatorOutput(next_state, next_logsoftmax, self.attender.get_last_attention())