def generate(self, src, forced_trg_ids=None, search_strategy=None):
    """Greedily decode a target sequence for ``src``, one token per iteration.

    Each iteration renews the computation graph, scores the next token with
    ``self.calc_loss(..., infer_prediction=True)`` given the tokens decoded so
    far, and keeps the arg-max. Decoding stops after the end-of-sentence id
    has been emitted or after ``self.max_len`` steps.

    Args:
        src: Source sentence or batch; anything not yet batched is wrapped
            into a batch of one.
        forced_trg_ids: Not used by this implementation.
        search_strategy: Not used by this implementation.

    Returns:
        A one-element list. The element is a ``sent.SimpleSentence`` when the
        model has a non-``None`` ``trg_vocab``; otherwise it is the tuple
        ``(output_actions, score)`` where ``score`` is always ``0.``.

    Side effects:
        Overwrites ``self.max_len`` with 100 on every call.
    """
    def _ensure_batched(item):
        # Wrap anything that is not already a batch into a batch of one.
        if batchers.is_batched(item):
            return item
        return batchers.mark_as_batch([item])

    event_trigger.start_sent(src)
    src = _ensure_batched(src)
    # Decoding is seeded with a target prefix consisting of the single id 0.
    trg = _ensure_batched(sent.SimpleSentence([0]))

    decoded_ids = []
    total_score = 0.

    # TODO Fix this with generate_one_step and use the appropriate search_strategy
    self.max_len = 100  # This is a temporary hack
    for _ in range(self.max_len):
        dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                    check_validity=settings.CHECK_VALIDITY)
        log_prob_tail = self.calc_loss(src, trg, loss_cal=None, infer_prediction=True)
        next_id = np.argmax(log_prob_tail.npvalue(), axis=0).astype('i')
        decoded_ids.append(next_id)
        if next_id == Vocab.ES:
            break
        # Re-feed everything decoded so far, followed by a trailing 0.
        trg = _ensure_batched(sent.SimpleSentence(words=decoded_ids + [0]))

    # Package the decoded ids for the caller.
    if getattr(self, "trg_vocab", None) is not None:
        result = sent.SimpleSentence(words=decoded_ids, vocab=self.trg_vocab)
    else:
        result = (decoded_ids, total_score)
    return [result]
def read_sent(self, line: str, idx: numbers.Integral) -> sent.SimpleSentence:
    """Convert a whitespace-tokenized line to word ids, corrupting it in training.

    Outside of training (``self.train`` falsy) the words are converted with
    ``self.vocab.convert`` and an end-of-sentence id is appended.

    In training, a random subset of positions is overwritten with random
    vocabulary ids:

    1. A count ``num_words`` in ``[0, length)`` is drawn from a softmax over
       ``-tau * k`` for ``k = 0 .. length-1``.
    2. Each position is independently marked with probability
       ``num_words / length``.
    3. Marked positions are replaced by ids drawn uniformly from
       ``[2, len(self.vocab))``.

    Args:
        line: Raw text line; split on whitespace after stripping.
        idx: Sentence index stored on the returned sentence.

    Returns:
        A ``sent.SimpleSentence`` whose words end with ``vocabs.Vocab.ES``.
    """
    words = line.strip().split()
    # An empty line has nothing to corrupt. Without this guard the training
    # path would call np.max on a zero-size array and divide by a zero length.
    if not self.train or not words:
        return sent.SimpleSentence(
            idx=idx,
            words=[self.vocab.convert(word) for word in words] + [vocabs.Vocab.ES],
            vocab=self.vocab,
            output_procs=self.output_procs)

    word_ids = np.array([self.vocab.convert(word) for word in words])
    length = len(word_ids)

    # Softmax over -tau * k; the max is subtracted for numerical stability.
    logits = np.arange(length) * (-1) * self.tau
    logits = np.exp(logits - np.max(logits))
    probs = logits / np.sum(logits)

    # The RNG calls below are kept in their original order so that seeded
    # runs produce the same corruption as before.
    num_words = np.random.choice(length, p=probs)
    corrupt_pos = np.random.binomial(1, p=num_words / length, size=(length, ))
    num_words_to_sample = np.sum(corrupt_pos)
    sampled_words = np.random.choice(np.arange(2, len(self.vocab)),
                                     size=(num_words_to_sample, ))
    word_ids[np.where(corrupt_pos == 1)[0].tolist()] = sampled_words

    return sent.SimpleSentence(
        idx=idx,
        words=word_ids.tolist() + [vocabs.Vocab.ES],
        vocab=self.vocab,
        output_procs=self.output_procs)
