Example #1
    def calc_loss(self, src, trg, loss_calculator):
        assert batcher.is_batched(src) and batcher.is_batched(trg)
        batch_size, encodings, outputs, seq_len = self._encode_src(src)

        if trg.sent_len() != seq_len:
            if self.auto_cut_pad:
                trg = self._cut_or_pad_targets(seq_len, trg)
            else:
                raise ValueError(
                    f"src/trg length do not match: {seq_len} != {trg.sent_len()}")

        ref_action = np.asarray([sent.words for sent in trg]).reshape(
            (seq_len * batch_size, ))
        loss_expr_perstep = self.scorer.calc_loss(
            outputs, batcher.mark_as_batch(ref_action))
        # equivalent to dy.pickneglogsoftmax_batch(outputs, ref_action) for a plain softmax scorer
        loss_expr_perstep = dy.reshape(loss_expr_perstep, (seq_len, ),
                                       batch_size=batch_size)
        if trg.mask:
            loss_expr_perstep = dy.cmult(
                loss_expr_perstep,
                dy.inputTensor(1.0 - trg.mask.np_arr.T, batched=True))
        loss_expr = dy.sum_elems(loss_expr_perstep)

        model_loss = loss.FactoredLossExpr()
        model_loss.add_loss("mle", loss_expr)

        return model_loss
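
The masking step above zeroes out per-step losses at padded target positions: trg.mask.np_arr appears to hold 1 at padded entries (hence the 1.0 - mask), with shape (batch_size, seq_len) given the transpose. A minimal numpy sketch of that arithmetic, under those assumed conventions:

import numpy as np

# Per-step negative log-likelihoods, one row per sentence in the batch.
seq_len, batch_size = 4, 2
nll = np.arange(1, seq_len * batch_size + 1, dtype=float).reshape(batch_size, seq_len)

# Assumed mask convention from the code above: 1 marks a padded position.
mask = np.array([[0, 0, 0, 1],
                 [0, 0, 1, 1]], dtype=float)

# dy.cmult(loss, dy.inputTensor(1.0 - mask.T, batched=True)) zeroes padded steps;
# dy.sum_elems then sums the surviving per-step losses for each sentence.
masked = nll * (1.0 - mask)          # losses at padded steps become 0
per_sent_loss = masked.sum(axis=1)   # one scalar loss per sentence
print(per_sent_loss)                 # [ 6. 11.]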
Example #2
    def calc_loss(self, src, trg, loss_calculator):
        self.start_sent(src)
        embeddings = self.src_embedder.embed_sent(src)
        encodings = self.encoder(embeddings)
        self.attender.init_sent(encodings)
        # Initialize the hidden state from the encoder
        ss = mark_as_batch([Vocab.SS] *
                           len(src)) if is_batched(src) else Vocab.SS
        dec_state = self.decoder.initial_state(self.encoder.get_final_states(),
                                               self.trg_embedder.embed(ss))
        # Compose losses
        model_loss = LossBuilder()
        model_loss.add_loss("mle", loss_calculator(self, dec_state, src, trg))

        if self.calc_global_fertility or self.calc_attention_entropy:
            # philip30: we assume attention_vecs is already masked along the
            # source axis; here the target-side mask is applied to it as well.
            masked_attn = self.attender.attention_vecs
            if trg.mask is not None:
                trg_mask = trg.mask.get_active_one_mask().transpose()
                masked_attn = [
                    dy.cmult(attn, dy.inputTensor(mask, batched=True))
                    for attn, mask in zip(masked_attn, trg_mask)
                ]

        if self.calc_global_fertility:
            model_loss.add_loss("fertility",
                                self.global_fertility(masked_attn))
        if self.calc_attention_entropy:
            model_loss.add_loss("H(attn)", self.attention_entropy(masked_attn))

        return model_loss
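
For context, attention_entropy in this snippet presumably measures how peaked the (target-masked) attention distributions are; the standard per-step quantity is H(p) = -Σ_i p_i log p_i. A hedged numpy sketch of that formulation (the exact reduction xnmt applies is not visible here, and global_fertility reduces the same masked vectors differently):

import numpy as np

def attention_entropy(attn_vecs, eps=1e-10):
    # attn_vecs: one distribution over src positions per decoding step.
    # Returns the entropies summed over steps.
    return sum(-np.sum(p * np.log(p + eps)) for p in attn_vecs)

attn = [np.array([0.9, 0.05, 0.05]),      # peaked step: low entropy
        np.array([1/3, 1/3, 1/3])]        # uniform step: high entropy
print(attention_entropy(attn))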
Example #3
    def generate(self, src, idx, forced_trg_ids=None, normalize_scores=False):
        if not batcher.is_batched(src):
            src = batcher.mark_as_batch([src])
            if forced_trg_ids:
                forced_trg_ids = batcher.mark_as_batch([forced_trg_ids])
        assert src.batch_size() == 1, "batch size > 1 not properly tested"

        batch_size, encodings, outputs, seq_len = self._encode_src(src)
        score_expr = self.scorer.calc_log_softmax(
            outputs) if normalize_scores else self.scorer.calc_scores(outputs)
        scores = score_expr.npvalue()  # vocab_size x seq_len

        if forced_trg_ids:
            output_actions = forced_trg_ids
        else:
            output_actions = [np.argmax(scores[:, j]) for j in range(seq_len)]
        score = np.sum([scores[output_actions[j], j] for j in range(seq_len)])

        outputs = [
            output.TextOutput(
                actions=output_actions,
                vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                score=score)
        ]

        return outputs
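
Generation here reduces to a per-position argmax over the score matrix, with no search. A minimal numpy sketch of the selection and scoring, using the vocab_size x seq_len layout noted in the comment above:

import numpy as np

vocab_size, seq_len = 5, 3
scores = np.random.randn(vocab_size, seq_len)  # stand-in for score_expr.npvalue()

# Greedy path: the best-scoring word at every position, independently.
output_actions = [int(np.argmax(scores[:, j])) for j in range(seq_len)]
# Sequence score: sum of the selected entries (log-probs when normalize_scores=True).
score = float(sum(scores[output_actions[j], j] for j in range(seq_len)))
print(output_actions, score)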
Example #4
 def generate(self, src, idx, src_mask=None, forced_trg_ids=None):
   if not xnmt.batcher.is_batched(src):
     src = xnmt.batcher.mark_as_batch([src])
   else:
     assert src_mask is not None
   outputs = []
   for sents in src:
     self.start_sent(src)
     embeddings = self.src_embedder.embed_sent(src)
     encodings = self.encoder(embeddings)
     self.attender.init_sent(encodings)
     ss = mark_as_batch([Vocab.SS] * len(src)) if is_batched(src) else Vocab.SS
     dec_state = self.decoder.initial_state(self.encoder.get_final_states(), self.trg_embedder.embed(ss))
     output_actions, score = self.search_strategy.generate_output(self.decoder, self.attender, self.trg_embedder, dec_state, src_length=len(sents), forced_trg_ids=forced_trg_ids)
     # In case of reporting
     if self.report_path is not None:
       src_words = [self.reporting_src_vocab[w] for w in sents]
       trg_words = [self.trg_vocab[w] for w in output_actions[1:]]
       attentions = self.attender.attention_vecs
       self.set_report_input(idx, src_words, trg_words, attentions)
       self.set_report_resource("src_words", src_words)
       self.set_report_path('{}.{}'.format(self.report_path, str(idx)))
       self.generate_report(self.report_type)
     # Append output to the outputs
     outputs.append(TextOutput(actions=output_actions,
                               vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                               score=score))
   return outputs
Example #5
  def calc_loss(self, src, trg, loss_calculator):
    if not batcher.is_batched(src):
      src = batcher.ListBatch([src])

    src_inputs = batcher.ListBatch([s[:-1] for s in src], mask=batcher.Mask(src.mask.np_arr[:,:-1]) if src.mask else None)
    src_targets = batcher.ListBatch([s[1:] for s in src], mask=batcher.Mask(src.mask.np_arr[:,1:]) if src.mask else None)

    self.start_sent(src)
    embeddings = self.src_embedder.embed_sent(src_inputs)
    encodings = self.rnn.transduce(embeddings)
    encodings_tensor = encodings.as_tensor()
    ((hidden_dim, seq_len), batch_size) = encodings.dim()
    encoding_reshaped = dy.reshape(encodings_tensor, (hidden_dim,), batch_size=batch_size * seq_len)
    outputs = self.transform(encoding_reshaped)

    ref_action = np.asarray([sent.words for sent in src_targets]).reshape((seq_len * batch_size,))
    loss_expr_perstep = self.scorer.calc_loss(outputs, batcher.mark_as_batch(ref_action))
    loss_expr_perstep = dy.reshape(loss_expr_perstep, (seq_len,), batch_size=batch_size)
    if src_targets.mask:
      loss_expr_perstep = dy.cmult(loss_expr_perstep, dy.inputTensor(1.0-src_targets.mask.np_arr.T, batched=True))
    loss_expr = dy.sum_elems(loss_expr_perstep)

    model_loss = loss.FactoredLossExpr()
    model_loss.add_loss("mle", loss_expr)

    return model_loss
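
The input/target split at the top of this method is the usual language-model shift: the token at position t is read so that the token at position t+1 can be predicted, and the masks are sliced the same way. A minimal sketch with a plain list standing in for a sentence:

sent = ["<s>", "a", "b", "c", "</s>"]
src_input  = sent[:-1]   # ["<s>", "a", "b", "c"]    fed to the RNN
src_target = sent[1:]    # ["a", "b", "c", "</s>"]   scored at each step
for inp, tgt in zip(src_input, src_target):
    print(f"read {inp!r} -> predict {tgt!r}")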
Example #6
    def calc_loss(self, x: dy.Expression,
                  y: Union[int, List[int]]) -> dy.Expression:

        scores = self.calc_scores(x)

        if self.label_smoothing == 0.0:
            # single mode
            if not batcher.is_batched(y):
                loss = dy.pickneglogsoftmax(scores, y)
            # minibatch mode
            else:
                loss = dy.pickneglogsoftmax_batch(scores, y)
        else:
            log_prob = dy.log_softmax(scores)
            if not batcher.is_batched(y):
                pre_loss = -dy.pick(log_prob, y)
            else:
                pre_loss = -dy.pick_batch(log_prob, y)

            ls_loss = -dy.mean_elems(log_prob)
            loss = ((1 - self.label_smoothing) *
                    pre_loss) + (self.label_smoothing * ls_loss)

        return loss
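
The smoothed branch interpolates the usual NLL with the mean negative log-probability over the whole vocabulary: loss = (1 - eps) * (-log p(y)) + eps * mean_v(-log p(v)). A minimal numpy sketch of the unbatched case, with a hand-rolled log-softmax standing in for dy.log_softmax:

import numpy as np

def smoothed_loss(scores, y, label_smoothing=0.1):
    # scores: (vocab_size,) raw logits; y: gold label index.
    shifted = scores - np.max(scores)                     # for numerical stability
    log_prob = shifted - np.log(np.sum(np.exp(shifted)))  # log-softmax
    pre_loss = -log_prob[y]          # standard NLL of the gold label
    ls_loss = -np.mean(log_prob)     # uniform smoothing term
    return (1 - label_smoothing) * pre_loss + label_smoothing * ls_loss

print(smoothed_loss(np.array([2.0, 0.5, -1.0]), y=0))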
Example #7
 def calc_loss(self, src, trg, loss_calculator):
   """
   :param src: source sequence (unbatched, or batched + padded)
   :param trg: target sequence (unbatched, or batched + padded); losses are accumulated only where trg_mask[batch,pos]==0, or everywhere if no mask is set
   :param loss_calculator: callable invoked as loss_calculator(translator, dec_state, src, trg) to compute the actual loss
   :returns: (possibly batched) loss expression
   """
   self.start_sent(src)
   embeddings = self.src_embedder.embed_sent(src)
   encodings = self.encoder(embeddings)
   self.attender.init_sent(encodings)
   # Initialize the hidden state from the encoder
   ss = mark_as_batch([Vocab.SS] * len(src)) if is_batched(src) else Vocab.SS
   dec_state = self.decoder.initial_state(self.encoder.get_final_states(), self.trg_embedder.embed(ss))
   return loss_calculator(self, dec_state, src, trg)
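
Note the division of labor: the translator only prepares the initial decoder state and defers the loss computation to loss_calculator, which both this example and Example #2 invoke as loss_calculator(translator, dec_state, src, trg). A toy stand-in obeying that assumed call signature (not xnmt's real implementation), so the wiring can be exercised in isolation:

class ConstantLoss:
    """Hypothetical loss_calculator: accepts the assumed arguments and
    returns a fixed dummy value instead of a real loss expression."""
    def __init__(self, value=0.0):
        self.value = value

    def __call__(self, translator, dec_state, src, trg):
        return self.value

calc = ConstantLoss(1.5)
print(calc(translator=None, dec_state=None, src=["a"], trg=["b"]))  # 1.5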
Example #8
 def generate(self, src, idx, search_strategy, src_mask=None, forced_trg_ids=None):
   if not xnmt.batcher.is_batched(src):
     src = xnmt.batcher.mark_as_batch([src])
   else:
     assert src_mask is not None
   # Generating outputs
   outputs = []
   for sents in src:
     self.start_sent(src)
     embeddings = self.src_embedder.embed_sent(src)
     encodings = self.encoder(embeddings)
     self.attender.init_sent(encodings)
     ss = mark_as_batch([Vocab.SS] * len(src)) if is_batched(src) else Vocab.SS
     initial_state = self.decoder.initial_state(self.encoder.get_final_states(), self.trg_embedder.embed(ss))
     search_outputs = search_strategy.generate_output(self, initial_state,
                                                      src_length=[len(sents)],
                                                      forced_trg_ids=forced_trg_ids)
     best_output = sorted(search_outputs, key=lambda x: x.score[0], reverse=True)[0]
     output_actions = [x for x in best_output.word_ids[0]]
     attentions = [x for x in best_output.attentions[0]]
     score = best_output.score[0]
     # In case of reporting
     if self.report_path is not None:
       if self.reporting_src_vocab:
         src_words = [self.reporting_src_vocab[w] for w in sents]
       else:
         src_words = ['' for w in sents]
       trg_words = [self.trg_vocab[w] for w in output_actions]
       # Attentions
       attentions = np.concatenate([x.npvalue() for x in attentions], axis=1)
       # Segmentation
       segment = self.get_report_resource("segmentation")
       if segment is not None:
         segment = [int(x[0]) for x in segment]
         src_inp = [x[0] for x in self.encoder.apply_segmentation(src_words, segment)]
       else:
         src_inp = src_words
       # Other Resources
       self.set_report_input(idx, src_inp, trg_words, attentions)
       self.set_report_resource("src_words", src_words)
       self.set_report_path('{}.{}'.format(self.report_path, str(idx)))
       self.generate_report(self.report_type)
     # Append output to the outputs
     outputs.append(TextOutput(actions=output_actions,
                               vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                               score=score))
   self.outputs = outputs
   return outputs
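
Picking the best hypothesis out of search_outputs is a plain sort on the first element of each (batched) score. A runnable sketch with a stand-in hypothesis type mirroring the word_ids/score fields used above:

from collections import namedtuple

Hyp = namedtuple("Hyp", ["word_ids", "score"])
search_outputs = [Hyp(word_ids=[[4, 7, 2]], score=[-3.1]),
                  Hyp(word_ids=[[4, 9, 2]], score=[-2.4])]

best_output = sorted(search_outputs, key=lambda x: x.score[0], reverse=True)[0]
print(best_output.word_ids[0], best_output.score[0])  # [4, 9, 2] -2.4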
Example #9
 def generate(self, src, idx, src_mask=None, forced_trg_ids=None):
   if not xnmt.batcher.is_batched(src):
     src = xnmt.batcher.mark_as_batch([src])
   else:
     assert src_mask is not None
   outputs = []
   for sents in src:
     self.start_sent(src)
     embeddings = self.src_embedder.embed_sent(src)
     encodings = self.encoder(embeddings)
     self.attender.init_sent(encodings)
     ss = mark_as_batch([Vocab.SS] * len(src)) if is_batched(src) else Vocab.SS
     dec_state = self.decoder.initial_state(self.encoder.get_final_states(), self.trg_embedder.embed(ss))
     output_actions, score = self.search_strategy.generate_output(self.decoder, self.attender, self.trg_embedder, dec_state, src_length=len(sents), forced_trg_ids=forced_trg_ids)
     # In case of reporting
     if self.report_path is not None:
       if self.reporting_src_vocab:
         src_words = [self.reporting_src_vocab[w] for w in sents]
       else:
         src_words = ['' for w in sents]
       trg_words = [self.trg_vocab[w] for w in output_actions.word_ids]
       # Attentions
       attentions = output_actions.attentions
       if type(attentions) == dy.Expression:
         attentions = attentions.npvalue()
       elif type(attentions) == list:
         attentions = np.concatenate([x.npvalue() for x in attentions], axis=1)
       elif type(attentions) != np.ndarray:
         raise RuntimeError("Illegal type for attentions in translator report: {}".format(type(attentions)))
       # Segmentation
       segment = self.get_report_resource("segmentation")
       if segment is not None:
         segment = [int(x[0]) for x in segment]
         src_inp = [x[0] for x in self.encoder.apply_segmentation(src_words, segment)]
       else:
         src_inp = src_words
       # Other Resources
       self.set_report_input(idx, src_inp, trg_words, attentions)
       self.set_report_resource("src_words", src_words)
       self.set_report_path('{}.{}'.format(self.report_path, str(idx)))
       self.generate_report(self.report_type)
     # Append output to the outputs
     outputs.append(TextOutput(actions=output_actions.word_ids,
                               vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                               score=score))
   self.outputs = outputs
   return outputs
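
The attention handling above coerces three possible carrier types into a single np.ndarray, with source positions on the rows and one column per decoding step. A small numpy sketch of the list branch, with column vectors standing in for the per-step x.npvalue() results:

import numpy as np

steps = [np.array([[0.7], [0.2], [0.1]]),   # attention at decoding step 0
         np.array([[0.1], [0.8], [0.1]])]   # attention at decoding step 1
attentions = np.concatenate(steps, axis=1)  # shape (src_len, trg_len)
print(attentions.shape)                     # (3, 2)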
Example #10
 def generate(self, src, idx, forced_trg_ids=None, normalize_scores=False):
   if not batcher.is_batched(src):
     src = batcher.mark_as_batch([src])
     if forced_trg_ids:
       forced_trg_ids = batcher.mark_as_batch([forced_trg_ids])
   h = self._encode_src(src)
   scores = self.scorer.calc_log_probs(h) if normalize_scores else self.scorer.calc_scores(h)
   np_scores = scores.npvalue()
   if forced_trg_ids:
     output_action = forced_trg_ids
   else:
     output_action = np.argmax(np_scores, axis=0)
   outputs = []
   for batch_i in range(src.batch_size()):
     score = np_scores[:, batch_i][output_action[batch_i]]
      outputs.append(output.ScalarOutput(actions=[output_action[batch_i]],
                                        vocab=None,
                                        score=score))
   return outputs
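
Classification generation is the same argmax pattern as Example #3, but over a single class axis: argmax along axis 0 picks the best class for each batch element. A minimal numpy sketch using the class_count x batch_size layout implied above:

import numpy as np

n_classes, batch_size = 4, 3
np_scores = np.random.randn(n_classes, batch_size)  # stand-in for scores.npvalue()

output_action = np.argmax(np_scores, axis=0)        # best class per sentence
for batch_i in range(batch_size):
    score = np_scores[:, batch_i][output_action[batch_i]]
    print(f"sentence {batch_i}: class {output_action[batch_i]}, score {score:.3f}")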
Example #11
 def _encode_src(self, src):
     embeddings = self.src_embedder.embed_sent(src)
     # We assume that the encoder can generate multiple possible encodings
     encodings = self.encoder.transduce(embeddings)
      # In most cases we take this branch: the encoder produced a single encoding
     if type(encodings) != CompoundSeqExpression:
         encodings = CompoundSeqExpression([encodings])
         final_states = [self.encoder.get_final_states()]
     else:
         final_states = self.encoder.get_final_states()
     initial_states = []
     for encoding, final_state in zip(encodings, final_states):
         self.attender.init_sent(encoding)
         ss = mark_as_batch(
             [Vocab.SS] * src.batch_size()) if is_batched(src) else Vocab.SS
         initial_states.append(
             self.decoder.initial_state(final_state,
                                        self.trg_embedder.embed(ss)))
     return CompoundSeqExpression(initial_states)
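
The branch above normalizes the encoder's two possible contracts (a single encoding vs. several candidate encodings) into one iterable shape, with CompoundSeqExpression playing the role of a list wrapper. The same pattern in plain Python:

def as_candidates(encodings):
    # Wrap a lone encoding so downstream code can always iterate.
    return encodings if isinstance(encodings, list) else [encodings]

print(as_candidates("enc"))         # ['enc']
print(as_candidates(["e1", "e2"]))  # ['e1', 'e2']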
Example #12
 def calc_loss(self, src, trg, loss_calculator):
   h = self._encode_src(src)
   ids = trg.value if not batcher.is_batched(trg) else batcher.ListBatch([trg_i.value for trg_i in trg])
   loss_expr = self.scorer.calc_loss(h, ids)
   classifier_loss = loss.FactoredLossExpr({"mle" : loss_expr})
   return classifier_loss