Exemplo n.º 1
0
 def embed_sent(self,
                x: sent.Sentence) -> expression_seqs.ExpressionSequence:
     """Embed a sentence (or a batch of sentences) without learned parameters.

     Two input kinds are distinguished:

     * sentences exposing ``get_array()`` are wrapped lazily in a
       ``LazyNumpyExpressionSequence`` (unbatched: the array itself; batched:
       the batch of sentences together with the batch mask);
     * otherwise every word position goes through ``self.embed``: one call per
       word when unbatched, one batched call per time step when batched.

     Args:
       x: a single sentence or a batch of sentences.

     Returns:
       An expression sequence. A mask is attached only when ``x`` is batched;
       a bare sentence is not assumed to carry a ``mask`` attribute.
     """
     # TODO refactor: seems a bit too many special cases that need to be distinguished
     batched = batchers.is_batched(x)
     first_sent = x[0] if batched else x

     if hasattr(first_sent, "get_array"):
         if not batched:
             return expression_seqs.LazyNumpyExpressionSequence(
                 lazy_data=x.get_array())
         return expression_seqs.LazyNumpyExpressionSequence(
             lazy_data=batchers.mark_as_batch(list(x)),
             mask=x.mask)

     if not batched:
         embeddings = [self.embed(word) for word in x]
     else:
         # One batched lookup per time step: gather word `word_i` of every
         # sentence in the batch.
         embeddings = [
             self.embed(
                 batchers.mark_as_batch(
                     [single_sent[word_i] for single_sent in x]))
             for word_i in range(x.sent_len())
         ]
     # Only a Batch carries a mask (matches the default embed_sent behaviour).
     return expression_seqs.ExpressionSequence(
         expr_list=embeddings, mask=x.mask if batched else None)