def test_batch_src(self):
    """Check the padded word lists of the first two batches from ``SrcBatcher``.

    Six source sentences of lengths 1..6 (pad token 1) are paired with target
    sentences of lengths ``(i + 3) % 6 + 1`` (pad token 2) and packed with a
    batch size of 3.
    """
    lengths = range(1, 7)
    src_sents = [sent.SimpleSentence([0] * n, pad_token=1) for n in lengths]
    trg_sents = [sent.SimpleSentence([0] * ((n + 3) % 6 + 1), pad_token=2) for n in lengths]

    packer = batchers.SrcBatcher(batch_size=3)
    src, trg = packer.pack(src_sents, trg_sents)

    # (batched side, batch index, expected padded words), checked in order.
    expectations = (
        (src, 0, [[0, 0, 1], [0, 1, 1], [0, 0, 0]]),
        (trg, 0, [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 2], [0, 2, 2, 2, 2, 2]]),
        (src, 1, [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1]]),
        (trg, 1, [[0, 0, 0, 0], [0, 0, 0, 2], [0, 0, 2, 2]]),
    )
    for side, batch_index, expected_words in expectations:
        self.assertEqual(expected_words, [x.words for x in side[batch_index]])
def generate(self, src, forced_trg_ids):
    """Score the encoded source with a logistic output layer and emit one sentence.

    The single source sentence is embedded and encoded, then turned into
    logistic scores according to ``self.mode``:

    * ``"avg_mlp"`` / ``"final_mlp"``: the output layer is applied to each
      encoder step (per-step generation, ``avg_mlp`` only), to the average
      of the encodings, or to the encoder's last final state.
    * ``"lin_sum_sig"``: the output layer is applied to each step (masked
      steps are zeroed); scores are per step, or the logistic of the
      mask-normalized sum.

    With ``self.generate_per_step`` the output words are the per-step
    arg-max and the score is the sum of per-step maxima. Otherwise the words
    are the indices whose score exceeds 0.5 and the score is their sum.

    Args:
        src: A batch containing exactly one source sentence.
        forced_trg_ids: Must be falsy; forced decoding is not supported.

    Returns:
        A one-element list containing a ``sent.SimpleSentence``.

    Raises:
        ValueError: If ``self.mode`` is not one of the modes above.
    """
    assert not forced_trg_ids
    assert batchers.is_batched(src) and src.batch_size()==1, "batched generation not fully implemented"
    src = src[0]

    event_trigger.start_sent(src)
    embeddings = self.src_embedder.embed_sent(src)
    encodings = self.encoder.transduce(embeddings)
    per_step = self.generate_per_step

    if self.mode in ["avg_mlp", "final_mlp"]:
        if per_step:
            assert self.mode == "avg_mlp", "final_mlp not supported with generate_per_step=True"
            scores = [dy.logistic(self.output_layer.transform(enc_i)) for enc_i in encodings]
        else:
            if self.mode == "avg_mlp":
                # Average the encodings over the second dimension of the tensor.
                summed = dy.sum_dim(encodings.as_tensor(), [1])
                encoding_fixed_size = summed * (1.0 / encodings.dim()[0][1])
            elif self.mode == "final_mlp":
                encoding_fixed_size = self.encoder.get_final_states()[-1].main_expr()
            scores = dy.logistic(self.output_layer.transform(encoding_fixed_size))
    elif self.mode == "lin_sum_sig":
        mask = encodings.mask
        enc_lin = []
        for step_i, enc_i in enumerate(encodings):
            step_linear = self.output_layer.transform(enc_i)
            if mask and np.sum(mask.np_arr[:, step_i]) > 0:
                # Zero out batch elements that are masked at this step.
                keep = dy.inputTensor(1.0 - mask.np_arr[:, step_i], batched=True)
                step_linear = dy.cmult(step_linear, keep)
            enc_lin.append(step_linear)
        if per_step:
            scores = [dy.logistic(enc_i) for enc_i in enc_lin]
        else:
            if mask:
                # Normalize by the number of unmasked steps per batch element.
                unmasked_counts = dy.inputTensor(np.sum(1.0 - mask.np_arr, axis=1), batched=True)
                encoding_fixed_size = dy.cdiv(dy.esum(enc_lin), unmasked_counts)
            else:
                encoding_fixed_size = dy.esum(enc_lin) / encodings.dim()[0][1]
            scores = dy.logistic(encoding_fixed_size)
    else:
        raise ValueError(f"unknown mode '{self.mode}'")

    if per_step:
        step_values = [score_i.npvalue() for score_i in scores]
        output_actions = [np.argmax(values) for values in step_values]
        score = np.sum([np.max(values) for values in step_values])
    else:
        scores_arr = scores.npvalue()
        selected = scores_arr > 0.5
        output_actions = list(np.nonzero(selected)[0])
        score = np.sum(scores_arr[selected])

    return [sent.SimpleSentence(words=output_actions,
                                idx=src.idx,
                                vocab=getattr(self.trg_reader, "vocab", None),
                                score=score,
                                output_procs=self.trg_reader.output_procs)]
