Exemplo n.º 1
0
  def eval(self):
    """
    Compute the loss over the evaluation corpus, normalized per reference word.

    On the first call (while ``self.src_data`` is still ``None``) the parallel corpus is read via
    ``xnmt.input_reader.read_parallel_corpus`` and the resulting data and batches are cached on the
    instance; later calls reuse them. For every (source, reference) batch pair a fresh computation
    graph is started, the model's standard and additional losses are collected in a ``LossBuilder``,
    and the resulting loss statistics are accumulated.

    Returns:
      A tuple ``(score, ref_words_cnt)`` where ``score`` is a ``LossScore`` built from the primary
      loss (as named by ``self.model.get_primary_loss()``) and ``ref_words_cnt`` is the number of
      reference words counted by ``self.model.trg_reader.count_words`` across all batches.

    Raises:
      RuntimeError: if the primary loss name is not among the accumulated loss statistics.
    """
    # Identity check on purpose: ``== None`` would go through the cached object's ``__eq__``,
    # which is not guaranteed to return a plain bool.
    if self.src_data is None:
      self.src_data, self.ref_data, self.src_batches, self.ref_batches = \
        xnmt.input_reader.read_parallel_corpus(self.model.src_reader, self.model.trg_reader,
                                        self.src_file, self.ref_file, batcher=self.batcher,
                                        max_src_len=self.max_src_len, max_trg_len=self.max_trg_len)
    loss_val = LossScalarBuilder()
    ref_words_cnt = 0
    for src, trg in zip(self.src_batches, self.ref_batches):
      # Each batch is evaluated in its own computation graph.
      dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)

      loss_builder = LossBuilder()
      standard_loss = self.model.calc_loss(src, trg, self.loss_calculator)
      additional_loss = self.model.calc_additional_loss(standard_loss)
      loss_builder.add_loss("standard_loss", standard_loss)
      loss_builder.add_loss("additional_loss", additional_loss)

      ref_words_cnt += self.model.trg_reader.count_words(trg)
      loss_val += loss_builder.get_loss_stats()

    # Turn the summed losses into per-reference-word averages.
    loss_stats = {k: v/ref_words_cnt for k, v in loss_val.items()}

    # Only the lookup is guarded, so that a KeyError raised while constructing the LossScore
    # is not misreported as a missing primary loss.
    try:
      primary_loss = loss_stats[self.model.get_primary_loss()]
    except KeyError as err:
      raise RuntimeError("Did you wrap your loss calculation with LossBuilder({'primary_loss': loss_value}) ?") from err
    return LossScore(primary_loss, loss_stats=loss_stats, desc=self.desc), ref_words_cnt
Exemplo n.º 2
0
  def training_step(self, src, trg):
    """
    Build the loss for one minibatch and report it to the logger.

    The model's standard loss and the additional loss derived from it are registered in a
    ``LossBuilder``; the builder's computed value is returned after the loss statistics have been
    passed to ``self.logger`` and the training progress has been reported.

    Args:
      src: source side of the minibatch, forwarded to ``self.model.calc_loss``.
      trg: target side of the minibatch, forwarded to ``self.model.calc_loss``.

    Returns:
      The result of ``LossBuilder.compute()`` for this minibatch.
    """
    builder = LossBuilder()
    main_term = self.model.calc_loss(src, trg, self.loss_calculator)
    extra_term = self.model.calc_additional_loss(main_term)

    # Register both terms under their fixed names, standard loss first.
    named_terms = (
      ("standard_loss", main_term),
      ("additional_loss", extra_term),
    )
    for term_name, term in named_terms:
      builder.add_loss(term_name, term)

    total = builder.compute()

    # The statistics are fetched only after compute(), matching the call order the logger sees.
    batch_stats = builder.get_loss_stats()
    self.logger.update_epoch_loss(src, trg, batch_stats)
    self.logger.report_train_process()
    return total