Example #1
 def run_training(self, save_fct: Callable) -> None:
     task_generators = OrderedDict()
     for task in self.tasks:
         task_generators[task] = task.next_minibatch()
     dev_zero = {i: self.dev_zero for i in range(len(self.tasks))}
     if self.tasks[0].run_for_epochs > 0:
         while True:
             tt.reset_graph()
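             # Sample which task to train on this step, weighted by self.task_weights.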
             cur_task_i = np.random.choice(range(len(self.tasks)),
                                           p=self.task_weights)
             cur_task = self.tasks[cur_task_i]
             task_gen = task_generators[cur_task]
             if dev_zero[cur_task_i]:
                 self.checkpoint_and_save(cur_task, cur_task_i, save_fct,
                                          dev_zero)
             with cur_task.train_loss_tracker.time_tracker:
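                 # Accumulate gradients over update_every_within minibatches, then apply a single optimizer update.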
                 for _ in range(self.update_every_within):
                     src, trg = next(task_gen)
                     event_trigger.set_train(True)
                     loss_builder = cur_task.training_step(src, trg)
                     self.backward(loss=loss_builder.compute(
                         comb_method=self.loss_comb_method))
                 self.update(trainer=self.trainer)
             cur_task.train_loss_tracker.report(
                 trg, loss_builder.get_factored_loss_val())
             self.checkpoint_and_save(cur_task, cur_task_i, save_fct,
                                      dev_zero)
             if self.tasks[0].should_stop_training(): break
Example #2
 def run_training(self, save_fct: Callable) -> None:
     dev_zero = {i: self.dev_zero for i in range(len(self.tasks))}
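     # Train the tasks sequentially: each task runs until its own stopping criterion fires.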
     for cur_task_id in range(len(self.tasks)):
         self.train = None
         cur_task = self.tasks[cur_task_id]
         task_gen = cur_task.next_minibatch()
         if cur_task.run_for_epochs > 0:
             while True:
                 tt.reset_graph()
                 src, trg = next(task_gen)
                 if dev_zero[cur_task_id]:
                     self.checkpoint_and_save(cur_task, cur_task_id,
                                              save_fct, dev_zero)
                 with cur_task.train_loss_tracker.time_tracker:
                     event_trigger.set_train(True)
                     loss_builder = cur_task.training_step(src, trg)
                     task_loss = loss_builder.compute(
                         comb_method=self.loss_comb_method)
                     self.backward(task_loss)
                     self.update(self.trainer)
                 cur_task.train_loss_tracker.report(
                     trg, loss_builder.get_factored_loss_val())
                 self.checkpoint_and_save(cur_task, cur_task_id, save_fct,
                                          dev_zero)
                 if cur_task.should_stop_training(): break
Example #3
 def test_py_lstm_encoder_len(self):
     layer_dim = 512
     model = DefaultTranslator(
         src_reader=self.src_reader,
         trg_reader=self.trg_reader,
         src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
         encoder=PyramidalLSTMSeqTransducer(input_dim=layer_dim,
                                            hidden_dim=layer_dim,
                                            layers=3),
         attender=MlpAttender(input_dim=layer_dim,
                              state_dim=layer_dim,
                              hidden_dim=layer_dim),
         decoder=AutoRegressiveDecoder(
             input_dim=layer_dim,
             embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
             rnn=UniLSTMSeqTransducer(input_dim=layer_dim,
                                      hidden_dim=layer_dim,
                                      decoder_input_dim=layer_dim,
                                      yaml_path="model.decoder.rnn"),
             transform=NonLinear(input_dim=layer_dim * 2,
                                 output_dim=layer_dim),
             scorer=Softmax(input_dim=layer_dim, vocab_size=100),
             bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
     )
     event_trigger.set_train(True)
     for sent_i in range(10):
         tt.reset_graph()
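         # Pad the source to a multiple of 4: a 3-layer pyramidal LSTM halves the
         # sequence length twice, i.e. downsamples by 2**(layers-1) = 4.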
         src = self.src_data[sent_i].create_padded_sent(
             4 - (self.src_data[sent_i].sent_len() % 4))
         event_trigger.start_sent(src)
         embeddings = model.src_embedder.embed_sent(src)
         encodings = model.encoder.transduce(embeddings)
         self.assertEqual(int(math.ceil(len(embeddings) / float(4))),
                          len(encodings))
Example #4
    def _forced_decode_one_batch(self,
                                 generator: 'models.GeneratorModel',
                                 batcher: Optional[batchers.Batcher] = None,
                                 src_batch: batchers.Batch = None,
                                 ref_batch: batchers.Batch = None,
                                 assert_scores: batchers.Batch = None,
                                 max_src_len: Optional[int] = None):
        """
    Performs forced decoding for a single batch.
    """
        batch_size = len(src_batch)
        src_batches, ref_batches = batcher.pack(src_batch, ref_batch)
        src_batch = src_batches[0]
        src_len = src_batch.sent_len()

        # Decode only when the source fits within max_src_len (or no limit is set).
        if max_src_len is None or src_len <= max_src_len:
            with utils.ReportOnException({
                    "src": src_batch,
                    "graph": utils.print_cg_conditional
            }):
                tt.reset_graph()
                outputs = self.generate_one(generator, src_batch)
                if self.reporter: self._create_sent_report()
                for i in range(len(outputs)):
                    if assert_scores is not None:
                        # If debugging forced decoding, make sure it matches
                        assert batch_size == len(outputs), \
                            "debug forced decoding not supported with nbest inference"
                        if (abs(outputs[i].score - assert_scores[i]) /
                                abs(assert_scores[i])) > 1e-5:
                            raise ValueError(
                                f'Forced decoding score {outputs[i].score} and loss {assert_scores[i]} do not match at '
                                f'sentence {i}')
Example #5
 def _generate_one_batch(self,
                         generator: 'models.GeneratorModel',
                         batcher: Optional[batchers.Batcher] = None,
                         src_batch: batchers.Batch = None,
                         max_src_len: Optional[int] = None,
                         fp: TextIO = None):
     """
 Generate outputs for a single batch and write them to the output file.
 """
     batch_size = len(src_batch)
     src_batches = batcher.pack(src_batch, None)
     src_batch = src_batches[0]
     src_len = src_batch.sent_len()
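     # Skip decoding for over-long inputs, writing a placeholder line per sentence instead.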
     if max_src_len is not None and src_len > max_src_len:
         output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
         fp.write(f"{output_txt}\n")
     else:
         with utils.ReportOnException({
                 "src": src_batch,
                 "graph": utils.print_cg_conditional
         }):
             tt.reset_graph()
             with torch.no_grad() if xnmt.backend_torch else utils.dummy_context_mgr():
                 outputs = self.generate_one(generator, src_batch)
             if self.reporter: self._create_sent_report()
             for i in range(len(outputs)):
                 output_txt = outputs[i].sent_str(
                     custom_output_procs=self.post_processor)
                 fp.write(f"{output_txt}\n")
Example #6
 def assert_in_out_len_equal(self, model):
     tt.reset_graph()
     event_trigger.set_train(True)
     src = self.src_data[0]
     event_trigger.start_sent(src)
     embeddings = model.src_embedder.embed_sent(src)
     encodings = model.encoder.transduce(embeddings)
     self.assertEqual(len(embeddings), len(encodings))
Example #7
 def run_training(self, save_fct: Callable) -> None:
     """
 Main training loop (overwrites TrainingRegimen.run_training())
 """
     tt.reset_graph()
     if self.run_for_epochs is None or self.run_for_epochs > 0:
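         # Losses are accumulated across minibatches; backward/update fires only
         # once every self.update_every steps.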
         total_loss = losses.FactoredLossExpr()
         # Needed for report
         total_trg = []
         for src, trg in self.next_minibatch():
             if self.dev_zero:
                 self.checkpoint_and_save(save_fct)
                 self.dev_zero = False
         with utils.ReportOnException({
                 "src": src,
                 "trg": trg,
                 "graph": utils.print_cg_conditional
         }):
                 with self.train_loss_tracker.time_tracker:
                     event_trigger.set_train(True)
                     total_trg.append(trg[0])
                     loss_builder = self.training_step(src, trg)
                     total_loss.add_factored_loss_expr(loss_builder)
                     # num_updates_skipped is incremented in update but
                     # we need to call backward before update
                     if self.num_updates_skipped == self.update_every - 1:
                         self.backward(
                             total_loss.compute(
                                 comb_method=self.loss_comb_method))
                     self.update(self.trainer)
                 if self.num_updates_skipped == 0:
                     total_loss_val = total_loss.get_factored_loss_val()
                     reported_trg = batchers.ListBatch(total_trg)
                     self.train_loss_tracker.report(reported_trg,
                                                    total_loss_val)
                     total_loss = losses.FactoredLossExpr()
                     total_trg = []
                     tt.reset_graph()
             if self.checkpoint_needed():
                 # Do a last update before checkpoint
                 # Force forward-backward for the last batch even if it's smaller than update_every
                 self.num_updates_skipped = self.update_every - 1
                 self.backward(
                     total_loss.compute(comb_method=self.loss_comb_method))
                 self.update(self.trainer)
                 total_loss_val = total_loss.get_factored_loss_val()
                 reported_trg = batchers.ListBatch(total_trg)
                 self.train_loss_tracker.report(reported_trg,
                                                total_loss_val)
                 total_loss = losses.FactoredLossExpr()
                 total_trg = []
                 self.checkpoint_and_save(save_fct)
             if self.should_stop_training(): break
Example #8
    def test_single(self):
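        # The score returned by greedy search should match the negative of the model's NLL.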
        tt.reset_graph()
        outputs = self.model.generate(
            batchers.mark_as_batch([self.src_data[0]]), GreedySearch())
        output_score = outputs[0].score

        tt.reset_graph()
        train_loss = tt.npvalue(
            self.model.calc_nll(src=self.src_data[0], trg=outputs[0]))

        self.assertAlmostEqual(-output_score, train_loss[0], places=3)
Example #9
    def test_greedy_vs_beam(self):
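        # Beam search with beam_size=1 should score identically to greedy search.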
        tt.reset_graph()
        outputs = self.model.generate(
            batchers.mark_as_batch([self.src_data[0]]),
            BeamSearch(beam_size=1))
        output_score1 = outputs[0].score

        tt.reset_graph()
        outputs = self.model.generate(
            batchers.mark_as_batch([self.src_data[0]]), GreedySearch())
        output_score2 = outputs[0].score

        self.assertAlmostEqual(output_score1, output_score2)
Example #10
    def assert_single_loss_equals_batch_loss(self,
                                             model,
                                             pad_src_to_multiple=1):
        """
    Tests whether single loss equals batch loss.
    Here we don't truncate the target side and use masking.
    """
        batch_size = 5
        src_sents = self.src_data[:batch_size]
        src_min = min([x.sent_len() for x in src_sents])
        src_sents_trunc = [s.words[:src_min] for s in src_sents]
        for single_sent in src_sents_trunc:
            single_sent[src_min - 1] = Vocab.ES
            while len(single_sent) % pad_src_to_multiple != 0:
                single_sent.append(Vocab.ES)
        trg_sents = sorted(self.trg_data[:batch_size],
                           key=lambda x: x.sent_len(),
                           reverse=True)
        trg_max = max([x.sent_len() for x in trg_sents])
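        # Build a target mask where 1.0 marks padded positions beyond each sentence's real length.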
        np_arr = np.zeros([batch_size, trg_max])
        for i in range(batch_size):
            for j in range(trg_sents[i].sent_len(), trg_max):
                np_arr[i, j] = 1.0
        trg_masks = Mask(np_arr)
        trg_sents_padded = [[w for w in s] + [Vocab.ES] *
                            (trg_max - s.sent_len()) for s in trg_sents]

        src_sents_trunc = [
            sent.SimpleSentence(words=s) for s in src_sents_trunc
        ]
        trg_sents_padded = [
            sent.SimpleSentence(words=s) for s in trg_sents_padded
        ]

        single_loss = 0.0
        for sent_id in range(batch_size):
            tt.reset_graph()
            train_loss = MLELoss().calc_loss(model=model,
                                             src=src_sents_trunc[sent_id],
                                             trg=trg_sents[sent_id]).value()
            single_loss += train_loss[0]

        tt.reset_graph()

        batched_loss = MLELoss().calc_loss(model=model,
                                           src=mark_as_batch(src_sents_trunc),
                                           trg=mark_as_batch(
                                               trg_sents_padded,
                                               trg_masks)).value()
        self.assertAlmostEqual(single_loss, np.sum(batched_loss), places=4)
Example #11
 def _compute_losses(self, generator, ref_corpus, src_corpus,
                     max_num_sents) -> List[numbers.Real]:
     batched_src, batched_ref = self.batcher.pack(src_corpus, ref_corpus)
     ref_scores = []
     for sent_count, (src, ref) in enumerate(zip(batched_src, batched_ref)):
         if max_num_sents and sent_count >= max_num_sents: break
         tt.reset_graph()
         loss = self.compute_losses_one(generator, src, ref)
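         # Batched losses come back as one value per sentence; single losses as a scalar.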
         if isinstance(loss.value(), collections.abc.Iterable):
             ref_scores.extend(loss.value())
         else:
             ref_scores.append(loss.value())
     ref_scores = [-x for x in ref_scores]
     return ref_scores
Example #12
 def run_training(self, save_fct: Callable) -> None:
     task_generators = OrderedDict()
     for task in self.tasks:
         task_generators[task] = task.next_minibatch()
     if self.tasks[0].run_for_epochs > 0:
         while True:
             task_losses = []
             task_src_trg = []
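             # Draw the configured number of minibatches (n_task_steps) from each task
             # for this joint training step.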
             for (task, task_gen), task_n in zip(task_generators.items(),
                                                 self.n_task_steps):
                 for _ in range(task_n):
                     src, trg = next(task_gen)
                     task_src_trg.append((task, src, trg))
             if self.dev_zero:  # True only in first iteration
                 self.checkpoint_and_save(save_fct)
             tt.reset_graph()
             task_trg_loss_stats = {}
             # Use an exit stack to choose between global and per-task time tracking.
             with contextlib.ExitStack() as stack:
                 if not self.per_task_backward:
                     stack.enter_context(
                         self.tasks[0].train_loss_tracker.time_tracker)
                 event_trigger.set_train(True)
                 for task, src, trg in task_src_trg:
                     with contextlib.ExitStack() as stack2:
                         if self.per_task_backward:
                             stack2.enter_context(
                                 task.train_loss_tracker.time_tracker)
                         loss_builder = task.training_step(src, trg)
                         task_trg_loss_stats[task] = (
                             trg, loss_builder.get_factored_loss_val())
                         if self.per_task_backward:
                             self.backward(
                                 loss_builder.compute(
                                     comb_method=self.loss_comb_method))
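                             # Clear the graph but keep accumulated gradients for the shared update below.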
                             tt.reset_graph(zero_grad=False)
                         else:
                             task_losses.append(
                                 loss_builder.compute(
                                     comb_method=self.loss_comb_method))
                 if not self.per_task_backward:
                     self.backward(sum(task_losses))
                 self.update(self.trainer)
             for task, (trg, stats) in task_trg_loss_stats.items():
                 task.train_loss_tracker.report(trg, stats)
             self.checkpoint_and_save(save_fct)
             if self.tasks[0].should_stop_training(): break
Example #13
    def test_single(self):
        tt.reset_graph()
        outputs = self.model.generate(
            batchers.mark_as_batch([self.src_data[0]]), BeamSearch())

        # Make sure the output of beam search is the same as the target sentence
        # (this is a very overfit model on exactly this data)
        self.assertEqual(outputs[0].sent_len(), self.trg_data[0].sent_len())
        for i in range(outputs[0].sent_len()):
            self.assertEqual(outputs[0][i], self.trg_data[0][i])

        # Verify that the loss we get from beam search is the same as the loss
        # we get if we call model.calc_nll
        tt.reset_graph()
        train_loss = self.model.calc_nll(src=self.src_data[0],
                                         trg=outputs[0]).value()

        self.assertAlmostEqual(-outputs[0].score, train_loss, places=4)
Example #14
    def assert_single_loss_equals_batch_loss(self,
                                             model,
                                             pad_src_to_multiple=1):
        """
    Tests whether single loss equals batch loss.
    Truncating src / trg sents to same length so no masking is necessary
    """
        batch_size = 5
        src_sents = self.src_data[:batch_size]
        src_min = min([x.sent_len() for x in src_sents])
        src_sents_trunc = [s.words[:src_min] for s in src_sents]
        for single_sent in src_sents_trunc:
            single_sent[src_min - 1] = Vocab.ES
            while len(single_sent) % pad_src_to_multiple != 0:
                single_sent.append(Vocab.ES)
        trg_sents = self.trg_data[:batch_size]
        trg_min = min([x.sent_len() for x in trg_sents])
        trg_sents_trunc = [s.words[:trg_min] for s in trg_sents]
        for single_sent in trg_sents_trunc:
            single_sent[trg_min - 1] = Vocab.ES

        src_sents_trunc = [
            sent.SimpleSentence(words=s) for s in src_sents_trunc
        ]
        trg_sents_trunc = [
            sent.SimpleSentence(words=s) for s in trg_sents_trunc
        ]

        single_loss = 0.0
        for sent_id in range(batch_size):
            tt.reset_graph()
            train_loss = MLELoss().calc_loss(
                model=model,
                src=src_sents_trunc[sent_id],
                trg=trg_sents_trunc[sent_id]).value()
            single_loss += train_loss[0]

        tt.reset_graph()

        batched_loss = MLELoss().calc_loss(
            model=model,
            src=mark_as_batch(src_sents_trunc),
            trg=mark_as_batch(trg_sents_trunc)).value()
        self.assertAlmostEqual(single_loss, np.sum(batched_loss), places=4)
Example #15
    def test_py_lstm_mask(self):
        layer_dim = 512
        model = DefaultTranslator(
            src_reader=self.src_reader,
            trg_reader=self.trg_reader,
            src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
            encoder=PyramidalLSTMSeqTransducer(input_dim=layer_dim,
                                               hidden_dim=layer_dim,
                                               layers=1),
            attender=MlpAttender(input_dim=layer_dim,
                                 state_dim=layer_dim,
                                 hidden_dim=layer_dim),
            decoder=AutoRegressiveDecoder(
                input_dim=layer_dim,
                embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
                rnn=UniLSTMSeqTransducer(input_dim=layer_dim,
                                         hidden_dim=layer_dim,
                                         decoder_input_dim=layer_dim,
                                         yaml_path="model.decoder.rnn"),
                transform=NonLinear(input_dim=layer_dim * 2,
                                    output_dim=layer_dim),
                scorer=Softmax(input_dim=layer_dim, vocab_size=100),
                bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
        )

        batcher = batchers.TrgBatcher(batch_size=3)
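        # Packing batches by target length can leave source sides padded, hence masked.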
        train_src, _ = batcher.pack(self.src_data, self.trg_data)

        event_trigger.set_train(True)
        for sent_i in range(3):
            tt.reset_graph()
            src = train_src[sent_i]
            event_trigger.start_sent(src)
            embeddings = model.src_embedder.embed_sent(src)
            encodings = model.encoder.transduce(embeddings)
            if train_src[sent_i].mask is None:
                assert encodings.mask is None
            else:
                np.testing.assert_array_almost_equal(
                    train_src[sent_i].mask.np_arr, encodings.mask.np_arr)
Example #16
  def eval(self) -> 'metrics.EvalScore':
    """
    Perform evaluation task.

    Returns:
      Evaluated score
    """
    event_trigger.set_train(False)
    if self.src_data is None:
      self.src_data, self.ref_data, self.src_batches, self.ref_batches = \
        input_readers.read_parallel_corpus(src_reader=self.model.src_reader,
                                           trg_reader=self.model.trg_reader,
                                           src_file=self.src_file,
                                           trg_file=self.ref_file,
                                           batcher=self.batcher,
                                           max_num_sents=self.max_num_sents,
                                           max_src_len=self.max_src_len,
                                           max_trg_len=self.max_trg_len)
    loss_val = losses.FactoredLossVal()
    ref_words_cnt = 0
    for src, trg in zip(self.src_batches, self.ref_batches):
      with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}):
        tt.reset_graph()
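        # Disable gradient tracking under the PyTorch backend; otherwise fall back to a no-op context.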
        with torch.no_grad() if xnmt.backend_torch else utils.dummy_context_mgr():
          loss = self.loss_calculator.calc_loss(self.model, src, trg)

          ref_words_cnt += sum([trg_i.len_unpadded() for trg_i in trg])
          loss_val += loss.get_factored_loss_val()
      if settings.PRETEND: break

    loss_stats = {k: v/ref_words_cnt for k, v in loss_val.items()}

    self.src_data, self.ref_data, self.src_batches, self.ref_batches = None, None, None, None

    return metrics.LossScore(sum(loss_stats.values()),
                             loss_stats=loss_stats,
                             num_ref_words=ref_words_cnt,
                             desc=self.desc)
Example #17
 def run_training(self, save_fct: Callable) -> None:
     """
 Main training loop (overwrites TrainingRegimen.run_training())
 """
     if self.run_for_epochs is None or self.run_for_epochs > 0:
         for src, trg in self.next_minibatch():
             if self.dev_zero:
                 self.checkpoint_and_save(save_fct)
                 self.dev_zero = False
             with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}), \
                  self.skip_out_of_memory:
                 tt.reset_graph()
                 with self.train_loss_tracker.time_tracker:
                     event_trigger.set_train(True)
                     loss_builder = self.training_step(src, trg)
                     loss = loss_builder.compute(
                         comb_method=self.loss_comb_method)
                     self.backward(loss)
                     self.update(self.trainer)
                 self.train_loss_tracker.report(
                     trg, loss_builder.get_factored_loss_val())
             if self.checkpoint_needed():
                 self.checkpoint_and_save(save_fct)
             if self.should_stop_training(): break