def test_batch_random_no_ties(self):
    """Repeated ``SrcBatcher.pack`` calls give the same first target batch length.

    The batcher is packed once to record the sentence length of the first
    target batch, then packed ten more times; every repetition must report
    the same length.
    """
    lengths = range(1, 7)
    src_sents = [sent.SimpleSentence([0] * n, pad_token=1) for n in lengths]
    trg_sents = [sent.SimpleSentence([0] * ((n + 3) % 6 + 1), pad_token=2) for n in lengths]

    packer = batchers.SrcBatcher(batch_size=3)
    _, first_trg = packer.pack(src_sents, trg_sents)
    reference_len = first_trg[0].sent_len()

    for _ in range(10):
        _, repacked_trg = packer.pack(src_sents, trg_sents)
        current_len = repacked_trg[0].sent_len()
        self.assertTrue(current_len==reference_len)
def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
    """
    Tests whether single loss equals batch loss.
    Here we don't truncate the target side and use masking.

    Source sentences are truncated to the shortest source length (last token
    forced to ES, then ES-padded up to a multiple of ``pad_src_to_multiple``).
    Target sentences are sorted longest-first, ES-padded to the longest one,
    and accompanied by a mask that is 1.0 on padded positions. The sum of the
    per-sentence losses must match the sum of the batched loss to 4 places.
    """
    batch_size = 5

    # --- source side: truncate to the shortest sentence, terminate with ES ---
    src_sents = self.src_data[:batch_size]
    src_min = min(x.sent_len() for x in src_sents)
    src_word_lists = [s.words[:src_min] for s in src_sents]
    for word_list in src_word_lists:
        word_list[src_min - 1] = Vocab.ES
        while len(word_list) % pad_src_to_multiple != 0:
            word_list.append(Vocab.ES)

    # --- target side: keep full length, pad and mask instead ---
    trg_sents = sorted(self.trg_data[:batch_size], key=lambda x: x.sent_len(), reverse=True)
    trg_max = max(x.sent_len() for x in trg_sents)
    mask_values = np.zeros([batch_size, trg_max])
    for row in range(batch_size):
        # Everything beyond the sentence's own length is padding.
        mask_values[row, trg_sents[row].sent_len():trg_max] = 1.0
    trg_masks = Mask(mask_values)

    trg_word_lists = []
    for s in trg_sents:
        padding = [Vocab.ES] * (trg_max - s.sent_len())
        trg_word_lists.append([w for w in s] + padding)

    src_sents_trunc = [sent.SimpleSentence(words=words) for words in src_word_lists]
    trg_sents_padded = [sent.SimpleSentence(words=words) for words in trg_word_lists]

    # --- loss accumulated one sentence at a time (unpadded targets) ---
    single_loss = 0.0
    for sent_id in range(batch_size):
        dy.renew_cg()
        train_loss, _ = MLELoss().calc_loss(
            model=model,
            src=src_sents_trunc[sent_id],
            trg=trg_sents[sent_id]).compute()
        single_loss += train_loss.value()

    # --- loss computed on the whole masked batch at once ---
    dy.renew_cg()
    batched_loss, _ = MLELoss().calc_loss(
        model=model,
        src=mark_as_batch(src_sents_trunc),
        trg=mark_as_batch(trg_sents_padded, trg_masks)).compute()

    self.assertAlmostEqual(single_loss, np.sum(batched_loss.value()), places=4)
def test_batch_word_src(self):
    """Check the five batches produced by ``WordSrcBatcher(words_per_batch=12)``.

    Six source sentences of lengths 1..6 (pad token 1) are paired with target
    sentences of lengths ``(i + 3) % 6 + 1`` (pad token 2); the padded word
    lists of every source and target batch are compared to fixed values.
    """
    lengths = range(1, 7)
    src_sents = [sent.SimpleSentence([0] * n, pad_token=1) for n in lengths]
    trg_sents = [sent.SimpleSentence([0] * ((n + 3) % 6 + 1), pad_token=2) for n in lengths]

    packer = batchers.WordSrcBatcher(words_per_batch=12)
    src, trg = packer.pack(src_sents, trg_sents)

    # One (expected source words, expected target words) pair per batch.
    expected_batches = (
        ([[0]], [[0, 0, 0, 0, 0]]),
        ([[0, 0]], [[0, 0, 0, 0, 0, 0]]),
        ([[0, 0, 0, 0], [0, 0, 0, 1]], [[0, 0], [0, 2]]),
        ([[0, 0, 0, 0, 0]], [[0, 0, 0]]),
        ([[0, 0, 0, 0, 0, 0]], [[0, 0, 0, 0]]),
    )
    for batch_index, (expected_src, expected_trg) in enumerate(expected_batches):
        self.assertEqual(expected_src, [x.words for x in src[batch_index]])
        self.assertEqual(expected_trg, [x.words for x in trg[batch_index]])
