Example #1
  def generate(self, src, forced_trg_ids=None, search_strategy=None):
    event_trigger.start_sent(src)
    if not batchers.is_batched(src):
      src = batchers.mark_as_batch([src])
    outputs = []

    trg = sent.SimpleSentence([0])  # seed the target with a single start token (ID 0)

    if not batchers.is_batched(trg):
      trg = batchers.mark_as_batch([trg])

    output_actions = []
    score = 0.

    # TODO Fix this with generate_one_step and use the appropriate search_strategy
    self.max_len = 100 # This is a temporary hack
    for _ in range(self.max_len):
      dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
      log_prob_tail = self.calc_loss(src, trg, loss_cal=None, infer_prediction=True)
      ys = np.argmax(log_prob_tail.npvalue(), axis=0).astype('i')
      output_actions.append(ys)
      if ys == Vocab.ES:
        break
      trg = sent.SimpleSentence(words=output_actions + [0])  # trailing 0 is a placeholder for the next prediction
      if not batchers.is_batched(trg):
        trg = batchers.mark_as_batch([trg])

    # Wrap the generated word IDs in an output sentence
    if hasattr(self, "trg_vocab") and self.trg_vocab is not None:
      outputs.append(sent.SimpleSentence(words=output_actions, vocab=self.trg_vocab))
    else:
      outputs.append((output_actions, score))

    return outputs
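Note: the loop above re-scores the growing prefix from scratch at every step and stops once the end-of-sentence ID is produced. A minimal self-contained sketch of that greedy pattern, assuming a hypothetical `step_scores` callback in place of `calc_loss` and a made-up EOS ID:

import numpy as np

EOS_ID = 1  # hypothetical end-of-sentence ID, standing in for Vocab.ES

def greedy_decode(step_scores, max_len=100):
  """Greedy decoding: append the argmax token each step, stop at EOS."""
  output_actions = []
  for _ in range(max_len):
    ys = int(np.argmax(step_scores(output_actions)))
    output_actions.append(ys)
    if ys == EOS_ID:
      break
  return output_actions

# toy scorer: always prefers token 3, then emits EOS after five steps
print(greedy_decode(lambda prefix: np.eye(5)[3 if len(prefix) < 5 else EOS_ID]))
# [3, 3, 3, 3, 3, 1]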
Example #2
 def read_sent(self, line: str,
               idx: numbers.Integral) -> sent.SimpleSentence:
     words = line.strip().split()
     if not self.train:
         return sent.SimpleSentence(
             idx=idx,
             words=[self.vocab.convert(word)
                    for word in words] + [vocabs.Vocab.ES],
             vocab=self.vocab,
             output_procs=self.output_procs)
     # Training-time noising: sample how many positions to corrupt (a count
     # distribution softened by temperature tau, favoring small counts), flip
     # each position with that rate, and substitute flipped positions with
     # random vocabulary items (IDs >= 2 to skip the special tokens).
     word_ids = np.array([self.vocab.convert(word) for word in words])
     length = len(word_ids)
     logits = np.arange(length) * (-1) * self.tau
     logits = np.exp(logits - np.max(logits))
     probs = logits / np.sum(logits)
     num_words = np.random.choice(length, p=probs)
     corrupt_pos = np.random.binomial(1,
                                      p=num_words / length,
                                      size=(length, ))
     num_words_to_sample = np.sum(corrupt_pos)
     sampled_words = np.random.choice(np.arange(2, len(self.vocab)),
                                      size=(num_words_to_sample, ))
     word_ids[np.where(corrupt_pos == 1)[0].tolist()] = sampled_words
     return sent.SimpleSentence(idx=idx,
                                words=word_ids.tolist() + [vocabs.Vocab.ES],
                                vocab=self.vocab,
                                output_procs=self.output_procs)
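The sampling math can be exercised in isolation. A rough sketch, assuming made-up values for the vocabulary size and `tau` (the function name and seed handling are hypothetical, not the reader's API):

import numpy as np

def corrupt(word_ids, vocab_size, tau=1.0, seed=0):
  """Replace a random subset of positions with random vocab items (IDs >= 2)."""
  rng = np.random.default_rng(seed)
  word_ids = np.array(word_ids)
  length = len(word_ids)
  # probability of corrupting k words decays with k; higher tau -> fewer corruptions
  logits = -np.arange(length) * tau
  probs = np.exp(logits - np.max(logits))
  probs /= probs.sum()
  num_words = rng.choice(length, p=probs)
  corrupt_pos = rng.binomial(1, p=num_words / length, size=(length,))
  word_ids[corrupt_pos == 1] = rng.choice(np.arange(2, vocab_size),
                                          size=int(corrupt_pos.sum()))
  return word_ids

print(corrupt([5, 6, 7, 8, 9], vocab_size=100, tau=0.5))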
Example #3
 def test_batch_src(self):
   src_sents = [sent.SimpleSentence([0] * i, pad_token=1) for i in range(1,7)]
   trg_sents = [sent.SimpleSentence([0] * ((i+3)%6 + 1), pad_token=2) for i in range(1,7)]
   my_batcher = batchers.SrcBatcher(batch_size=3)
   src, trg = my_batcher.pack(src_sents, trg_sents)
   self.assertEqual([[0, 0, 1], [0, 1, 1], [0, 0, 0]], [x.words for x in src[0]])
   self.assertEqual([[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 2], [0, 2, 2, 2, 2, 2]], [x.words for x in trg[0]])
   self.assertEqual([[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1]], [x.words for x in src[1]])
   self.assertEqual([[0, 0, 0, 0], [0, 0, 0, 2], [0, 0, 2, 2]], [x.words for x in trg[1]])
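The expected values show what `SrcBatcher` does: it groups sentences by source length and pads every member of a batch to the longest sentence in that batch with the sentence's `pad_token`. The padding step alone, as a minimal sketch (not the actual batcher API):

def pad_batch(word_lists, pad_token):
  """Right-pad each word list to the length of the longest one."""
  max_len = max(len(w) for w in word_lists)
  return [w + [pad_token] * (max_len - len(w)) for w in word_lists]

print(pad_batch([[0], [0, 0], [0, 0, 0]], pad_token=1))
# [[0, 1, 1], [0, 0, 1], [0, 0, 0]]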
Example #4
  def generate(self, src, forced_trg_ids):
    assert not forced_trg_ids
    assert batchers.is_batched(src) and src.batch_size()==1, "batched generation not fully implemented"
    src = src[0]
    # Generating outputs
    outputs = []
    event_trigger.start_sent(src)
    embeddings = self.src_embedder.embed_sent(src)
    encodings = self.encoder.transduce(embeddings)
    if self.mode in ["avg_mlp", "final_mlp"]:
      if self.generate_per_step:
        assert self.mode == "avg_mlp", "final_mlp not supported with generate_per_step=True"
        scores = [dy.logistic(self.output_layer.transform(enc_i)) for enc_i in encodings]
      else:
        if self.mode == "avg_mlp":
          encoding_fixed_size = dy.sum_dim(encodings.as_tensor(), [1]) * (1.0 / encodings.dim()[0][1])
        elif self.mode == "final_mlp":
          encoding_fixed_size = self.encoder.get_final_states()[-1].main_expr()
        scores = dy.logistic(self.output_layer.transform(encoding_fixed_size))
    elif self.mode == "lin_sum_sig":
      enc_lin = []
      for step_i, enc_i in enumerate(encodings):
        step_linear = self.output_layer.transform(enc_i)
        if encodings.mask and np.sum(encodings.mask.np_arr[:, step_i]) > 0:
          step_linear = dy.cmult(step_linear, dy.inputTensor(1.0 - encodings.mask.np_arr[:, step_i], batched=True))
        enc_lin.append(step_linear)
      if self.generate_per_step:
        scores = [dy.logistic(enc_i) for enc_i in enc_lin]
      else:
        if encodings.mask:
          encoding_fixed_size = dy.cdiv(dy.esum(enc_lin),
                                        dy.inputTensor(np.sum(1.0 - encodings.mask.np_arr, axis=1), batched=True))
        else:
          encoding_fixed_size = dy.esum(enc_lin) / encodings.dim()[0][1]
        scores = dy.logistic(encoding_fixed_size)
    else:
      raise ValueError(f"unknown mode '{self.mode}'")

    if self.generate_per_step:
      # pick the highest-scoring label at each time step
      output_actions = [np.argmax(score_i.npvalue()) for score_i in scores]
      score = np.sum([np.max(score_i.npvalue()) for score_i in scores])
      outputs.append(sent.SimpleSentence(words=output_actions,
                                         idx=src.idx,
                                         vocab=getattr(self.trg_reader, "vocab", None),
                                         score=score,
                                         output_procs=self.trg_reader.output_procs))
    else:
      # pooled prediction: emit every label whose sigmoid score exceeds 0.5
      scores_arr = scores.npvalue()
      output_actions = list(np.nonzero(scores_arr > 0.5)[0])
      score = np.sum(scores_arr[scores_arr > 0.5])
      outputs.append(sent.SimpleSentence(words=output_actions,
                                         idx=src.idx,
                                         vocab=getattr(self.trg_reader, "vocab", None),
                                         score=score,
                                         output_procs=self.trg_reader.output_procs))
    return outputs
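In the pooled branch, decoding is not a search at all: thresholding the sigmoid scores at 0.5 yields the label set, and the sentence score is the sum of the winning scores. The thresholding in isolation, with made-up values:

import numpy as np

scores_arr = np.array([0.9, 0.2, 0.7, 0.4])   # made-up per-label sigmoid outputs
output_actions = np.nonzero(scores_arr > 0.5)[0].tolist()
score = float(np.sum(scores_arr[scores_arr > 0.5]))
print(output_actions, round(score, 2))  # [0, 2] 1.6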
Example #5
 def test_batch_random_no_ties(self):
   src_sents = [sent.SimpleSentence([0] * i, pad_token=1) for i in range(1,7)]
   trg_sents = [sent.SimpleSentence([0] * ((i+3)%6 + 1), pad_token=2) for i in range(1,7)]
   my_batcher = batchers.SrcBatcher(batch_size=3)
   _, trg = my_batcher.pack(src_sents, trg_sents)
   l0 = trg[0].sent_len()
   for _ in range(10):
     _, trg = my_batcher.pack(src_sents, trg_sents)
     l = trg[0].sent_len()
      self.assertEqual(l, l0)
Example #6
    def assert_single_loss_equals_batch_loss(self,
                                             model,
                                             pad_src_to_multiple=1):
        """
    Tests whether single loss equals batch loss.
    Here we don't truncate the target side and use masking.
    """
        batch_size = 5
        src_sents = self.src_data[:batch_size]
        src_min = min([x.sent_len() for x in src_sents])
        src_sents_trunc = [s.words[:src_min] for s in src_sents]
        for single_sent in src_sents_trunc:
            single_sent[src_min - 1] = Vocab.ES
            while len(single_sent) % pad_src_to_multiple != 0:
                single_sent.append(Vocab.ES)
        trg_sents = sorted(self.trg_data[:batch_size],
                           key=lambda x: x.sent_len(),
                           reverse=True)
        trg_max = max([x.sent_len() for x in trg_sents])
        np_arr = np.zeros([batch_size, trg_max])
        for i in range(batch_size):
            for j in range(trg_sents[i].sent_len(), trg_max):
                np_arr[i, j] = 1.0
        trg_masks = Mask(np_arr)
        trg_sents_padded = [[w for w in s] + [Vocab.ES] *
                            (trg_max - s.sent_len()) for s in trg_sents]

        src_sents_trunc = [
            sent.SimpleSentence(words=s) for s in src_sents_trunc
        ]
        trg_sents_padded = [
            sent.SimpleSentence(words=s) for s in trg_sents_padded
        ]

        single_loss = 0.0
        for sent_id in range(batch_size):
            dy.renew_cg()
            train_loss, _ = MLELoss().calc_loss(
                model=model,
                src=src_sents_trunc[sent_id],
                trg=trg_sents[sent_id]).compute()
            single_loss += train_loss.value()

        dy.renew_cg()

        batched_loss, _ = MLELoss().calc_loss(
            model=model,
            src=mark_as_batch(src_sents_trunc),
            trg=mark_as_batch(trg_sents_padded, trg_masks)).compute()
        self.assertAlmostEqual(single_loss,
                               np.sum(batched_loss.value()),
                               places=4)
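The mask construction is the key detail: positions past each target sentence's real length get a 1.0 (meaning "ignore"), so the padded batch loss matches the sum of the per-sentence losses. Building such an array on its own, with made-up lengths:

import numpy as np

def build_trg_mask(lengths):
  """1.0 marks padding positions to be ignored; 0.0 marks real tokens."""
  trg_max = max(lengths)
  np_arr = np.zeros([len(lengths), trg_max])
  for i, length in enumerate(lengths):
    np_arr[i, length:] = 1.0
  return np_arr

print(build_trg_mask([4, 2, 3]))
# [[0. 0. 0. 0.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]]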
Example #7
 def test_batch_word_src(self):
   src_sents = [sent.SimpleSentence([0] * i, pad_token=1) for i in range(1,7)]
   trg_sents = [sent.SimpleSentence([0] * ((i+3)%6 + 1), pad_token=2) for i in range(1,7)]
   my_batcher = batchers.WordSrcBatcher(words_per_batch=12)
   src, trg = my_batcher.pack(src_sents, trg_sents)
   self.assertEqual([[0]], [x.words for x in src[0]])
   self.assertEqual([[0, 0, 0, 0, 0]], [x.words for x in trg[0]])
   self.assertEqual([[0, 0]], [x.words for x in src[1]])
   self.assertEqual([[0, 0, 0, 0, 0, 0]], [x.words for x in trg[1]])
   self.assertEqual([[0, 0, 0, 0], [0, 0, 0, 1]], [x.words for x in src[2]])
   self.assertEqual([[0, 0], [0, 2]], [x.words for x in trg[2]])
   self.assertEqual([[0, 0, 0, 0, 0]], [x.words for x in src[3]])
   self.assertEqual([[0, 0, 0]], [x.words for x in trg[3]])
   self.assertEqual([[0, 0, 0, 0, 0, 0]], [x.words for x in src[4]])
   self.assertEqual([[0, 0, 0, 0]], [x.words for x in trg[4]])
Example #8
 def _emit_translation(self, src, output_actions, score):
     return sent.SimpleSentence(idx=src[0].idx,
                                words=output_actions,
                                vocab=getattr(self.trg_reader, "vocab",
                                              None),
                                output_procs=self.trg_reader.output_procs,
                                score=score)
Example #9
    def generate(
            self,
            src: batchers.Batch,
            forced_trg_ids: Sequence[numbers.Integral] = None,
            normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
        if not batchers.is_batched(src):
            src = batchers.mark_as_batch([src])
            if forced_trg_ids:
                forced_trg_ids = batchers.mark_as_batch([forced_trg_ids])
        assert src.batch_size() == 1, "batch size > 1 not properly tested"

        batch_size, encodings, outputs, seq_len = self._encode_src(src)
        score_expr = self.scorer.calc_log_softmax(
            outputs) if normalize_scores else self.scorer.calc_scores(outputs)
        scores = score_expr.npvalue()  # vocab_size x seq_len

        if forced_trg_ids:
            output_actions = forced_trg_ids
        else:
            output_actions = [np.argmax(scores[:, j]) for j in range(seq_len)]
        score = np.sum([scores[output_actions[j], j] for j in range(seq_len)])

        outputs = [
            sent.SimpleSentence(
                words=output_actions,
                idx=src[0].idx,
                vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                output_procs=self.trg_reader.output_procs,
                score=score)
        ]

        return outputs
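With a vocab_size x seq_len score matrix, unforced decoding reduces to a column-wise argmax, and the sentence score is the sum of the selected entries. In numpy, with a made-up matrix:

import numpy as np

scores = np.array([[0.1, 2.0],    # vocab_size x seq_len
                   [1.5, 0.3]])
seq_len = scores.shape[1]
output_actions = [int(np.argmax(scores[:, j])) for j in range(seq_len)]
score = sum(scores[output_actions[j], j] for j in range(seq_len))
print(output_actions, score)  # [1, 0] 3.5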
Example #10
    def generate(
            self,
            src: batchers.Batch,
            normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
        if not batchers.is_batched(src):
            src = batchers.mark_as_batch([src])
        assert src.batch_size() == 1, "batch size > 1 not properly tested"

        batch_size, encodings, outputs, seq_len = self._encode_src(src)

        best_words, best_scores = self.scorer.best_k(
            outputs, k=1, normalize_scores=normalize_scores)
        best_words = best_words[0, :]
        score = np.sum(best_scores, axis=1)

        outputs = [
            sent.SimpleSentence(
                words=best_words,
                idx=src[0].idx,
                vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                output_procs=self.trg_reader.output_procs,
                score=score)
        ]

        return outputs
Example #11
 def read_sent(self, line: str,
               idx: numbers.Integral) -> sent.SimpleSentence:
     return sent.SimpleSentence(
         idx=idx,
         words=[self.vocab.convert(word)
                for word in line.strip().split()] + [vocabs.Vocab.ES],
         vocab=self.vocab,
         output_procs=self.output_procs)
Example #12
    def assert_single_loss_equals_batch_loss(self,
                                             model,
                                             pad_src_to_multiple=1):
        """
    Tests whether single loss equals batch loss.
    Truncating src / trg sents to same length so no masking is necessary
    """
        batch_size = 5
        src_sents = self.src_data[:batch_size]
        src_min = min([x.sent_len() for x in src_sents])
        src_sents_trunc = [s.words[:src_min] for s in src_sents]
        for single_sent in src_sents_trunc:
            single_sent[src_min - 1] = Vocab.ES
            while len(single_sent) % pad_src_to_multiple != 0:
                single_sent.append(Vocab.ES)
        trg_sents = self.trg_data[:batch_size]
        trg_min = min([x.sent_len() for x in trg_sents])
        trg_sents_trunc = [s.words[:trg_min] for s in trg_sents]
        for single_sent in trg_sents_trunc:
            single_sent[trg_min - 1] = Vocab.ES

        src_sents_trunc = [
            sent.SimpleSentence(words=s) for s in src_sents_trunc
        ]
        trg_sents_trunc = [
            sent.SimpleSentence(words=s) for s in trg_sents_trunc
        ]

        single_loss = 0.0
        for sent_id in range(batch_size):
            dy.renew_cg()
            train_loss, _ = MLELoss().calc_loss(
                model=model,
                src=src_sents_trunc[sent_id],
                trg=trg_sents_trunc[sent_id]).compute()
            single_loss += train_loss.value()

        dy.renew_cg()

        batched_loss, _ = MLELoss().calc_loss(
            model=model,
            src=mark_as_batch(src_sents_trunc),
            trg=mark_as_batch(trg_sents_trunc)).compute()
        self.assertAlmostEqual(single_loss,
                               np.sum(batched_loss.value()),
                               places=4)
Example #13
 def read_sent(self, line: str, idx: numbers.Integral) -> sent.SimpleSentence:
   if self.sample_train and self.train:
     words = self.subword_model.SampleEncodeAsPieces(line.strip(), self.l, self.alpha)
   else:
     words = self.subword_model.EncodeAsPieces(line.strip())
   return sent.SimpleSentence(idx=idx,
                              words=[self.vocab.convert(word) for word in words] + [self.vocab.convert(vocabs.Vocab.ES_STR)],
                              vocab=self.vocab,
                              output_procs=self.output_procs)
Example #14
 def read_sent(self, line: str, idx: numbers.Integral) -> sent.Sentence:
   if self.vocab:
     convert_fct = self.vocab.convert
   else:
     convert_fct = convert_int
   if self.read_sent_len:
     return sent.ScalarSentence(idx=idx, value=len(line.strip().split()))
   else:
     return sent.SimpleSentence(idx=idx,
                                words=[convert_fct(word) for word in line.strip().split()] + [vocabs.Vocab.ES],
                                vocab=self.vocab,
                                output_procs=self.output_procs)
Example #15
    def generate(
            self,
            src: batchers.Batch,
            search_strategy: search_strategies.SearchStrategy,
            forced_trg_ids: batchers.Batch = None) -> Sequence[sent.Sentence]:
        """
    Takes in a batch of source sentences and outputs a list of search outputs.
    Args:
      src: The source sentences
      search_strategy: The strategy with which to perform the search
      forced_trg_ids: The target IDs to generate if performing forced decoding
    Returns:
      A list of search outputs including scores, etc.
    """
        assert src.batch_size() == 1
        search_outputs = self.generate_search_output(src, search_strategy,
                                                     forced_trg_ids)
        if isinstance(src, batchers.CompoundBatch): src = src.batches[0]
        sorted_outputs = sorted(search_outputs,
                                key=lambda x: x.score[0],
                                reverse=True)
        assert len(sorted_outputs) >= 1
        outputs = []
        for curr_output in sorted_outputs:
            output_actions = [x for x in curr_output.word_ids[0]]
            attentions = [x for x in curr_output.attentions[0]]
            score = curr_output.score[0]
            out_sent = sent.SimpleSentence(
                idx=src[0].idx,
                words=output_actions,
                vocab=getattr(self.trg_reader, "vocab", None),
                output_procs=self.trg_reader.output_procs,
                score=score)
            if len(sorted_outputs) == 1:
                outputs.append(out_sent)
            else:
                outputs.append(
                    sent.NbestSentence(base_sent=out_sent,
                                       nbest_id=src[0].idx))

        if self.is_reporting():
            attentions = np.concatenate([x.npvalue() for x in attentions],
                                        axis=1)
            self.report_sent_info({
                "attentions": attentions,
                "src": src[0],
                "output": outputs[0]
            })

        return outputs
Example #16
 def transduce(self, x):
     # some preparations
     output_states = []
     current_state = self._encode_src(x, apply_emb=False)
     if self.mode_transduce == "split":
         first_state = SymmetricDecoderState(
             rnn_state=current_state.rnn_state,
             context=current_state.context)
     batch_size = x.dim()[1]
     done = [False] * batch_size
     out_mask = batchers.Mask(np_arr=np.zeros((batch_size,
                                               self.max_dec_len)))
     out_mask.np_arr.flags.writeable = True
      # teacher / split mode: unfold guided by reference targets
      #  -> feed everything up to (but excluding) the last token back into the LSTM
      # other modes: unfold until EOS is output or max len is reached
      if self.mode_transduce in ["teacher", "split"]:
          max_dec_len = self.cur_src.batches[1].sent_len()
      else:
          max_dec_len = self.max_dec_len
     atts_list = []
     generated_word_ids = []
     for pos in range(max_dec_len):
         if self.train and self.mode_transduce in ["teacher", "split"]:
             # unroll RNN guided by reference
             prev_ref_action, ref_action = None, None
             if pos > 0:
                 prev_ref_action = self._batch_ref_action(pos - 1)
             if self.transducer_loss:
                 ref_action = self._batch_ref_action(pos)
             step_loss = self.calc_loss_one_step(
                 dec_state=current_state,
                 batch_size=batch_size,
                 mode=self.mode_transduce,
                 ref_action=ref_action,
                 prev_ref_action=prev_ref_action)
             self.transducer_losses.append(step_loss)
         else:  # inference
             # unroll RNN guided by model predictions
             if self.mode_transduce in ["teacher", "split"]:
                 prev_ref_action = self._batch_max_action(
                     batch_size, current_state, pos)
             else:
                 prev_ref_action = None
             out_scores = self.generate_one_step(
                 dec_state=current_state,
                 mask=out_mask,
                 cur_step=pos,
                 batch_size=batch_size,
                 mode=self.mode_transduce,
                 prev_ref_action=prev_ref_action)
             word_id = np.argmax(out_scores.npvalue(), axis=0)
             word_id = word_id.reshape((word_id.size, ))
             generated_word_ids.append(word_id[0])
             for batch_i in range(batch_size):
                 if self._terminate_rnn(batch_i=batch_i,
                                        pos=pos,
                                        batched_word_id=word_id):
                     done[batch_i] = True
                     out_mask.np_arr[batch_i, pos + 1:] = 1.0
             if pos > 0 and all(done):
                 atts_list.append(self.attender.get_last_attention())
                 output_states.append(current_state.rnn_state.h()[-1])
                 break
         output_states.append(current_state.rnn_state.h()[-1])
         atts_list.append(self.attender.get_last_attention())
     if self.mode_transduce == "split":
         # split mode: use attentions to compute context, then run RNNs over these context inputs
         if self.split_regularizer:
             assert len(atts_list) == len(
                 self._chosen_rnn_inputs
             ), f"{len(atts_list)} != {len(self._chosen_rnn_inputs)}"
         split_output_states = []
         split_rnn_state = first_state.rnn_state
         for pos, att in enumerate(atts_list):
              lstm_input_context = self.attender.curr_sent.as_tensor() * att  # TODO: better reuse the already computed context vecs
             lstm_input_context = dy.reshape(
                 lstm_input_context, (lstm_input_context.dim()[0][0], ),
                 batch_size=batch_size)
             if self.split_dual:
                 lstm_input_label = self._chosen_rnn_inputs[pos]
                 if self.split_dual[0] > 0.0 and self.train:
                     lstm_input_context = dy.dropout_batch(
                         lstm_input_context, self.split_dual[0])
                 if self.split_dual[1] > 0.0 and self.train:
                     lstm_input_label = dy.dropout_batch(
                         lstm_input_label, self.split_dual[1])
                 if self.split_context_transform:
                     lstm_input_context = self.split_context_transform.transform(
                         lstm_input_context)
                 lstm_input_context = self.split_dual_proj.transform(
                     dy.concatenate([lstm_input_context, lstm_input_label]))
             if self.split_regularizer and pos < len(
                     self._chosen_rnn_inputs):
                 # _chosen_rnn_inputs does not contain first (empty) input, so this is in fact like comparing to pos-1:
                 penalty = dy.squared_norm(lstm_input_context -
                                           self._chosen_rnn_inputs[pos])
                 if self.split_regularizer != 1:
                     penalty = self.split_regularizer * penalty
                 self.split_reg_penalty_expr = penalty
             split_rnn_state = split_rnn_state.add_input(lstm_input_context)
             split_output_states.append(split_rnn_state.h()[-1])
         assert len(output_states) == len(split_output_states)
         output_states = split_output_states
     out_mask.np_arr = out_mask.np_arr[:, :len(output_states)]
     self._final_states = []
     if self.compute_report:
         # for symmetric reporter (this can only be run at inference time)
         assert batch_size == 1
          atts_matrix = np.asarray([att.npvalue() for att in atts_list]).reshape(
              len(atts_list), atts_list[0].dim()[0][0]).T
         self.report_sent_info({
             "symm_att":
             atts_matrix,
             "symm_out":
             sent.SimpleSentence(
                 words=generated_word_ids,
                 idx=self.cur_src.batches[0][0].idx,
                 vocab=self.cur_src.batches[1][0].vocab,
                 output_procs=self.cur_src.batches[1][0].output_procs),
             "symm_ref":
             self.cur_src.batches[1][0] if isinstance(
                 self.cur_src, batchers.CompoundBatch) else None
         })
     # prepare final outputs
     for layer_i in range(len(current_state.rnn_state.h())):
         self._final_states.append(
             transducers.FinalTransducerState(
                 main_expr=current_state.rnn_state.h()[layer_i],
                 cell_expr=current_state.rnn_state._c[layer_i]))
     out_mask.np_arr.flags.writeable = False
     return expression_seqs.ExpressionSequence(expr_list=output_states,
                                               mask=out_mask)
Example #17
    def generate(self, src, forced_trg_ids=None, **kwargs):
        event_trigger.start_sent(src)
        if isinstance(src, batchers.CompoundBatch):
            src = src.batches[0]

        outputs = []

        batch_size = src.batch_size()
        score = batchers.ListBatch([[] for _ in range(batch_size)])
        words = batchers.ListBatch([[] for _ in range(batch_size)])
        done = [False] * batch_size
        initial_state = self._encode_src(src)
        current_state = initial_state
        attentions = []
        for pos in range(self.max_dec_len):
            prev_ref_action = None
            if pos > 0 and self.mode_translate != "context":
                if forced_trg_ids is not None:
                    prev_ref_action = batchers.mark_as_batch([
                        forced_trg_ids[batch_i][pos - 1]
                        for batch_i in range(batch_size)
                    ])
                elif batch_size > 1:
                    prev_ref_action = batchers.mark_as_batch(
                        np.argmax(current_state.out_prob.npvalue(), axis=0))
                else:
                    prev_ref_action = batchers.mark_as_batch(
                        [np.argmax(current_state.out_prob.npvalue(), axis=0)])

            logsoft = self.generate_one_step(dec_state=current_state,
                                             batch_size=batch_size,
                                             mode=self.mode_translate,
                                             cur_step=pos,
                                             prev_ref_action=prev_ref_action)
            attentions.append(self.attender.get_last_attention().npvalue())
            logsoft = logsoft.npvalue()
            logsoft = logsoft.reshape(logsoft.shape[0], batch_size)
            if forced_trg_ids is None:
                batched_word_id = np.argmax(logsoft, axis=0)
                batched_word_id = batched_word_id.reshape(
                    (batched_word_id.size, ))
            else:
                batched_word_id = [
                    forced_trg_batch_elem[pos]
                    for forced_trg_batch_elem in forced_trg_ids
                ]
            for batch_i in range(batch_size):
                if done[batch_i]:
                    batch_word = vocabs.Vocab.ES
                    batch_score = 0.0
                else:
                    batch_word = batched_word_id[batch_i]
                    batch_score = logsoft[batch_word, batch_i]
                    if self._terminate_rnn(batch_i=batch_i,
                                           pos=pos,
                                           batched_word_id=batched_word_id):
                        done[batch_i] = True
                score[batch_i].append(batch_score)
                words[batch_i].append(batch_word)
            if all(done):
                break
        for batch_i in range(batch_size):
            batch_elem_score = sum(score[batch_i])
            outputs.append(
                sent.SimpleSentence(words=words[batch_i],
                                    idx=src[batch_i].idx,
                                    vocab=getattr(self.trg_reader, "vocab",
                                                  None),
                                    score=batch_elem_score,
                                    output_procs=self.trg_reader.output_procs))
            if self.compute_report:
                if batch_size > 1:
                    cur_attentions = [x[:, :, batch_i] for x in attentions]
                else:
                    cur_attentions = attentions
                attention = np.concatenate(cur_attentions, axis=1)
                self.report_sent_info({
                    "attentions": attention,
                    "src": src[batch_i],
                    "output": outputs[-1]
                })

        return outputs
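The batched loop keeps a `done` flag per batch element: once a sentence emits EOS, it keeps producing EOS with score 0.0 so the rest of the batch can continue. The bookkeeping in isolation, as a rough sketch (the `step_logprobs` callback and EOS ID are hypothetical stand-ins for `generate_one_step` and `vocabs.Vocab.ES`):

import numpy as np

EOS_ID = 1  # hypothetical end-of-sentence ID

def batched_greedy(step_logprobs, batch_size, max_len):
  """Greedily decode a whole batch; finished rows keep emitting EOS with score 0."""
  words = [[] for _ in range(batch_size)]
  scores = [[] for _ in range(batch_size)]
  done = [False] * batch_size
  for pos in range(max_len):
    logprobs = step_logprobs(pos)            # shape: vocab_size x batch_size
    picks = np.argmax(logprobs, axis=0)
    for b in range(batch_size):
      if done[b]:
        words[b].append(EOS_ID)
        scores[b].append(0.0)
      else:
        words[b].append(int(picks[b]))
        scores[b].append(float(logprobs[picks[b], b]))
        if picks[b] == EOS_ID:
          done[b] = True
    if all(done):
      break
  return words, [sum(s) for s in scores]

# toy scorer: batch element 0 prefers EOS immediately; element 1 emits token 2, then EOS
def toy(pos):
  probs = np.array([[0.1, 0.1],
                    [0.8, 0.1 if pos < 1 else 0.8],
                    [0.1, 0.8 if pos < 1 else 0.1]])
  return np.log(probs)

print(batched_greedy(toy, batch_size=2, max_len=5))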