Example #1
0
File: eval_task.py  Project: bastings/xnmt
  def eval(self):
    """Evaluate the model's loss on the configured parallel corpus.

    On the first call the source/reference corpus is read with
    ``xnmt.input_reader.read_parallel_corpus`` and cached on ``self``
    (``src_data``, ``ref_data``, ``src_batches``, ``ref_batches``).
    For every (src, trg) batch a fresh computation graph is created, the
    standard loss and the model's additional loss are collected into a
    ``LossBuilder``, and their stats are accumulated. Each accumulated loss is
    finally divided by the total number of reference words.

    Returns:
      A tuple ``(LossScore, ref_words_cnt)``. The score's value is the
      normalized primary loss (``self.model.get_primary_loss()``), and
      ``loss_stats`` holds every normalized loss.

    Raises:
      RuntimeError: if no reference words were counted (nothing to normalize
        by), or if the primary loss name is missing from the collected stats.
    """
    if self.src_data is None:
      self.src_data, self.ref_data, self.src_batches, self.ref_batches = \
        xnmt.input_reader.read_parallel_corpus(self.model.src_reader, self.model.trg_reader,
                                        self.src_file, self.ref_file, batcher=self.batcher,
                                        max_src_len=self.max_src_len, max_trg_len=self.max_trg_len)
    loss_val = LossScalarBuilder()
    ref_words_cnt = 0
    for src, trg in zip(self.src_batches, self.ref_batches):
      # Start a new graph per batch so expressions from earlier batches are discarded.
      dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)

      loss_builder = LossBuilder()
      standard_loss = self.model.calc_loss(src, trg, self.loss_calculator)
      additional_loss = self.model.calc_additional_loss(standard_loss)
      loss_builder.add_loss("standard_loss", standard_loss)
      loss_builder.add_loss("additional_loss", additional_loss)

      ref_words_cnt += self.model.trg_reader.count_words(trg)
      loss_val += loss_builder.get_loss_stats()

    # Guard against an empty corpus: without it this would surface as a bare
    # ZeroDivisionError, or as the misleading "wrap your loss" hint below.
    if ref_words_cnt == 0:
      raise RuntimeError("No reference words found in the evaluation data; cannot normalize the loss.")

    loss_stats = {k: v / ref_words_cnt for k, v in loss_val.items()}

    # Only the dictionary lookup belongs in the try, so KeyErrors raised inside
    # LossScore are not misreported as a missing primary loss.
    try:
      primary_loss = loss_stats[self.model.get_primary_loss()]
    except KeyError as err:
      raise RuntimeError("Did you wrap your loss calculation with LossBuilder({'primary_loss': loss_value}) ?") from err
    return LossScore(primary_loss, loss_stats=loss_stats, desc=self.desc), ref_words_cnt
Example #2
0
 def compute_dev_loss(self):
     """Compute the model's loss over the development set.

     Iterates over ``self.dev_src`` / ``self.dev_trg`` in lockstep. For each
     pair it starts a fresh computation graph, adds the pair's loss to a shared
     ``LossBuilder`` under the ``"loss"`` key, adds the target word count
     reported by ``self.logger.count_trg_words``, and calls ``compute()`` on the
     builder.

     Returns:
       ``(trg_words_cnt, LossScore)``: the total number of target words, and a
       score built from the builder's summed loss divided by that count.
     """
     dev_builder = LossBuilder()
     total_words = 0
     for src_sent, trg_sent in zip(self.dev_src, self.dev_trg):
         dy.renew_cg()
         sent_loss = self.model.calc_loss(src_sent, trg_sent)
         dev_builder.add_loss("loss", sent_loss)
         total_words += self.logger.count_trg_words(trg_sent)
         dev_builder.compute()
     # NOTE(review): an empty dev set (or zero target words) divides by zero here.
     summed_loss = dev_builder.sum()
     return total_words, LossScore(summed_loss / total_words)
Example #3
0
    def eval(self) -> 'EvalScore':
        """Perform the evaluation task.

        Puts the model into non-training mode via ``set_train(False)``. On the
        first call it reads the source/reference corpus with
        ``xnmt.input_reader.read_parallel_corpus`` and caches it on ``self``.
        For every (src, trg) batch it then builds a fresh computation graph,
        collects the standard and additional factored losses, and accumulates
        their values using ``self.loss_comb_method``. Each accumulated loss is
        divided by the total number of unpadded reference tokens.

        Returns:
          A ``LossScore`` whose value is the normalized primary loss
          (``self.model.get_primary_loss()``), carrying all normalized loss
          stats and the reference word count.

        Raises:
          RuntimeError: if no reference words were counted, or if the primary
            loss name is missing from the collected loss stats.
        """
        self.model.set_train(False)
        if self.src_data is None:
            self.src_data, self.ref_data, self.src_batches, self.ref_batches = \
              xnmt.input_reader.read_parallel_corpus(src_reader=self.model.src_reader,
                                                     trg_reader=self.model.trg_reader,
                                                     src_file=self.src_file,
                                                     trg_file=self.ref_file,
                                                     batcher=self.batcher,
                                                     max_src_len=self.max_src_len,
                                                     max_trg_len=self.max_trg_len)
        loss_val = FactoredLossVal()
        ref_words_cnt = 0
        for src, trg in zip(self.src_batches, self.ref_batches):
            # If anything in this batch raises, report the batch contents and graph.
            with util.ReportOnException({
                    "src": src,
                    "trg": trg,
                    "graph": dy.print_text_graphviz
            }):
                dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                            check_validity=settings.CHECK_VALIDITY)

                loss_builder = FactoredLossExpr()
                standard_loss = self.model.calc_loss(src, trg,
                                                     self.loss_calculator)
                additional_loss = self.model.calc_additional_loss(
                    trg, self.model, standard_loss)
                loss_builder.add_factored_loss_expr(standard_loss)
                loss_builder.add_factored_loss_expr(additional_loss)

                # Count only real tokens, not padding.
                ref_words_cnt += sum(trg_i.len_unpadded() for trg_i in trg)
                loss_val += loss_builder.get_factored_loss_val(
                    comb_method=self.loss_comb_method)

        # Guard against an empty corpus: without it this would surface as a bare
        # ZeroDivisionError, or as the misleading "wrap your loss" hint below.
        if ref_words_cnt == 0:
            raise RuntimeError(
                "No reference words found in the evaluation data; cannot normalize the loss."
            )

        loss_stats = {k: v / ref_words_cnt for k, v in loss_val.items()}

        # Only the lookup is guarded, so KeyErrors raised inside LossScore are
        # not misreported as a missing primary loss.
        try:
            primary_loss = loss_stats[self.model.get_primary_loss()]
        except KeyError as err:
            raise RuntimeError(
                "Did you wrap your loss calculation with FactoredLossExpr({'primary_loss': loss_value}) ?"
            ) from err
        return LossScore(primary_loss,
                         loss_stats=loss_stats,
                         num_ref_words=ref_words_cnt,
                         desc=self.desc)