Example #1
    def calc_nll(self, src, trg):
        if not batchers.is_batched(src):
            src = batchers.ListBatch([src])

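        # Build language-model style inputs/targets by shifting each sentence:
        # inputs drop the last token, targets drop the first, and any padding
        # mask is sliced to match.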
        src_inputs = batchers.ListBatch(
            [s[:-1] for s in src],
            mask=batchers.Mask(src.mask.np_arr[:, :-1]) if src.mask else None)
        src_targets = batchers.ListBatch(
            [s[1:] for s in src],
            mask=batchers.Mask(src.mask.np_arr[:, 1:]) if src.mask else None)

        event_trigger.start_sent(src)
        embeddings = self.src_embedder.embed_sent(src_inputs)
        encodings = self.rnn.transduce(embeddings)
        encodings_tensor = encodings.as_tensor()
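        # Fold the time dimension into the batch dimension so the transform
        # and scorer treat every position as a separate batch element.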
        ((hidden_dim, seq_len), batch_size) = encodings.dim()
        encoding_reshaped = dy.reshape(encodings_tensor, (hidden_dim, ),
                                       batch_size=batch_size * seq_len)
        outputs = self.transform.transform(encoding_reshaped)

        ref_action = np.asarray([sent.words for sent in src_targets]).reshape(
            (seq_len * batch_size, ))
        loss_expr_perstep = self.scorer.calc_loss(
            outputs, batchers.mark_as_batch(ref_action))
        loss_expr_perstep = dy.reshape(loss_expr_perstep, (seq_len, ),
                                       batch_size=batch_size)
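        # Zero out the loss at padded target positions before summing.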
        if src_targets.mask:
            loss_expr_perstep = dy.cmult(
                loss_expr_perstep,
                dy.inputTensor(1.0 - src_targets.mask.np_arr.T, batched=True))
        loss = dy.sum_elems(loss_expr_perstep)

        return loss
Example #2
    def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
            -> tt.Tensor:
        if not batchers.is_batched(src):
            src = batchers.ListBatch([src])

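        # Shift each sentence by one position: inputs drop the last token,
        # targets drop the first; padding masks are sliced accordingly.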
        src_inputs = batchers.ListBatch(
            [s[:-1] for s in src],
            mask=batchers.Mask(src.mask.np_arr[:, :-1]) if src.mask else None)
        src_targets = batchers.ListBatch(
            [s[1:] for s in src],
            mask=batchers.Mask(src.mask.np_arr[:, 1:]) if src.mask else None)

        event_trigger.start_sent(src)
        embeddings = self.src_embedder.embed_sent(src_inputs)
        encodings = self.rnn.transduce(embeddings)
        encodings_tensor = encodings.as_tensor()

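        # Merge the time and batch dimensions via the tensor-tools (tt) helpers
        # so the transform and scorer see one position per batch element.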
        encoding_reshaped = tt.merge_time_batch_dims(encodings_tensor)
        seq_len = tt.sent_len(encodings_tensor)
        batch_size = tt.batch_size(encodings_tensor)

        outputs = self.transform.transform(encoding_reshaped)

        ref_action = np.asarray([sent.words for sent in src_targets]).reshape(
            (seq_len * batch_size, ))
        loss_expr_perstep = self.scorer.calc_loss(
            outputs, batchers.mark_as_batch(ref_action))

        loss_expr_perstep = tt.unmerge_time_batch_dims(loss_expr_perstep,
                                                       batch_size)

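        # Aggregate the per-step losses while respecting the target-side mask.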
        loss = tt.aggregate_masked_loss(loss_expr_perstep, src_targets.mask)

        return loss
Example #3
 def run_training(self, save_fct: Callable) -> None:
     """
 Main training loop (overwrites TrainingRegimen.run_training())
 """
     dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                 check_validity=settings.CHECK_VALIDITY)
     if self.run_for_epochs is None or self.run_for_epochs > 0:
         total_loss = losses.FactoredLossExpr()
         # Needed for report
         total_trg = []
         for src, trg in self.next_minibatch():
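             # When dev_zero is set, run a checkpoint (and save) once before
             # the first training step.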
             if self.dev_zero:
                 self.checkpoint_and_save(save_fct)
                 self.dev_zero = False
             with utils.ReportOnException({
                     "src":
                     src,
                     "trg":
                     trg,
                     "graph":
                     utils.print_cg_conditional
             }):
                 with self.train_loss_tracker.time_tracker:
                     event_trigger.set_train(True)
                     total_trg.append(trg[0])
                     loss_builder = self.training_step(src, trg)
                     total_loss.add_factored_loss_expr(loss_builder)
                     # num_updates_skipped is incremented in update but
                     # we need to call backward before update
                     if self.num_updates_skipped == self.update_every - 1:
                         self.backward(total_loss.compute(),
                                       self.dynet_profiling)
                     self.update(self.trainer)
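                 # num_updates_skipped == 0 right after update() means the
                 # parameters were actually updated: report the accumulated
                 # loss, then reset the accumulators and the computation graph.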
                 if self.num_updates_skipped == 0:
                     total_loss_val = total_loss.get_factored_loss_val(
                         comb_method=self.loss_comb_method)
                     reported_trg = batchers.ListBatch(total_trg)
                     self.train_loss_tracker.report(reported_trg,
                                                    total_loss_val)
                     total_loss = losses.FactoredLossExpr()
                     total_trg = []
                     dy.renew_cg(
                         immediate_compute=settings.IMMEDIATE_COMPUTE,
                         check_validity=settings.CHECK_VALIDITY)
             if self.checkpoint_needed():
                 # Do a last update before checkpoint
                 # Force forward-backward for the last batch even if it's smaller than update_every
                 self.num_updates_skipped = self.update_every - 1
                 self.backward(total_loss.compute(), self.dynet_profiling)
                 self.update(self.trainer)
                 total_loss_val = total_loss.get_factored_loss_val(
                     comb_method=self.loss_comb_method)
                 reported_trg = batchers.ListBatch(total_trg)
                 self.train_loss_tracker.report(reported_trg,
                                                total_loss_val)
                 total_loss = losses.FactoredLossExpr()
                 total_trg = []
                 self.checkpoint_and_save(save_fct)
             if self.should_stop_training(): break
Example #4
 def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
         -> dy.Expression:
     h = self._encode_src(src)
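     # Collect the target label id(s), unwrapping a batch into a ListBatch of values.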
     ids = trg.value if not batchers.is_batched(
         trg) else batchers.ListBatch([trg_i.value for trg_i in trg])
     loss_expr = self.scorer.calc_loss(h, ids)
     return loss_expr
Example #5
  def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
          -> dy.Expression:
    src_inputs = batchers.ListBatch([s[:-1] for s in src], mask=batchers.Mask(src.mask.np_arr[:, :-1]) if src.mask else None)
    src_targets = batchers.ListBatch([s[1:] for s in src], mask=batchers.Mask(src.mask.np_arr[:, 1:]) if src.mask else None)

    embeddings = self.src_embedder.embed_sent(src_inputs)
    encodings = self.rnn.transduce(embeddings)
    encodings_tensor = encodings.as_tensor()
    ((hidden_dim, seq_len), batch_size) = encodings.dim()
    encoding_reshaped = dy.reshape(encodings_tensor, (hidden_dim,), batch_size=batch_size * seq_len)
    outputs = self.transform.transform(encoding_reshaped)

    ref_action = np.asarray([sent.words for sent in src_targets]).reshape((seq_len * batch_size,))
    loss_expr_perstep = self.scorer.calc_loss(outputs, batchers.mark_as_batch(ref_action))
    loss_expr_perstep = dy.reshape(loss_expr_perstep, (seq_len,), batch_size=batch_size)
    if src_targets.mask:
      loss_expr_perstep = dy.cmult(loss_expr_perstep, dy.inputTensor(1.0-src_targets.mask.np_arr.T, batched=True))
    loss = dy.sum_elems(loss_expr_perstep)
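    # len_unpadded() gives the number of real (non-padding) tokens per sentence,
    # which LossExpr keeps alongside the loss value.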
    units = [s.len_unpadded() for s in src]
    return LossExpr(loss, units)
Example #6
    def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
            -> LossExpr:
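        # Handle both batched and single-sentence targets: collect the unpadded
        # lengths and the label ids accordingly.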
        if batchers.is_batched(trg):
            units = [t.len_unpadded() for t in trg]
            ids = batchers.ListBatch([t.value for t in trg])
        else:
            units = trg.len_unpadded()
            ids = trg.value

        h = self._encode_src(src)
        loss_expr = self.scorer.calc_loss(h, ids)
        return LossExpr(loss_expr, units)
Example #7
    def generate(self, src, forced_trg_ids=None, **kwargs):
        event_trigger.start_sent(src)
        if isinstance(src, batchers.CompoundBatch):
            src = src.batches[0]

        outputs = []

        batch_size = src.batch_size()
        score = batchers.ListBatch([[] for _ in range(batch_size)])
        words = batchers.ListBatch([[] for _ in range(batch_size)])
        done = [False] * batch_size
        initial_state = self._encode_src(src)
        current_state = initial_state
        attentions = []
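        # Decode step by step, up to max_dec_len steps or until every sentence
        # in the batch is marked as done.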
        for pos in range(self.max_dec_len):
            prev_ref_action = None
            if pos > 0 and self.mode_translate != "context":
                if forced_trg_ids is not None:
                    prev_ref_action = batchers.mark_as_batch([
                        forced_trg_ids[batch_i][pos - 1]
                        for batch_i in range(batch_size)
                    ])
                elif batch_size > 1:
                    prev_ref_action = batchers.mark_as_batch(
                        np.argmax(current_state.out_prob.npvalue(), axis=0))
                else:
                    prev_ref_action = batchers.mark_as_batch(
                        [np.argmax(current_state.out_prob.npvalue(), axis=0)])

            logsoft = self.generate_one_step(dec_state=current_state,
                                             batch_size=batch_size,
                                             mode=self.mode_translate,
                                             cur_step=pos,
                                             prev_ref_action=prev_ref_action)
            attentions.append(self.attender.get_last_attention().npvalue())
            logsoft = logsoft.npvalue()
            logsoft = logsoft.reshape(logsoft.shape[0], batch_size)
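            # Pick the next word: argmax over the log-softmax unless forced
            # target ids are given.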
            if forced_trg_ids is None:
                batched_word_id = np.argmax(logsoft, axis=0)
                batched_word_id = batched_word_id.reshape(
                    (batched_word_id.size, ))
            else:
                batched_word_id = [
                    forced_trg_batch_elem[pos]
                    for forced_trg_batch_elem in forced_trg_ids
                ]
            for batch_i in range(batch_size):
                if done[batch_i]:
                    batch_word = vocabs.Vocab.ES
                    batch_score = 0.0
                else:
                    batch_word = batched_word_id[batch_i]
                    batch_score = logsoft[batch_word, batch_i]
                    if self._terminate_rnn(batch_i=batch_i,
                                           pos=pos,
                                           batched_word_id=batched_word_id):
                        done[batch_i] = True
                score[batch_i].append(batch_score)
                words[batch_i].append(batch_word)
            if all(done):
                break
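        # Assemble one output sentence per batch element, scored by the sum of
        # its per-step word scores.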
        for batch_i in range(batch_size):
            batch_elem_score = sum(score[batch_i])
            outputs.append(
                sent.SimpleSentence(words=words[batch_i],
                                    idx=src[batch_i].idx,
                                    vocab=getattr(self.trg_reader, "vocab",
                                                  None),
                                    score=batch_elem_score,
                                    output_procs=self.trg_reader.output_procs))
            if self.compute_report:
                if batch_size > 1:
                    cur_attentions = [x[:, :, batch_i] for x in attentions]
                else:
                    cur_attentions = attentions
                attention = np.concatenate(cur_attentions, axis=1)
                self.report_sent_info({
                    "attentions": attention,
                    "src": src[batch_i],
                    "output": outputs[-1]
                })

        return outputs
Example #8
 def calc_nll(self, src, trg):
     h = self._encode_src(src)
     ids = trg.value if not batchers.is_batched(
         trg) else batchers.ListBatch([trg_i.value for trg_i in trg])
     loss_expr = self.scorer.calc_loss(h, ids)
     return loss_expr