def _emit_translation(self, src, output_actions, score):
    """Wrap decoded actions into a ``sent.SimpleSentence``.

    Args:
        src: Indexable source batch; the index of its first element is
            copied onto the output sentence.
        output_actions: Word ids of the output sentence.
        score: Score attached to the output sentence.

    Returns:
        The sentence, carrying the target reader's vocabulary (``None`` if
        the reader has no ``vocab`` attribute) and its output processors.
    """
    sentence_kwargs = dict(
        idx=src[0].idx,
        words=output_actions,
        vocab=getattr(self.trg_reader, "vocab", None),
        output_procs=self.trg_reader.output_procs,
        score=score,
    )
    return sent.SimpleSentence(**sentence_kwargs)
def generate(
        self,
        src: batchers.Batch,
        forced_trg_ids: Sequence[numbers.Integral] = None,
        normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
    """Label every source position with its best-scoring (or forced) output id.

    Args:
        src: Source batch of size 1; an unbatched sentence is wrapped first.
        forced_trg_ids: If truthy, these ids are used as the output actions
            instead of the per-position arg-max.
        normalize_scores: Use the scorer's log-softmax instead of raw scores.

    Returns:
        A one-element list with a ``sent.SimpleSentence`` whose score is the
        sum of the chosen actions' scores over all positions.
    """
    # NOTE(review): the source was whitespace-flattened; forced_trg_ids is
    # assumed to be wrapped only when src itself had to be wrapped — confirm.
    if not batchers.is_batched(src):
        src = batchers.mark_as_batch([src])
        if forced_trg_ids:
            forced_trg_ids = batchers.mark_as_batch([forced_trg_ids])
    assert src.batch_size() == 1, "batch size > 1 not properly tested"

    _batch_size, _encodings, encoder_outputs, seq_len = self._encode_src(src)
    if normalize_scores:
        score_expr = self.scorer.calc_log_softmax(encoder_outputs)
    else:
        score_expr = self.scorer.calc_scores(encoder_outputs)
    scores = score_expr.npvalue()  # vocab_size x seq_len

    positions = range(seq_len)
    if forced_trg_ids:
        output_actions = forced_trg_ids
    else:
        output_actions = [np.argmax(scores[:, j]) for j in positions]
    score = np.sum([scores[output_actions[j], j] for j in positions])

    labeled = sent.SimpleSentence(
        words=output_actions,
        idx=src[0].idx,
        vocab=getattr(self, "trg_vocab", None),
        output_procs=self.trg_reader.output_procs,
        score=score)
    return [labeled]
def generate(
        self,
        src: batchers.Batch,
        normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
    """Emit the scorer's 1-best word for every encoded source position.

    Args:
        src: Source batch of size 1; an unbatched sentence is wrapped first.
        normalize_scores: Forwarded to ``self.scorer.best_k``.

    Returns:
        A one-element list with a ``sent.SimpleSentence`` holding the first
        row of the best-word array and the best scores summed over axis 1.
    """
    if not batchers.is_batched(src):
        src = batchers.mark_as_batch([src])
    assert src.batch_size() == 1, "batch size > 1 not properly tested"

    _batch_size, _encodings, encoder_outputs, _seq_len = self._encode_src(src)
    top_words, top_scores = self.scorer.best_k(
        encoder_outputs, k=1, normalize_scores=normalize_scores)
    # k=1, so only the first row of the word array is needed.
    chosen_words = top_words[0, :]
    total_score = np.sum(top_scores, axis=1)

    labeled = sent.SimpleSentence(
        words=chosen_words,
        idx=src[0].idx,
        vocab=getattr(self, "trg_vocab", None),
        output_procs=self.trg_reader.output_procs,
        score=total_score)
    return [labeled]
def read_sent(self, line: str, idx: numbers.Integral) -> sent.SimpleSentence:
    """Convert a whitespace-tokenized line into a sentence of vocabulary ids.

    Args:
        line: Raw text line; split on whitespace after stripping.
        idx: Sentence index stored on the returned sentence.

    Returns:
        A ``sent.SimpleSentence`` whose words are the converted tokens
        followed by ``vocabs.Vocab.ES``.
    """
    word_ids = [self.vocab.convert(token) for token in line.strip().split()]
    word_ids.append(vocabs.Vocab.ES)
    return sent.SimpleSentence(idx=idx,
                               words=word_ids,
                               vocab=self.vocab,
                               output_procs=self.output_procs)
def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
    """
    Tests whether single loss equals batch loss.
    Truncating src / trg sents to same length so no masking is necessary

    Both sides are cut to their shortest sentence and terminated with ES; the
    source side is additionally ES-padded up to a multiple of
    ``pad_src_to_multiple``. The sum of the per-sentence losses must match
    the sum of the batched loss to 4 places.
    """
    def _truncate_to_shortest(sentences):
        # Cut every word list to the shortest length and end it with ES.
        shortest = min(x.sent_len() for x in sentences)
        clipped = [s.words[:shortest] for s in sentences]
        for word_list in clipped:
            word_list[shortest - 1] = Vocab.ES
        return clipped

    batch_size = 5

    src_word_lists = _truncate_to_shortest(self.src_data[:batch_size])
    for word_list in src_word_lists:
        while len(word_list) % pad_src_to_multiple != 0:
            word_list.append(Vocab.ES)
    trg_word_lists = _truncate_to_shortest(self.trg_data[:batch_size])

    src_sents_trunc = [sent.SimpleSentence(words=words) for words in src_word_lists]
    trg_sents_trunc = [sent.SimpleSentence(words=words) for words in trg_word_lists]

    # Loss accumulated one sentence pair at a time.
    single_loss = 0.0
    for sent_id in range(batch_size):
        dy.renew_cg()
        train_loss, _ = MLELoss().calc_loss(
            model=model,
            src=src_sents_trunc[sent_id],
            trg=trg_sents_trunc[sent_id]).compute()
        single_loss += train_loss.value()

    # Loss computed on the whole batch at once.
    dy.renew_cg()
    batched_loss, _ = MLELoss().calc_loss(
        model=model,
        src=mark_as_batch(src_sents_trunc),
        trg=mark_as_batch(trg_sents_trunc)).compute()

    self.assertAlmostEqual(single_loss, np.sum(batched_loss.value()), places=4)
def read_sent(self, line: str, idx: numbers.Integral) -> sent.SimpleSentence:
    """Split a line into subword pieces and convert them to vocabulary ids.

    When both ``self.sample_train`` and ``self.train`` are truthy the pieces
    come from ``SampleEncodeAsPieces`` (with ``self.l`` and ``self.alpha``);
    otherwise from ``EncodeAsPieces``.

    Args:
        line: Raw text line; stripped before encoding.
        idx: Sentence index stored on the returned sentence.

    Returns:
        A ``sent.SimpleSentence`` whose words are the piece ids followed by
        the id of ``vocabs.Vocab.ES_STR``.
    """
    text = line.strip()
    if self.sample_train and self.train:
        pieces = self.subword_model.SampleEncodeAsPieces(text, self.l, self.alpha)
    else:
        pieces = self.subword_model.EncodeAsPieces(text)

    word_ids = [self.vocab.convert(piece) for piece in pieces]
    # The end-of-sentence marker is looked up by its string form.
    word_ids.append(self.vocab.convert(vocabs.Vocab.ES_STR))
    return sent.SimpleSentence(idx=idx,
                               words=word_ids,
                               vocab=self.vocab,
                               output_procs=self.output_procs)
def read_sent(self, line: str, idx: numbers.Integral) -> sent.Sentence:
    """Read a line as a sequence of word ids or as a scalar sentence length.

    Args:
        line: Raw text line; split on whitespace after stripping.
        idx: Sentence index stored on the returned sentence.

    Returns:
        If ``self.read_sent_len`` is truthy, a ``sent.ScalarSentence`` whose
        value is the number of tokens. Otherwise a ``sent.SimpleSentence`` of
        converted tokens followed by ``vocabs.Vocab.ES``; tokens are
        converted with ``self.vocab.convert`` when a vocabulary is set and
        with ``convert_int`` otherwise.
    """
    tokens = line.strip().split()
    if self.read_sent_len:
        return sent.ScalarSentence(idx=idx, value=len(tokens))

    to_id = self.vocab.convert if self.vocab else convert_int
    word_ids = [to_id(token) for token in tokens]
    word_ids.append(vocabs.Vocab.ES)
    return sent.SimpleSentence(idx=idx,
                               words=word_ids,
                               vocab=self.vocab,
                               output_procs=self.output_procs)
def generate(
        self,
        src: batchers.Batch,
        search_strategy: search_strategies.SearchStrategy,
        forced_trg_ids: batchers.Batch = None) -> Sequence[sent.Sentence]:
    """
    Takes in a batch of source sentences and outputs a list of search outputs.

    The search outputs are sorted by score, best first. A single hypothesis
    is returned as a plain ``sent.SimpleSentence``; several hypotheses are
    each wrapped in a ``sent.NbestSentence``. When reporting is active, the
    attention matrix and output of the best hypothesis are reported.

    Args:
      src: The source sentences (batch size must be 1)
      search_strategy: The strategy with which to perform the search
      forced_trg_ids: The target IDs to generate if performing forced decoding
    Returns:
      A list of search outputs including scores, etc.
    """
    assert src.batch_size() == 1
    search_outputs = self.generate_search_output(src, search_strategy, forced_trg_ids)
    if isinstance(src, batchers.CompoundBatch):
        src = src.batches[0]
    sorted_outputs = sorted(search_outputs, key=lambda x: x.score[0], reverse=True)
    assert len(sorted_outputs) >= 1

    is_nbest = len(sorted_outputs) != 1
    outputs = []
    for curr_output in sorted_outputs:
        out_sent = sent.SimpleSentence(
            idx=src[0].idx,
            words=list(curr_output.word_ids[0]),
            vocab=getattr(self.trg_reader, "vocab", None),
            output_procs=self.trg_reader.output_procs,
            score=curr_output.score[0])
        if is_nbest:
            outputs.append(sent.NbestSentence(base_sent=out_sent, nbest_id=src[0].idx))
        else:
            outputs.append(out_sent)

    if self.is_reporting():
        # Report the attentions of the best hypothesis so they match
        # outputs[0]. Previously the attentions left over from the last loop
        # iteration (the lowest-scoring hypothesis) were paired with the
        # best output.
        best_attentions = sorted_outputs[0].attentions[0]
        attentions = np.concatenate([x.npvalue() for x in best_attentions], axis=1)
        self.report_sent_info({
            "attentions": attentions,
            "src": src[0],
            "output": outputs[0]
        })

    return outputs
def transduce(self, x):
    """Unroll the decoder over ``x`` and return its hidden states as a sequence.

    At each position the last entry of ``current_state.rnn_state.h()`` and
    the attender's last attention are collected. In ``"split"`` mode a second
    pass feeds attention-weighted contexts through the RNN starting from the
    initial state, and its outputs replace the collected states.

    Args:
        x: Input expression; ``x.dim()[1]`` is used as the batch size.

    Returns:
        An ``expression_seqs.ExpressionSequence`` over the collected output
        states, with a mask truncated to the number of output states.

    Side effects:
        Resets ``self._final_states`` and fills it with one
        ``FinalTransducerState`` per RNN layer; appends step losses to
        ``self.transducer_losses`` when training in teacher/split mode; may
        set ``self.split_reg_penalty_expr``; calls ``self.report_sent_info``
        when ``self.compute_report`` is set.

    NOTE(review): the block nesting below was reconstructed from a
    whitespace-flattened source — confirm against the original file.
    NOTE(review): ``current_state`` is never re-assigned inside the loop in
    the visible code; presumably the step helpers update it in place —
    verify.
    """
    # some preparations
    output_states = []
    current_state = self._encode_src(x, apply_emb=False)
    if self.mode_transduce == "split":
        # Remember the initial state; the split pass restarts from it.
        first_state = SymmetricDecoderState(
            rnn_state=current_state.rnn_state, context=current_state.context)
    batch_size = x.dim()[1]
    done = [False] * batch_size
    out_mask = batchers.Mask(np_arr=np.zeros((batch_size, self.max_dec_len)))
    out_mask.np_arr.flags.writeable = True
    # teacher / split mode: unfold guided by reference targets
    #  -> feed everything up unto (except) the last token back into the LSTM
    # other modes: unfold until EOS is output or max len is reached
    max_dec_len = self.cur_src.batches[1].sent_len() if self.mode_transduce in ["teacher", "split"] else self.max_dec_len
    atts_list = []
    generated_word_ids = []
    for pos in range(max_dec_len):
        if self.train and self.mode_transduce in ["teacher", "split"]:
            # unroll RNN guided by reference
            prev_ref_action, ref_action = None, None
            if pos > 0:
                prev_ref_action = self._batch_ref_action(pos - 1)
            if self.transducer_loss:
                ref_action = self._batch_ref_action(pos)
            step_loss = self.calc_loss_one_step(
                dec_state=current_state,
                batch_size=batch_size,
                mode=self.mode_transduce,
                ref_action=ref_action,
                prev_ref_action=prev_ref_action)
            self.transducer_losses.append(step_loss)
        else:  # inference
            # unroll RNN guided by model predictions
            if self.mode_transduce in ["teacher", "split"]:
                prev_ref_action = self._batch_max_action(
                    batch_size, current_state, pos)
            else:
                prev_ref_action = None
            out_scores = self.generate_one_step(
                dec_state=current_state,
                mask=out_mask,
                cur_step=pos,
                batch_size=batch_size,
                mode=self.mode_transduce,
                prev_ref_action=prev_ref_action)
            word_id = np.argmax(out_scores.npvalue(), axis=0)
            word_id = word_id.reshape((word_id.size, ))
            # Only the first batch element's word is kept (used for reporting).
            generated_word_ids.append(word_id[0])
            for batch_i in range(batch_size):
                if self._terminate_rnn(batch_i=batch_i, pos=pos, batched_word_id=word_id):
                    done[batch_i] = True
                    # Mask out all later positions for a finished element.
                    out_mask.np_arr[batch_i, pos + 1:] = 1.0
            if pos > 0 and all(done):
                # Record this final step before leaving the loop.
                atts_list.append(self.attender.get_last_attention())
                output_states.append(current_state.rnn_state.h()[-1])
                break
        output_states.append(current_state.rnn_state.h()[-1])
        atts_list.append(self.attender.get_last_attention())
    if self.mode_transduce == "split":
        # split mode: use attentions to compute context, then run RNNs over these context inputs
        if self.split_regularizer:
            assert len(atts_list) == len(
                self._chosen_rnn_inputs
            ), f"{len(atts_list)} != {len(self._chosen_rnn_inputs)}"
        split_output_states = []
        split_rnn_state = first_state.rnn_state
        for pos, att in enumerate(atts_list):
            lstm_input_context = self.attender.curr_sent.as_tensor() * att  # TODO: better reuse the already computed context vecs
            lstm_input_context = dy.reshape(
                lstm_input_context, (lstm_input_context.dim()[0][0], ),
                batch_size=batch_size)
            if self.split_dual:
                # Dual input: combine context with the chosen label input;
                # split_dual[0] / split_dual[1] are used as dropout rates
                # for the context and label inputs respectively.
                lstm_input_label = self._chosen_rnn_inputs[pos]
                if self.split_dual[0] > 0.0 and self.train:
                    lstm_input_context = dy.dropout_batch(
                        lstm_input_context, self.split_dual[0])
                if self.split_dual[1] > 0.0 and self.train:
                    lstm_input_label = dy.dropout_batch(
                        lstm_input_label, self.split_dual[1])
                if self.split_context_transform:
                    lstm_input_context = self.split_context_transform.transform(
                        lstm_input_context)
                lstm_input_context = self.split_dual_proj.transform(
                    dy.concatenate([lstm_input_context, lstm_input_label]))
            if self.split_regularizer and pos < len(
                    self._chosen_rnn_inputs):
                # _chosen_rnn_inputs does not contain first (empty) input, so this is in fact like comparing to pos-1:
                penalty = dy.squared_norm(lstm_input_context -
                                          self._chosen_rnn_inputs[pos])
                if self.split_regularizer != 1:
                    penalty = self.split_regularizer * penalty
                # Overwritten on every step; only the last penalty is kept.
                self.split_reg_penalty_expr = penalty
            split_rnn_state = split_rnn_state.add_input(lstm_input_context)
            split_output_states.append(split_rnn_state.h()[-1])
        assert len(output_states) == len(split_output_states)
        output_states = split_output_states
    # Trim the mask to the number of steps actually produced.
    out_mask.np_arr = out_mask.np_arr[:, :len(output_states)]
    self._final_states = []
    if self.compute_report:
        # for symmetric reporter (this can only be run at inference time)
        assert batch_size == 1
        atts_matrix = np.asarray([att.npvalue() for att in atts_list
                                  ]).reshape(len(atts_list), atts_list[0].dim()[0][0]).T
        self.report_sent_info({
            "symm_att":
            atts_matrix,
            "symm_out":
            sent.SimpleSentence(
                words=generated_word_ids,
                idx=self.cur_src.batches[0][0].idx,
                vocab=self.cur_src.batches[1][0].vocab,
                output_procs=self.cur_src.batches[1][0].output_procs),
            "symm_ref":
            self.cur_src.batches[1][0] if isinstance(
                self.cur_src, batchers.CompoundBatch) else None
        })
    # prepare final outputs
    for layer_i in range(len(current_state.rnn_state.h())):
        self._final_states.append(
            transducers.FinalTransducerState(
                main_expr=current_state.rnn_state.h()[layer_i],
                cell_expr=current_state.rnn_state._c[layer_i]))
    out_mask.np_arr.flags.writeable = False
    return expression_seqs.ExpressionSequence(expr_list=output_states,
                                              mask=out_mask)
def generate(self, src, forced_trg_ids=None, **kwargs):
    """Decode every sentence of a source batch, greedily or with forced ids.

    For up to ``self.max_dec_len`` positions, one decoding step is run for
    the whole batch. Per batch element the chosen word (arg-max of the step
    scores, or the forced id) and its score are recorded; elements that have
    already terminated receive ``vocabs.Vocab.ES`` with score 0.0. Decoding
    stops early once all elements are done.

    Args:
        src: Source batch; for a ``batchers.CompoundBatch`` only its first
            sub-batch is used.
        forced_trg_ids: Optional per-element sequences of target ids. When
            given, they replace the model's predictions.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one ``sent.SimpleSentence`` per batch element, scored by
        the sum of its per-step scores.

    NOTE(review): the block nesting below was reconstructed from a
    whitespace-flattened source — confirm against the original file.
    NOTE(review): ``current_state`` is never re-assigned inside the loop in
    the visible code; presumably ``generate_one_step`` updates it in place —
    verify.
    """
    event_trigger.start_sent(src)
    if isinstance(src, batchers.CompoundBatch):
        src = src.batches[0]
    outputs = []

    batch_size = src.batch_size()
    # Per-element accumulators for step scores and chosen words.
    score = batchers.ListBatch([[] for _ in range(batch_size)])
    words = batchers.ListBatch([[] for _ in range(batch_size)])
    done = [False] * batch_size
    initial_state = self._encode_src(src)
    current_state = initial_state
    attentions = []
    for pos in range(self.max_dec_len):
        # Choose the previous action fed into this step (none at position 0
        # or in "context" mode).
        prev_ref_action = None
        if pos > 0 and self.mode_translate != "context":
            if forced_trg_ids is not None:
                prev_ref_action = batchers.mark_as_batch([
                    forced_trg_ids[batch_i][pos - 1]
                    for batch_i in range(batch_size)
                ])
            elif batch_size > 1:
                prev_ref_action = batchers.mark_as_batch(
                    np.argmax(current_state.out_prob.npvalue(), axis=0))
            else:
                prev_ref_action = batchers.mark_as_batch(
                    [np.argmax(current_state.out_prob.npvalue(), axis=0)])

        logsoft = self.generate_one_step(dec_state=current_state,
                                         batch_size=batch_size,
                                         mode=self.mode_translate,
                                         cur_step=pos,
                                         prev_ref_action=prev_ref_action)
        attentions.append(self.attender.get_last_attention().npvalue())
        logsoft = logsoft.npvalue()
        # Force a 2-D (first dim x batch) layout, also for batch size 1.
        logsoft = logsoft.reshape(logsoft.shape[0], batch_size)
        if forced_trg_ids is None:
            batched_word_id = np.argmax(logsoft, axis=0)
            batched_word_id = batched_word_id.reshape(
                (batched_word_id.size, ))
        else:
            batched_word_id = [
                forced_trg_batch_elem[pos]
                for forced_trg_batch_elem in forced_trg_ids
            ]
        for batch_i in range(batch_size):
            if done[batch_i]:
                # Finished elements are padded with ES at no cost.
                batch_word = vocabs.Vocab.ES
                batch_score = 0.0
            else:
                batch_word = batched_word_id[batch_i]
                batch_score = logsoft[batch_word, batch_i]
                if self._terminate_rnn(batch_i=batch_i,
                                       pos=pos,
                                       batched_word_id=batched_word_id):
                    done[batch_i] = True
            score[batch_i].append(batch_score)
            words[batch_i].append(batch_word)
        if all(done):
            break
    for batch_i in range(batch_size):
        batch_elem_score = sum(score[batch_i])
        outputs.append(
            sent.SimpleSentence(words=words[batch_i],
                                idx=src[batch_i].idx,
                                vocab=getattr(self.trg_reader, "vocab", None),
                                score=batch_elem_score,
                                output_procs=self.trg_reader.output_procs))
        if self.compute_report:
            # With batch size > 1, slice this element out of the third axis
            # of each step's attention array.
            if batch_size > 1:
                cur_attentions = [x[:, :, batch_i] for x in attentions]
            else:
                cur_attentions = attentions
            attention = np.concatenate(cur_attentions, axis=1)
            self.report_sent_info({
                "attentions": attention,
                "src": src[batch_i],
                "output": outputs[-1]
            })

    return outputs