Code Example #1
File: tasks.py Project: msperber/misc
  def eval(self) -> 'metrics.EvalScore':
    """
    Perform evaluation task.

    Returns:
      Evaluated score
    """
    event_trigger.set_train(False)
    if self.src_data is None:
      self.src_data, self.ref_data, self.src_batches, self.ref_batches = \
        input_readers.read_parallel_corpus(src_reader=self.model.src_reader,
                                           trg_reader=self.model.trg_reader,
                                           src_file=self.src_file,
                                           trg_file=self.ref_file,
                                           batcher=self.batcher,
                                           max_num_sents=self.max_num_sents,
                                           max_src_len=self.max_src_len,
                                           max_trg_len=self.max_trg_len)
    loss_val = losses.FactoredLossVal()
    ref_words_cnt = 0
    for src, trg in zip(self.src_batches, self.ref_batches):
      with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}):
        dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)

        loss = self.loss_calculator.calc_loss(self.model, src, trg)

        ref_words_cnt += sum([trg_i.len_unpadded() for trg_i in trg])
        loss_val += loss.get_factored_loss_val(comb_method=self.loss_comb_method)

    loss_stats = {k: v/ref_words_cnt for k, v in loss_val.items()}
    return metrics.LossScore(sum(loss_stats.values()),
                             loss_stats=loss_stats,
                             num_ref_words=ref_words_cnt,
                             desc=self.desc)
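
Every example on this page wraps graph construction in utils.ReportOnException, which takes a dict mapping labels to debug artifacts: input batches, or a zero-argument callable such as utils.print_cg_conditional. xnmt's own implementation is not shown on this page; a minimal sketch of such a context manager, assuming values are either printable objects or callables, might look like this:

from contextlib import contextmanager

@contextmanager
def report_on_exception(artifacts: dict):
    """Print the given debug artifacts if the wrapped block raises, then re-raise."""
    try:
        yield
    except Exception:
        for name, value in artifacts.items():
            print(f"--- {name} ---")
            if callable(value):
                value()  # e.g. a conditional computation-graph printer
            else:
                print(repr(value))
        raise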
Code Example #2
File: regimens.py Project: msperber/misc
 def run_training(self, save_fct: Callable) -> None:
     """
 Main training loop (overwrites TrainingRegimen.run_training())
 """
     if self.run_for_epochs is None or self.run_for_epochs > 0:
         for src, trg in self.next_minibatch():
             if self.dev_zero:
                 self.checkpoint_and_save(save_fct)
                 self.dev_zero = False
              with utils.ReportOnException({"src": src, "trg": trg,
                                            "graph": utils.print_cg_conditional}):
                 dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                             check_validity=settings.CHECK_VALIDITY)
                 with self.train_loss_tracker.time_tracker:
                     event_trigger.set_train(True)
                     loss_builder = self.training_step(src, trg)
                     loss = loss_builder.compute()
                     self.backward(loss, self.dynet_profiling)
                     self.update(self.trainer)
                 self.train_loss_tracker.report(
                     trg,
                     loss_builder.get_factored_loss_val(
                         comb_method=self.loss_comb_method))
             if self.checkpoint_needed():
                 self.checkpoint_and_save(save_fct)
             if self.should_stop_training(): break
Code Example #3
 def _generate_one_batch(self,
                         generator: 'models.GeneratorModel',
                         batcher: Optional[batchers.Batcher] = None,
                         src_batch: batchers.Batch = None,
                         max_src_len: Optional[int] = None,
                         fp: TextIO = None):
     """
 Generate outputs for a single batch and write them to the output file.
 """
     batch_size = len(src_batch)
     src_batches = batcher.pack(src_batch, None)
     src_batch = src_batches[0]
     src_len = src_batch.sent_len()
     if max_src_len is not None and src_len > max_src_len:
         output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
         fp.write(f"{output_txt}\n")
     else:
         with utils.ReportOnException({
                 "src": src_batch,
                 "graph": utils.print_cg_conditional
         }):
             dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                         check_validity=settings.CHECK_VALIDITY)
             outputs = self.generate_one(generator, src_batch)
             if self.reporter: self._create_sent_report()
             for i in range(len(outputs)):
                 output_txt = outputs[i].sent_str(
                     custom_output_procs=self.post_processor)
                 fp.write(f"{output_txt}\n")
Code Example #4
File: inferences.py Project: yzhen-li/xnmt
    def _forced_decode_one_batch(self,
                                 generator: 'models.GeneratorModel',
                                 batcher: Optional[batchers.Batcher] = None,
                                 src_batch: batchers.Batch = None,
                                 ref_batch: batchers.Batch = None,
                                 assert_scores: Optional[Sequence[float]] = None,
                                 max_src_len: Optional[int] = None):
        """
    Performs forced decoding for a single batch.
    """
        batch_size = len(src_batch)
        src_batches, ref_batches = batcher.pack(src_batch, ref_batch)
        src_batch = src_batches[0]
        src_len = src_batch.sent_len()

        if max_src_len is None or src_len <= max_src_len:
            with utils.ReportOnException({
                    "src": src_batch,
                    "graph": utils.print_cg_conditional
            }):
                tt.reset_graph()
                outputs = self.generate_one(generator, src_batch)
                if self.reporter: self._create_sent_report()
                for i in range(len(outputs)):
                    if assert_scores is not None:
                        # If debugging forced decoding, make sure it matches
                        assert batch_size == len(
                            outputs
                        ), "debug forced decoding not supported with nbest inference"
                        if (abs(outputs[i].score - assert_scores[i]) /
                                abs(assert_scores[i])) > 1e-5:
                            raise ValueError(
                                f'Forced decoding score {outputs[i].score} and loss {assert_scores[i]} do not match at '
                                f'sentence {i}')
Code Example #5
File: inferences.py Project: yzhen-li/xnmt
 def _generate_one_batch(self,
                         generator: 'models.GeneratorModel',
                         batcher: Optional[batchers.Batcher] = None,
                         src_batch: batchers.Batch = None,
                         max_src_len: Optional[int] = None,
                         fp: TextIO = None):
     """
 Generate outputs for a single batch and write them to the output file.
 """
     batch_size = len(src_batch)
     src_batches = batcher.pack(src_batch, None)
     src_batch = src_batches[0]
     src_len = src_batch.sent_len()
     if max_src_len is not None and src_len > max_src_len:
         output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
         fp.write(f"{output_txt}\n")
     else:
         with utils.ReportOnException({
                 "src": src_batch,
                 "graph": utils.print_cg_conditional
         }):
             tt.reset_graph()
             with torch.no_grad(
             ) if xnmt.backend_torch else utils.dummy_context_mgr():
                 outputs = self.generate_one(generator, src_batch)
             if self.reporter: self._create_sent_report()
             for i in range(len(outputs)):
                 output_txt = outputs[i].sent_str(
                     custom_output_procs=self.post_processor)
                 fp.write(f"{output_txt}\n")
Code Example #6
File: regimens.py Project: msperber/misc
 def run_training(self, save_fct: Callable) -> None:
     """
 Main training loop (overwrites TrainingRegimen.run_training())
 """
     dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                 check_validity=settings.CHECK_VALIDITY)
     if self.run_for_epochs is None or self.run_for_epochs > 0:
         total_loss = losses.FactoredLossExpr()
         # Needed for report
         total_trg = []
         for src, trg in self.next_minibatch():
             if self.dev_zero:
                 self.checkpoint_and_save(save_fct)
                 self.dev_zero = False
              with utils.ReportOnException({"src": src, "trg": trg,
                                            "graph": utils.print_cg_conditional}):
                 with self.train_loss_tracker.time_tracker:
                     event_trigger.set_train(True)
                     total_trg.append(trg[0])
                     loss_builder = self.training_step(src, trg)
                     total_loss.add_factored_loss_expr(loss_builder)
                     # num_updates_skipped is incremented in update but
                     # we need to call backward before update
                     if self.num_updates_skipped == self.update_every - 1:
                         self.backward(total_loss.compute(),
                                       self.dynet_profiling)
                     self.update(self.trainer)
                 if self.num_updates_skipped == 0:
                     total_loss_val = total_loss.get_factored_loss_val(
                         comb_method=self.loss_comb_method)
                     reported_trg = batchers.ListBatch(total_trg)
                     self.train_loss_tracker.report(reported_trg,
                                                    total_loss_val)
                     total_loss = losses.FactoredLossExpr()
                     total_trg = []
                     dy.renew_cg(
                         immediate_compute=settings.IMMEDIATE_COMPUTE,
                         check_validity=settings.CHECK_VALIDITY)
             if self.checkpoint_needed():
                 # Do a last update before checkpoint
                 # Force forward-backward for the last batch even if it's smaller than update_every
                 self.num_updates_skipped = self.update_every - 1
                 self.backward(total_loss.compute(), self.dynet_profiling)
                 self.update(self.trainer)
                 total_loss_val = total_loss.get_factored_loss_val(
                     comb_method=self.loss_comb_method)
                 reported_trg = batchers.ListBatch(total_trg)
                 self.train_loss_tracker.report(reported_trg,
                                                total_loss_val)
                 total_loss = losses.FactoredLossExpr()
                 total_trg = []
                 self.checkpoint_and_save(save_fct)
             if self.should_stop_training(): break
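
Example #6 adds gradient accumulation on top of the basic loop: losses from update_every consecutive minibatches are summed into one FactoredLossExpr, backward runs once just before the update that actually fires, and num_updates_skipped counts the minibatches in between. Stripped of the xnmt plumbing, the counter logic reduces to the following sketch (trainer here is a placeholder with an update() method, not the actual xnmt trainer API):

class AccumulatingUpdater:
    """Sketch of the update-skipping counter from example #6."""

    def __init__(self, update_every: int):
        self.update_every = update_every
        self.num_updates_skipped = 0

    def update(self, trainer) -> None:
        # Only every update_every-th call applies gradients; the caller runs
        # backward on the accumulated loss when the counter is about to reset.
        if self.num_updates_skipped == self.update_every - 1:
            trainer.update()
            self.num_updates_skipped = 0
        else:
            self.num_updates_skipped += 1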
Code Example #7
  def _generate_output(self, generator: 'models.GeneratorModel', src_corpus: Sequence[sent.Sentence],
                       trg_file: str, batcher: Optional[batchers.Batcher] = None, max_src_len: Optional[int] = None,
                       forced_ref_corpus: Optional[Sequence[sent.Sentence]] = None,
                       assert_scores: Optional[Sequence[float]] = None) -> None:
    """
    Generate outputs and write them to file.

    Args:
      generator: generator model to use
      src_corpus: src-side inputs to generate outputs for
      trg_file: file to write outputs to
      batcher: necessary with some cases of input pre-processing such as padding or truncation
      max_src_len: if given, skip inputs that are too long
      forced_ref_corpus: if given, perform forced decoding with the given trg-side inputs
      assert_scores: if given, raise exception if the scores for generated outputs don't match the given scores
    """
    with open(trg_file, 'wt', encoding='utf-8') as fp:  # Saving the translated output to a trg file
      if forced_ref_corpus:
        src_batches, ref_batches = batcher.pack(src_corpus, forced_ref_corpus)
      else:
        src_batches = batcher.pack(src_corpus, None)
      cur_sent_i = 0
      ref_batch = None
      for batch_i, src_batch in enumerate(src_batches):
        batch_size = src_batch.batch_size()
        src_len = src_batch.sent_len()
        if max_src_len is not None and src_len > max_src_len:
          output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
          fp.write(f"{output_txt}\n")
        else:
          if forced_ref_corpus: ref_batch = ref_batches[batch_i]
          with utils.ReportOnException({"batchno":batch_i, "src": src_batch, "graph": utils.print_cg_conditional}):
            dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
            outputs = self.generate_one(generator, src_batch, ref_batch)
            if self.reporter: self._create_sent_report()
            for i in range(len(outputs)):
              if assert_scores is not None:
                # If debugging forced decoding, make sure it matches
                assert batch_size == len(outputs), "debug forced decoding not supported with nbest inference"
                if (abs(outputs[i].score - assert_scores[cur_sent_i + i]) / abs(assert_scores[cur_sent_i + i])) > 1e-5:
                  raise ValueError(
                    f'Forced decoding score {outputs[i].score} and loss {assert_scores[cur_sent_i + i]} do not match at '
                    f'sentence {cur_sent_i + i}')
              output_txt = outputs[i].sent_str(custom_output_procs=self.post_processor)
              fp.write(f"{output_txt}\n")
        cur_sent_i += batch_size
        if self.max_num_sents and cur_sent_i >= self.max_num_sents: break
      if self.reporter: self._conclude_report()
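
Examples #4, #7, and #10 validate forced decoding by comparing the decoder's score against the reference loss with a relative tolerance of 1e-5. The check normalizes the difference by the reference score alone, which is close to, but not identical to, math.isclose (which normalizes by the larger of the two magnitudes):

import math

def scores_match(decoded: float, reference: float, rel_tol: float = 1e-5) -> bool:
    # The check used in the examples above: the difference, relative to the
    # reference score, must stay within rel_tol.
    return abs(decoded - reference) / abs(reference) <= rel_tol

# Near-equivalent standard-library form:
#   math.isclose(decoded, reference, rel_tol=1e-5, abs_tol=0.0)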
Code Example #8
    def eval(self) -> 'metrics.EvalScore':
        """
    Perform evaluation task.

    Returns:
      Evaluated score
    """
        event_trigger.set_train(False)
        if self.src_data is None:
            self.src_data, self.ref_data, self.src_batches, self.ref_batches = \
              input_readers.read_parallel_corpus(src_reader=self.model.src_reader,
                                                 trg_reader=self.model.trg_reader,
                                                 src_file=self.src_file,
                                                 trg_file=self.ref_file,
                                                 batcher=self.batcher,
                                                 max_num_sents=self.max_num_sents,
                                                 max_src_len=self.max_src_len,
                                                 max_trg_len=self.max_trg_len)
        ref_words_cnt = 0
        loss_maps = defaultdict(float)
        loss_wrds = defaultdict(float)
        for src, trg in zip(self.src_batches, self.ref_batches):
            with utils.ReportOnException({
                    "src": src,
                    "trg": trg,
                    "graph": utils.print_cg_conditional
            }):
                dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                            check_validity=settings.CHECK_VALIDITY)

                loss_expr = self.loss_calculator.calc_loss(
                    self.model, src, trg)
                loss, loss_value = loss_expr.compute(
                    comb_method=self.loss_comb_method)

                # Count reference words so num_ref_words below is meaningful
                ref_words_cnt += sum(trg_i.len_unpadded() for trg_i in trg)
                for k, (value, unit) in loss_value.items():
                    loss_maps[k] += value
                    loss_wrds[k] += unit

        loss_stats = {k: loss_maps[k] / loss_wrds[k] for k in loss_maps.keys()}

        return metrics.LossScore(sum(loss_stats.values()),
                                 loss_stats=loss_stats,
                                 num_ref_words=ref_words_cnt,
                                 desc=self.desc)
Code Example #9
  def eval(self) -> 'metrics.EvalScore':
    """
    Perform evaluation task.

    Returns:
      Evaluated score
    """
    event_trigger.set_train(False)
    if self.src_data is None:
      self.src_data, self.ref_data, self.src_batches, self.ref_batches = \
        input_readers.read_parallel_corpus(src_reader=self.model.src_reader,
                                           trg_reader=self.model.trg_reader,
                                           src_file=self.src_file,
                                           trg_file=self.ref_file,
                                           batcher=self.batcher,
                                           max_num_sents=self.max_num_sents,
                                           max_src_len=self.max_src_len,
                                           max_trg_len=self.max_trg_len)
    loss_val = losses.FactoredLossVal()
    ref_words_cnt = 0
    for src, trg in zip(self.src_batches, self.ref_batches):
      with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}):
        tt.reset_graph()
        with torch.no_grad() if xnmt.backend_torch else utils.dummy_context_mgr():
          loss = self.loss_calculator.calc_loss(self.model, src, trg)

          ref_words_cnt += sum([trg_i.len_unpadded() for trg_i in trg])
          loss_val += loss.get_factored_loss_val()
      if settings.PRETEND: break

    loss_stats = {k: v/ref_words_cnt for k, v in loss_val.items()}

    self.src_data, self.ref_data, self.src_batches, self.ref_batches = None, None, None, None

    return metrics.LossScore(sum(loss_stats.values()),
                             loss_stats=loss_stats,
                             num_ref_words=ref_words_cnt,
                             desc=self.desc)
Code Example #10
File: inferences.py Project: msperber/misc
 def _generate_one_batch(self, generator: 'models.GeneratorModel',
                               batcher: Optional[batchers.Batcher] = None,
                               src_batch: batchers.Batch = None,
                               ref_batch: Optional[batchers.Batch] = None,
                               assert_scores: Optional[List[float]] = None,
                               max_src_len: Optional[int] = None,
                               fp: TextIO = None):
   """
   Generate outputs for a single batch and write them to the output file.
   """
   batch_size = len(src_batch)
   if ref_batch[0] is not None:
     src_batches, ref_batches = batcher.pack(src_batch, ref_batch)
     ref_batch = ref_batches[0]
   else:
     src_batches = batcher.pack(src_batch, None)
     ref_batch = None
   src_batch = src_batches[0]
   src_len = src_batch.sent_len()
   if max_src_len is not None and src_len > max_src_len:
     output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
     fp.write(f"{output_txt}\n")
   else:
     with utils.ReportOnException({"src": src_batch, "graph": utils.print_cg_conditional}):
       dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
       outputs = self.generate_one(generator, src_batch, ref_batch)
       if self.reporter: self._create_sent_report()
       for i in range(len(outputs)):
         if assert_scores[0] is not None:
           # If debugging forced decoding, make sure it matches
           assert batch_size == len(outputs), "debug forced decoding not supported with nbest inference"
           if (abs(outputs[i].score - assert_scores[i]) / abs(assert_scores[i])) > 1e-5:
             raise ValueError(
               f'Forced decoding score {outputs[i].score} and loss {assert_scores[i]} do not match at '
               f'sentence {i}')
         output_txt = outputs[i].sent_str(custom_output_procs=self.post_processor)
         fp.write(f"{output_txt}\n")
Code Example #11
 def run_training(self, save_fct: Callable) -> None:
     """
 Main training loop (overwrites TrainingRegimen.run_training())
 """
     if self.run_for_epochs is None or self.run_for_epochs > 0:
         for src, trg in self.next_minibatch():
             if self.dev_zero:
                 self.checkpoint_and_save(save_fct)
                 self.dev_zero = False
             with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}), \
                  self.skip_out_of_memory:
                 tt.reset_graph()
                 with self.train_loss_tracker.time_tracker:
                     event_trigger.set_train(True)
                     loss_builder = self.training_step(src, trg)
                     loss = loss_builder.compute(
                         comb_method=self.loss_comb_method)
                     self.backward(loss)
                     self.update(self.trainer)
                 self.train_loss_tracker.report(
                     trg, loss_builder.get_factored_loss_val())
             if self.checkpoint_needed():
                 self.checkpoint_and_save(save_fct)
             if self.should_stop_training(): break
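
Example #11 also enters self.skip_out_of_memory, whose name suggests a context manager that drops the current minibatch after an out-of-memory error instead of crashing the run. The real helper does not appear on this page; a hypothetical guard for the Torch backend might look like this:

import contextlib
import torch

@contextlib.contextmanager
def skip_out_of_memory(enabled: bool = True):
    # Hypothetical OOM guard: swallow CUDA out-of-memory errors so the
    # training loop can continue with the next minibatch.
    try:
        yield
    except RuntimeError as e:
        if enabled and "out of memory" in str(e):
            print("WARNING: skipping minibatch after OOM")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # release cached allocator blocks
        else:
            raise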