Exemplo n.º 2
0
    def generate(
            self,
            src: batchers.Batch,
            forced_trg_ids: Sequence[numbers.Integral] = None,
            normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
        """Label every source position with its best-scoring (or forced) tag.

        Args:
          src: a batch holding exactly one source sentence, or a single
            unbatched sentence (wrapped into a batch of one).
          forced_trg_ids: optional label ids to output instead of the argmax;
            wrapped into a batch when ``src`` arrives unbatched.
          normalize_scores: score with log-softmax instead of raw scores.

        Returns:
          A one-element list holding a ``SimpleSentence`` whose ``score`` is the
          sum over positions of the chosen labels' scores.
        """
        if not batchers.is_batched(src):
            src = batchers.mark_as_batch([src])
            if forced_trg_ids:
                forced_trg_ids = batchers.mark_as_batch([forced_trg_ids])
        assert src.batch_size() == 1, "batch size > 1 not properly tested"

        _, _, scorer_input, seq_len = self._encode_src(src)
        if normalize_scores:
            score_expr = self.scorer.calc_log_softmax(scorer_input)
        else:
            score_expr = self.scorer.calc_scores(scorer_input)
        score_matrix = score_expr.npvalue()  # vocab_size x seq_len

        positions = range(seq_len)
        if forced_trg_ids:
            labels = forced_trg_ids
        else:
            labels = [np.argmax(score_matrix[:, pos]) for pos in positions]
        total_score = np.sum([score_matrix[labels[pos], pos] for pos in positions])

        hypothesis = sent.SimpleSentence(
            words=labels,
            idx=src[0].idx,
            vocab=getattr(self, "trg_vocab", None),
            output_procs=self.trg_reader.output_procs,
            score=total_score)
        return [hypothesis]
Exemplo n.º 3
0
 def generate(self,
              src: Union[batchers.Batch, sent.Sentence],
              forced_trg_ids: Optional[Sequence[numbers.Integral]] = None,
              normalize_scores: bool = False):
     """Classify each source sentence, optionally forcing the output class.

     Args:
       src: a batch of sentences, or a single sentence (wrapped into a batch
         of one).
       forced_trg_ids: optional class id(s) to output instead of the argmax.
         When ``src`` is unbatched this is wrapped into a batch of one,
         mirroring ``src``.
       normalize_scores: score with log-probabilities instead of raw scores.

     Returns:
       One ``ScalarSentence`` per batch element with the chosen class and its
       score.
     """
     # `is not None` instead of truthiness: class id 0 is a valid forced label
     # and numpy arrays have no unambiguous truth value.
     forced = forced_trg_ids is not None
     if not batchers.is_batched(src):
         src = batchers.mark_as_batch([src])
         if forced:
             forced_trg_ids = batchers.mark_as_batch([forced_trg_ids])
     h = self._encode_src(src)
     if normalize_scores:
         scores = self.scorer.calc_log_probs(h)
     else:
         scores = self.scorer.calc_scores(h)
     np_scores = scores.npvalue()
     if forced:
         output_action = forced_trg_ids
     else:
         output_action = np.argmax(np_scores, axis=0)
     outputs = []
     for batch_i in range(src.batch_size()):
         if src.batch_size() > 1:
             my_action = output_action[batch_i]
             score = np_scores[:, batch_i][my_action]
         else:
             my_action = output_action
             if forced and batchers.is_batched(my_action):
                 # A forced label for a single sentence is a batch of one;
                 # unwrap it so a scalar id (not a list) is stored and indexed.
                 my_action = my_action[0]
             score = np_scores[my_action]
         outputs.append(sent.ScalarSentence(value=my_action, score=score))
     return outputs
Exemplo n.º 4
0
 def assert_forced_decoding(self, sent_id):
     """Check that forced decoding reproduces reference target `sent_id` in order.

     ``assertItemsEqual`` does not exist on Python 3's ``unittest.TestCase``
     (and was order-insensitive where it did exist). Forced decoding must
     reproduce the reference words exactly, so an ordered comparison is used.
     """
     dy.renew_cg()
     outputs = self.model.generate(
         batchers.mark_as_batch([self.src_data[sent_id]]),
         BeamSearch(),
         forced_trg_ids=batchers.mark_as_batch([self.trg_data[sent_id]]))
     # list() on both sides so list / ndarray word containers compare by value.
     self.assertSequenceEqual(list(self.trg_data[sent_id].words),
                              list(outputs[0].words))
Exemplo n.º 5
0
  def generate(self, src, forced_trg_ids=None, search_strategy=None):
    """Greedily decode `src` one target token at a time.

    Each step rebuilds the computation graph and re-runs ``calc_loss`` in
    prediction mode on the prefix decoded so far (plus a placeholder id 0),
    then takes the argmax of the returned scores. Decoding stops once
    ``Vocab.ES`` is produced or after 100 steps. ``forced_trg_ids`` and
    ``search_strategy`` are currently ignored.

    Returns:
      A one-element list: a ``SimpleSentence`` when ``self.trg_vocab`` is set,
      otherwise an ``(actions, score)`` tuple whose score is always 0.
    """
    def as_batch(item):
      return item if batchers.is_batched(item) else batchers.mark_as_batch([item])

    event_trigger.start_sent(src)
    src = as_batch(src)
    trg = as_batch(sent.SimpleSentence([0]))

    decoded = []
    score = 0.

    # TODO Fix this with generate_one_step and use the appropriate search_strategy
    self.max_len = 100  # This is a temporary hack
    for _ in range(self.max_len):
      dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
      log_prob_tail = self.calc_loss(src, trg, loss_cal=None, infer_prediction=True)
      next_id = np.argmax(log_prob_tail.npvalue(), axis=0).astype('i')
      decoded.append(next_id)
      if next_id == Vocab.ES:
        break
      trg = as_batch(sent.SimpleSentence(words=decoded + [0]))

    if getattr(self, "trg_vocab", None) is not None:
      return [sent.SimpleSentence(words=decoded, vocab=self.trg_vocab)]
    return [(decoded, score)]
Exemplo n.º 6
0
 def _batch_max_action(self, batch_size, current_state, pos):
     """Return the argmax action id(s) as a batch, or None at position 0.

     The argmax is taken over axis 0 of ``current_state.out_prob``'s value.
     When ``batch_size`` is 1 (or less) the single result is wrapped in a list
     before being marked as a batch; otherwise the argmax array is used as is.
     """
     if pos == 0:
         return None
     best = np.argmax(current_state.out_prob.npvalue(), axis=0)
     if batch_size <= 1:
         best = [best]
     return batchers.mark_as_batch(best)
Exemplo n.º 7
0
 def test_single(self):
     """Forced beam search (beam 1) must score the negated training NLL."""
     dy.renew_cg()
     src, trg = self.src_data[0], self.trg_data[0]
     train_loss = self.model.calc_nll(src=src, trg=trg).value()

     dy.renew_cg()
     hyps = self.model.generate(batchers.mark_as_batch([src]),
                                BeamSearch(beam_size=1),
                                forced_trg_ids=batchers.mark_as_batch([trg]))
     self.assertAlmostEqual(-hyps[0].score, train_loss, places=4)
Exemplo n.º 8
0
    def calc_loss(self, src, trg, infer_prediction=False):
        """Compute the transformer MLE loss, or next-token logits for inference.

        Both sides are batched if needed. The source gets ``Vocab.SS``
        prepended (with an unmasked mask column). The decoder input is
        ``[Vocab.SS] + words[:-1]`` for each target.

        Args:
          src: source sentence or batch.
          trg: target sentence or batch.
          infer_prediction: if True, return the decoder output for the last
            target position instead of a loss.

        Returns:
          The last-position decoder output when ``infer_prediction`` is set,
          otherwise a ``FactoredLossExpr`` with key ``"mle"``.
        """
        event_trigger.start_sent(src)
        if not batchers.is_batched(src):
            src = batchers.mark_as_batch([src])
        if not batchers.is_batched(trg):
            trg = batchers.mark_as_batch([trg])

        # `np.int` was only an alias of the builtin `int`; it was deprecated in
        # NumPy 1.20 and removed in 1.24, so plain `int` is used as the dtype.
        src_words = np.array([[vocabs.Vocab.SS] + x.words for x in src])
        batch_size, src_len = src_words.shape

        if src.mask is None:
            src_mask = np.zeros((batch_size, src_len), dtype=int)
        else:
            # Prepend an unmasked column for the SS token added above.
            src_mask = np.concatenate(
                [np.zeros((batch_size, 1), dtype=int),
                 src.mask.np_arr.astype(int)],
                axis=1)

        src_embeddings = self.sentence_block_embed(
            self.src_embedder.embeddings, src_words, src_mask)
        src_embeddings = self.make_input_embedding(src_embeddings, src_len)

        # Decoder input: SS followed by every target word except the last.
        trg_words = np.array([[vocabs.Vocab.SS] + x.words[:-1] for x in trg])
        batch_size, trg_len = trg_words.shape

        if trg.mask is None:
            trg_mask = np.zeros((batch_size, trg_len), dtype=int)
        else:
            trg_mask = trg.mask.np_arr.astype(int)

        trg_embeddings = self.sentence_block_embed(
            self.trg_embedder.embeddings, trg_words, trg_mask)
        trg_embeddings = self.make_input_embedding(trg_embeddings, trg_len)

        xx_mask = self.make_attention_mask(src_mask, src_mask)
        xy_mask = self.make_attention_mask(trg_mask, src_mask)
        yy_mask = self.make_attention_mask(trg_mask, trg_mask)
        # Presumably restricts target self-attention to earlier positions.
        yy_mask *= self.make_history_mask(trg_mask)

        z_blocks = self.encoder.transduce(src_embeddings, xx_mask)
        h_block = self.decoder(trg_embeddings, z_blocks, xy_mask, yy_mask)

        if infer_prediction:
            # Only the last target column is scored.
            y_len = h_block.dim()[0][1]
            last_col = dy.pick(h_block, dim=1, index=y_len - 1)
            return self.decoder.output(last_col)

        ref_list = list(itertools.chain.from_iterable(x.words for x in trg))
        # Reference ids with masked positions (mask == 1) zeroed out.
        concat_t_block = (1 - trg_mask.ravel()) * np.array(ref_list)
        loss = self.decoder.output_and_loss(h_block, concat_t_block)
        return losses.FactoredLossExpr({"mle": loss})
Exemplo n.º 9
0
    def calc_loss(
        self, model: 'model_base.ConditionedModel',
        src: Union[sent.Sentence, 'batchers.Batch'],
        trg: Union[sent.Sentence,
                   'batchers.Batch']) -> losses.FactoredLossExpr:
        """Batch `src` and `trg` if needed, signal sentence start, then delegate.

        Returns:
          Whatever ``self._perform_calc_loss(model, src, trg)`` returns for the
          batched inputs.
        """
        src, trg = (
            side if batchers.is_batched(side) else batchers.mark_as_batch([side])
            for side in (src, trg))
        event_trigger.start_sent(src)
        return self._perform_calc_loss(model, src, trg)
Exemplo n.º 10
0
    def assert_single_loss_equals_batch_loss(self,
                                             model,
                                             pad_src_to_multiple=1):
        """Assert that per-sentence losses sum to the masked batched loss.

        Sources are cut to the shortest source length, their last kept token is
        overwritten with ``Vocab.ES``, and they are padded with ``Vocab.ES`` to
        a multiple of ``pad_src_to_multiple``. Targets are not truncated: they
        are sorted by decreasing length, padded with ``Vocab.ES`` to the
        longest target, and the padding is masked (value 1.0) in the batch run.
        """
        batch_size = 5

        src_sents = self.src_data[:batch_size]
        src_min = min(s.sent_len() for s in src_sents)
        src_word_lists = [s.words[:src_min] for s in src_sents]
        for words in src_word_lists:
            words[src_min - 1] = Vocab.ES
            while len(words) % pad_src_to_multiple != 0:
                words.append(Vocab.ES)

        trg_sents = sorted(self.trg_data[:batch_size],
                           key=lambda s: s.sent_len(),
                           reverse=True)
        trg_max = max(s.sent_len() for s in trg_sents)
        mask_arr = np.zeros([batch_size, trg_max])
        for row in range(batch_size):
            # Everything past this sentence's end is padding.
            mask_arr[row, trg_sents[row].sent_len():] = 1.0
        trg_masks = Mask(mask_arr)
        trg_word_lists = [list(s) + [Vocab.ES] * (trg_max - s.sent_len())
                          for s in trg_sents]

        src_items = [sent.SimpleSentence(words=w) for w in src_word_lists]
        trg_padded_items = [sent.SimpleSentence(words=w) for w in trg_word_lists]

        single_loss = 0.0
        for idx in range(batch_size):
            dy.renew_cg()
            train_loss, _ = MLELoss().calc_loss(
                model=model, src=src_items[idx], trg=trg_sents[idx]).compute()
            single_loss += train_loss.value()

        dy.renew_cg()
        batched_loss, _ = MLELoss().calc_loss(
            model=model,
            src=mark_as_batch(src_items),
            trg=mark_as_batch(trg_padded_items, trg_masks)).compute()
        self.assertAlmostEqual(single_loss,
                               np.sum(batched_loss.value()),
                               places=4)
Exemplo n.º 11
0
 def _cut_or_pad_targets(self, seq_len, trg):
   """Truncate or pad every target sentence of batch `trg` to `seq_len` words.

   Longer sentences are shortened with ``get_truncated_sent``. Shorter ones are
   extended with ``create_padded_sent``. If the batch has a mask, it is cut or
   padded the same way, with new padding columns marked as masked (value 1).
   In both branches the mask is stored as a ``batchers.Mask``.

   Returns:
     The new target batch.
   """
   old_mask = trg.mask
   cur_len = len(trg[0])
   if cur_len > seq_len:
     trunc_len = cur_len - seq_len
     trg = batchers.mark_as_batch([trg_sent.get_truncated_sent(trunc_len=trunc_len) for trg_sent in trg])
     if old_mask:
       trg.mask = batchers.Mask(np_arr=old_mask.np_arr[:, :-trunc_len])
   else:
     pad_len = seq_len - cur_len
     trg = batchers.mark_as_batch([trg_sent.create_padded_sent(pad_len=pad_len) for trg_sent in trg])
     if old_mask:
       # Previously a raw ndarray was stored here; downstream code reads
       # `trg.mask.np_arr`, so wrap it in a Mask like the truncation branch.
       padded = np.pad(old_mask.np_arr, pad_width=((0, 0), (0, pad_len)), mode="constant", constant_values=1)
       trg.mask = batchers.Mask(np_arr=padded)
   return trg
Exemplo n.º 12
0
    def test_greedy_vs_beam(self):
        """Beam search with beam size 1 must score the same as greedy search."""
        scores = []
        for make_strategy in (lambda: BeamSearch(beam_size=1), GreedySearch):
            dy.renew_cg()
            hyps = self.model.generate(
                batchers.mark_as_batch([self.src_data[0]]), make_strategy())
            scores.append(hyps[0].score)
        beam_score, greedy_score = scores
        self.assertAlmostEqual(beam_score, greedy_score)
Exemplo n.º 13
0
    def embed_sent(self, x) -> expression_seqs.ExpressionSequence:
        """Embed a full sentence worth of words. By default, just do a for loop.

        Args:
          x: This will generally be a list of word IDs, but could also be a list
            of strings or some other format. It could also be batched, in which
            case it will be a (possibly masked) :class:`xnmt.batcher.Batch`
            object whose sentences must all have the same length.

        Returns:
          An expression sequence representing vectors of each word in the input,
          carrying the batch mask when ``x`` is batched.
        """
        if not batchers.is_batched(x):
            # single mode: one embed call per word
            return expression_seqs.ExpressionSequence(
                expr_list=[self.embed(word) for word in x], mask=None)

        # minibatch mode: one batched embed call per time step
        seq_len = x.sent_len()
        for single_sent in x:
            assert single_sent.sent_len() == seq_len
        step_exprs = [
            self.embed(batchers.mark_as_batch([single_sent[pos] for single_sent in x]))
            for pos in range(seq_len)
        ]
        return expression_seqs.ExpressionSequence(expr_list=step_exprs, mask=x.mask)
Exemplo n.º 14
0
    def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
            -> dy.Expression:
        """Summed per-position negative log-likelihood of `trg` labels given `src`.

        Both inputs must already be batched. If the target length differs from
        the encoded source length, targets are cut or padded when
        ``self.auto_cut_pad`` is set; otherwise ``ValueError`` is raised.
        Positions whose target mask is 1 contribute zero loss.

        Returns:
          ``dy.sum_elems`` of the (optionally masked) per-position losses.
        """
        assert batchers.is_batched(src) and batchers.is_batched(trg)
        batch_size, _, outputs, seq_len = self._encode_src(src)

        if trg.sent_len() != seq_len:
            if not self.auto_cut_pad:
                raise ValueError(
                    f"src/trg length do not match: {seq_len} != {len(trg[0])}")
            trg = self._cut_or_pad_targets(seq_len, trg)

        # Flatten the per-sentence reference labels into one batched vector.
        flat_refs = np.asarray([t.words for t in trg]).reshape((seq_len * batch_size, ))
        per_step = self.scorer.calc_loss(outputs, batchers.mark_as_batch(flat_refs))
        per_step = dy.reshape(per_step, (seq_len, ), batch_size=batch_size)
        if trg.mask:
            keep = dy.inputTensor(1.0 - trg.mask.np_arr.T, batched=True)
            per_step = dy.cmult(per_step, keep)
        return dy.sum_elems(per_step)
Exemplo n.º 15
0
    def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
            -> tt.Tensor:
        """Language-model NLL of `src` under next-token prediction.

        The model reads ``src[:-1]`` and is scored against ``src[1:]``; the
        mask is sliced the same way. ``trg`` is not used.

        Returns:
          The result of ``tt.aggregate_masked_loss`` over per-position losses.
        """
        if not batchers.is_batched(src):
            src = batchers.ListBatch([src])

        def shift(window):
            # Apply the same time slice to every sentence and to the mask.
            return batchers.ListBatch(
                [s[window] for s in src],
                mask=batchers.Mask(src.mask.np_arr[:, window]) if src.mask else None)

        src_inputs = shift(slice(None, -1))   # all but the last token
        src_targets = shift(slice(1, None))   # all but the first token

        event_trigger.start_sent(src)
        encodings = self.rnn.transduce(self.src_embedder.embed_sent(src_inputs))
        enc_tensor = encodings.as_tensor()

        merged = tt.merge_time_batch_dims(enc_tensor)
        seq_len = tt.sent_len(enc_tensor)
        batch_size = tt.batch_size(enc_tensor)
        outputs = self.transform.transform(merged)

        # `tgt` (not `sent`) so the module name is not shadowed.
        flat_refs = np.asarray([tgt.words for tgt in src_targets]).reshape(
            (seq_len * batch_size, ))
        per_step = self.scorer.calc_loss(outputs, batchers.mark_as_batch(flat_refs))
        per_step = tt.unmerge_time_batch_dims(per_step, batch_size)
        return tt.aggregate_masked_loss(per_step, src_targets.mask)
Exemplo n.º 16
0
    def generate(
            self,
            src: batchers.Batch,
            normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
        """Label each source position with its single best-scoring tag.

        Args:
          src: a batch with exactly one sentence, or a single unbatched sentence.
          normalize_scores: forwarded to ``self.scorer.best_k``.

        Returns:
          A one-element list with a ``SimpleSentence`` of the top-1 labels. Its
          score is ``np.sum(best_scores, axis=1)``. NOTE(review): this is
          presumably an array of length k=1 rather than a scalar — confirm
          against ``best_k``.
        """
        if not batchers.is_batched(src):
            src = batchers.mark_as_batch([src])
        assert src.batch_size() == 1, "batch size > 1 not properly tested"

        _, _, scorer_input, _ = self._encode_src(src)
        top_words, top_scores = self.scorer.best_k(
            scorer_input, k=1, normalize_scores=normalize_scores)

        return [
            sent.SimpleSentence(
                words=top_words[0, :],
                idx=src[0].idx,
                vocab=getattr(self, "trg_vocab", None),
                output_procs=self.trg_reader.output_procs,
                score=np.sum(top_scores, axis=1))
        ]
Exemplo n.º 17
0
  def generate_search_output(self,
                             src: batchers.Batch,
                             search_strategy: search_strategies.SearchStrategy,
                             forced_trg_ids: batchers.Batch=None) -> List[search_strategies.SearchOutput]:
    """
    Takes in a batch of source sentences and outputs a list of search outputs.
    Args:
      src: The source sentences (batch size must be 1). If it is a
        ``CompoundBatch``, the whole compound batch is encoded and its first
        sub-batch supplies the source length.
      search_strategy: The strategy with which to perform the search
      forced_trg_ids: The target IDs to generate if performing forced decoding
    Returns:
      A list of search outputs including scores, etc.
    Raises:
      NotImplementedError: if ``src`` holds more than one sentence.
    """
    if src.batch_size() != 1:
      raise NotImplementedError("batched decoding not implemented for DefaultTranslator. "
                                "Specify inference batcher with batch size 1.")
    event_trigger.start_sent(src)
    all_src = src
    if isinstance(src, batchers.CompoundBatch):
      src = src.batches[0]
    # The former `sent_mask` / `sent_batch` locals were never used, and the
    # batch was built from the `sent` module (a typo for `src_sent`); removed.
    src_sent = src[0]

    # Encode the sentence
    initial_state = self._encode_src(all_src)

    cur_forced_trg = forced_trg_ids[0] if forced_trg_ids is not None else None
    return search_strategy.generate_output(self, initial_state,
                                           src_length=[src_sent.sent_len()],
                                           forced_trg_ids=cur_forced_trg)
Exemplo n.º 18
0
    def calc_nll(self, src, trg):
        """Language-model NLL of `src` under next-token prediction (DyNet).

        The model reads ``src[:-1]`` and is scored against ``src[1:]``; the
        mask is sliced the same way and masked positions (value 1) contribute
        zero loss. ``trg`` is not used.

        Returns:
          ``dy.sum_elems`` of the per-position losses.
        """
        if not batchers.is_batched(src):
            src = batchers.ListBatch([src])

        def shift(window):
            # Apply the same time slice to every sentence and to the mask.
            return batchers.ListBatch(
                [s[window] for s in src],
                mask=batchers.Mask(src.mask.np_arr[:, window]) if src.mask else None)

        src_inputs = shift(slice(None, -1))   # all but the last token
        src_targets = shift(slice(1, None))   # all but the first token

        event_trigger.start_sent(src)
        encodings = self.rnn.transduce(self.src_embedder.embed_sent(src_inputs))
        encodings_tensor = encodings.as_tensor()
        (hidden_dim, seq_len), batch_size = encodings.dim()
        # Fold time into the batch dimension so the transform runs per position.
        folded = dy.reshape(encodings_tensor, (hidden_dim, ),
                            batch_size=batch_size * seq_len)
        outputs = self.transform.transform(folded)

        flat_refs = np.asarray([tgt.words for tgt in src_targets]).reshape(
            (seq_len * batch_size, ))
        per_step = self.scorer.calc_loss(outputs, batchers.mark_as_batch(flat_refs))
        per_step = dy.reshape(per_step, (seq_len, ), batch_size=batch_size)
        if src_targets.mask:
            keep = dy.inputTensor(1.0 - src_targets.mask.np_arr.T, batched=True)
            per_step = dy.cmult(per_step, keep)
        return dy.sum_elems(per_step)
Exemplo n.º 19
0
 def test_composite(self):
   """A composite MLE + policy-MLE loss runs in training; greedy decoding runs after."""
   event_trigger.set_train(True)
   parts = [loss_calculators.MLELoss(), loss_calculators.PolicyMLELoss()]
   loss_calculators.CompositeLoss(parts).calc_loss(self.model, self.src[0], self.trg[0])

   event_trigger.set_train(False)
   first_src = batchers.mark_as_batch([self.src_data[0]])
   self.model.generate(first_src, GreedySearch())
Exemplo n.º 20
0
    def _embed_word(self, word: sent.SegmentedWord, is_batched: bool = False):
        """Embed `word` by composing the embeddings of its characters.

        All character ids are looked up in one batched call. The result is
        split back into one expression per character and handed to
        ``self.composer.compose``. ``is_batched`` is accepted but unused.
        """
        batched_lookup = self.embeddings.batch(batchers.mark_as_batch(word.chars))
        per_char = [dy.pick_batch_elem(batched_lookup, idx)
                    for idx in range(len(word.chars))]
        return self.composer.compose(per_char)
Exemplo n.º 21
0
 def test_train_mle_only(self):
   """With no policy network, the MLE loss and greedy decoding both run."""
   self.model.policy_network = None

   event_trigger.set_train(True)
   loss_calculators.MLELoss().calc_loss(self.model, self.src[0], self.trg[0])

   event_trigger.set_train(False)
   first_src = batchers.mark_as_batch([self.src_data[0]])
   self.model.generate(first_src, GreedySearch())
Exemplo n.º 22
0
 def _encode_src(self, src: Union[batchers.Batch, sent.Sentence]):
   """Encode `src`, prime the attender, and return the decoder's initial state.

   The decoder starts from the encoder's final states and the embedding of
   ``Vocab.SS``. When `src` is batched, this is one SS token per batch element.
   """
   encoding = self.encoder.transduce(self.src_embedder.embed_sent(src))
   final_state = self.encoder.get_final_states()
   self.attender.init_sent(encoding)
   if batchers.is_batched(src):
     start_tokens = batchers.mark_as_batch([Vocab.SS] * src.batch_size())
   else:
     start_tokens = Vocab.SS
   return self.decoder.initial_state(final_state, self.trg_embedder.embed(start_tokens))
Exemplo n.º 23
0
    def assert_single_loss_equals_batch_loss(self,
                                             model,
                                             pad_src_to_multiple=1):
        """Assert that per-sentence losses sum to the unmasked batched loss.

        Sources and targets are each cut to their shortest length, with the last
        kept token overwritten by ``Vocab.ES``. All batch elements therefore
        have equal length and no mask is needed. Sources are additionally
        padded with ``Vocab.ES`` to a multiple of ``pad_src_to_multiple``.
        """
        batch_size = 5

        src_sents = self.src_data[:batch_size]
        src_min = min(s.sent_len() for s in src_sents)
        src_word_lists = [s.words[:src_min] for s in src_sents]
        for words in src_word_lists:
            words[src_min - 1] = Vocab.ES
            while len(words) % pad_src_to_multiple != 0:
                words.append(Vocab.ES)

        trg_sents = self.trg_data[:batch_size]
        trg_min = min(s.sent_len() for s in trg_sents)
        trg_word_lists = [s.words[:trg_min] for s in trg_sents]
        for words in trg_word_lists:
            words[trg_min - 1] = Vocab.ES

        src_items = [sent.SimpleSentence(words=w) for w in src_word_lists]
        trg_items = [sent.SimpleSentence(words=w) for w in trg_word_lists]

        single_loss = 0.0
        for idx in range(batch_size):
            dy.renew_cg()
            train_loss, _ = MLELoss().calc_loss(
                model=model, src=src_items[idx], trg=trg_items[idx]).compute()
            single_loss += train_loss.value()

        dy.renew_cg()
        batched_loss, _ = MLELoss().calc_loss(
            model=model,
            src=mark_as_batch(src_items),
            trg=mark_as_batch(trg_items)).compute()
        self.assertAlmostEqual(single_loss,
                               np.sum(batched_loss.value()),
                               places=4)
Exemplo n.º 24
0
 def _select_ref_words(sent, index, truncate_masked = False):
   """Pick the word at position `index` from a sentence or from each batch element.

   Unbatched input returns ``sent[index]`` directly. Batched input returns a
   new batch of per-sentence words. With ``truncate_masked`` set, a sentence
   whose mask entry at `index` is non-zero is skipped, unless the mask column
   at `index` sums to the batch size (then every sentence is kept). Skipped
   sentences must all come after kept ones; this is asserted.
   """
   if not batchers.is_batched(sent):
     return sent[index]
   if not truncate_masked:
     return batchers.mark_as_batch([single_trg[index] for single_trg in sent])

   mask = sent.mask
   # Loop-invariant: no mask, or the whole column is masked -> keep everything.
   keep_all = mask is None or np.sum(mask.np_arr[:, index]) == mask.np_arr.shape[0]
   picked = []
   seen_masked = False
   for row, single_trg in enumerate(sent):
     if keep_all or mask.np_arr[row, index] == 0:
       assert not seen_masked, "sentences must be sorted by decreasing target length"
       picked.append(single_trg[index])
     else:
       seen_masked = True
   return batchers.mark_as_batch(picked)
Exemplo n.º 25
0
 def _batch_ref_action(self, pos):
     """Batch the entry at `pos` of each sentence in ``self.cur_src.batches[1]``.

     ``None`` entries are replaced by ``vocabs.Vocab.ES``.
     """
     actions = [vocabs.Vocab.ES if src_sent[pos] is None else src_sent[pos]
                for src_sent in self.cur_src.batches[1]]
     return batchers.mark_as_batch(actions)
Exemplo n.º 26
0
  def test_train_nll(self):
    """MLE and policy-MLE losses compute in training mode; greedy decoding runs after."""
    event_trigger.set_train(True)
    src0, trg0 = self.src[0], self.trg[0]
    loss_calculators.MLELoss().calc_loss(self.model, src0, trg0)
    loss_calculators.PolicyMLELoss()._perform_calc_loss(self.model, src0, trg0)

    event_trigger.set_train(False)
    self.model.generate(batchers.mark_as_batch([self.src_data[0]]), GreedySearch())
Exemplo n.º 27
0
  def test_single(self):
    """The greedy decoding score must equal the negated NLL of the decoded output."""
    dy.renew_cg()
    hyp = self.model.generate(batchers.mark_as_batch([self.src_data[0]]), GreedySearch())[0]
    decoded_score = hyp.score

    dy.renew_cg()
    nll = self.model.calc_nll(src=self.src_data[0], trg=hyp).value()
    self.assertAlmostEqual(-decoded_score, nll, places=3)
Exemplo n.º 28
0
    def test_single(self):
        """The greedy decoding score must equal the negated NLL of the decoded output."""
        tt.reset_graph()
        hyp = self.model.generate(
            batchers.mark_as_batch([self.src_data[0]]), GreedySearch())[0]
        decoded_score = hyp.score

        tt.reset_graph()
        nll = tt.npvalue(self.model.calc_nll(src=self.src_data[0], trg=hyp))
        self.assertAlmostEqual(-decoded_score, nll[0], places=3)
Exemplo n.º 29
0
 def calc_loss(self, dec_state, ref_action):
     """Loss for a batch of reference RNNG actions at one decoder step.

     The action-type loss is always computed. The type of the first action in
     the batch then selects an auxiliary content loss:

     * NT: ``nt_scorer``;
     * GEN: ``term_scorer``;
     * REDUCE_LEFT / REDUCE_RIGHT: ``edge_scorer``;
     * any other type: no auxiliary loss.
     """
     types = sent.RNNGAction.Type
     state = self._calc_transform(dec_state)
     type_ids = batchers.mark_as_batch([a.action_type.value for a in ref_action])
     action_type = ref_action[0].action_type
     total = self.action_scorer.calc_loss(state, type_ids)

     # Aux Losses based on action content
     if action_type == types.NT:
         aux_scorer = self.nt_scorer
     elif action_type == types.GEN:
         aux_scorer = self.term_scorer
     elif action_type in (types.REDUCE_LEFT, types.REDUCE_RIGHT):
         aux_scorer = self.edge_scorer
     else:
         aux_scorer = None

     if aux_scorer is not None:
         contents = batchers.mark_as_batch([a.action_content for a in ref_action])
         total += aux_scorer.calc_loss(state, contents)
     # Total Loss
     return total
Exemplo n.º 30
0
 def generate_one_step(self, current_word: Any, current_state: AutoRegressiveDecoderState) -> TranslatorOutput:
   """Advance the decoder by one step and score the next word.

   If `current_word` is given, it is embedded and fed to the decoder. A plain
   ``int`` is first wrapped in a list, and a list or ndarray is marked as a
   batch. If it is None, `current_state` is used unchanged. The attender then
   computes a context from the RNN output of the resulting state.

   Returns:
     ``TranslatorOutput(state, decoder log-probs, last attention)``.
   """
   next_state = current_state
   if current_word is not None:
     if type(current_word) is int:
       current_word = [current_word]
     if type(current_word) in (list, np.ndarray):
       current_word = batchers.mark_as_batch(current_word)
     next_state = self.decoder.add_input(current_state, self.trg_embedder.embed(current_word))
   next_state.context = self.attender.calc_context(next_state.rnn_state.output())
   return TranslatorOutput(next_state,
                           self.decoder.calc_log_probs(next_state),
                           self.attender.get_last_attention())